diff --git a/gap_validation/export_rasters.py b/gap_validation/export_rasters.py new file mode 100644 index 0000000..682d2d3 --- /dev/null +++ b/gap_validation/export_rasters.py @@ -0,0 +1,460 @@ +"""Export 2×4 RGB panels for Tier-A gap validation (thesis appendix).""" + +from __future__ import annotations + +import json +import re +from datetime import date, datetime +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import rasterio +from rasterio.transform import rowcol +from rasterio.warp import Resampling, reproject + +from gap_validation.s2_mask_dir import acquisition_yyyymmdd_in_window, yyyymmdd_from_iso + +REFL_DATE_RE = re.compile(r"S2A_MSIL2A_(\d{8})_REFL\.tif$") +S3_COMPOSITE_RE = re.compile(r"composite_(\d{8})\.tif$") +TRANSITIONS = ("green_up", "green_down") +COL_TITLES = ("Withheld S2", "Gap fusion", "S3 composite", "Nearest S2") +ROW_LABELS = {"green_up": "Green-up", "green_down": "Green-down"} +# Fixed pixel window around PhenoCam for comparable framing across sites (~1 km). +DISPLAY_HALF_PX = 48 +# Match postprocessing / spatial_metrics positive-reflectance mask. +VALID_REFL_THRESHOLD = 0.001 +NODATA_RGB = (0.15, 0.15, 0.15) + + +def _parse_bti_scenario(scenario: str) -> tuple[str, int]: + m = re.match(r"^(aggressive|nonaggressive)_sigma(20|30)$", scenario) + if not m: + raise ValueError(f"expected BtI scenario key, got {scenario!r}") + return m.group(1), int(m.group(2)) + + +def _prepared_base(data_dir: Path, site: str, season: int, strategy: str) -> Path: + return data_dir / site / str(season) / f"prepared_{strategy}" + + +def _gap_spatial_fusion_dir( + data_dir: Path, + site: str, + season: int, + gap_days: int, + transition: str, + strategy: str, + sigma: int, +) -> Path: + return ( + data_dir + / site + / str(season) + / "validation" + / "fusion" + / f"gap_{gap_days}_{transition}" + / f"{strategy}_sigma{sigma}_bti" + ) + + +def _iso_to_date(iso_d: str) -> date: + return datetime.strptime(iso_d[:10], "%Y-%m-%d").date() + + +def _exclude_ymds(entry: dict) -> set[str]: + w0 = _iso_to_date(entry["window_start"]) + w1 = _iso_to_date(entry["window_end"]) + withheld_fn = entry.get("withheld_s2_filename") or "" + m = REFL_DATE_RE.search(withheld_fn) + excluded = set() + if m: + excluded.add(m.group(1)) + return excluded + + +def nearest_stack_s2( + prepared_s2_dir: Path, + prediction_iso: str, + *, + exclude_ymds: set[str], +) -> Path | None: + """Nearest prepared REFL acquisition to prediction, outside excluded days.""" + if not prepared_s2_dir.is_dir(): + return None + target = _iso_to_date(prediction_iso) + best_path: Path | None = None + best_delta: int | None = None + for p in prepared_s2_dir.glob("S2A_MSIL2A_*_REFL.tif"): + m = REFL_DATE_RE.search(p.name) + if not m: + continue + ymd = m.group(1) + if ymd in exclude_ymds: + continue + d = datetime.strptime(ymd, "%Y%m%d").date() + delta = abs((d - target).days) + if best_delta is None or delta < best_delta: + best_delta = delta + best_path = p + return best_path + + +def nearest_s3_composite( + prepared_s3_dir: Path, + prediction_iso: str, +) -> Path | None: + """Nearest daily S3 composite to prediction date.""" + if not prepared_s3_dir.is_dir(): + return None + target = _iso_to_date(prediction_iso) + best_path: Path | None = None + best_delta: int | None = None + for p in prepared_s3_dir.glob("composite_*.tif"): + m = S3_COMPOSITE_RE.search(p.name) + if not m: + continue + d = datetime.strptime(m.group(1), "%Y%m%d").date() + delta = abs((d - target).days) + if best_delta is None or delta < best_delta: + best_delta = delta + best_path = p + return best_path + + +def _reference_grid(path: Path) -> dict | None: + """Grid metadata from fusion / S2 REFL (fine S2 resolution).""" + if not path.is_file(): + return None + with rasterio.open(path) as src: + return { + "transform": src.transform, + "crs": src.crs, + "height": src.height, + "width": src.width, + "bounds": src.bounds, + } + + +def _grids_match(src: rasterio.DatasetReader, ref: dict) -> bool: + return ( + src.height == ref["height"] + and src.width == ref["width"] + and src.transform == ref["transform"] + and src.crs == ref["crs"] + ) + + +def _read_bgr_on_grid(path: Path, ref: dict) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None: + """Read BGR on the fusion grid; reproject only when needed (S3 coarse stack).""" + if not path.is_file(): + return None + with rasterio.open(path) as src: + if src.count < 3: + return None + if _grids_match(src, ref): + return ( + src.read(1).astype(np.float64), + src.read(2).astype(np.float64), + src.read(3).astype(np.float64), + ) + shape = (ref["height"], ref["width"]) + bands: list[np.ndarray] = [] + for band_idx in (1, 2, 3): + dst = np.full(shape, np.nan, dtype=np.float32) + reproject( + source=rasterio.band(src, band_idx), + destination=dst, + src_transform=src.transform, + src_crs=src.crs, + dst_transform=ref["transform"], + dst_crs=ref["crs"], + resampling=Resampling.bilinear, + ) + bands.append(dst.astype(np.float64)) + return bands[0], bands[1], bands[2] + + +def _refl_valid( + blue: np.ndarray, green: np.ndarray, red: np.ndarray +) -> np.ndarray: + """Positive reflectance mask (aligned with postprocessing / Tier-A metrics).""" + return ( + np.isfinite(blue) + & np.isfinite(green) + & np.isfinite(red) + & (blue > VALID_REFL_THRESHOLD) + & (green > VALID_REFL_THRESHOLD) + & (red > VALID_REFL_THRESHOLD) + ) + + +def _panel_stretch_limits( + blue: np.ndarray, green: np.ndarray, red: np.ndarray, valid: np.ndarray +) -> tuple[float, float]: + if not valid.any(): + return 0.0, 1.0 + vals = np.concatenate([red[valid], green[valid], blue[valid]]) + lo, hi = np.percentile(vals, (2, 98)) + if hi <= lo: + return 0.0, 1.0 + return float(lo), float(hi) + + +def _bgr_to_rgba( + blue: np.ndarray, + green: np.ndarray, + red: np.ndarray, + *, + valid: np.ndarray, + vmin: float, + vmax: float, +) -> np.ndarray: + rgba = np.zeros((*blue.shape, 4), dtype=np.float32) + rgba[..., 3] = 1.0 + rgba[~valid, 0] = NODATA_RGB[0] + rgba[~valid, 1] = NODATA_RGB[1] + rgba[~valid, 2] = NODATA_RGB[2] + span = vmax - vmin or 1.0 + for band, idx in ((red, 0), (green, 1), (blue, 2)): + norm = np.clip((band - vmin) / span, 0.0, 1.0) + rgba[..., idx] = np.where(valid, norm, rgba[..., idx]) + return rgba + + +def _tight_slices(mask: np.ndarray, margin: int = 2) -> tuple[slice, slice] | None: + """Crop to the bounding box of True pixels in *mask*.""" + rows, cols = np.where(mask) + if rows.size < 8: + return None + r0 = max(0, int(rows.min()) - margin) + r1 = min(mask.shape[0], int(rows.max()) + margin + 1) + c0 = max(0, int(cols.min()) - margin) + c1 = min(mask.shape[1], int(cols.max()) + margin + 1) + if r1 - r0 < 8 or c1 - c0 < 8: + return None + return slice(r0, r1), slice(c0, c1) + + +def _crop_slices( + height: int, width: int, center_row: int, center_col: int, half_px: int +) -> tuple[slice, slice]: + r0 = max(0, center_row - half_px) + r1 = min(height, center_row + half_px) + c0 = max(0, center_col - half_px) + c1 = min(width, center_col + half_px) + return slice(r0, r1), slice(c0, c1) + + +def _phenocam_pixel( + meta: dict, + site_position_lat_lon: tuple[float, float], +) -> tuple[int, int] | None: + lat, lon = site_position_lat_lon + try: + r, c = rowcol(meta["transform"], [lon], [lat], op=meta.get("crs")) + return int(r[0]), int(c[0]) + except Exception: + return None + + +def _resolve_row_paths( + data_dir: Path, + site: str, + season: int, + entry: dict, + strategy: str, + sigma: int, + *, + gap_days: int, +) -> tuple[Path, Path, Path, Path] | None: + pred_ymd = yyyymmdd_from_iso(entry["prediction_date"]) + transition = entry["transition"] + prep = _prepared_base(data_dir, site, season, strategy) + withheld_fn = entry.get("withheld_s2_filename") + if not withheld_fn: + return None + withheld = prep / "s2" / withheld_fn + fusion = ( + _gap_spatial_fusion_dir(data_dir, site, season, gap_days, transition, strategy, sigma) + / f"REFL_{pred_ymd}.tif" + ) + s3_exact = prep / "s3" / f"composite_{pred_ymd}.tif" + s3 = ( + s3_exact + if s3_exact.is_file() + else nearest_s3_composite(prep / "s3", entry["prediction_date"]) + ) + w0 = _iso_to_date(entry["window_start"]) + w1 = _iso_to_date(entry["window_end"]) + window_ymds = acquisition_yyyymmdd_in_window(prep / "s2", w0, w1) + exclude = window_ymds | _exclude_ymds(entry) + nearest = nearest_stack_s2(prep / "s2", entry["prediction_date"], exclude_ymds=exclude) + if not withheld.is_file() or not fusion.is_file() or s3 is None or nearest is None: + return None + return withheld, fusion, s3, nearest + + +def _panel_date_subtitle(paths: tuple[Path, Path, Path, Path], entry: dict) -> str: + pred = entry["prediction_date"][:10] + wh = entry.get("withheld_s2_date") or "" + wh_short = wh[:10] if wh else "—" + nearest_ymd = REFL_DATE_RE.search(paths[3].name) + s3_ymd = S3_COMPOSITE_RE.search(paths[2].name) + n_s2 = nearest_ymd.group(1) if nearest_ymd else "?" + n_s3 = s3_ymd.group(1) if s3_ymd else "?" + s3_note = "" if paths[2].name.endswith(f"{pred.replace('-', '')}.tif") else f" (S3 {n_s3})" + return ( + f"pred. {pred}; withheld {wh_short}; " + f"nearest S2 {n_s2[:4]}-{n_s2[4:6]}-{n_s2[6:8]}{s3_note}" + ) + + +def build_site_panel( + site: str, + season: int, + data_dir: Path, + out_png: Path, + *, + best_bti_scenario: str, + site_label: str, + site_position_lat_lon: tuple[float, float] | None = None, + gap_days: int = 30, +) -> bool: + """Build 2×4 RGB figure; return False if manifest or any transition row is incomplete.""" + manifest_path = data_dir / site / str(season) / "validation" / "gap_manifest.json" + if not manifest_path.is_file(): + return False + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + strategy, sigma = _parse_bti_scenario(best_bti_scenario) + rows: list[tuple[str, dict, tuple[Path, Path, Path, Path]]] = [] + for transition in TRANSITIONS: + entry = next( + ( + e + for e in manifest["entries"] + if e.get("gap_days") == gap_days and e.get("transition") == transition + ), + None, + ) + if not entry: + continue + paths = _resolve_row_paths( + data_dir, site, season, entry, strategy, sigma, gap_days=gap_days + ) + if paths is None: + continue + rows.append((transition, entry, paths)) + + if not rows: + return False + + crop_px = DISPLAY_HALF_PX + fig, axes = plt.subplots( + len(rows), + 4, + figsize=(12.0, 2.8 * len(rows)), + squeeze=False, + constrained_layout=True, + ) + for row_idx, (transition, entry, paths) in enumerate(rows): + row_title = ROW_LABELS.get(transition, transition) + subtitle = _panel_date_subtitle(paths, entry) + ref_grid = _reference_grid(paths[1]) + if ref_grid is None: + continue + layers: list[tuple[np.ndarray, np.ndarray, np.ndarray]] = [] + for path in paths: + bgr = _read_bgr_on_grid(path, ref_grid) + if bgr is None: + layers = [] + break + layers.append(bgr) + if len(layers) != 4: + for ax in axes[row_idx]: + ax.set_visible(False) + continue + + h, w = ref_grid["height"], ref_grid["width"] + center_row, center_col = h // 2, w // 2 + if site_position_lat_lon: + pix = _phenocam_pixel(ref_grid, site_position_lat_lon) + if pix: + center_row, center_col = pix + rs, cs = _crop_slices(h, w, center_row, center_col, crop_px) + if rs.stop - rs.start < 8 or cs.stop - cs.start < 8: + for ax in axes[row_idx]: + ax.set_visible(False) + continue + + cropped_valid: list[np.ndarray] = [] + rgba_panels: list[np.ndarray] = [] + for bgr in layers: + blue, green, red = (b[rs, cs] for b in bgr) + valid = _refl_valid(blue, green, red) + cropped_valid.append(valid) + vmin, vmax = _panel_stretch_limits(blue, green, red, valid) + rgba_panels.append( + _bgr_to_rgba( + blue, green, red, valid=valid, vmin=vmin, vmax=vmax + ) + ) + + union = cropped_valid[0] + for v in cropped_valid[1:]: + union |= v + tight = _tight_slices(union) + if tight is not None: + tr, tc = tight + rgba_panels = [p[tr, tc] for p in rgba_panels] + union = union[tr, tc] + + crop_h, crop_w = rgba_panels[0].shape[:2] + mark_r = center_row - rs.start + mark_c = center_col - cs.start + if tight is not None: + mark_r -= tr.start + mark_c -= tc.start + + for col_idx, (col_title, rgba) in enumerate( + zip(COL_TITLES, rgba_panels, strict=True) + ): + ax = axes[row_idx, col_idx] + ax.imshow(rgba, origin="upper", aspect="equal", interpolation="nearest") + if col_idx == 0 and 0 <= mark_r < crop_h and 0 <= mark_c < crop_w: + ax.plot( + mark_c, + mark_r, + "+", + color="red", + markersize=8, + markeredgewidth=1.2, + ) + if row_idx == 0: + ax.set_title(col_title, fontsize=9) + if col_idx == 0: + ax.set_ylabel(row_title, fontsize=9) + ax.set_xticks([]) + ax.set_yticks([]) + if col_idx == 3: + ax.text( + 0.02, + 0.02, + subtitle, + transform=ax.transAxes, + fontsize=6, + color="white", + va="bottom", + bbox=dict(boxstyle="round,pad=0.2", facecolor="black", alpha=0.55), + ) + + cal = manifest.get("s2_calendar_strategy", strategy) + cal_note = f"; S2 calendar {cal}" if cal != strategy else "" + fig.suptitle( + f"{site_label} ({season}) — best BtI {best_bti_scenario}{cal_note}, {gap_days}-d gap", + fontsize=10, + ) + out_png.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(out_png, dpi=150) + plt.close(fig) + return True