efast-phenocam-validation/gap_validation/export_rasters.py
Felix Delattre 25cbd97662 foo
2026-06-02 11:03:58 +02:00

438 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Export 2×4 RGB panels for Tier-A gap validation (thesis appendix).
Crops follow the same fusion-valid bounding box as ``postprocessing.process_cropped``
and the webapp (``processed_*`` / ``common.js``), anchored on gap-degraded fusion at the
prediction date; S2 and S3 are read from prepared stacks on that shared window.
"""
from __future__ import annotations
import json
import re
from datetime import date, datetime
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import rasterio
from rasterio import windows
from rasterio.transform import rowcol
from rasterio.warp import Resampling, reproject
from gap_validation.s2_mask_dir import acquisition_yyyymmdd_in_window, yyyymmdd_from_iso
# Filename patterns of prepared inputs; group 1 captures the YYYYMMDD date.
REFL_DATE_RE = re.compile(r"S2A_MSIL2A_(\d{8})_REFL\.tif$")
S3_COMPOSITE_RE = re.compile(r"composite_(\d{8})\.tif$")

# Figure layout: one row per phenological transition, one column per source.
TRANSITIONS = ("green_up", "green_down")
COL_TITLES = ("Withheld S2", "Gap fusion", "S3 composite", "Nearest S2")
ROW_LABELS = dict(green_up="Green-up", green_down="Green-down")

# Reflectance values must be finite and strictly above this to count as valid.
VALID_REFL_THRESHOLD = 1e-3
# RGB fill (0-1 range) painted over invalid pixels.
NODATA_RGB = (0.15,) * 3
def _parse_bti_scenario(scenario: str) -> tuple[str, int]:
m = re.match(r"^(aggressive|nonaggressive)_sigma(20|30)$", scenario)
if not m:
raise ValueError(f"expected BtI scenario key, got {scenario!r}")
return m.group(1), int(m.group(2))
def _prepared_base(data_dir: Path, site: str, season: int, strategy: str) -> Path:
return data_dir / site / str(season) / f"prepared_{strategy}"
def _s2_strategy_fallbacks(strategy: str, manifest: dict) -> tuple[str, ...]:
"""Prepared trees to try for S2 REFL (best-BtI first, then manifest calendar)."""
order: list[str] = []
for s in (strategy, manifest.get("s2_calendar_strategy")):
if isinstance(s, str) and s and s not in order:
order.append(s)
for s in ("aggressive", "nonaggressive"):
if s not in order:
order.append(s)
return tuple(order)
def _find_prepared_s2_refl(
    data_dir: Path,
    site: str,
    season: int,
    filename: str,
    strategies: tuple[str, ...],
) -> Path | None:
    """Return the first existing ``prepared_<strategy>/s2/<filename>``.

    Strategies are tried in the given order; the first tree that contains
    the file wins. Returns None when none of them has it.
    """
    candidates = (
        _prepared_base(data_dir, site, season, strat) / "s2" / filename
        for strat in strategies
    )
    return next((path for path in candidates if path.is_file()), None)
def _gap_spatial_fusion_dir(
data_dir: Path,
site: str,
season: int,
gap_days: int,
transition: str,
strategy: str,
sigma: int,
) -> Path:
return (
data_dir
/ site
/ str(season)
/ "validation"
/ "fusion"
/ f"gap_{gap_days}_{transition}"
/ f"{strategy}_sigma{sigma}_bti"
)
def _iso_to_date(iso_d: str) -> date:
return datetime.strptime(iso_d[:10], "%Y-%m-%d").date()
def _exclude_ymds(entry: dict) -> set[str]:
    """Acquisition date (YYYYMMDD) of the entry's withheld S2 scene, as a set.

    Returns an empty set when ``withheld_s2_filename`` is missing, empty, or
    does not match ``REFL_DATE_RE``.
    """
    match = REFL_DATE_RE.search(entry.get("withheld_s2_filename") or "")
    if match is None:
        return set()
    return {match.group(1)}
def nearest_stack_s2(
    prepared_s2_dir: Path,
    prediction_iso: str,
    *,
    exclude_ymds: set[str],
) -> Path | None:
    """Return the prepared S2 REFL stack acquired closest to ``prediction_iso``.

    Only files named ``S2A_MSIL2A_<YYYYMMDD>_REFL.tif`` are considered; any
    whose ``YYYYMMDD`` is in ``exclude_ymds`` is skipped. When two stacks are
    equally far from the prediction date, the earlier acquisition wins, so the
    result does not depend on directory listing order.

    Returns:
        The chosen path, or None if the directory is missing or holds no
        eligible stack.
    """
    if not prepared_s2_dir.is_dir():
        return None
    target = _iso_to_date(prediction_iso)
    candidates: list[tuple[int, date, Path]] = []
    for path in prepared_s2_dir.glob("S2A_MSIL2A_*_REFL.tif"):
        m = REFL_DATE_RE.search(path.name)
        if m is None or m.group(1) in exclude_ymds:
            continue
        acquired = datetime.strptime(m.group(1), "%Y%m%d").date()
        candidates.append((abs((acquired - target).days), acquired, path))
    if not candidates:
        return None
    # glob() yields entries in arbitrary order; ordering by (distance, date)
    # makes ties between equidistant acquisitions resolve deterministically.
    return min(candidates)[2]
def nearest_s3_composite(prepared_s3_dir: Path, prediction_iso: str) -> Path | None:
    """Return the S3 ``composite_<YYYYMMDD>.tif`` closest to ``prediction_iso``.

    When two composites are equally far from the prediction date, the earlier
    one wins, so the result does not depend on directory listing order.

    Returns:
        The chosen path, or None if the directory is missing or holds no
        matching composite.
    """
    if not prepared_s3_dir.is_dir():
        return None
    target = _iso_to_date(prediction_iso)
    candidates: list[tuple[int, date, Path]] = []
    for path in prepared_s3_dir.glob("composite_*.tif"):
        m = S3_COMPOSITE_RE.search(path.name)
        if m is None:
            continue
        composite_day = datetime.strptime(m.group(1), "%Y%m%d").date()
        candidates.append((abs((composite_day - target).days), composite_day, path))
    if not candidates:
        return None
    # glob() order is arbitrary; (distance, date) ordering breaks ties deterministically.
    return min(candidates)[2]
def _crop_window_from_fusion(fusion_path: Path) -> dict | None:
    """Fusion-valid crop (``postprocessing.process_cropped``) on the full prepared grid.

    A pixel counts as valid when at least one band is finite and above
    ``VALID_REFL_THRESHOLD``; the crop is the bounding box of valid pixels.

    Returns:
        None if the file is missing or has no valid pixel. Otherwise a dict
        with ``window`` (the crop), ``crop_transform`` (affine of the crop),
        ``full_transform``, ``crs`` and ``profile`` of the fusion raster.
    """
    if not fusion_path.is_file():
        return None
    with rasterio.open(fusion_path) as src:
        stack = src.read()
        # Collapse the band axis: True where any band is usable.
        any_band_valid = (np.isfinite(stack) & (stack > VALID_REFL_THRESHOLD)).any(axis=0)
        valid_rows = np.flatnonzero(any_band_valid.any(axis=1))
        valid_cols = np.flatnonzero(any_band_valid.any(axis=0))
        if valid_rows.size == 0 or valid_cols.size == 0:
            return None
        top, bottom = int(valid_rows[0]), int(valid_rows[-1])
        left, right = int(valid_cols[0]), int(valid_cols[-1])
        crop_win = windows.Window(left, top, right - left + 1, bottom - top + 1)
        return {
            "window": crop_win,
            "crop_transform": windows.transform(crop_win, src.transform),
            "full_transform": src.transform,
            "crs": src.crs,
            "profile": src.profile.copy(),
        }
def _read_bgr_prepared_s2(prepared_refl: Path, crop: dict) -> tuple[np.ndarray, ...] | None:
    """Read bands 1-3 of a prepared S2 REFL stack (as blue, green, red) on the crop.

    ``crop["window"]`` is applied directly, so this assumes the S2 stack shares
    the fusion raster's pixel grid — NOTE(review): confirm for all sites.

    Returns:
        ``(blue, green, red)`` as float64 arrays, or None if the file is missing
        or has fewer than three bands.
    """
    if not prepared_refl.is_file():
        return None
    with rasterio.open(prepared_refl) as src:
        if src.count < 3:
            return None
        stack = src.read(indexes=(1, 2, 3), window=crop["window"])
    return tuple(band.astype(np.float64) for band in stack)
def _read_bgr_gap_fusion(fusion_path: Path, crop: dict) -> tuple[np.ndarray, ...] | None:
    """Read bands 1-3 of a gap-fusion REFL raster (as blue, green, red) on the crop.

    Returns:
        ``(blue, green, red)`` as float64 arrays, or None if the file is missing
        or has fewer than three bands.
    """
    if not fusion_path.is_file():
        return None
    with rasterio.open(fusion_path) as src:
        if src.count < 3:
            return None
        rgb_window = crop["window"]
        blue, green, red = src.read(indexes=(1, 2, 3), window=rgb_window)
    return (
        blue.astype(np.float64),
        green.astype(np.float64),
        red.astype(np.float64),
    )
def _read_bgr_prepared_s3(s3_path: Path, crop: dict) -> tuple[np.ndarray, ...] | None:
    """Resample S3 composite to the fusion grid, then crop (matches ``process_cropped``).

    Bands 1-3 (read as blue, green, red) are warped with nearest-neighbour
    resampling onto the full fusion grid (``crop["full_transform"]`` /
    ``crop["crs"]``) and then cut with ``crop["window"]``. Only those three
    bands are reprojected, because no other band is ever displayed.

    Returns:
        ``(blue, green, red)`` as float64 arrays, or None if the file is missing
        or has fewer than three bands.
    """
    if not s3_path.is_file():
        return None
    rgb_bands = (1, 2, 3)
    with rasterio.open(s3_path) as src:
        if src.count < 3:
            return None
        # In-memory float32 copy of the *full* fusion grid, so the fusion crop
        # window applies unchanged after warping.
        temp_profile = crop["profile"].copy()
        temp_profile.update({"dtype": "float32", "count": len(rgb_bands)})
        with rasterio.MemoryFile() as memfile:
            with memfile.open(**temp_profile) as resampled:
                for band_idx in rgb_bands:
                    reproject(
                        source=rasterio.band(src, band_idx),
                        destination=rasterio.band(resampled, band_idx),
                        src_transform=src.transform,
                        src_crs=src.crs,
                        dst_transform=crop["full_transform"],
                        dst_crs=crop["crs"],
                        resampling=Resampling.nearest,
                    )
                blue, green, red = resampled.read(
                    indexes=rgb_bands, window=crop["window"]
                )
    return (
        blue.astype(np.float64),
        green.astype(np.float64),
        red.astype(np.float64),
    )
def _refl_valid(blue: np.ndarray, green: np.ndarray, red: np.ndarray) -> np.ndarray:
    """Boolean mask: True where all three bands are finite and above ``VALID_REFL_THRESHOLD``."""
    per_band = [
        np.isfinite(band) & (band > VALID_REFL_THRESHOLD)
        for band in (blue, green, red)
    ]
    return per_band[0] & per_band[1] & per_band[2]
def _panel_stretch_limits(
blue: np.ndarray, green: np.ndarray, red: np.ndarray, valid: np.ndarray
) -> tuple[float, float]:
"""Per-panel 2--98 % stretch on positive reflectance (webapp ``common.js`` style)."""
if not valid.any():
return 0.0, 1.0
vals = np.concatenate([red[valid], green[valid], blue[valid]])
lo, hi = np.percentile(vals, (2, 98))
if hi <= lo:
return 0.0, 1.0
return float(lo), float(hi)
def _bgr_to_rgba(
    blue: np.ndarray,
    green: np.ndarray,
    red: np.ndarray,
    *,
    valid: np.ndarray,
    vmin: float,
    vmax: float,
) -> np.ndarray:
    """Compose an opaque float32 RGBA image from BGR bands.

    Valid pixels are linearly stretched from ``[vmin, vmax]`` to ``[0, 1]``
    (clipped); invalid pixels are painted ``NODATA_RGB``. A zero span falls
    back to 1.0 to avoid dividing by zero.
    """
    out = np.zeros((*blue.shape, 4), dtype=np.float32)
    out[..., 3] = 1.0
    span = (vmax - vmin) or 1.0
    # Output channels are R, G, B, so walk the bands in that order.
    for channel, (band, fill) in enumerate(zip((red, green, blue), NODATA_RGB)):
        scaled = np.clip((band - vmin) / span, 0.0, 1.0)
        out[..., channel] = np.where(valid, scaled, fill)
    return out
def _phenocam_pixel_cropped(
    crop: dict, site_position_lat_lon: tuple[float, float]
) -> tuple[int, int] | None:
    """Locate the PhenoCam site as a ``(row, col)`` pixel inside the fusion crop.

    ``site_position_lat_lon`` is a ``(lat, lon)`` pair, assumed to be WGS84
    degrees (EPSG:4326) — NOTE(review): confirm against the caller. The point is
    transformed into ``crop["crs"]`` and then mapped through
    ``crop["crop_transform"]``. Previously the CRS was passed as ``rowcol``'s
    ``op`` argument (which is the rounding function), so the call always failed
    and no marker was ever drawn.

    Returns:
        ``(row, col)`` relative to the crop (may lie outside it; the caller
        bounds-checks), or None if the crop has no CRS or the transform fails.
    """
    from rasterio.warp import transform as transform_coords
    import warnings

    crs = crop.get("crs")
    if crs is None:
        return None
    lat, lon = site_position_lat_lon
    try:
        # rasterio uses x=lon, y=lat ordering for geographic coordinates.
        xs, ys = transform_coords("EPSG:4326", crs, [lon], [lat])
        rows, cols = rowcol(crop["crop_transform"], xs, ys)
        return int(rows[0]), int(cols[0])
    except Exception as exc:  # marker is optional: warn instead of failing the figure
        warnings.warn(
            f"could not place PhenoCam marker: {exc}", RuntimeWarning, stacklevel=2
        )
        return None
def _resolve_row_paths(
    data_dir: Path,
    site: str,
    season: int,
    entry: dict,
    strategy: str,
    sigma: int,
    *,
    gap_days: int,
    manifest: dict,
) -> tuple[Path, Path, Path, Path] | None:
    """Locate the four rasters shown in one panel row.

    Returns ``(withheld_s2, gap_fusion, s3_composite, nearest_s2)``, or None if
    the entry names no withheld scene or any of the four cannot be found.

    * withheld S2: ``withheld_s2_filename`` searched across the S2 strategy
      fallbacks;
    * gap fusion: ``REFL_<YYYYMMDD>.tif`` for the prediction date;
    * S3: the exact-date composite, else the nearest one in time;
    * nearest S2: the closest S2 stack acquired outside the entry's
      ``window_start``..``window_end`` gap and not equal to the withheld
      scene, taken from the first fallback strategy that has one.
    """
    prediction_iso = entry["prediction_date"]
    pred_ymd = yyyymmdd_from_iso(prediction_iso)
    transition = entry["transition"]
    prep_dir = _prepared_base(data_dir, site, season, strategy)
    s2_strategies = _s2_strategy_fallbacks(strategy, manifest)

    withheld_name = entry.get("withheld_s2_filename")
    if not withheld_name:
        return None
    withheld = _find_prepared_s2_refl(
        data_dir, site, season, withheld_name, s2_strategies
    )

    fusion_dir = _gap_spatial_fusion_dir(
        data_dir, site, season, gap_days, transition, strategy, sigma
    )
    fusion = fusion_dir / f"REFL_{pred_ymd}.tif"

    s3_dir = prep_dir / "s3"
    s3_exact = s3_dir / f"composite_{pred_ymd}.tif"
    if s3_exact.is_file():
        s3 = s3_exact
    else:
        s3 = nearest_s3_composite(s3_dir, prediction_iso)

    window_start = _iso_to_date(entry["window_start"])
    window_end = _iso_to_date(entry["window_end"])

    def _nearest_outside_window(strat: str) -> Path | None:
        s2_dir = _prepared_base(data_dir, site, season, strat) / "s2"
        skipped = acquisition_yyyymmdd_in_window(
            s2_dir, window_start, window_end
        ) | _exclude_ymds(entry)
        return nearest_stack_s2(s2_dir, prediction_iso, exclude_ymds=skipped)

    # map() is lazy, so strategies after the first hit are never scanned.
    nearest = next(
        (hit for hit in map(_nearest_outside_window, s2_strategies) if hit is not None),
        None,
    )

    if withheld is None or not fusion.is_file() or s3 is None or nearest is None:
        return None
    return withheld, fusion, s3, nearest
def _manifest_entry(manifest: dict, gap_days: int, transition: str) -> dict | None:
    """Return the first manifest entry matching ``gap_days`` and ``transition``, or None."""
    return next(
        (
            e
            for e in manifest.get("entries", ())
            if e.get("gap_days") == gap_days and e.get("transition") == transition
        ),
        None,
    )


def _load_row_layers(
    paths: tuple[Path, Path, Path, Path],
) -> tuple[dict, list[tuple[np.ndarray, ...]]] | None:
    """Compute the crop and read all four BGR layers of one row; None if any step fails."""
    # The crop is anchored on the gap-fusion raster (second path); all other
    # layers are read on that same window.
    crop = _crop_window_from_fusion(paths[1])
    if crop is None:
        return None
    # Readers in column order: withheld S2, gap fusion, S3 composite, nearest S2.
    readers = (
        _read_bgr_prepared_s2,
        _read_bgr_gap_fusion,
        _read_bgr_prepared_s3,
        _read_bgr_prepared_s2,
    )
    layers: list[tuple[np.ndarray, ...]] = []
    for path, read_fn in zip(paths, readers, strict=True):
        bgr = read_fn(path, crop)
        if bgr is None:
            return None
        layers.append(bgr)
    return crop, layers


def _draw_layer(ax, bgr: tuple[np.ndarray, ...], *, mark: tuple[int, int] | None) -> None:
    """Draw one BGR layer as a stretched RGB image, optionally marking ``mark`` (row, col)."""
    blue, green, red = bgr
    valid = _refl_valid(blue, green, red)
    vmin, vmax = _panel_stretch_limits(blue, green, red, valid)
    rgba = _bgr_to_rgba(blue, green, red, valid=valid, vmin=vmin, vmax=vmax)
    ax.imshow(rgba, origin="upper", aspect="equal", interpolation="nearest")
    h, w = rgba.shape[:2]
    # Only draw the marker when it falls inside this panel's pixel grid.
    if mark is not None and 0 <= mark[0] < h and 0 <= mark[1] < w:
        ax.plot(
            mark[1],
            mark[0],
            "+",
            color="red",
            markersize=8,
            markeredgewidth=1.2,
        )
    ax.set_xticks([])
    ax.set_yticks([])


def build_site_panel(
    site: str,
    season: int,
    data_dir: Path,
    out_png: Path,
    *,
    best_bti_scenario: str,
    site_label: str,
    site_position_lat_lon: tuple[float, float] | None = None,
    gap_days: int = 30,
) -> bool:
    """Build the (up to) 2×4 RGB figure for one site-season and save it as PNG.

    Rows are the transitions in ``TRANSITIONS`` that have a manifest entry for
    ``gap_days`` and all four input rasters; columns follow ``COL_TITLES``.
    Transitions without a complete entry are left out of the figure. A row
    whose rasters cannot be cropped or read is kept as hidden (blank) axes.

    Returns:
        False when ``gap_manifest.json`` is missing or no transition yields a
        row (nothing is written); True once ``out_png`` has been saved.

    Raises:
        ValueError: if ``best_bti_scenario`` is not a recognised BtI scenario key.
    """
    manifest_path = data_dir / site / str(season) / "validation" / "gap_manifest.json"
    if not manifest_path.is_file():
        return False
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    strategy, sigma = _parse_bti_scenario(best_bti_scenario)

    rows: list[tuple[str, tuple[Path, Path, Path, Path]]] = []
    for transition in TRANSITIONS:
        entry = _manifest_entry(manifest, gap_days, transition)
        if not entry:
            continue
        paths = _resolve_row_paths(
            data_dir,
            site,
            season,
            entry,
            strategy,
            sigma,
            gap_days=gap_days,
            manifest=manifest,
        )
        if paths is not None:
            rows.append((transition, paths))
    if not rows:
        return False

    fig, axes = plt.subplots(
        len(rows),
        4,
        figsize=(12.0, 2.8 * len(rows)),
        squeeze=False,
        constrained_layout=True,
    )
    # Close the figure even if reading, drawing or saving raises, so a loop
    # over many sites does not accumulate open figures.
    try:
        for row_idx, (transition, paths) in enumerate(rows):
            loaded = _load_row_layers(paths)
            if loaded is None:
                for ax in axes[row_idx]:
                    ax.set_visible(False)
                continue
            crop, layers = loaded
            mark = (
                _phenocam_pixel_cropped(crop, site_position_lat_lon)
                if site_position_lat_lon
                else None
            )
            row_title = ROW_LABELS.get(transition, transition)
            for col_idx, (col_title, bgr) in enumerate(
                zip(COL_TITLES, layers, strict=True)
            ):
                ax = axes[row_idx, col_idx]
                # The PhenoCam marker is drawn only on the first (withheld S2) column.
                _draw_layer(ax, bgr, mark=mark if col_idx == 0 else None)
                if row_idx == 0:
                    ax.set_title(col_title, fontsize=9)
                if col_idx == 0:
                    ax.set_ylabel(row_title, fontsize=9)
        fig.suptitle(f"{site_label} ({season})", fontsize=10)
        out_png.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_png, dpi=150)
    finally:
        plt.close(fig)
    return True