efast-phenocam-validation/gap_validation/export_rasters.py
Felix Delattre 2dba38af5b foo
2026-06-02 10:34:20 +02:00

460 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Export 2×4 RGB panels for Tier-A gap validation (thesis appendix)."""
from __future__ import annotations
import json
import re
from datetime import date, datetime
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import rasterio
from rasterio.transform import rowcol
from rasterio.warp import Resampling, reproject
from gap_validation.s2_mask_dir import acquisition_yyyymmdd_in_window, yyyymmdd_from_iso
# Captures the 8-digit (YYYYMMDD) acquisition date from a prepared S2 reflectance filename.
# NOTE(review): only the "S2A_" prefix is matched — confirm prepared files never use "S2B_".
REFL_DATE_RE = re.compile(r"S2A_MSIL2A_(\d{8})_REFL\.tif$")
# Captures the 8-digit (YYYYMMDD) date from a prepared S3 composite filename.
S3_COMPOSITE_RE = re.compile(r"composite_(\d{8})\.tif$")
# Transition keys looked up in the gap manifest; each resolved transition becomes one figure row.
TRANSITIONS = ("green_up", "green_down")
# Column titles, in the same order as the path tuple returned by _resolve_row_paths.
COL_TITLES = ("Withheld S2", "Gap fusion", "S3 composite", "Nearest S2")
# Row labels shown on the figure, keyed by transition.
ROW_LABELS = {"green_up": "Green-up", "green_down": "Green-down"}
# Fixed pixel window around PhenoCam for comparable framing across sites (~1 km).
DISPLAY_HALF_PX = 48
# Match postprocessing / spatial_metrics positive-reflectance mask.
VALID_REFL_THRESHOLD = 0.001
# RGB fill colour (0-1 floats) used for pixels that fail the validity mask.
NODATA_RGB = (0.15, 0.15, 0.15)
def _parse_bti_scenario(scenario: str) -> tuple[str, int]:
    """Split a BtI scenario key into its strategy name and sigma value.

    Accepted keys have the form ``<strategy>_sigma<NN>`` where the strategy is
    ``aggressive`` or ``nonaggressive`` and ``NN`` is ``20`` or ``30``.

    Returns:
        ``(strategy, sigma)`` with sigma converted to ``int``.

    Raises:
        ValueError: if *scenario* is not exactly one of the accepted keys.
    """
    # fullmatch instead of match + "$": "$" also matches just before a trailing
    # newline, so "aggressive_sigma20\n" used to be accepted as a valid key.
    m = re.fullmatch(r"(aggressive|nonaggressive)_sigma(20|30)", scenario)
    if m is None:
        raise ValueError(f"expected BtI scenario key, got {scenario!r}")
    strategy, sigma = m.groups()
    return strategy, int(sigma)
def _prepared_base(data_dir: Path, site: str, season: int, strategy: str) -> Path:
    """Return ``<data_dir>/<site>/<season>/prepared_<strategy>``."""
    season_dir = data_dir.joinpath(site, str(season))
    return season_dir / f"prepared_{strategy}"
def _gap_spatial_fusion_dir(
    data_dir: Path,
    site: str,
    season: int,
    gap_days: int,
    transition: str,
    strategy: str,
    sigma: int,
) -> Path:
    """Return the fusion output directory for one gap/transition/scenario.

    Layout: ``<data_dir>/<site>/<season>/validation/fusion/``
    ``gap_<gap_days>_<transition>/<strategy>_sigma<sigma>_bti``.
    """
    gap_folder = f"gap_{gap_days}_{transition}"
    scenario_folder = f"{strategy}_sigma{sigma}_bti"
    return data_dir.joinpath(
        site, str(season), "validation", "fusion", gap_folder, scenario_folder
    )
def _iso_to_date(iso_d: str) -> date:
    """Parse the leading ``YYYY-MM-DD`` of *iso_d* into a ``date``.

    Only the first ten characters are parsed, so a trailing time component is
    ignored. Raises ``ValueError`` when that prefix is not a valid date.
    """
    day_part = iso_d[:10]
    parsed = datetime.strptime(day_part, "%Y-%m-%d")
    return parsed.date()
def _exclude_ymds(entry: dict) -> set[str]:
    """Return the acquisition day of the withheld S2 scene as a set of ``YYYYMMDD``.

    The day is taken from ``entry["withheld_s2_filename"]`` via ``REFL_DATE_RE``.
    The set is empty when the key is missing/empty or the filename does not
    match the pattern.

    The gap window is not consulted here: previously ``window_start`` and
    ``window_end`` were parsed into unused locals, which made this helper
    raise on entries lacking those keys for no benefit.
    """
    withheld_fn = entry.get("withheld_s2_filename") or ""
    m = REFL_DATE_RE.search(withheld_fn)
    return {m.group(1)} if m else set()
def nearest_stack_s2(
prepared_s2_dir: Path,
prediction_iso: str,
*,
exclude_ymds: set[str],
) -> Path | None:
"""Nearest prepared REFL acquisition to prediction, outside excluded days."""
if not prepared_s2_dir.is_dir():
return None
target = _iso_to_date(prediction_iso)
best_path: Path | None = None
best_delta: int | None = None
for p in prepared_s2_dir.glob("S2A_MSIL2A_*_REFL.tif"):
m = REFL_DATE_RE.search(p.name)
if not m:
continue
ymd = m.group(1)
if ymd in exclude_ymds:
continue
d = datetime.strptime(ymd, "%Y%m%d").date()
delta = abs((d - target).days)
if best_delta is None or delta < best_delta:
best_delta = delta
best_path = p
return best_path
def nearest_s3_composite(
prepared_s3_dir: Path,
prediction_iso: str,
) -> Path | None:
"""Nearest daily S3 composite to prediction date."""
if not prepared_s3_dir.is_dir():
return None
target = _iso_to_date(prediction_iso)
best_path: Path | None = None
best_delta: int | None = None
for p in prepared_s3_dir.glob("composite_*.tif"):
m = S3_COMPOSITE_RE.search(p.name)
if not m:
continue
d = datetime.strptime(m.group(1), "%Y%m%d").date()
delta = abs((d - target).days)
if best_delta is None or delta < best_delta:
best_delta = delta
best_path = p
return best_path
def _reference_grid(path: Path) -> dict | None:
"""Grid metadata from fusion / S2 REFL (fine S2 resolution)."""
if not path.is_file():
return None
with rasterio.open(path) as src:
return {
"transform": src.transform,
"crs": src.crs,
"height": src.height,
"width": src.width,
"bounds": src.bounds,
}
def _grids_match(src: rasterio.DatasetReader, ref: dict) -> bool:
    """True when *src* has the same height, width, transform and CRS as *ref*."""
    compared = ("height", "width", "transform", "crs")
    # all() short-circuits in the same order as the chained comparison it replaces.
    return all(getattr(src, key) == ref[key] for key in compared)
def _read_bgr_on_grid(path: Path, ref: dict) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None:
"""Read BGR on the fusion grid; reproject only when needed (S3 coarse stack)."""
if not path.is_file():
return None
with rasterio.open(path) as src:
if src.count < 3:
return None
if _grids_match(src, ref):
return (
src.read(1).astype(np.float64),
src.read(2).astype(np.float64),
src.read(3).astype(np.float64),
)
shape = (ref["height"], ref["width"])
bands: list[np.ndarray] = []
for band_idx in (1, 2, 3):
dst = np.full(shape, np.nan, dtype=np.float32)
reproject(
source=rasterio.band(src, band_idx),
destination=dst,
src_transform=src.transform,
src_crs=src.crs,
dst_transform=ref["transform"],
dst_crs=ref["crs"],
resampling=Resampling.bilinear,
)
bands.append(dst.astype(np.float64))
return bands[0], bands[1], bands[2]
def _refl_valid(
    blue: np.ndarray, green: np.ndarray, red: np.ndarray
) -> np.ndarray:
    """Positive reflectance mask (aligned with postprocessing / Tier-A metrics).

    A pixel is valid when all three bands are finite and strictly greater than
    ``VALID_REFL_THRESHOLD``.
    """
    finite = [np.isfinite(band) for band in (blue, green, red)]
    positive = [band > VALID_REFL_THRESHOLD for band in (blue, green, red)]
    all_finite = finite[0] & finite[1] & finite[2]
    all_positive = positive[0] & positive[1] & positive[2]
    return all_finite & all_positive
def _panel_stretch_limits(
    blue: np.ndarray, green: np.ndarray, red: np.ndarray, valid: np.ndarray
) -> tuple[float, float]:
    """Return the 2nd/98th percentile of all valid band values as display limits.

    The three bands are pooled so that one common stretch is applied to the
    whole panel. Falls back to ``(0.0, 1.0)`` when no pixel is valid or when
    the upper percentile does not exceed the lower one.
    """
    fallback = (0.0, 1.0)
    if not valid.any():
        return fallback
    pooled = np.concatenate([band[valid] for band in (red, green, blue)])
    low, high = np.percentile(pooled, (2, 98))
    if high <= low:
        return fallback
    return float(low), float(high)
def _bgr_to_rgba(
    blue: np.ndarray,
    green: np.ndarray,
    red: np.ndarray,
    *,
    valid: np.ndarray,
    vmin: float,
    vmax: float,
) -> np.ndarray:
    """Build an opaque float32 RGBA image from three bands.

    Valid pixels are linearly stretched from ``[vmin, vmax]`` to ``[0, 1]`` and
    clipped; invalid pixels are painted with ``NODATA_RGB``. Alpha is 1
    everywhere. The channel order of the result is red, green, blue, alpha.
    """
    out = np.zeros((*blue.shape, 4), dtype=np.float32)
    out[..., 3] = 1.0
    out[~valid, :3] = NODATA_RGB
    span = vmax - vmin
    if not span:
        # Degenerate stretch: avoid dividing by zero.
        span = 1.0
    for channel, band in enumerate((red, green, blue)):
        scaled = np.clip((band - vmin) / span, 0.0, 1.0)
        out[..., channel] = np.where(valid, scaled, out[..., channel])
    return out
def _tight_slices(mask: np.ndarray, margin: int = 2) -> tuple[slice, slice] | None:
"""Crop to the bounding box of True pixels in *mask*."""
rows, cols = np.where(mask)
if rows.size < 8:
return None
r0 = max(0, int(rows.min()) - margin)
r1 = min(mask.shape[0], int(rows.max()) + margin + 1)
c0 = max(0, int(cols.min()) - margin)
c1 = min(mask.shape[1], int(cols.max()) + margin + 1)
if r1 - r0 < 8 or c1 - c0 < 8:
return None
return slice(r0, r1), slice(c0, c1)
def _crop_slices(
    height: int, width: int, center_row: int, center_col: int, half_px: int
) -> tuple[slice, slice]:
    """Return row/column slices of a window of half-width *half_px* around a centre.

    Each slice runs from ``center - half_px`` to ``center + half_px`` and is
    clamped to ``[0, height]`` / ``[0, width]`` respectively.
    """

    def _window(center: int, limit: int) -> slice:
        return slice(max(0, center - half_px), min(limit, center + half_px))

    return _window(center_row, height), _window(center_col, width)
def _phenocam_pixel(
    meta: dict,
    site_position_lat_lon: tuple[float, float],
) -> tuple[int, int] | None:
    """Return the ``(row, col)`` pixel of the site position on the grid in *meta*.

    Args:
        meta: Grid description with a ``"transform"`` key and an optional
            ``"crs"`` key (as produced by ``_reference_grid``).
        site_position_lat_lon: ``(lat, lon)`` of the site.
            NOTE(review): assumed to be WGS84 / EPSG:4326 degrees — confirm
            against the caller.

    Returns:
        Integer ``(row, col)``, which may lie outside the raster, or ``None``
        when the position cannot be converted.

    The previous implementation passed the CRS as ``rowcol``'s ``op=``
    argument. ``op`` is the rounding callable, so the call always failed and
    the broad ``except`` silently turned that into ``None``; the lat/lon was
    also never converted into the raster CRS.
    """
    from rasterio.crs import CRS
    from rasterio.warp import transform as warp_transform

    lat, lon = site_position_lat_lon
    try:
        xs, ys = [lon], [lat]
        raster_crs = meta.get("crs")
        if raster_crs is not None:
            # rowcol() expects coordinates in the raster CRS, so reproject the
            # geographic position first (x = lon, y = lat).
            xs, ys = warp_transform(CRS.from_epsg(4326), raster_crs, xs, ys)
        r, c = rowcol(meta["transform"], xs, ys)
        return int(r[0]), int(c[0])
    except Exception:
        # Deliberate best-effort: callers fall back to the raster centre.
        return None
def _resolve_row_paths(
    data_dir: Path,
    site: str,
    season: int,
    entry: dict,
    strategy: str,
    sigma: int,
    *,
    gap_days: int,
) -> tuple[Path, Path, Path, Path] | None:
    """Locate the four rasters shown in one figure row for a manifest *entry*.

    Returns ``(withheld S2, gap fusion, S3 composite, nearest S2)`` or ``None``
    when the entry names no withheld file or any of the four cannot be found.
    The S3 composite of the prediction day is preferred, falling back to the
    nearest available composite. The nearest S2 scene is searched outside the
    gap window and never equals the withheld acquisition day.
    """
    prediction_iso = entry["prediction_date"]
    pred_ymd = yyyymmdd_from_iso(prediction_iso)
    transition = entry["transition"]
    prep = _prepared_base(data_dir, site, season, strategy)
    withheld_fn = entry.get("withheld_s2_filename")
    if not withheld_fn:
        return None

    s2_dir = prep / "s2"
    s3_dir = prep / "s3"
    withheld = s2_dir / withheld_fn
    fusion_dir = _gap_spatial_fusion_dir(
        data_dir, site, season, gap_days, transition, strategy, sigma
    )
    fusion = fusion_dir / f"REFL_{pred_ymd}.tif"

    s3 = s3_dir / f"composite_{pred_ymd}.tif"
    if not s3.is_file():
        s3 = nearest_s3_composite(s3_dir, prediction_iso)

    window_start = _iso_to_date(entry["window_start"])
    window_end = _iso_to_date(entry["window_end"])
    in_window = acquisition_yyyymmdd_in_window(s2_dir, window_start, window_end)
    exclude = in_window | _exclude_ymds(entry)
    nearest = nearest_stack_s2(s2_dir, prediction_iso, exclude_ymds=exclude)

    complete = (
        withheld.is_file()
        and fusion.is_file()
        and s3 is not None
        and nearest is not None
    )
    if not complete:
        return None
    return withheld, fusion, s3, nearest
def _panel_date_subtitle(paths: tuple[Path, Path, Path, Path], entry: dict) -> str:
    """Compose the small date caption for one figure row.

    Lists the prediction date, the withheld S2 date and the date of the
    nearest S2 scene (taken from its filename). The S3 composite date is
    appended only when the composite is not the one of the prediction day.
    """
    pred = entry["prediction_date"][:10]
    wh_short = (entry.get("withheld_s2_date") or "")[:10]

    nearest_hit = REFL_DATE_RE.search(paths[3].name)
    s3_hit = S3_COMPOSITE_RE.search(paths[2].name)
    n_s2 = "?" if nearest_hit is None else nearest_hit.group(1)
    n_s3 = "?" if s3_hit is None else s3_hit.group(1)

    exact_suffix = f"{pred.replace('-', '')}.tif"
    if paths[2].name.endswith(exact_suffix):
        s3_note = ""
    else:
        s3_note = f" (S3 {n_s3})"

    head = f"pred. {pred}; withheld {wh_short}; "
    tail = f"nearest S2 {n_s2[:4]}-{n_s2[4:6]}-{n_s2[6:8]}{s3_note}"
    return head + tail
def build_site_panel(
site: str,
season: int,
data_dir: Path,
out_png: Path,
*,
best_bti_scenario: str,
site_label: str,
site_position_lat_lon: tuple[float, float] | None = None,
gap_days: int = 30,
) -> bool:
"""Build 2×4 RGB figure; return False if manifest or any transition row is incomplete."""
manifest_path = data_dir / site / str(season) / "validation" / "gap_manifest.json"
if not manifest_path.is_file():
return False
manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
strategy, sigma = _parse_bti_scenario(best_bti_scenario)
rows: list[tuple[str, dict, tuple[Path, Path, Path, Path]]] = []
for transition in TRANSITIONS:
entry = next(
(
e
for e in manifest["entries"]
if e.get("gap_days") == gap_days and e.get("transition") == transition
),
None,
)
if not entry:
continue
paths = _resolve_row_paths(
data_dir, site, season, entry, strategy, sigma, gap_days=gap_days
)
if paths is None:
continue
rows.append((transition, entry, paths))
if not rows:
return False
crop_px = DISPLAY_HALF_PX
fig, axes = plt.subplots(
len(rows),
4,
figsize=(12.0, 2.8 * len(rows)),
squeeze=False,
constrained_layout=True,
)
for row_idx, (transition, entry, paths) in enumerate(rows):
row_title = ROW_LABELS.get(transition, transition)
subtitle = _panel_date_subtitle(paths, entry)
ref_grid = _reference_grid(paths[1])
if ref_grid is None:
continue
layers: list[tuple[np.ndarray, np.ndarray, np.ndarray]] = []
for path in paths:
bgr = _read_bgr_on_grid(path, ref_grid)
if bgr is None:
layers = []
break
layers.append(bgr)
if len(layers) != 4:
for ax in axes[row_idx]:
ax.set_visible(False)
continue
h, w = ref_grid["height"], ref_grid["width"]
center_row, center_col = h // 2, w // 2
if site_position_lat_lon:
pix = _phenocam_pixel(ref_grid, site_position_lat_lon)
if pix:
center_row, center_col = pix
rs, cs = _crop_slices(h, w, center_row, center_col, crop_px)
if rs.stop - rs.start < 8 or cs.stop - cs.start < 8:
for ax in axes[row_idx]:
ax.set_visible(False)
continue
cropped_valid: list[np.ndarray] = []
rgba_panels: list[np.ndarray] = []
for bgr in layers:
blue, green, red = (b[rs, cs] for b in bgr)
valid = _refl_valid(blue, green, red)
cropped_valid.append(valid)
vmin, vmax = _panel_stretch_limits(blue, green, red, valid)
rgba_panels.append(
_bgr_to_rgba(
blue, green, red, valid=valid, vmin=vmin, vmax=vmax
)
)
union = cropped_valid[0]
for v in cropped_valid[1:]:
union |= v
tight = _tight_slices(union)
if tight is not None:
tr, tc = tight
rgba_panels = [p[tr, tc] for p in rgba_panels]
union = union[tr, tc]
crop_h, crop_w = rgba_panels[0].shape[:2]
mark_r = center_row - rs.start
mark_c = center_col - cs.start
if tight is not None:
mark_r -= tr.start
mark_c -= tc.start
for col_idx, (col_title, rgba) in enumerate(
zip(COL_TITLES, rgba_panels, strict=True)
):
ax = axes[row_idx, col_idx]
ax.imshow(rgba, origin="upper", aspect="equal", interpolation="nearest")
if col_idx == 0 and 0 <= mark_r < crop_h and 0 <= mark_c < crop_w:
ax.plot(
mark_c,
mark_r,
"+",
color="red",
markersize=8,
markeredgewidth=1.2,
)
if row_idx == 0:
ax.set_title(col_title, fontsize=9)
if col_idx == 0:
ax.set_ylabel(row_title, fontsize=9)
ax.set_xticks([])
ax.set_yticks([])
if col_idx == 3:
ax.text(
0.02,
0.02,
subtitle,
transform=ax.transAxes,
fontsize=6,
color="white",
va="bottom",
bbox=dict(boxstyle="round,pad=0.2", facecolor="black", alpha=0.55),
)
cal = manifest.get("s2_calendar_strategy", strategy)
cal_note = f"; S2 calendar {cal}" if cal != strategy else ""
fig.suptitle(
f"{site_label} ({season}) — best BtI {best_bti_scenario}{cal_note}, {gap_days}-d gap",
fontsize=10,
)
out_png.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_png, dpi=150)
plt.close(fig)
return True