This commit is contained in:
Felix Delattre 2026-06-02 11:03:58 +02:00
parent 2dba38af5b
commit 25cbd97662

View file

@ -1,4 +1,9 @@
"""Export 2×4 RGB panels for Tier-A gap validation (thesis appendix)."""
"""Export 2×4 RGB panels for Tier-A gap validation (thesis appendix).
Crops follow the same fusion-valid bounding box as ``postprocessing.process_cropped``
and the webapp (``processed_*`` / ``common.js``), anchored on gap-degraded fusion at the
prediction date; S2 and S3 are read from prepared stacks on that shared window.
"""
from __future__ import annotations
@ -10,6 +15,7 @@ from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import rasterio
from rasterio import windows
from rasterio.transform import rowcol
from rasterio.warp import Resampling, reproject
@ -20,9 +26,6 @@ S3_COMPOSITE_RE = re.compile(r"composite_(\d{8})\.tif$")
# Phenological transition keys; also the keys of ROW_LABELS below.
TRANSITIONS = ("green_up", "green_down")
# Column headers of the 2x4 panel, in the order the row paths are resolved:
# withheld S2, gap fusion, S3 composite, nearest S2.
COL_TITLES = ("Withheld S2", "Gap fusion", "S3 composite", "Nearest S2")
# Human-readable y-axis label per transition key (the raw key is used when missing).
ROW_LABELS = {"green_up": "Green-up", "green_down": "Green-down"}
# Fixed pixel window around PhenoCam for comparable framing across sites (~1 km).
# Half-width in pixels: the crop spans center +/- DISPLAY_HALF_PX on each axis.
DISPLAY_HALF_PX = 48
# Match postprocessing / spatial_metrics positive-reflectance mask.
# Values must be strictly greater than this threshold to count as valid.
VALID_REFL_THRESHOLD = 0.001
# NOTE(review): presumably the RGB fill colour for invalid pixels in the rendered
# panels -- its use is not visible in this excerpt; confirm against _bgr_to_rgba.
NODATA_RGB = (0.15, 0.15, 0.15)
@ -38,6 +41,32 @@ def _prepared_base(data_dir: Path, site: str, season: int, strategy: str) -> Pat
return data_dir / site / str(season) / f"prepared_{strategy}"
def _s2_strategy_fallbacks(strategy: str, manifest: dict) -> tuple[str, ...]:
"""Prepared trees to try for S2 REFL (best-BtI first, then manifest calendar)."""
order: list[str] = []
for s in (strategy, manifest.get("s2_calendar_strategy")):
if isinstance(s, str) and s and s not in order:
order.append(s)
for s in ("aggressive", "nonaggressive"):
if s not in order:
order.append(s)
return tuple(order)
def _find_prepared_s2_refl(
data_dir: Path,
site: str,
season: int,
filename: str,
strategies: tuple[str, ...],
) -> Path | None:
for strat in strategies:
p = _prepared_base(data_dir, site, season, strat) / "s2" / filename
if p.is_file():
return p
return None
def _gap_spatial_fusion_dir(
data_dir: Path,
site: str,
@ -63,14 +92,9 @@ def _iso_to_date(iso_d: str) -> date:
def _exclude_ymds(entry: dict) -> set[str]:
    """Return the withheld S2 acquisition day as a one-element set, if known.

    The day is capture group 1 of ``REFL_DATE_RE`` applied to
    ``entry["withheld_s2_filename"]``. An absent, empty or non-matching
    filename yields an empty set.
    """
    # The gap-window bounds are not needed here: callers combine this set with
    # the in-window acquisition days themselves.
    withheld_fn = entry.get("withheld_s2_filename") or ""
    match = REFL_DATE_RE.search(withheld_fn)
    return {match.group(1)} if match else set()
def nearest_stack_s2(
@ -79,7 +103,6 @@ def nearest_stack_s2(
*,
exclude_ymds: set[str],
) -> Path | None:
"""Nearest prepared REFL acquisition to prediction, outside excluded days."""
if not prepared_s2_dir.is_dir():
return None
target = _iso_to_date(prediction_iso)
@ -87,24 +110,16 @@ def nearest_stack_s2(
best_delta: int | None = None
for p in prepared_s2_dir.glob("S2A_MSIL2A_*_REFL.tif"):
m = REFL_DATE_RE.search(p.name)
if not m:
if not m or m.group(1) in exclude_ymds:
continue
ymd = m.group(1)
if ymd in exclude_ymds:
continue
d = datetime.strptime(ymd, "%Y%m%d").date()
delta = abs((d - target).days)
delta = abs((datetime.strptime(m.group(1), "%Y%m%d").date() - target).days)
if best_delta is None or delta < best_delta:
best_delta = delta
best_path = p
return best_path
def nearest_s3_composite(
prepared_s3_dir: Path,
prediction_iso: str,
) -> Path | None:
"""Nearest daily S3 composite to prediction date."""
def nearest_s3_composite(prepared_s3_dir: Path, prediction_iso: str) -> Path | None:
if not prepared_s3_dir.is_dir():
return None
target = _iso_to_date(prediction_iso)
@ -114,71 +129,93 @@ def nearest_s3_composite(
m = S3_COMPOSITE_RE.search(p.name)
if not m:
continue
d = datetime.strptime(m.group(1), "%Y%m%d").date()
delta = abs((d - target).days)
delta = abs((datetime.strptime(m.group(1), "%Y%m%d").date() - target).days)
if best_delta is None or delta < best_delta:
best_delta = delta
best_path = p
return best_path
def _reference_grid(path: Path) -> dict | None:
"""Grid metadata from fusion / S2 REFL (fine S2 resolution)."""
if not path.is_file():
def _crop_window_from_fusion(fusion_path: Path) -> dict | None:
"""Fusion-valid crop (``postprocessing.process_cropped``) on the full prepared grid."""
if not fusion_path.is_file():
return None
with rasterio.open(path) as src:
with rasterio.open(fusion_path) as src:
data = src.read()
valid = np.isfinite(data) & (data > VALID_REFL_THRESHOLD)
rows = np.any(valid, axis=(0, 2))
cols = np.any(valid, axis=(0, 1))
row_idx = np.where(rows)[0]
col_idx = np.where(cols)[0]
if len(row_idx) == 0 or len(col_idx) == 0:
return None
r0, r1 = int(row_idx[0]), int(row_idx[-1])
c0, c1 = int(col_idx[0]), int(col_idx[-1])
w, h = c1 - c0 + 1, r1 - r0 + 1
win = windows.Window(c0, r0, w, h)
return {
"transform": src.transform,
"window": win,
"crop_transform": windows.transform(win, src.transform),
"full_transform": src.transform,
"crs": src.crs,
"height": src.height,
"width": src.width,
"bounds": src.bounds,
"profile": src.profile.copy(),
}
def _grids_match(src: rasterio.DatasetReader, ref: dict) -> bool:
return (
src.height == ref["height"]
and src.width == ref["width"]
and src.transform == ref["transform"]
and src.crs == ref["crs"]
)
def _read_bgr_on_grid(path: Path, ref: dict) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None:
"""Read BGR on the fusion grid; reproject only when needed (S3 coarse stack)."""
if not path.is_file():
def _read_bgr_prepared_s2(prepared_refl: Path, crop: dict) -> tuple[np.ndarray, ...] | None:
if not prepared_refl.is_file():
return None
with rasterio.open(path) as src:
with rasterio.open(prepared_refl) as src:
if src.count < 3:
return None
if _grids_match(src, ref):
return (
src.read(1).astype(np.float64),
src.read(2).astype(np.float64),
src.read(3).astype(np.float64),
)
shape = (ref["height"], ref["width"])
b, g, r = src.read(indexes=(1, 2, 3), window=crop["window"])
return b.astype(np.float64), g.astype(np.float64), r.astype(np.float64)
def _read_bgr_gap_fusion(fusion_path: Path, crop: dict) -> tuple[np.ndarray, ...] | None:
if not fusion_path.is_file():
return None
with rasterio.open(fusion_path) as src:
if src.count < 3:
return None
b, g, r = src.read(indexes=(1, 2, 3), window=crop["window"])
return b.astype(np.float64), g.astype(np.float64), r.astype(np.float64)
def _read_bgr_prepared_s3(s3_path: Path, crop: dict) -> tuple[np.ndarray, ...] | None:
"""Resample S3 composite to the fusion grid, then crop (matches ``process_cropped``)."""
if not s3_path.is_file():
return None
with rasterio.open(s3_path) as src:
if src.count < 3:
return None
temp_profile = crop["profile"].copy()
temp_profile.update({"dtype": "float32", "count": src.count})
bands: list[np.ndarray] = []
for band_idx in (1, 2, 3):
dst = np.full(shape, np.nan, dtype=np.float32)
with rasterio.MemoryFile() as memfile:
with memfile.open(**temp_profile) as resampled:
for i in range(1, src.count + 1):
reproject(
source=rasterio.band(src, band_idx),
destination=dst,
source=rasterio.band(src, i),
destination=rasterio.band(resampled, i),
src_transform=src.transform,
src_crs=src.crs,
dst_transform=ref["transform"],
dst_crs=ref["crs"],
resampling=Resampling.bilinear,
dst_transform=crop["full_transform"],
dst_crs=crop["crs"],
resampling=Resampling.nearest,
)
bands.append(dst.astype(np.float64))
b, g, r = resampled.read(
indexes=(1, 2, 3), window=crop["window"]
)
bands = [
b.astype(np.float64),
g.astype(np.float64),
r.astype(np.float64),
]
return bands[0], bands[1], bands[2]
def _refl_valid(
blue: np.ndarray, green: np.ndarray, red: np.ndarray
) -> np.ndarray:
"""Positive reflectance mask (aligned with postprocessing / Tier-A metrics)."""
def _refl_valid(blue: np.ndarray, green: np.ndarray, red: np.ndarray) -> np.ndarray:
return (
np.isfinite(blue)
& np.isfinite(green)
@ -192,6 +229,7 @@ def _refl_valid(
def _panel_stretch_limits(
blue: np.ndarray, green: np.ndarray, red: np.ndarray, valid: np.ndarray
) -> tuple[float, float]:
"""Per-panel 2--98 % stretch on positive reflectance (webapp ``common.js`` style)."""
if not valid.any():
return 0.0, 1.0
vals = np.concatenate([red[valid], green[valid], blue[valid]])
@ -222,37 +260,14 @@ def _bgr_to_rgba(
return rgba
def _tight_slices(mask: np.ndarray, margin: int = 2) -> tuple[slice, slice] | None:
"""Crop to the bounding box of True pixels in *mask*."""
rows, cols = np.where(mask)
if rows.size < 8:
return None
r0 = max(0, int(rows.min()) - margin)
r1 = min(mask.shape[0], int(rows.max()) + margin + 1)
c0 = max(0, int(cols.min()) - margin)
c1 = min(mask.shape[1], int(cols.max()) + margin + 1)
if r1 - r0 < 8 or c1 - c0 < 8:
return None
return slice(r0, r1), slice(c0, c1)
def _crop_slices(
height: int, width: int, center_row: int, center_col: int, half_px: int
) -> tuple[slice, slice]:
r0 = max(0, center_row - half_px)
r1 = min(height, center_row + half_px)
c0 = max(0, center_col - half_px)
c1 = min(width, center_col + half_px)
return slice(r0, r1), slice(c0, c1)
def _lonlat_to_pixel(
    grid: dict, transform_key: str, site_position_lat_lon: tuple[float, float]
) -> tuple[int, int] | None:
    """Map a ``(lat, lon)`` position to a ``(row, col)`` index on a raster grid.

    *grid* supplies the affine transform under *transform_key* and, optionally,
    the grid CRS under ``"crs"``. Returns ``None`` when the lookup fails.
    """
    lat, lon = site_position_lat_lon
    try:
        affine = grid[transform_key]
        crs = grid.get("crs")
        xs, ys = [lon], [lat]
        if crs is not None:
            # Local import: the module header only pulls Resampling/reproject
            # from rasterio.warp.
            from rasterio.warp import transform as warp_transform

            # NOTE(review): assumes the site position is WGS84 (EPSG:4326)
            # degrees -- confirm against the caller. The affine transform works
            # in grid-CRS units, so geographic coordinates are projected first.
            xs, ys = warp_transform("EPSG:4326", crs, xs, ys)
        # ``op`` of rowcol is the rounding callable, not a CRS; the default
        # (floor) is kept by not passing it.
        rows, cols = rowcol(affine, xs, ys)
        # atleast_1d tolerates rasterio versions that return scalars, lists or
        # arrays for single-point input.
        return int(np.atleast_1d(rows)[0]), int(np.atleast_1d(cols)[0])
    except Exception:
        # Best effort: the marker is decorative, so any failure just omits it.
        return None


def _phenocam_pixel(
    meta: dict,
    site_position_lat_lon: tuple[float, float],
) -> tuple[int, int] | None:
    """Return the PhenoCam ``(row, col)`` on the full grid described by *meta*.

    Uses ``meta["transform"]`` and ``meta.get("crs")``; ``None`` on failure.
    """
    return _lonlat_to_pixel(meta, "transform", site_position_lat_lon)


def _phenocam_pixel_cropped(
    crop: dict, site_position_lat_lon: tuple[float, float]
) -> tuple[int, int] | None:
    """Return the PhenoCam ``(row, col)`` relative to the cropped window.

    Uses ``crop["crop_transform"]`` and ``crop["crs"]``; ``None`` on failure.
    The result may lie outside the crop, so callers must bounds-check it.
    """
    return _lonlat_to_pixel(crop, "crop_transform", site_position_lat_lon)
@ -267,14 +282,18 @@ def _resolve_row_paths(
sigma: int,
*,
gap_days: int,
manifest: dict,
) -> tuple[Path, Path, Path, Path] | None:
pred_ymd = yyyymmdd_from_iso(entry["prediction_date"])
transition = entry["transition"]
prep = _prepared_base(data_dir, site, season, strategy)
s2_strats = _s2_strategy_fallbacks(strategy, manifest)
withheld_fn = entry.get("withheld_s2_filename")
if not withheld_fn:
return None
withheld = prep / "s2" / withheld_fn
withheld = _find_prepared_s2_refl(
data_dir, site, season, withheld_fn, s2_strats
)
fusion = (
_gap_spatial_fusion_dir(data_dir, site, season, gap_days, transition, strategy, sigma)
/ f"REFL_{pred_ymd}.tif"
@ -287,29 +306,21 @@ def _resolve_row_paths(
)
w0 = _iso_to_date(entry["window_start"])
w1 = _iso_to_date(entry["window_end"])
window_ymds = acquisition_yyyymmdd_in_window(prep / "s2", w0, w1)
nearest: Path | None = None
for strat in s2_strats:
prep_s2 = _prepared_base(data_dir, site, season, strat) / "s2"
window_ymds = acquisition_yyyymmdd_in_window(prep_s2, w0, w1)
exclude = window_ymds | _exclude_ymds(entry)
nearest = nearest_stack_s2(prep / "s2", entry["prediction_date"], exclude_ymds=exclude)
if not withheld.is_file() or not fusion.is_file() or s3 is None or nearest is None:
nearest = nearest_stack_s2(
prep_s2, entry["prediction_date"], exclude_ymds=exclude
)
if nearest is not None:
break
if withheld is None or not fusion.is_file() or s3 is None or nearest is None:
return None
return withheld, fusion, s3, nearest
def _panel_date_subtitle(paths: tuple[Path, Path, Path, Path], entry: dict) -> str:
    """Build the one-line date caption for a panel row.

    *paths* is ``(withheld S2, gap fusion, S3 composite, nearest S2)``. The
    caption gives the prediction date, the withheld S2 date and the nearest S2
    acquisition date. The S3 composite date is appended only when the composite
    file name does not end in the prediction day.
    """
    pred = entry["prediction_date"][:10]
    wh_short = (entry.get("withheld_s2_date") or "")[:10]

    nearest_match = REFL_DATE_RE.search(paths[3].name)
    if nearest_match:
        ymd = nearest_match.group(1)
        nearest_label = f"{ymd[:4]}-{ymd[4:6]}-{ymd[6:8]}"
    else:
        # Slicing the "?" placeholder as a date would render as "?--".
        nearest_label = "?"

    s3_match = S3_COMPOSITE_RE.search(paths[2].name)
    n_s3 = s3_match.group(1) if s3_match else "?"
    pred_ymd = pred.replace("-", "")
    s3_note = "" if paths[2].name.endswith(f"{pred_ymd}.tif") else f" (S3 {n_s3})"

    return (
        f"pred. {pred}; withheld {wh_short}; "
        f"nearest S2 {nearest_label}{s3_note}"
    )
def build_site_panel(
site: str,
season: int,
@ -340,7 +351,14 @@ def build_site_panel(
if not entry:
continue
paths = _resolve_row_paths(
data_dir, site, season, entry, strategy, sigma, gap_days=gap_days
data_dir,
site,
season,
entry,
strategy,
sigma,
gap_days=gap_days,
manifest=manifest,
)
if paths is None:
continue
@ -349,7 +367,13 @@ def build_site_panel(
if not rows:
return False
crop_px = DISPLAY_HALF_PX
readers = (
_read_bgr_prepared_s2,
_read_bgr_gap_fusion,
_read_bgr_prepared_s3,
_read_bgr_prepared_s2,
)
fig, axes = plt.subplots(
len(rows),
4,
@ -359,13 +383,15 @@ def build_site_panel(
)
for row_idx, (transition, entry, paths) in enumerate(rows):
row_title = ROW_LABELS.get(transition, transition)
subtitle = _panel_date_subtitle(paths, entry)
ref_grid = _reference_grid(paths[1])
if ref_grid is None:
crop = _crop_window_from_fusion(paths[1])
if crop is None:
for ax in axes[row_idx]:
ax.set_visible(False)
continue
layers: list[tuple[np.ndarray, np.ndarray, np.ndarray]] = []
for path in paths:
bgr = _read_bgr_on_grid(path, ref_grid)
for path, read_fn in zip(paths, readers, strict=True):
bgr = read_fn(path, crop)
if bgr is None:
layers = []
break
@ -375,56 +401,24 @@ def build_site_panel(
ax.set_visible(False)
continue
h, w = ref_grid["height"], ref_grid["width"]
center_row, center_col = h // 2, w // 2
mark: tuple[int, int] | None = None
if site_position_lat_lon:
pix = _phenocam_pixel(ref_grid, site_position_lat_lon)
if pix:
center_row, center_col = pix
rs, cs = _crop_slices(h, w, center_row, center_col, crop_px)
if rs.stop - rs.start < 8 or cs.stop - cs.start < 8:
for ax in axes[row_idx]:
ax.set_visible(False)
continue
mark = _phenocam_pixel_cropped(crop, site_position_lat_lon)
cropped_valid: list[np.ndarray] = []
rgba_panels: list[np.ndarray] = []
for bgr in layers:
blue, green, red = (b[rs, cs] for b in bgr)
for col_idx, (col_title, bgr) in enumerate(zip(COL_TITLES, layers, strict=True)):
ax = axes[row_idx, col_idx]
blue, green, red = bgr
valid = _refl_valid(blue, green, red)
cropped_valid.append(valid)
vmin, vmax = _panel_stretch_limits(blue, green, red, valid)
rgba_panels.append(
_bgr_to_rgba(
rgba = _bgr_to_rgba(
blue, green, red, valid=valid, vmin=vmin, vmax=vmax
)
)
union = cropped_valid[0]
for v in cropped_valid[1:]:
union |= v
tight = _tight_slices(union)
if tight is not None:
tr, tc = tight
rgba_panels = [p[tr, tc] for p in rgba_panels]
union = union[tr, tc]
crop_h, crop_w = rgba_panels[0].shape[:2]
mark_r = center_row - rs.start
mark_c = center_col - cs.start
if tight is not None:
mark_r -= tr.start
mark_c -= tc.start
for col_idx, (col_title, rgba) in enumerate(
zip(COL_TITLES, rgba_panels, strict=True)
):
ax = axes[row_idx, col_idx]
ax.imshow(rgba, origin="upper", aspect="equal", interpolation="nearest")
if col_idx == 0 and 0 <= mark_r < crop_h and 0 <= mark_c < crop_w:
h, w = rgba.shape[:2]
if col_idx == 0 and mark and 0 <= mark[0] < h and 0 <= mark[1] < w:
ax.plot(
mark_c,
mark_r,
mark[1],
mark[0],
"+",
color="red",
markersize=8,
@ -436,24 +430,8 @@ def build_site_panel(
ax.set_ylabel(row_title, fontsize=9)
ax.set_xticks([])
ax.set_yticks([])
if col_idx == 3:
ax.text(
0.02,
0.02,
subtitle,
transform=ax.transAxes,
fontsize=6,
color="white",
va="bottom",
bbox=dict(boxstyle="round,pad=0.2", facecolor="black", alpha=0.55),
)
cal = manifest.get("s2_calendar_strategy", strategy)
cal_note = f"; S2 calendar {cal}" if cal != strategy else ""
fig.suptitle(
f"{site_label} ({season}) — best BtI {best_bti_scenario}{cal_note}, {gap_days}-d gap",
fontsize=10,
)
fig.suptitle(f"{site_label} ({season})", fontsize=10)
out_png.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_png, dpi=150)
plt.close(fig)