This commit is contained in:
Felix Delattre 2026-06-02 11:03:58 +02:00
parent 2dba38af5b
commit 25cbd97662

View file

@ -1,4 +1,9 @@
"""Export 2×4 RGB panels for Tier-A gap validation (thesis appendix).""" """Export 2×4 RGB panels for Tier-A gap validation (thesis appendix).
Crops follow the same fusion-valid bounding box as ``postprocessing.process_cropped``
and the webapp (``processed_*`` / ``common.js``), anchored on gap-degraded fusion at the
prediction date; S2 and S3 are read from prepared stacks on that shared window.
"""
from __future__ import annotations from __future__ import annotations
@ -10,6 +15,7 @@ from pathlib import Path
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import rasterio import rasterio
from rasterio import windows
from rasterio.transform import rowcol from rasterio.transform import rowcol
from rasterio.warp import Resampling, reproject from rasterio.warp import Resampling, reproject
@ -20,9 +26,6 @@ S3_COMPOSITE_RE = re.compile(r"composite_(\d{8})\.tif$")
# Transition keys; each one has a row caption in ROW_LABELS.
TRANSITIONS = ("green_up", "green_down")
# Column captions, paired left-to-right with the four image layers of a row.
COL_TITLES = ("Withheld S2", "Gap fusion", "S3 composite", "Nearest S2")
# Row caption for each transition key.
ROW_LABELS = {"green_up": "Green-up", "green_down": "Green-down"}
# Fixed pixel window around PhenoCam for comparable framing across sites (~1 km).
DISPLAY_HALF_PX = 48
# Match postprocessing / spatial_metrics positive-reflectance mask.
VALID_REFL_THRESHOLD = 0.001
# NOTE(review): presumably the RGB fill for pixels failing the validity mask;
# the code that consumes it is outside this view - confirm.
NODATA_RGB = (0.15, 0.15, 0.15)
@ -38,6 +41,32 @@ def _prepared_base(data_dir: Path, site: str, season: int, strategy: str) -> Pat
return data_dir / site / str(season) / f"prepared_{strategy}" return data_dir / site / str(season) / f"prepared_{strategy}"
def _s2_strategy_fallbacks(strategy: str, manifest: dict) -> tuple[str, ...]:
"""Prepared trees to try for S2 REFL (best-BtI first, then manifest calendar)."""
order: list[str] = []
for s in (strategy, manifest.get("s2_calendar_strategy")):
if isinstance(s, str) and s and s not in order:
order.append(s)
for s in ("aggressive", "nonaggressive"):
if s not in order:
order.append(s)
return tuple(order)
def _find_prepared_s2_refl(
    data_dir: Path,
    site: str,
    season: int,
    filename: str,
    strategies: tuple[str, ...],
) -> Path | None:
    """Return the first existing ``<prepared base>/s2/<filename>`` across strategies.

    Args:
        data_dir: Root data directory handed to ``_prepared_base``.
        site: Site identifier handed to ``_prepared_base``.
        season: Season handed to ``_prepared_base``.
        filename: File name looked up inside each strategy's ``s2`` folder.
        strategies: Strategy names, probed in the given order.

    Returns:
        Path of the first candidate that is an existing file, or ``None``
        when no strategy provides it.
    """
    candidates = (
        _prepared_base(data_dir, site, season, strat) / "s2" / filename
        for strat in strategies
    )
    # Lazy generator: probing stops at the first hit, like an early return.
    return next((p for p in candidates if p.is_file()), None)
def _gap_spatial_fusion_dir( def _gap_spatial_fusion_dir(
data_dir: Path, data_dir: Path,
site: str, site: str,
@ -63,14 +92,9 @@ def _iso_to_date(iso_d: str) -> date:
def _exclude_ymds(entry: dict) -> set[str]:
    """Return the acquisition-date token of the withheld S2 scene, if any.

    Args:
        entry: Mapping whose optional ``"withheld_s2_filename"`` value is
            searched with ``REFL_DATE_RE``.

    Returns:
        A one-element set holding group 1 of the match, or an empty set when
        the filename is missing, empty, or does not match.
    """
    # Only the filename is consulted; the window bounds are not needed here,
    # so entries lacking "window_start"/"window_end" no longer raise KeyError.
    withheld_fn = entry.get("withheld_s2_filename") or ""
    m = REFL_DATE_RE.search(withheld_fn)
    return {m.group(1)} if m else set()
def nearest_stack_s2( def nearest_stack_s2(
@ -79,7 +103,6 @@ def nearest_stack_s2(
*, *,
exclude_ymds: set[str], exclude_ymds: set[str],
) -> Path | None: ) -> Path | None:
"""Nearest prepared REFL acquisition to prediction, outside excluded days."""
if not prepared_s2_dir.is_dir(): if not prepared_s2_dir.is_dir():
return None return None
target = _iso_to_date(prediction_iso) target = _iso_to_date(prediction_iso)
@ -87,24 +110,16 @@ def nearest_stack_s2(
best_delta: int | None = None best_delta: int | None = None
for p in prepared_s2_dir.glob("S2A_MSIL2A_*_REFL.tif"): for p in prepared_s2_dir.glob("S2A_MSIL2A_*_REFL.tif"):
m = REFL_DATE_RE.search(p.name) m = REFL_DATE_RE.search(p.name)
if not m: if not m or m.group(1) in exclude_ymds:
continue continue
ymd = m.group(1) delta = abs((datetime.strptime(m.group(1), "%Y%m%d").date() - target).days)
if ymd in exclude_ymds:
continue
d = datetime.strptime(ymd, "%Y%m%d").date()
delta = abs((d - target).days)
if best_delta is None or delta < best_delta: if best_delta is None or delta < best_delta:
best_delta = delta best_delta = delta
best_path = p best_path = p
return best_path return best_path
def nearest_s3_composite( def nearest_s3_composite(prepared_s3_dir: Path, prediction_iso: str) -> Path | None:
prepared_s3_dir: Path,
prediction_iso: str,
) -> Path | None:
"""Nearest daily S3 composite to prediction date."""
if not prepared_s3_dir.is_dir(): if not prepared_s3_dir.is_dir():
return None return None
target = _iso_to_date(prediction_iso) target = _iso_to_date(prediction_iso)
@ -114,71 +129,93 @@ def nearest_s3_composite(
m = S3_COMPOSITE_RE.search(p.name) m = S3_COMPOSITE_RE.search(p.name)
if not m: if not m:
continue continue
d = datetime.strptime(m.group(1), "%Y%m%d").date() delta = abs((datetime.strptime(m.group(1), "%Y%m%d").date() - target).days)
delta = abs((d - target).days)
if best_delta is None or delta < best_delta: if best_delta is None or delta < best_delta:
best_delta = delta best_delta = delta
best_path = p best_path = p
return best_path return best_path
def _crop_window_from_fusion(fusion_path: Path) -> dict | None:
    """Fusion-valid crop (``postprocessing.process_cropped``) on the full prepared grid.

    The crop is the bounding box of every pixel that is finite and greater
    than ``VALID_REFL_THRESHOLD`` in at least one band of the fusion raster.

    Args:
        fusion_path: Path to the fusion raster.

    Returns:
        ``None`` when the file does not exist or holds no valid pixel;
        otherwise a dict with

        * ``"window"``: ``rasterio.windows.Window`` of the bounding box,
        * ``"crop_transform"``: affine transform of that window,
        * ``"full_transform"``: affine transform of the whole raster,
        * ``"crs"``: CRS of the raster,
        * ``"profile"``: copy of the raster profile (full, uncropped grid).
    """
    if not fusion_path.is_file():
        return None
    with rasterio.open(fusion_path) as src:
        data = src.read()
        valid = np.isfinite(data) & (data > VALID_REFL_THRESHOLD)
        # data is (band, row, col): reduce over bands and the other spatial
        # axis to flag rows / columns that hold at least one valid pixel.
        rows = np.any(valid, axis=(0, 2))
        cols = np.any(valid, axis=(0, 1))
        row_idx = np.where(rows)[0]
        col_idx = np.where(cols)[0]
        if len(row_idx) == 0 or len(col_idx) == 0:
            return None
        r0, r1 = int(row_idx[0]), int(row_idx[-1])
        c0, c1 = int(col_idx[0]), int(col_idx[-1])
        w, h = c1 - c0 + 1, r1 - r0 + 1
        win = windows.Window(c0, r0, w, h)
        return {
            "window": win,
            "crop_transform": windows.transform(win, src.transform),
            "full_transform": src.transform,
            "crs": src.crs,
            "profile": src.profile.copy(),
        }
def _read_bgr_prepared_s2(prepared_refl: Path, crop: dict) -> tuple[np.ndarray, ...] | None:
    """Read bands 1-3 of a prepared S2 REFL raster inside the crop window.

    Args:
        prepared_refl: Path to the prepared S2 reflectance raster.
        crop: Crop description; only ``crop["window"]`` is used here.

    Returns:
        ``(band1, band2, band3)`` as float64 arrays restricted to the window,
        or ``None`` when the file is missing or has fewer than three bands.
        The caller unpacks the tuple as blue, green, red.
    """
    if not prepared_refl.is_file():
        return None
    with rasterio.open(prepared_refl) as src:
        if src.count < 3:
            return None
        b, g, r = src.read(indexes=(1, 2, 3), window=crop["window"])
    return b.astype(np.float64), g.astype(np.float64), r.astype(np.float64)
def _read_bgr_gap_fusion(fusion_path: Path, crop: dict) -> tuple[np.ndarray, ...] | None:
    """Read bands 1-3 of the gap-fusion raster inside the crop window.

    Args:
        fusion_path: Path to the fusion raster.
        crop: Crop description; only ``crop["window"]`` is used here.

    Returns:
        ``(band1, band2, band3)`` as float64 arrays restricted to the window,
        or ``None`` when the file is missing or has fewer than three bands.
        The caller unpacks the tuple as blue, green, red.
    """
    if not fusion_path.is_file():
        return None
    with rasterio.open(fusion_path) as src:
        if src.count < 3:
            return None
        b, g, r = src.read(indexes=(1, 2, 3), window=crop["window"])
    return b.astype(np.float64), g.astype(np.float64), r.astype(np.float64)
def _read_bgr_prepared_s3(s3_path: Path, crop: dict) -> tuple[np.ndarray, ...] | None:
    """Resample S3 composite to the fusion grid, then crop (matches ``process_cropped``).

    The composite is warped with nearest-neighbour resampling onto an
    in-memory raster built from the fusion profile (``crop["profile"]``,
    ``crop["full_transform"]``, ``crop["crs"]``), and ``crop["window"]`` is
    then read back from that raster.

    Args:
        s3_path: Path to the S3 composite raster.
        crop: Crop description; uses the ``"profile"``, ``"full_transform"``,
            ``"crs"`` and ``"window"`` entries.

    Returns:
        ``(band1, band2, band3)`` as float64 arrays on the crop window, or
        ``None`` when the file is missing or has fewer than three bands.
    """
    if not s3_path.is_file():
        return None
    with rasterio.open(s3_path) as src:
        if src.count < 3:
            return None
        temp_profile = crop["profile"].copy()
        # Only bands 1-3 are read back, so only those three are warped.
        temp_profile.update({"dtype": "float32", "count": 3})
        with rasterio.MemoryFile() as memfile:
            with memfile.open(**temp_profile) as resampled:
                for i in (1, 2, 3):
                    reproject(
                        source=rasterio.band(src, i),
                        destination=rasterio.band(resampled, i),
                        src_transform=src.transform,
                        src_crs=src.crs,
                        dst_transform=crop["full_transform"],
                        dst_crs=crop["crs"],
                        resampling=Resampling.nearest,
                    )
                b, g, r = resampled.read(
                    indexes=(1, 2, 3), window=crop["window"]
                )
    return b.astype(np.float64), g.astype(np.float64), r.astype(np.float64)
def _refl_valid( def _refl_valid(blue: np.ndarray, green: np.ndarray, red: np.ndarray) -> np.ndarray:
blue: np.ndarray, green: np.ndarray, red: np.ndarray
) -> np.ndarray:
"""Positive reflectance mask (aligned with postprocessing / Tier-A metrics)."""
return ( return (
np.isfinite(blue) np.isfinite(blue)
& np.isfinite(green) & np.isfinite(green)
@ -192,6 +229,7 @@ def _refl_valid(
def _panel_stretch_limits( def _panel_stretch_limits(
blue: np.ndarray, green: np.ndarray, red: np.ndarray, valid: np.ndarray blue: np.ndarray, green: np.ndarray, red: np.ndarray, valid: np.ndarray
) -> tuple[float, float]: ) -> tuple[float, float]:
"""Per-panel 2--98 % stretch on positive reflectance (webapp ``common.js`` style)."""
if not valid.any(): if not valid.any():
return 0.0, 1.0 return 0.0, 1.0
vals = np.concatenate([red[valid], green[valid], blue[valid]]) vals = np.concatenate([red[valid], green[valid], blue[valid]])
@ -222,37 +260,14 @@ def _bgr_to_rgba(
return rgba return rgba
def _phenocam_pixel_cropped(
    crop: dict, site_position_lat_lon: tuple[float, float]
) -> tuple[int, int] | None:
    """Locate the site position as a (row, col) pixel of the cropped window.

    ``rasterio.transform.rowcol`` only applies the inverse affine transform:
    it does not reproject, and its ``op`` argument is the rounding callable,
    not a CRS. The position is therefore reprojected into the raster CRS
    first and then mapped through ``crop["crop_transform"]``.

    Args:
        crop: Crop description; uses ``"crop_transform"`` and ``"crs"``.
        site_position_lat_lon: ``(lat, lon)`` of the site.
            NOTE(review): assumed to be WGS84 degrees (EPSG:4326) - confirm
            against the caller.

    Returns:
        ``(row, col)`` relative to the crop window (may lie outside it; the
        caller bounds-checks), or ``None`` if the position cannot be mapped.
    """
    from rasterio.warp import transform as warp_transform

    lat, lon = site_position_lat_lon
    try:
        crs = crop["crs"]
        if crs is None:
            # No CRS to reproject into: use the coordinates as given.
            xs, ys = [lon], [lat]
        else:
            xs, ys = warp_transform("EPSG:4326", crs, [lon], [lat])
        r, c = rowcol(crop["crop_transform"], xs, ys)
        return int(r[0]), int(c[0])
    except Exception:
        # Deliberate best-effort: the marker is cosmetic, so a failed lookup
        # must not abort the figure export.
        return None
@ -267,14 +282,18 @@ def _resolve_row_paths(
sigma: int, sigma: int,
*, *,
gap_days: int, gap_days: int,
manifest: dict,
) -> tuple[Path, Path, Path, Path] | None: ) -> tuple[Path, Path, Path, Path] | None:
pred_ymd = yyyymmdd_from_iso(entry["prediction_date"]) pred_ymd = yyyymmdd_from_iso(entry["prediction_date"])
transition = entry["transition"] transition = entry["transition"]
prep = _prepared_base(data_dir, site, season, strategy) prep = _prepared_base(data_dir, site, season, strategy)
s2_strats = _s2_strategy_fallbacks(strategy, manifest)
withheld_fn = entry.get("withheld_s2_filename") withheld_fn = entry.get("withheld_s2_filename")
if not withheld_fn: if not withheld_fn:
return None return None
withheld = prep / "s2" / withheld_fn withheld = _find_prepared_s2_refl(
data_dir, site, season, withheld_fn, s2_strats
)
fusion = ( fusion = (
_gap_spatial_fusion_dir(data_dir, site, season, gap_days, transition, strategy, sigma) _gap_spatial_fusion_dir(data_dir, site, season, gap_days, transition, strategy, sigma)
/ f"REFL_{pred_ymd}.tif" / f"REFL_{pred_ymd}.tif"
@ -287,29 +306,21 @@ def _resolve_row_paths(
) )
w0 = _iso_to_date(entry["window_start"]) w0 = _iso_to_date(entry["window_start"])
w1 = _iso_to_date(entry["window_end"]) w1 = _iso_to_date(entry["window_end"])
window_ymds = acquisition_yyyymmdd_in_window(prep / "s2", w0, w1) nearest: Path | None = None
for strat in s2_strats:
prep_s2 = _prepared_base(data_dir, site, season, strat) / "s2"
window_ymds = acquisition_yyyymmdd_in_window(prep_s2, w0, w1)
exclude = window_ymds | _exclude_ymds(entry) exclude = window_ymds | _exclude_ymds(entry)
nearest = nearest_stack_s2(prep / "s2", entry["prediction_date"], exclude_ymds=exclude) nearest = nearest_stack_s2(
if not withheld.is_file() or not fusion.is_file() or s3 is None or nearest is None: prep_s2, entry["prediction_date"], exclude_ymds=exclude
)
if nearest is not None:
break
if withheld is None or not fusion.is_file() or s3 is None or nearest is None:
return None return None
return withheld, fusion, s3, nearest return withheld, fusion, s3, nearest
def _panel_date_subtitle(paths: tuple[Path, Path, Path, Path], entry: dict) -> str:
    """Build the one-line date caption for a panel row.

    Args:
        paths: Row paths; index 2 is the S3 composite and index 3 the nearest
            S2 scene (only their file names are inspected).
        entry: Mapping with ``"prediction_date"`` and optional
            ``"withheld_s2_date"`` strings; the first 10 characters of each
            are shown.

    Returns:
        ``"pred. <date>; withheld <date>; nearest S2 <YYYY-MM-DD>"``, followed
        by ``" (S3 <yyyymmdd>)"`` when the S3 file name does not end with the
        prediction date. Dates that cannot be parsed from a file name are
        shown as ``"?"``.
    """
    pred = entry["prediction_date"][:10]
    wh_short = (entry.get("withheld_s2_date") or "")[:10]
    nearest_ymd = REFL_DATE_RE.search(paths[3].name)
    s3_ymd = S3_COMPOSITE_RE.search(paths[2].name)
    if nearest_ymd:
        n_s2 = nearest_ymd.group(1)
        nearest_txt = f"{n_s2[:4]}-{n_s2[4:6]}-{n_s2[6:8]}"
    else:
        # Slicing the "?" placeholder like a date used to render "?--".
        nearest_txt = "?"
    n_s3 = s3_ymd.group(1) if s3_ymd else "?"
    s3_note = "" if paths[2].name.endswith(f"{pred.replace('-', '')}.tif") else f" (S3 {n_s3})"
    return (
        f"pred. {pred}; withheld {wh_short}; "
        f"nearest S2 {nearest_txt}{s3_note}"
    )
def build_site_panel( def build_site_panel(
site: str, site: str,
season: int, season: int,
@ -340,7 +351,14 @@ def build_site_panel(
if not entry: if not entry:
continue continue
paths = _resolve_row_paths( paths = _resolve_row_paths(
data_dir, site, season, entry, strategy, sigma, gap_days=gap_days data_dir,
site,
season,
entry,
strategy,
sigma,
gap_days=gap_days,
manifest=manifest,
) )
if paths is None: if paths is None:
continue continue
@ -349,7 +367,13 @@ def build_site_panel(
if not rows: if not rows:
return False return False
crop_px = DISPLAY_HALF_PX readers = (
_read_bgr_prepared_s2,
_read_bgr_gap_fusion,
_read_bgr_prepared_s3,
_read_bgr_prepared_s2,
)
fig, axes = plt.subplots( fig, axes = plt.subplots(
len(rows), len(rows),
4, 4,
@ -359,13 +383,15 @@ def build_site_panel(
) )
for row_idx, (transition, entry, paths) in enumerate(rows): for row_idx, (transition, entry, paths) in enumerate(rows):
row_title = ROW_LABELS.get(transition, transition) row_title = ROW_LABELS.get(transition, transition)
subtitle = _panel_date_subtitle(paths, entry) crop = _crop_window_from_fusion(paths[1])
ref_grid = _reference_grid(paths[1]) if crop is None:
if ref_grid is None: for ax in axes[row_idx]:
ax.set_visible(False)
continue continue
layers: list[tuple[np.ndarray, np.ndarray, np.ndarray]] = [] layers: list[tuple[np.ndarray, np.ndarray, np.ndarray]] = []
for path in paths: for path, read_fn in zip(paths, readers, strict=True):
bgr = _read_bgr_on_grid(path, ref_grid) bgr = read_fn(path, crop)
if bgr is None: if bgr is None:
layers = [] layers = []
break break
@ -375,56 +401,24 @@ def build_site_panel(
ax.set_visible(False) ax.set_visible(False)
continue continue
h, w = ref_grid["height"], ref_grid["width"] mark: tuple[int, int] | None = None
center_row, center_col = h // 2, w // 2
if site_position_lat_lon: if site_position_lat_lon:
pix = _phenocam_pixel(ref_grid, site_position_lat_lon) mark = _phenocam_pixel_cropped(crop, site_position_lat_lon)
if pix:
center_row, center_col = pix
rs, cs = _crop_slices(h, w, center_row, center_col, crop_px)
if rs.stop - rs.start < 8 or cs.stop - cs.start < 8:
for ax in axes[row_idx]:
ax.set_visible(False)
continue
cropped_valid: list[np.ndarray] = [] for col_idx, (col_title, bgr) in enumerate(zip(COL_TITLES, layers, strict=True)):
rgba_panels: list[np.ndarray] = [] ax = axes[row_idx, col_idx]
for bgr in layers: blue, green, red = bgr
blue, green, red = (b[rs, cs] for b in bgr)
valid = _refl_valid(blue, green, red) valid = _refl_valid(blue, green, red)
cropped_valid.append(valid)
vmin, vmax = _panel_stretch_limits(blue, green, red, valid) vmin, vmax = _panel_stretch_limits(blue, green, red, valid)
rgba_panels.append( rgba = _bgr_to_rgba(
_bgr_to_rgba(
blue, green, red, valid=valid, vmin=vmin, vmax=vmax blue, green, red, valid=valid, vmin=vmin, vmax=vmax
) )
)
union = cropped_valid[0]
for v in cropped_valid[1:]:
union |= v
tight = _tight_slices(union)
if tight is not None:
tr, tc = tight
rgba_panels = [p[tr, tc] for p in rgba_panels]
union = union[tr, tc]
crop_h, crop_w = rgba_panels[0].shape[:2]
mark_r = center_row - rs.start
mark_c = center_col - cs.start
if tight is not None:
mark_r -= tr.start
mark_c -= tc.start
for col_idx, (col_title, rgba) in enumerate(
zip(COL_TITLES, rgba_panels, strict=True)
):
ax = axes[row_idx, col_idx]
ax.imshow(rgba, origin="upper", aspect="equal", interpolation="nearest") ax.imshow(rgba, origin="upper", aspect="equal", interpolation="nearest")
if col_idx == 0 and 0 <= mark_r < crop_h and 0 <= mark_c < crop_w: h, w = rgba.shape[:2]
if col_idx == 0 and mark and 0 <= mark[0] < h and 0 <= mark[1] < w:
ax.plot( ax.plot(
mark_c, mark[1],
mark_r, mark[0],
"+", "+",
color="red", color="red",
markersize=8, markersize=8,
@ -436,24 +430,8 @@ def build_site_panel(
ax.set_ylabel(row_title, fontsize=9) ax.set_ylabel(row_title, fontsize=9)
ax.set_xticks([]) ax.set_xticks([])
ax.set_yticks([]) ax.set_yticks([])
if col_idx == 3:
ax.text(
0.02,
0.02,
subtitle,
transform=ax.transAxes,
fontsize=6,
color="white",
va="bottom",
bbox=dict(boxstyle="round,pad=0.2", facecolor="black", alpha=0.55),
)
cal = manifest.get("s2_calendar_strategy", strategy) fig.suptitle(f"{site_label} ({season})", fontsize=10)
cal_note = f"; S2 calendar {cal}" if cal != strategy else ""
fig.suptitle(
f"{site_label} ({season}) — best BtI {best_bti_scenario}{cal_note}, {gap_days}-d gap",
fontsize=10,
)
out_png.parent.mkdir(parents=True, exist_ok=True) out_png.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_png, dpi=150) fig.savefig(out_png, dpi=150)
plt.close(fig) plt.close(fig)