460 lines
15 KiB
Python
460 lines
15 KiB
Python
"""Export 2×4 RGB panels for Tier-A gap validation (thesis appendix)."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import re
|
||
from datetime import date, datetime
|
||
from pathlib import Path
|
||
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
import rasterio
|
||
from rasterio.transform import rowcol
|
||
from rasterio.warp import Resampling, reproject
|
||
|
||
from gap_validation.s2_mask_dir import acquisition_yyyymmdd_in_window, yyyymmdd_from_iso
|
||
|
||
# Prepared Sentinel-2 reflectance filename; group 1 captures the 8-digit YYYYMMDD date.
REFL_DATE_RE = re.compile(r"S2A_MSIL2A_(\d{8})_REFL\.tif$")
# Daily S3 composite filename; group 1 captures the 8-digit YYYYMMDD date.
S3_COMPOSITE_RE = re.compile(r"composite_(\d{8})\.tif$")
# Manifest transition keys, in the order figure rows are built.
TRANSITIONS = ("green_up", "green_down")
# Column headers; order matches the (withheld, fusion, S3, nearest) path tuple of a row.
COL_TITLES = ("Withheld S2", "Gap fusion", "S3 composite", "Nearest S2")
# Row (y-axis) labels keyed by transition.
ROW_LABELS = {"green_up": "Green-up", "green_down": "Green-down"}
# Fixed pixel window around PhenoCam for comparable framing across sites (~1 km).
DISPLAY_HALF_PX = 48
# Match postprocessing / spatial_metrics positive-reflectance mask.
VALID_REFL_THRESHOLD = 0.001
# RGB fill (dark grey) painted on pixels that fail the validity mask.
NODATA_RGB = (0.15, 0.15, 0.15)
|
||
|
||
|
||
def _parse_bti_scenario(scenario: str) -> tuple[str, int]:
|
||
m = re.match(r"^(aggressive|nonaggressive)_sigma(20|30)$", scenario)
|
||
if not m:
|
||
raise ValueError(f"expected BtI scenario key, got {scenario!r}")
|
||
return m.group(1), int(m.group(2))
|
||
|
||
|
||
def _prepared_base(data_dir: Path, site: str, season: int, strategy: str) -> Path:
|
||
return data_dir / site / str(season) / f"prepared_{strategy}"
|
||
|
||
|
||
def _gap_spatial_fusion_dir(
|
||
data_dir: Path,
|
||
site: str,
|
||
season: int,
|
||
gap_days: int,
|
||
transition: str,
|
||
strategy: str,
|
||
sigma: int,
|
||
) -> Path:
|
||
return (
|
||
data_dir
|
||
/ site
|
||
/ str(season)
|
||
/ "validation"
|
||
/ "fusion"
|
||
/ f"gap_{gap_days}_{transition}"
|
||
/ f"{strategy}_sigma{sigma}_bti"
|
||
)
|
||
|
||
|
||
def _iso_to_date(iso_d: str) -> date:
|
||
return datetime.strptime(iso_d[:10], "%Y-%m-%d").date()
|
||
|
||
|
||
def _exclude_ymds(entry: dict) -> set[str]:
    """Return the YYYYMMDD day of the withheld S2 scene named in a manifest entry.

    The day is parsed from ``entry["withheld_s2_filename"]`` with
    ``REFL_DATE_RE``. The result is empty when the key is missing, empty, or
    does not look like a prepared REFL filename.

    The gap window itself is not handled here; the window bounds that used to
    be parsed (and then discarded) in this function have been dropped.
    """
    withheld_fn = entry.get("withheld_s2_filename") or ""
    m = REFL_DATE_RE.search(withheld_fn)
    return {m.group(1)} if m else set()
|
||
|
||
|
||
def nearest_stack_s2(
    prepared_s2_dir: Path,
    prediction_iso: str,
    *,
    exclude_ymds: set[str],
) -> Path | None:
    """Nearest prepared REFL acquisition to prediction, outside excluded days.

    Args:
        prepared_s2_dir: Directory scanned for ``S2A_MSIL2A_<YYYYMMDD>_REFL.tif``.
        prediction_iso: ISO date string of the prediction (time part ignored).
        exclude_ymds: YYYYMMDD strings that must not be selected.

    Returns:
        Path of the acquisition with the smallest absolute day offset from the
        prediction date, or ``None`` when the directory does not exist or no
        candidate remains. When two acquisitions are equally distant, the
        earlier one is returned, so the choice does not depend on
        directory-iteration order.
    """
    if not prepared_s2_dir.is_dir():
        return None
    target = _iso_to_date(prediction_iso)
    candidates: list[tuple[int, str, str, Path]] = []
    for p in prepared_s2_dir.glob("S2A_MSIL2A_*_REFL.tif"):
        m = REFL_DATE_RE.search(p.name)
        if not m:
            continue
        ymd = m.group(1)
        if ymd in exclude_ymds:
            continue
        d = datetime.strptime(ymd, "%Y%m%d").date()
        candidates.append((abs((d - target).days), ymd, p.name, p))
    if not candidates:
        return None
    # Rank by (distance, date, filename) only; the Path itself is never compared.
    return min(candidates, key=lambda c: c[:3])[3]
|
||
|
||
|
||
def nearest_s3_composite(
    prepared_s3_dir: Path,
    prediction_iso: str,
) -> Path | None:
    """Nearest daily S3 composite to prediction date.

    Args:
        prepared_s3_dir: Directory scanned for ``composite_<YYYYMMDD>.tif``.
        prediction_iso: ISO date string of the prediction (time part ignored).

    Returns:
        Path of the composite with the smallest absolute day offset from the
        prediction date, or ``None`` when the directory does not exist or holds
        no matching file. When two composites are equally distant, the earlier
        one is returned, so the choice does not depend on directory-iteration
        order.
    """
    if not prepared_s3_dir.is_dir():
        return None
    target = _iso_to_date(prediction_iso)
    candidates: list[tuple[int, str, str, Path]] = []
    for p in prepared_s3_dir.glob("composite_*.tif"):
        m = S3_COMPOSITE_RE.search(p.name)
        if not m:
            continue
        ymd = m.group(1)
        d = datetime.strptime(ymd, "%Y%m%d").date()
        candidates.append((abs((d - target).days), ymd, p.name, p))
    if not candidates:
        return None
    # Rank by (distance, date, filename) only; the Path itself is never compared.
    return min(candidates, key=lambda c: c[:3])[3]
|
||
|
||
|
||
def _reference_grid(path: Path) -> dict | None:
|
||
"""Grid metadata from fusion / S2 REFL (fine S2 resolution)."""
|
||
if not path.is_file():
|
||
return None
|
||
with rasterio.open(path) as src:
|
||
return {
|
||
"transform": src.transform,
|
||
"crs": src.crs,
|
||
"height": src.height,
|
||
"width": src.width,
|
||
"bounds": src.bounds,
|
||
}
|
||
|
||
|
||
def _grids_match(src: rasterio.DatasetReader, ref: dict) -> bool:
|
||
return (
|
||
src.height == ref["height"]
|
||
and src.width == ref["width"]
|
||
and src.transform == ref["transform"]
|
||
and src.crs == ref["crs"]
|
||
)
|
||
|
||
|
||
def _read_bgr_on_grid(path: Path, ref: dict) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None:
|
||
"""Read BGR on the fusion grid; reproject only when needed (S3 coarse stack)."""
|
||
if not path.is_file():
|
||
return None
|
||
with rasterio.open(path) as src:
|
||
if src.count < 3:
|
||
return None
|
||
if _grids_match(src, ref):
|
||
return (
|
||
src.read(1).astype(np.float64),
|
||
src.read(2).astype(np.float64),
|
||
src.read(3).astype(np.float64),
|
||
)
|
||
shape = (ref["height"], ref["width"])
|
||
bands: list[np.ndarray] = []
|
||
for band_idx in (1, 2, 3):
|
||
dst = np.full(shape, np.nan, dtype=np.float32)
|
||
reproject(
|
||
source=rasterio.band(src, band_idx),
|
||
destination=dst,
|
||
src_transform=src.transform,
|
||
src_crs=src.crs,
|
||
dst_transform=ref["transform"],
|
||
dst_crs=ref["crs"],
|
||
resampling=Resampling.bilinear,
|
||
)
|
||
bands.append(dst.astype(np.float64))
|
||
return bands[0], bands[1], bands[2]
|
||
|
||
|
||
def _refl_valid(
    blue: np.ndarray, green: np.ndarray, red: np.ndarray
) -> np.ndarray:
    """Boolean mask of pixels whose three bands are all finite and above the threshold.

    Uses ``VALID_REFL_THRESHOLD`` (aligned with postprocessing / Tier-A metrics).
    """
    band_ok = [
        np.isfinite(band) & (band > VALID_REFL_THRESHOLD)
        for band in (blue, green, red)
    ]
    return band_ok[0] & band_ok[1] & band_ok[2]
|
||
|
||
|
||
def _panel_stretch_limits(
|
||
blue: np.ndarray, green: np.ndarray, red: np.ndarray, valid: np.ndarray
|
||
) -> tuple[float, float]:
|
||
if not valid.any():
|
||
return 0.0, 1.0
|
||
vals = np.concatenate([red[valid], green[valid], blue[valid]])
|
||
lo, hi = np.percentile(vals, (2, 98))
|
||
if hi <= lo:
|
||
return 0.0, 1.0
|
||
return float(lo), float(hi)
|
||
|
||
|
||
def _bgr_to_rgba(
    blue: np.ndarray,
    green: np.ndarray,
    red: np.ndarray,
    *,
    valid: np.ndarray,
    vmin: float,
    vmax: float,
) -> np.ndarray:
    """Compose an opaque float32 RGBA image from three bands.

    Valid pixels are linearly stretched from ``[vmin, vmax]`` to ``[0, 1]`` and
    clipped; every other pixel is painted with ``NODATA_RGB``.
    """
    image = np.zeros((*blue.shape, 4), dtype=np.float32)
    image[..., 3] = 1.0
    span = vmax - vmin
    if not span:
        # A zero range would divide by zero; fall back to a unit span.
        span = 1.0
    for channel, (band, fill) in enumerate(zip((red, green, blue), NODATA_RGB)):
        stretched = np.clip((band - vmin) / span, 0.0, 1.0)
        image[..., channel] = np.where(valid, stretched, fill)
    return image
|
||
|
||
|
||
def _tight_slices(mask: np.ndarray, margin: int = 2) -> tuple[slice, slice] | None:
|
||
"""Crop to the bounding box of True pixels in *mask*."""
|
||
rows, cols = np.where(mask)
|
||
if rows.size < 8:
|
||
return None
|
||
r0 = max(0, int(rows.min()) - margin)
|
||
r1 = min(mask.shape[0], int(rows.max()) + margin + 1)
|
||
c0 = max(0, int(cols.min()) - margin)
|
||
c1 = min(mask.shape[1], int(cols.max()) + margin + 1)
|
||
if r1 - r0 < 8 or c1 - c0 < 8:
|
||
return None
|
||
return slice(r0, r1), slice(c0, c1)
|
||
|
||
|
||
def _crop_slices(
|
||
height: int, width: int, center_row: int, center_col: int, half_px: int
|
||
) -> tuple[slice, slice]:
|
||
r0 = max(0, center_row - half_px)
|
||
r1 = min(height, center_row + half_px)
|
||
c0 = max(0, center_col - half_px)
|
||
c1 = min(width, center_col + half_px)
|
||
return slice(r0, r1), slice(c0, c1)
|
||
|
||
|
||
def _phenocam_pixel(
    meta: dict,
    site_position_lat_lon: tuple[float, float],
) -> tuple[int, int] | None:
    """Return the ``(row, col)`` of the site position on the grid described by *meta*.

    Args:
        meta: Grid description with a ``"transform"`` entry and, optionally, a
            ``"crs"`` entry (as produced by ``_reference_grid``).
        site_position_lat_lon: ``(lat, lon)`` of the site. Assumed to be WGS84
            (EPSG:4326) degrees -- TODO confirm against the caller.

    Returns:
        Integer pixel indices, which may lie outside the grid when the site is
        not covered by the raster, or ``None`` if the lookup fails.

    The previous implementation passed the CRS as ``rowcol``'s ``op`` argument.
    ``op`` is the rounding callable, so the call always raised, the exception
    was swallowed, and ``None`` was returned for every site. The coordinates
    are now reprojected into the grid CRS first and ``rowcol`` is called with
    its default rounding.
    """
    import logging

    from rasterio.crs import CRS
    from rasterio.warp import transform as warp_transform

    lat, lon = site_position_lat_lon
    try:
        xs, ys = [lon], [lat]
        grid_crs = meta.get("crs")
        if grid_crs is not None:
            # rowcol expects x/y in the raster's own CRS, not geographic degrees.
            xs, ys = warp_transform(CRS.from_epsg(4326), grid_crs, xs, ys)
        r, c = rowcol(meta["transform"], xs, ys)
        return int(r[0]), int(c[0])
    except Exception:
        # Deliberately best-effort (the marker is cosmetic), but never silent:
        # the swallowed error is what hid the original bug.
        logging.getLogger(__name__).warning(
            "could not locate site position lat=%s lon=%s on the raster grid",
            lat,
            lon,
            exc_info=True,
        )
        return None
|
||
|
||
|
||
def _resolve_row_paths(
    data_dir: Path,
    site: str,
    season: int,
    entry: dict,
    strategy: str,
    sigma: int,
    *,
    gap_days: int,
) -> tuple[Path, Path, Path, Path] | None:
    """Locate the four rasters of one figure row for a manifest *entry*.

    Returns ``(withheld S2, gap fusion, S3 composite, nearest S2)``, or ``None``
    when the entry names no withheld file or any of the four is unavailable.
    The S3 composite of the prediction day is preferred, with the nearest
    composite as fallback. The nearest S2 scene is searched outside both the
    gap window and the withheld day.
    """
    prediction_iso = entry["prediction_date"]
    pred_ymd = yyyymmdd_from_iso(prediction_iso)
    transition = entry["transition"]
    prepared = _prepared_base(data_dir, site, season, strategy)

    withheld_name = entry.get("withheld_s2_filename")
    if not withheld_name:
        return None

    s2_dir = prepared / "s2"
    s3_dir = prepared / "s3"
    withheld = s2_dir / withheld_name

    fusion_dir = _gap_spatial_fusion_dir(
        data_dir, site, season, gap_days, transition, strategy, sigma
    )
    fusion = fusion_dir / f"REFL_{pred_ymd}.tif"

    s3 = s3_dir / f"composite_{pred_ymd}.tif"
    if not s3.is_file():
        s3 = nearest_s3_composite(s3_dir, prediction_iso)

    window_start = _iso_to_date(entry["window_start"])
    window_end = _iso_to_date(entry["window_end"])
    in_window = acquisition_yyyymmdd_in_window(s2_dir, window_start, window_end)
    excluded = in_window | _exclude_ymds(entry)
    nearest = nearest_stack_s2(s2_dir, prediction_iso, exclude_ymds=excluded)

    complete = (
        withheld.is_file()
        and fusion.is_file()
        and s3 is not None
        and nearest is not None
    )
    if not complete:
        return None
    return withheld, fusion, s3, nearest
|
||
|
||
|
||
def _panel_date_subtitle(paths: tuple[Path, Path, Path, Path], entry: dict) -> str:
    """Compose the small date caption shown on a figure row.

    Lists the prediction date, the withheld S2 date and the nearest S2 date.
    The S3 date is appended only when the composite filename does not end in
    the prediction day.
    """
    s3_name = paths[2].name
    nearest_match = REFL_DATE_RE.search(paths[3].name)
    s3_match = S3_COMPOSITE_RE.search(s3_name)

    pred = entry["prediction_date"][:10]
    withheld_iso = entry.get("withheld_s2_date") or ""
    wh_short = withheld_iso[:10] if withheld_iso else "—"

    n_s2 = "?" if nearest_match is None else nearest_match.group(1)
    n_s3 = "?" if s3_match is None else s3_match.group(1)

    pred_compact = pred.replace("-", "")
    if s3_name.endswith(f"{pred_compact}.tif"):
        s3_note = ""
    else:
        s3_note = f" (S3 {n_s3})"

    return (
        f"pred. {pred}; withheld {wh_short}; "
        f"nearest S2 {n_s2[:4]}-{n_s2[4:6]}-{n_s2[6:8]}{s3_note}"
    )
|
||
|
||
|
||
def build_site_panel(
    site: str,
    season: int,
    data_dir: Path,
    out_png: Path,
    *,
    best_bti_scenario: str,
    site_label: str,
    site_position_lat_lon: tuple[float, float] | None = None,
    gap_days: int = 30,
) -> bool:
    """Build the RGB comparison figure for one site and save it to *out_png*.

    One figure row is drawn per transition in ``TRANSITIONS`` for which the gap
    manifest has an entry with the requested ``gap_days`` and all four rasters
    resolve; the four columns follow ``COL_TITLES``. Transitions that cannot be
    resolved are left out of the figure. A row whose rasters cannot be read, or
    whose crop window is under 8 pixels on a side, is hidden.

    Args:
        site: Site directory name under *data_dir*.
        season: Season (year) directory name under the site.
        data_dir: Root data directory.
        out_png: Output image path; parent directories are created.
        best_bti_scenario: BtI scenario key, e.g. ``"aggressive_sigma20"``.
        site_label: Human-readable site name for the figure title.
        site_position_lat_lon: Optional ``(lat, lon)`` used to centre the crop;
            the grid centre is used when absent or when the lookup fails.
        gap_days: Gap length selecting the manifest entries.

    Returns:
        ``True`` when a figure was written. ``False`` when the manifest is
        missing, no transition row could be resolved, or no row could be
        rendered; in those cases nothing is written.

    Raises:
        ValueError: If *best_bti_scenario* is not a recognised scenario key.
    """
    manifest_path = data_dir / site / str(season) / "validation" / "gap_manifest.json"
    if not manifest_path.is_file():
        return False
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    strategy, sigma = _parse_bti_scenario(best_bti_scenario)

    rows: list[tuple[str, dict, tuple[Path, Path, Path, Path]]] = []
    for transition in TRANSITIONS:
        entry = next(
            (
                e
                for e in manifest["entries"]
                if e.get("gap_days") == gap_days and e.get("transition") == transition
            ),
            None,
        )
        if not entry:
            continue
        paths = _resolve_row_paths(
            data_dir, site, season, entry, strategy, sigma, gap_days=gap_days
        )
        if paths is None:
            continue
        rows.append((transition, entry, paths))

    if not rows:
        return False

    def _hide_row(row_axes) -> None:
        # A skipped row must not leave empty, ticked axes in the figure.
        for hidden_ax in row_axes:
            hidden_ax.set_visible(False)

    fig, axes = plt.subplots(
        len(rows),
        4,
        figsize=(12.0, 2.8 * len(rows)),
        squeeze=False,
        constrained_layout=True,
    )
    # pyplot keeps every figure registered until it is closed, so close it on
    # every exit path, including exceptions raised while reading rasters.
    try:
        rendered_rows = 0
        for row_idx, (transition, entry, paths) in enumerate(rows):
            row_axes = axes[row_idx]

            # The fusion output defines the grid every layer is shown on.
            ref_grid = _reference_grid(paths[1])
            layers: list[tuple[np.ndarray, np.ndarray, np.ndarray]] = []
            if ref_grid is not None:
                for path in paths:
                    bgr = _read_bgr_on_grid(path, ref_grid)
                    if bgr is None:
                        break
                    layers.append(bgr)
            if len(layers) != len(paths):
                _hide_row(row_axes)
                continue

            h, w = ref_grid["height"], ref_grid["width"]
            center_row, center_col = h // 2, w // 2
            if site_position_lat_lon:
                pix = _phenocam_pixel(ref_grid, site_position_lat_lon)
                if pix is not None:
                    center_row, center_col = pix
            rs, cs = _crop_slices(h, w, center_row, center_col, DISPLAY_HALF_PX)
            if rs.stop - rs.start < 8 or cs.stop - cs.start < 8:
                _hide_row(row_axes)
                continue

            cropped_valid: list[np.ndarray] = []
            rgba_panels: list[np.ndarray] = []
            for bgr in layers:
                blue, green, red = (b[rs, cs] for b in bgr)
                valid = _refl_valid(blue, green, red)
                cropped_valid.append(valid)
                # Each panel gets its own stretch limits.
                vmin, vmax = _panel_stretch_limits(blue, green, red, valid)
                rgba_panels.append(
                    _bgr_to_rgba(blue, green, red, valid=valid, vmin=vmin, vmax=vmax)
                )

            # Crop all four panels to the area where at least one has data.
            # reduce() builds a new array instead of OR-ing into the first mask.
            union = np.logical_or.reduce(cropped_valid)
            tight = _tight_slices(union)

            # Marker position relative to the displayed (possibly tightened) crop.
            mark_r = center_row - rs.start
            mark_c = center_col - cs.start
            if tight is not None:
                tr, tc = tight
                rgba_panels = [p[tr, tc] for p in rgba_panels]
                mark_r -= tr.start
                mark_c -= tc.start
            crop_h, crop_w = rgba_panels[0].shape[:2]

            row_title = ROW_LABELS.get(transition, transition)
            subtitle = _panel_date_subtitle(paths, entry)
            # Column titles go on the first row actually drawn, which is not
            # necessarily row 0 when earlier rows were hidden.
            is_first_rendered = rendered_rows == 0

            for col_idx, (col_title, rgba) in enumerate(
                zip(COL_TITLES, rgba_panels, strict=True)
            ):
                ax = row_axes[col_idx]
                ax.imshow(rgba, origin="upper", aspect="equal", interpolation="nearest")
                if col_idx == 0 and 0 <= mark_r < crop_h and 0 <= mark_c < crop_w:
                    ax.plot(
                        mark_c,
                        mark_r,
                        "+",
                        color="red",
                        markersize=8,
                        markeredgewidth=1.2,
                    )
                if is_first_rendered:
                    ax.set_title(col_title, fontsize=9)
                if col_idx == 0:
                    ax.set_ylabel(row_title, fontsize=9)
                ax.set_xticks([])
                ax.set_yticks([])
                if col_idx == 3:
                    ax.text(
                        0.02,
                        0.02,
                        subtitle,
                        transform=ax.transAxes,
                        fontsize=6,
                        color="white",
                        va="bottom",
                        bbox=dict(boxstyle="round,pad=0.2", facecolor="black", alpha=0.55),
                    )
            rendered_rows += 1

        if rendered_rows == 0:
            # Every row was hidden: do not write an empty figure and call it a success.
            return False

        cal = manifest.get("s2_calendar_strategy", strategy)
        cal_note = f"; S2 calendar {cal}" if cal != strategy else ""
        fig.suptitle(
            f"{site_label} ({season}) — best BtI {best_bti_scenario}{cal_note}, {gap_days}-d gap",
            fontsize=10,
        )
        out_png.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_png, dpi=150)
        return True
    finally:
        plt.close(fig)
|