"""Export 2×4 RGB panels for Tier-A gap validation (thesis appendix).""" from __future__ import annotations import json import re from datetime import date, datetime from pathlib import Path import matplotlib.pyplot as plt import numpy as np import rasterio from rasterio.transform import rowcol from rasterio.warp import Resampling, reproject from gap_validation.s2_mask_dir import acquisition_yyyymmdd_in_window, yyyymmdd_from_iso REFL_DATE_RE = re.compile(r"S2A_MSIL2A_(\d{8})_REFL\.tif$") S3_COMPOSITE_RE = re.compile(r"composite_(\d{8})\.tif$") TRANSITIONS = ("green_up", "green_down") COL_TITLES = ("Withheld S2", "Gap fusion", "S3 composite", "Nearest S2") ROW_LABELS = {"green_up": "Green-up", "green_down": "Green-down"} # Fixed pixel window around PhenoCam for comparable framing across sites (~1 km). DISPLAY_HALF_PX = 48 # Match postprocessing / spatial_metrics positive-reflectance mask. VALID_REFL_THRESHOLD = 0.001 NODATA_RGB = (0.15, 0.15, 0.15) def _parse_bti_scenario(scenario: str) -> tuple[str, int]: m = re.match(r"^(aggressive|nonaggressive)_sigma(20|30)$", scenario) if not m: raise ValueError(f"expected BtI scenario key, got {scenario!r}") return m.group(1), int(m.group(2)) def _prepared_base(data_dir: Path, site: str, season: int, strategy: str) -> Path: return data_dir / site / str(season) / f"prepared_{strategy}" def _gap_spatial_fusion_dir( data_dir: Path, site: str, season: int, gap_days: int, transition: str, strategy: str, sigma: int, ) -> Path: return ( data_dir / site / str(season) / "validation" / "fusion" / f"gap_{gap_days}_{transition}" / f"{strategy}_sigma{sigma}_bti" ) def _iso_to_date(iso_d: str) -> date: return datetime.strptime(iso_d[:10], "%Y-%m-%d").date() def _exclude_ymds(entry: dict) -> set[str]: w0 = _iso_to_date(entry["window_start"]) w1 = _iso_to_date(entry["window_end"]) withheld_fn = entry.get("withheld_s2_filename") or "" m = REFL_DATE_RE.search(withheld_fn) excluded = set() if m: excluded.add(m.group(1)) return excluded def nearest_stack_s2( prepared_s2_dir: Path, prediction_iso: str, *, exclude_ymds: set[str], ) -> Path | None: """Nearest prepared REFL acquisition to prediction, outside excluded days.""" if not prepared_s2_dir.is_dir(): return None target = _iso_to_date(prediction_iso) best_path: Path | None = None best_delta: int | None = None for p in prepared_s2_dir.glob("S2A_MSIL2A_*_REFL.tif"): m = REFL_DATE_RE.search(p.name) if not m: continue ymd = m.group(1) if ymd in exclude_ymds: continue d = datetime.strptime(ymd, "%Y%m%d").date() delta = abs((d - target).days) if best_delta is None or delta < best_delta: best_delta = delta best_path = p return best_path def nearest_s3_composite( prepared_s3_dir: Path, prediction_iso: str, ) -> Path | None: """Nearest daily S3 composite to prediction date.""" if not prepared_s3_dir.is_dir(): return None target = _iso_to_date(prediction_iso) best_path: Path | None = None best_delta: int | None = None for p in prepared_s3_dir.glob("composite_*.tif"): m = S3_COMPOSITE_RE.search(p.name) if not m: continue d = datetime.strptime(m.group(1), "%Y%m%d").date() delta = abs((d - target).days) if best_delta is None or delta < best_delta: best_delta = delta best_path = p return best_path def _reference_grid(path: Path) -> dict | None: """Grid metadata from fusion / S2 REFL (fine S2 resolution).""" if not path.is_file(): return None with rasterio.open(path) as src: return { "transform": src.transform, "crs": src.crs, "height": src.height, "width": src.width, "bounds": src.bounds, } def _grids_match(src: rasterio.DatasetReader, ref: dict) -> bool: return ( src.height == ref["height"] and src.width == ref["width"] and src.transform == ref["transform"] and src.crs == ref["crs"] ) def _read_bgr_on_grid(path: Path, ref: dict) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None: """Read BGR on the fusion grid; reproject only when needed (S3 coarse stack).""" if not path.is_file(): return None with rasterio.open(path) as src: if src.count < 3: return None if _grids_match(src, ref): return ( src.read(1).astype(np.float64), src.read(2).astype(np.float64), src.read(3).astype(np.float64), ) shape = (ref["height"], ref["width"]) bands: list[np.ndarray] = [] for band_idx in (1, 2, 3): dst = np.full(shape, np.nan, dtype=np.float32) reproject( source=rasterio.band(src, band_idx), destination=dst, src_transform=src.transform, src_crs=src.crs, dst_transform=ref["transform"], dst_crs=ref["crs"], resampling=Resampling.bilinear, ) bands.append(dst.astype(np.float64)) return bands[0], bands[1], bands[2] def _refl_valid( blue: np.ndarray, green: np.ndarray, red: np.ndarray ) -> np.ndarray: """Positive reflectance mask (aligned with postprocessing / Tier-A metrics).""" return ( np.isfinite(blue) & np.isfinite(green) & np.isfinite(red) & (blue > VALID_REFL_THRESHOLD) & (green > VALID_REFL_THRESHOLD) & (red > VALID_REFL_THRESHOLD) ) def _panel_stretch_limits( blue: np.ndarray, green: np.ndarray, red: np.ndarray, valid: np.ndarray ) -> tuple[float, float]: if not valid.any(): return 0.0, 1.0 vals = np.concatenate([red[valid], green[valid], blue[valid]]) lo, hi = np.percentile(vals, (2, 98)) if hi <= lo: return 0.0, 1.0 return float(lo), float(hi) def _bgr_to_rgba( blue: np.ndarray, green: np.ndarray, red: np.ndarray, *, valid: np.ndarray, vmin: float, vmax: float, ) -> np.ndarray: rgba = np.zeros((*blue.shape, 4), dtype=np.float32) rgba[..., 3] = 1.0 rgba[~valid, 0] = NODATA_RGB[0] rgba[~valid, 1] = NODATA_RGB[1] rgba[~valid, 2] = NODATA_RGB[2] span = vmax - vmin or 1.0 for band, idx in ((red, 0), (green, 1), (blue, 2)): norm = np.clip((band - vmin) / span, 0.0, 1.0) rgba[..., idx] = np.where(valid, norm, rgba[..., idx]) return rgba def _tight_slices(mask: np.ndarray, margin: int = 2) -> tuple[slice, slice] | None: """Crop to the bounding box of True pixels in *mask*.""" rows, cols = np.where(mask) if rows.size < 8: return None r0 = max(0, int(rows.min()) - margin) r1 = min(mask.shape[0], int(rows.max()) + margin + 1) c0 = max(0, int(cols.min()) - margin) c1 = min(mask.shape[1], int(cols.max()) + margin + 1) if r1 - r0 < 8 or c1 - c0 < 8: return None return slice(r0, r1), slice(c0, c1) def _crop_slices( height: int, width: int, center_row: int, center_col: int, half_px: int ) -> tuple[slice, slice]: r0 = max(0, center_row - half_px) r1 = min(height, center_row + half_px) c0 = max(0, center_col - half_px) c1 = min(width, center_col + half_px) return slice(r0, r1), slice(c0, c1) def _phenocam_pixel( meta: dict, site_position_lat_lon: tuple[float, float], ) -> tuple[int, int] | None: lat, lon = site_position_lat_lon try: r, c = rowcol(meta["transform"], [lon], [lat], op=meta.get("crs")) return int(r[0]), int(c[0]) except Exception: return None def _resolve_row_paths( data_dir: Path, site: str, season: int, entry: dict, strategy: str, sigma: int, *, gap_days: int, ) -> tuple[Path, Path, Path, Path] | None: pred_ymd = yyyymmdd_from_iso(entry["prediction_date"]) transition = entry["transition"] prep = _prepared_base(data_dir, site, season, strategy) withheld_fn = entry.get("withheld_s2_filename") if not withheld_fn: return None withheld = prep / "s2" / withheld_fn fusion = ( _gap_spatial_fusion_dir(data_dir, site, season, gap_days, transition, strategy, sigma) / f"REFL_{pred_ymd}.tif" ) s3_exact = prep / "s3" / f"composite_{pred_ymd}.tif" s3 = ( s3_exact if s3_exact.is_file() else nearest_s3_composite(prep / "s3", entry["prediction_date"]) ) w0 = _iso_to_date(entry["window_start"]) w1 = _iso_to_date(entry["window_end"]) window_ymds = acquisition_yyyymmdd_in_window(prep / "s2", w0, w1) exclude = window_ymds | _exclude_ymds(entry) nearest = nearest_stack_s2(prep / "s2", entry["prediction_date"], exclude_ymds=exclude) if not withheld.is_file() or not fusion.is_file() or s3 is None or nearest is None: return None return withheld, fusion, s3, nearest def _panel_date_subtitle(paths: tuple[Path, Path, Path, Path], entry: dict) -> str: pred = entry["prediction_date"][:10] wh = entry.get("withheld_s2_date") or "" wh_short = wh[:10] if wh else "—" nearest_ymd = REFL_DATE_RE.search(paths[3].name) s3_ymd = S3_COMPOSITE_RE.search(paths[2].name) n_s2 = nearest_ymd.group(1) if nearest_ymd else "?" n_s3 = s3_ymd.group(1) if s3_ymd else "?" s3_note = "" if paths[2].name.endswith(f"{pred.replace('-', '')}.tif") else f" (S3 {n_s3})" return ( f"pred. {pred}; withheld {wh_short}; " f"nearest S2 {n_s2[:4]}-{n_s2[4:6]}-{n_s2[6:8]}{s3_note}" ) def build_site_panel( site: str, season: int, data_dir: Path, out_png: Path, *, best_bti_scenario: str, site_label: str, site_position_lat_lon: tuple[float, float] | None = None, gap_days: int = 30, ) -> bool: """Build 2×4 RGB figure; return False if manifest or any transition row is incomplete.""" manifest_path = data_dir / site / str(season) / "validation" / "gap_manifest.json" if not manifest_path.is_file(): return False manifest = json.loads(manifest_path.read_text(encoding="utf-8")) strategy, sigma = _parse_bti_scenario(best_bti_scenario) rows: list[tuple[str, dict, tuple[Path, Path, Path, Path]]] = [] for transition in TRANSITIONS: entry = next( ( e for e in manifest["entries"] if e.get("gap_days") == gap_days and e.get("transition") == transition ), None, ) if not entry: continue paths = _resolve_row_paths( data_dir, site, season, entry, strategy, sigma, gap_days=gap_days ) if paths is None: continue rows.append((transition, entry, paths)) if not rows: return False crop_px = DISPLAY_HALF_PX fig, axes = plt.subplots( len(rows), 4, figsize=(12.0, 2.8 * len(rows)), squeeze=False, constrained_layout=True, ) for row_idx, (transition, entry, paths) in enumerate(rows): row_title = ROW_LABELS.get(transition, transition) subtitle = _panel_date_subtitle(paths, entry) ref_grid = _reference_grid(paths[1]) if ref_grid is None: continue layers: list[tuple[np.ndarray, np.ndarray, np.ndarray]] = [] for path in paths: bgr = _read_bgr_on_grid(path, ref_grid) if bgr is None: layers = [] break layers.append(bgr) if len(layers) != 4: for ax in axes[row_idx]: ax.set_visible(False) continue h, w = ref_grid["height"], ref_grid["width"] center_row, center_col = h // 2, w // 2 if site_position_lat_lon: pix = _phenocam_pixel(ref_grid, site_position_lat_lon) if pix: center_row, center_col = pix rs, cs = _crop_slices(h, w, center_row, center_col, crop_px) if rs.stop - rs.start < 8 or cs.stop - cs.start < 8: for ax in axes[row_idx]: ax.set_visible(False) continue cropped_valid: list[np.ndarray] = [] rgba_panels: list[np.ndarray] = [] for bgr in layers: blue, green, red = (b[rs, cs] for b in bgr) valid = _refl_valid(blue, green, red) cropped_valid.append(valid) vmin, vmax = _panel_stretch_limits(blue, green, red, valid) rgba_panels.append( _bgr_to_rgba( blue, green, red, valid=valid, vmin=vmin, vmax=vmax ) ) union = cropped_valid[0] for v in cropped_valid[1:]: union |= v tight = _tight_slices(union) if tight is not None: tr, tc = tight rgba_panels = [p[tr, tc] for p in rgba_panels] union = union[tr, tc] crop_h, crop_w = rgba_panels[0].shape[:2] mark_r = center_row - rs.start mark_c = center_col - cs.start if tight is not None: mark_r -= tr.start mark_c -= tc.start for col_idx, (col_title, rgba) in enumerate( zip(COL_TITLES, rgba_panels, strict=True) ): ax = axes[row_idx, col_idx] ax.imshow(rgba, origin="upper", aspect="equal", interpolation="nearest") if col_idx == 0 and 0 <= mark_r < crop_h and 0 <= mark_c < crop_w: ax.plot( mark_c, mark_r, "+", color="red", markersize=8, markeredgewidth=1.2, ) if row_idx == 0: ax.set_title(col_title, fontsize=9) if col_idx == 0: ax.set_ylabel(row_title, fontsize=9) ax.set_xticks([]) ax.set_yticks([]) if col_idx == 3: ax.text( 0.02, 0.02, subtitle, transform=ax.transAxes, fontsize=6, color="white", va="bottom", bbox=dict(boxstyle="round,pad=0.2", facecolor="black", alpha=0.55), ) cal = manifest.get("s2_calendar_strategy", strategy) cal_note = f"; S2 calendar {cal}" if cal != strategy else "" fig.suptitle( f"{site_label} ({season}) — best BtI {best_bti_scenario}{cal_note}, {gap_days}-d gap", fontsize=10, ) out_png.parent.mkdir(parents=True, exist_ok=True) fig.savefig(out_png, dpi=150) plt.close(fig) return True