"""Export 2×4 RGB panels for Tier-A gap validation (thesis appendix).

Crops follow the same fusion-valid bounding box as
``postprocessing.process_cropped`` and the webapp (``processed_*`` /
``common.js``), anchored on gap-degraded fusion at the prediction date; S2 and
S3 are read from prepared stacks on that shared window.
"""

from __future__ import annotations

import json
import re
from datetime import date, datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import rasterio
from rasterio import windows
from rasterio.transform import rowcol
from rasterio.warp import Resampling, reproject
from rasterio.warp import transform as warp_transform

from gap_validation.s2_mask_dir import acquisition_yyyymmdd_in_window, yyyymmdd_from_iso

REFL_DATE_RE = re.compile(r"S2A_MSIL2A_(\d{8})_REFL\.tif$")
S3_COMPOSITE_RE = re.compile(r"composite_(\d{8})\.tif$")
TRANSITIONS = ("green_up", "green_down")
COL_TITLES = ("Withheld S2", "Gap fusion", "S3 composite", "Nearest S2")
ROW_LABELS = {"green_up": "Green-up", "green_down": "Green-down"}
# Values that are non-finite or at/below this threshold are treated as no-data.
VALID_REFL_THRESHOLD = 0.001
# RGB colour painted into invalid pixels.
NODATA_RGB = (0.15, 0.15, 0.15)


def _parse_bti_scenario(scenario: str) -> tuple[str, int]:
    """Split a BtI scenario key such as ``"aggressive_sigma20"`` into ``(strategy, sigma)``.

    Raises:
        ValueError: if *scenario* is not ``{aggressive|nonaggressive}_sigma{20|30}``.
    """
    m = re.match(r"^(aggressive|nonaggressive)_sigma(20|30)$", scenario)
    if not m:
        raise ValueError(f"expected BtI scenario key, got {scenario!r}")
    return m.group(1), int(m.group(2))


def _prepared_base(data_dir: Path, site: str, season: int, strategy: str) -> Path:
    """Return ``<data_dir>/<site>/<season>/prepared_<strategy>``."""
    return data_dir / site / str(season) / f"prepared_{strategy}"


def _s2_strategy_fallbacks(strategy: str, manifest: dict) -> tuple[str, ...]:
    """Prepared trees to try for S2 REFL (best-BtI first, then manifest calendar).

    ``"aggressive"`` and ``"nonaggressive"`` are appended as last resorts;
    duplicates and empty / non-string entries are dropped, first occurrence wins.
    """
    order: list[str] = []
    for s in (strategy, manifest.get("s2_calendar_strategy")):
        if isinstance(s, str) and s and s not in order:
            order.append(s)
    for s in ("aggressive", "nonaggressive"):
        if s not in order:
            order.append(s)
    return tuple(order)


def _find_prepared_s2_refl(
    data_dir: Path,
    site: str,
    season: int,
    filename: str,
    strategies: tuple[str, ...],
) -> Path | None:
    """Return the first existing ``prepared_<strategy>/s2/<filename>``, trying *strategies* in order."""
    for strat in strategies:
        p = _prepared_base(data_dir, site, season, strat) / "s2" / filename
        if p.is_file():
            return p
    return None


def _gap_spatial_fusion_dir(
    data_dir: Path,
    site: str,
    season: int,
    gap_days: int,
    transition: str,
    strategy: str,
    sigma: int,
) -> Path:
    """Directory holding gap-degraded fusion outputs for one gap/transition/BtI scenario."""
    return (
        data_dir
        / site
        / str(season)
        / "validation"
        / "fusion"
        / f"gap_{gap_days}_{transition}"
        / f"{strategy}_sigma{sigma}_bti"
    )


def _iso_to_date(iso_d: str) -> date:
    """Parse the leading ``YYYY-MM-DD`` of an ISO date/datetime string."""
    return datetime.strptime(iso_d[:10], "%Y-%m-%d").date()


def _exclude_ymds(entry: dict) -> set[str]:
    """``{YYYYMMDD}`` of the entry's withheld S2 file, or an empty set if it has none."""
    withheld_fn = entry.get("withheld_s2_filename") or ""
    m = REFL_DATE_RE.search(withheld_fn)
    return {m.group(1)} if m else set()


def _nearest_dated_file(
    directory: Path,
    glob_pattern: str,
    date_re: re.Pattern[str],
    prediction_iso: str,
    exclude_ymds: set[str] | frozenset[str] = frozenset(),
) -> Path | None:
    """Return the file in *directory* whose embedded ``YYYYMMDD`` is closest to *prediction_iso*.

    Files are matched by *glob_pattern*; the date is capture group 1 of *date_re*.
    Files whose date is in *exclude_ymds* (or whose name does not match) are skipped.
    Candidates are visited in sorted name order and only a strictly smaller
    distance replaces the best, so ties resolve to the first sorted name instead
    of depending on filesystem listing order.
    """
    if not directory.is_dir():
        return None
    target = _iso_to_date(prediction_iso)
    best_path: Path | None = None
    best_delta: int | None = None
    for p in sorted(directory.glob(glob_pattern)):
        m = date_re.search(p.name)
        if not m or m.group(1) in exclude_ymds:
            continue
        delta = abs((datetime.strptime(m.group(1), "%Y%m%d").date() - target).days)
        if best_delta is None or delta < best_delta:
            best_delta, best_path = delta, p
    return best_path


def nearest_stack_s2(
    prepared_s2_dir: Path,
    prediction_iso: str,
    *,
    exclude_ymds: set[str],
) -> Path | None:
    """Closest ``S2A_MSIL2A_<YYYYMMDD>_REFL.tif`` to *prediction_iso*, skipping *exclude_ymds*."""
    return _nearest_dated_file(
        prepared_s2_dir,
        "S2A_MSIL2A_*_REFL.tif",
        REFL_DATE_RE,
        prediction_iso,
        exclude_ymds,
    )


def nearest_s3_composite(prepared_s3_dir: Path, prediction_iso: str) -> Path | None:
    """Closest ``composite_<YYYYMMDD>.tif`` to *prediction_iso*, or ``None`` if none exists."""
    return _nearest_dated_file(
        prepared_s3_dir, "composite_*.tif", S3_COMPOSITE_RE, prediction_iso
    )


def _crop_window_from_fusion(fusion_path: Path) -> dict | None:
    """Fusion-valid crop (``postprocessing.process_cropped``) on the full prepared grid.

    A pixel is valid when any band is finite and above ``VALID_REFL_THRESHOLD``;
    the crop is the tight bounding box of valid pixels.

    Returns:
        ``None`` if the file is missing or has no valid pixel; otherwise a dict with
        ``window`` (the crop in fusion pixel coordinates), ``crop_transform``,
        ``full_transform``, ``crs`` and ``profile`` (a copy of the fusion profile).
    """
    if not fusion_path.is_file():
        return None
    with rasterio.open(fusion_path) as src:
        data = src.read()
        valid = np.isfinite(data) & (data > VALID_REFL_THRESHOLD)
        # Collapse bands and the other spatial axis to find occupied rows / cols.
        row_idx = np.flatnonzero(valid.any(axis=(0, 2)))
        col_idx = np.flatnonzero(valid.any(axis=(0, 1)))
        if row_idx.size == 0 or col_idx.size == 0:
            return None
        r0, r1 = int(row_idx[0]), int(row_idx[-1])
        c0, c1 = int(col_idx[0]), int(col_idx[-1])
        win = windows.Window(c0, r0, c1 - c0 + 1, r1 - r0 + 1)
        return {
            "window": win,
            "crop_transform": windows.transform(win, src.transform),
            "full_transform": src.transform,
            "crs": src.crs,
            "profile": src.profile.copy(),
        }


def _read_bgr_window(path: Path, crop: dict) -> tuple[np.ndarray, ...] | None:
    """Read bands 1-3 of *path* inside ``crop["window"]`` as float64 arrays.

    Bands 1, 2, 3 are returned in that order and treated as blue, green, red by
    callers. The window is in fusion pixel coordinates, so this assumes *path*
    shares the fusion grid (NOTE(review): confirm for prepared S2 stacks).
    Returns ``None`` if the file is missing or has fewer than three bands.
    """
    if not path.is_file():
        return None
    with rasterio.open(path) as src:
        if src.count < 3:
            return None
        bands = src.read(indexes=(1, 2, 3), window=crop["window"])
    return tuple(band.astype(np.float64) for band in bands)


def _read_bgr_prepared_s2(prepared_refl: Path, crop: dict) -> tuple[np.ndarray, ...] | None:
    """BGR of a prepared S2 REFL stack on the crop window (see ``_read_bgr_window``)."""
    return _read_bgr_window(prepared_refl, crop)


def _read_bgr_gap_fusion(fusion_path: Path, crop: dict) -> tuple[np.ndarray, ...] | None:
    """BGR of the gap-degraded fusion raster on the crop window (see ``_read_bgr_window``)."""
    return _read_bgr_window(fusion_path, crop)


def _read_bgr_prepared_s3(s3_path: Path, crop: dict) -> tuple[np.ndarray, ...] | None:
    """Resample S3 composite to the fusion grid, then crop (matches ``process_cropped``).

    Only bands 1-3 are reprojected (nearest neighbour) into an in-memory float32
    copy of the fusion profile; the crop window is then read as float64.
    Returns ``None`` if the file is missing or has fewer than three bands.
    """
    if not s3_path.is_file():
        return None
    with rasterio.open(s3_path) as src:
        if src.count < 3:
            return None
        temp_profile = crop["profile"].copy()
        temp_profile.update({"dtype": "float32", "count": 3})
        with rasterio.MemoryFile() as memfile, memfile.open(**temp_profile) as resampled:
            for i in (1, 2, 3):
                reproject(
                    source=rasterio.band(src, i),
                    destination=rasterio.band(resampled, i),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=crop["full_transform"],
                    dst_crs=crop["crs"],
                    resampling=Resampling.nearest,
                )
            bands = resampled.read(indexes=(1, 2, 3), window=crop["window"])
    return tuple(band.astype(np.float64) for band in bands)


def _refl_valid(blue: np.ndarray, green: np.ndarray, red: np.ndarray) -> np.ndarray:
    """Boolean mask: all three bands finite and above ``VALID_REFL_THRESHOLD``."""
    return (
        np.isfinite(blue)
        & np.isfinite(green)
        & np.isfinite(red)
        & (blue > VALID_REFL_THRESHOLD)
        & (green > VALID_REFL_THRESHOLD)
        & (red > VALID_REFL_THRESHOLD)
    )


def _panel_stretch_limits(
    blue: np.ndarray, green: np.ndarray, red: np.ndarray, valid: np.ndarray
) -> tuple[float, float]:
    """Per-panel 2--98 % stretch on positive reflectance (webapp ``common.js`` style).

    Percentiles are pooled over the valid pixels of all three bands. Falls back
    to ``(0.0, 1.0)`` when nothing is valid or the range is degenerate.
    """
    if not valid.any():
        return 0.0, 1.0
    vals = np.concatenate([red[valid], green[valid], blue[valid]])
    lo, hi = np.percentile(vals, (2, 98))
    if hi <= lo:
        return 0.0, 1.0
    return float(lo), float(hi)


def _bgr_to_rgba(
    blue: np.ndarray,
    green: np.ndarray,
    red: np.ndarray,
    *,
    valid: np.ndarray,
    vmin: float,
    vmax: float,
) -> np.ndarray:
    """Linearly stretch BGR to ``[0, 1]`` RGBA (float32); invalid pixels get ``NODATA_RGB``."""
    rgba = np.zeros((*blue.shape, 4), dtype=np.float32)
    rgba[..., 3] = 1.0
    rgba[~valid, :3] = NODATA_RGB
    span = vmax - vmin or 1.0  # guard against a zero-width stretch
    for idx, band in enumerate((red, green, blue)):
        norm = np.clip((band - vmin) / span, 0.0, 1.0)
        rgba[..., idx] = np.where(valid, norm, rgba[..., idx])
    return rgba


def _phenocam_pixel_cropped(
    crop: dict, site_position_lat_lon: tuple[float, float]
) -> tuple[int, int] | None:
    """Return the ``(row, col)`` of the site position inside the crop, or ``None``.

    The position is assumed to be WGS84 degrees ``(lat, lon)``. It is reprojected
    into the raster CRS, then mapped through the crop transform with rowcol's
    default (floor) rounding. The result may fall outside the crop; callers
    must bounds-check it.
    """
    if crop.get("crs") is None:
        return None
    lat, lon = site_position_lat_lon
    try:
        xs, ys = warp_transform("EPSG:4326", crop["crs"], [lon], [lat])
        rows, cols = rowcol(crop["crop_transform"], xs, ys)
    except Exception:
        # The marker is an optional overlay: a CRS/PROJ failure must not abort
        # the whole figure.
        return None
    return int(rows[0]), int(cols[0])


def _resolve_row_paths(
    data_dir: Path,
    site: str,
    season: int,
    entry: dict,
    strategy: str,
    sigma: int,
    *,
    gap_days: int,
    manifest: dict,
) -> tuple[Path, Path, Path, Path] | None:
    """Resolve ``(withheld S2, gap fusion, S3 composite, nearest S2)`` for one manifest entry.

    The S3 composite falls back to the date closest to the prediction date. The
    nearest S2 excludes every acquisition inside the entry's gap window and the
    withheld scene, trying S2 strategy trees in fallback order. Returns ``None``
    if any of the four inputs cannot be found.
    """
    withheld_fn = entry.get("withheld_s2_filename")
    if not withheld_fn:
        return None
    pred_ymd = yyyymmdd_from_iso(entry["prediction_date"])
    transition = entry["transition"]
    prep = _prepared_base(data_dir, site, season, strategy)
    s2_strats = _s2_strategy_fallbacks(strategy, manifest)

    withheld = _find_prepared_s2_refl(data_dir, site, season, withheld_fn, s2_strats)
    if withheld is None:
        return None

    fusion = (
        _gap_spatial_fusion_dir(data_dir, site, season, gap_days, transition, strategy, sigma)
        / f"REFL_{pred_ymd}.tif"
    )
    if not fusion.is_file():
        return None

    s3_exact = prep / "s3" / f"composite_{pred_ymd}.tif"
    s3 = (
        s3_exact
        if s3_exact.is_file()
        else nearest_s3_composite(prep / "s3", entry["prediction_date"])
    )
    if s3 is None:
        return None

    w0 = _iso_to_date(entry["window_start"])
    w1 = _iso_to_date(entry["window_end"])
    withheld_ymds = _exclude_ymds(entry)  # loop-invariant
    for strat in s2_strats:
        prep_s2 = _prepared_base(data_dir, site, season, strat) / "s2"
        exclude = acquisition_yyyymmdd_in_window(prep_s2, w0, w1) | withheld_ymds
        nearest = nearest_stack_s2(prep_s2, entry["prediction_date"], exclude_ymds=exclude)
        if nearest is not None:
            return withheld, fusion, s3, nearest
    return None


def _load_row_layers(
    paths: tuple[Path, Path, Path, Path],
) -> tuple[dict, list[tuple[np.ndarray, ...]]] | None:
    """Crop window (from the gap fusion, ``paths[1]``) and the four BGR layers, or ``None``."""
    crop = _crop_window_from_fusion(paths[1])
    if crop is None:
        return None
    readers = (
        _read_bgr_prepared_s2,
        _read_bgr_gap_fusion,
        _read_bgr_prepared_s3,
        _read_bgr_prepared_s2,
    )
    layers: list[tuple[np.ndarray, ...]] = []
    for path, read_fn in zip(paths, readers, strict=True):
        bgr = read_fn(path, crop)
        if bgr is None:
            return None
        layers.append(bgr)
    return crop, layers


def _draw_rgb_panel(
    ax: plt.Axes,
    bgr: tuple[np.ndarray, ...],
    mark: tuple[int, int] | None,
) -> None:
    """Render one stretched RGB panel on *ax*, with an optional red ``+`` at *mark* (row, col)."""
    blue, green, red = bgr
    valid = _refl_valid(blue, green, red)
    vmin, vmax = _panel_stretch_limits(blue, green, red, valid)
    rgba = _bgr_to_rgba(blue, green, red, valid=valid, vmin=vmin, vmax=vmax)
    ax.imshow(rgba, origin="upper", aspect="equal", interpolation="nearest")
    h, w = rgba.shape[:2]
    if mark is not None and 0 <= mark[0] < h and 0 <= mark[1] < w:
        ax.plot(mark[1], mark[0], "+", color="red", markersize=8, markeredgewidth=1.2)


def build_site_panel(
    site: str,
    season: int,
    data_dir: Path,
    out_png: Path,
    *,
    best_bti_scenario: str,
    site_label: str,
    site_position_lat_lon: tuple[float, float] | None = None,
    gap_days: int = 30,
) -> bool:
    """Build and save an RGB figure with one row per transition and four columns.

    Rows follow ``TRANSITIONS`` (up to two) for the manifest entries with
    *gap_days*; columns are ``COL_TITLES``. A transition without a manifest entry
    or with missing input files is omitted. A row whose crop or rasters cannot be
    read is hidden. When *site_position_lat_lon* (assumed WGS84 lat/lon) is
    given, it is marked on the first column.

    Returns:
        ``True`` if the figure was written to *out_png*. ``False`` (nothing
        written) if the manifest is missing, no transition has a complete set of
        inputs, or no row could be rendered.

    Raises:
        ValueError: if *best_bti_scenario* is not a valid BtI scenario key.
    """
    manifest_path = data_dir / site / str(season) / "validation" / "gap_manifest.json"
    if not manifest_path.is_file():
        return False
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    strategy, sigma = _parse_bti_scenario(best_bti_scenario)
    entries = manifest.get("entries") or ()

    rows: list[tuple[str, dict, tuple[Path, Path, Path, Path]]] = []
    for transition in TRANSITIONS:
        entry = next(
            (
                e
                for e in entries
                if e.get("gap_days") == gap_days and e.get("transition") == transition
            ),
            None,
        )
        if not entry:
            continue
        paths = _resolve_row_paths(
            data_dir,
            site,
            season,
            entry,
            strategy,
            sigma,
            gap_days=gap_days,
            manifest=manifest,
        )
        if paths is not None:
            rows.append((transition, entry, paths))
    if not rows:
        return False

    fig, axes = plt.subplots(
        len(rows),
        4,
        figsize=(12.0, 2.8 * len(rows)),
        squeeze=False,
        constrained_layout=True,
    )
    try:
        rendered = 0
        for row_idx, (transition, _entry, paths) in enumerate(rows):
            loaded = _load_row_layers(paths)
            if loaded is None:
                for ax in axes[row_idx]:
                    ax.set_visible(False)
                continue
            crop, layers = loaded
            mark = (
                _phenocam_pixel_cropped(crop, site_position_lat_lon)
                if site_position_lat_lon
                else None
            )
            row_title = ROW_LABELS.get(transition, transition)
            # Column titles go on the first *rendered* row so they survive a
            # hidden top row.
            show_titles = rendered == 0
            for col_idx, (col_title, bgr) in enumerate(zip(COL_TITLES, layers, strict=True)):
                ax = axes[row_idx, col_idx]
                _draw_rgb_panel(ax, bgr, mark if col_idx == 0 else None)
                if show_titles:
                    ax.set_title(col_title, fontsize=9)
                if col_idx == 0:
                    ax.set_ylabel(row_title, fontsize=9)
                ax.set_xticks([])
                ax.set_yticks([])
            rendered += 1
        if rendered == 0:
            return False

        fig.suptitle(f"{site_label} ({season})", fontsize=10)
        out_png.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_png, dpi=150)
    finally:
        # Always release the pyplot figure, even if reading or drawing raised.
        plt.close(fig)
    return True