438 lines
14 KiB
Python
438 lines
14 KiB
Python
"""Export 2×4 RGB panels for Tier-A gap validation (thesis appendix).

Crops follow the same fusion-valid bounding box as ``postprocessing.process_cropped``
and the webapp (``processed_*`` / ``common.js``), anchored on gap-degraded fusion at the
prediction date; S2 and S3 are read from prepared stacks on that shared window.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import re
|
||
from datetime import date, datetime
|
||
from pathlib import Path
|
||
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
import rasterio
|
||
from rasterio import windows
|
||
from rasterio.transform import rowcol
|
||
from rasterio.warp import Resampling, reproject
|
||
|
||
from gap_validation.s2_mask_dir import acquisition_yyyymmdd_in_window, yyyymmdd_from_iso
|
||
|
||
# Prepared S2 reflectance filename; group 1 is the 8-digit date (parsed as %Y%m%d below).
REFL_DATE_RE = re.compile(r"S2A_MSIL2A_(\d{8})_REFL\.tif$")
# Prepared S3 composite filename; group 1 is the 8-digit date (parsed as %Y%m%d below).
S3_COMPOSITE_RE = re.compile(r"composite_(\d{8})\.tif$")
# Manifest ``transition`` values, in the order their rows appear in the figure.
TRANSITIONS = ("green_up", "green_down")
# Column titles, in the same order as the path tuple returned by ``_resolve_row_paths``.
COL_TITLES = ("Withheld S2", "Gap fusion", "S3 composite", "Nearest S2")
# Y-axis label shown on the first column of each transition row.
ROW_LABELS = {"green_up": "Green-up", "green_down": "Green-down"}
# A value counts as valid only when it is finite and strictly greater than this.
VALID_REFL_THRESHOLD = 0.001
# RGB colour painted on pixels that fail the validity test.
NODATA_RGB = (0.15, 0.15, 0.15)
|
||
|
||
|
||
def _parse_bti_scenario(scenario: str) -> tuple[str, int]:
|
||
m = re.match(r"^(aggressive|nonaggressive)_sigma(20|30)$", scenario)
|
||
if not m:
|
||
raise ValueError(f"expected BtI scenario key, got {scenario!r}")
|
||
return m.group(1), int(m.group(2))
|
||
|
||
|
||
def _prepared_base(data_dir: Path, site: str, season: int, strategy: str) -> Path:
|
||
return data_dir / site / str(season) / f"prepared_{strategy}"
|
||
|
||
|
||
def _s2_strategy_fallbacks(strategy: str, manifest: dict) -> tuple[str, ...]:
|
||
"""Prepared trees to try for S2 REFL (best-BtI first, then manifest calendar)."""
|
||
order: list[str] = []
|
||
for s in (strategy, manifest.get("s2_calendar_strategy")):
|
||
if isinstance(s, str) and s and s not in order:
|
||
order.append(s)
|
||
for s in ("aggressive", "nonaggressive"):
|
||
if s not in order:
|
||
order.append(s)
|
||
return tuple(order)
|
||
|
||
|
||
def _find_prepared_s2_refl(
    data_dir: Path,
    site: str,
    season: int,
    filename: str,
    strategies: tuple[str, ...],
) -> Path | None:
    """Return the first existing ``<prepared tree>/s2/<filename>`` across *strategies*.

    Strategies are probed in the given order; ``None`` means no tree holds the file.
    """
    candidates = (
        _prepared_base(data_dir, site, season, strat) / "s2" / filename
        for strat in strategies
    )
    return next((path for path in candidates if path.is_file()), None)
|
||
|
||
|
||
def _gap_spatial_fusion_dir(
|
||
data_dir: Path,
|
||
site: str,
|
||
season: int,
|
||
gap_days: int,
|
||
transition: str,
|
||
strategy: str,
|
||
sigma: int,
|
||
) -> Path:
|
||
return (
|
||
data_dir
|
||
/ site
|
||
/ str(season)
|
||
/ "validation"
|
||
/ "fusion"
|
||
/ f"gap_{gap_days}_{transition}"
|
||
/ f"{strategy}_sigma{sigma}_bti"
|
||
)
|
||
|
||
|
||
def _iso_to_date(iso_d: str) -> date:
|
||
return datetime.strptime(iso_d[:10], "%Y-%m-%d").date()
|
||
|
||
|
||
def _exclude_ymds(entry: dict) -> set[str]:
    """Return the 8-digit date of the entry's withheld S2 file as a one-element set.

    The set is empty when ``withheld_s2_filename`` is missing, empty, or does
    not match ``REFL_DATE_RE``.
    """
    name = entry.get("withheld_s2_filename") or ""
    hit = REFL_DATE_RE.search(name)
    if hit is None:
        return set()
    return {hit.group(1)}
|
||
|
||
|
||
def nearest_stack_s2(
    prepared_s2_dir: Path,
    prediction_iso: str,
    *,
    exclude_ymds: set[str],
) -> Path | None:
    """Return the prepared S2 REFL file dated closest to the prediction date.

    Args:
        prepared_s2_dir: directory scanned for ``S2A_MSIL2A_*_REFL.tif`` files.
        prediction_iso: ISO date (or datetime) string of the prediction.
        exclude_ymds: ``YYYYMMDD`` strings whose files must not be returned.

    Returns:
        The closest candidate, or ``None`` when the directory is missing or no
        candidate remains. When two candidates are equally distant, the one
        that sorts first by path wins.
    """
    if not prepared_s2_dir.is_dir():
        return None
    target = _iso_to_date(prediction_iso)
    best_path: Path | None = None
    best_delta: int | None = None
    # Path.glob yields entries in a filesystem-dependent order; sorting makes
    # the strict "<" below resolve ties identically on every machine.
    for p in sorted(prepared_s2_dir.glob("S2A_MSIL2A_*_REFL.tif")):
        m = REFL_DATE_RE.search(p.name)
        if not m or m.group(1) in exclude_ymds:
            continue
        delta = abs((datetime.strptime(m.group(1), "%Y%m%d").date() - target).days)
        if best_delta is None or delta < best_delta:
            best_delta, best_path = delta, p
    return best_path
|
||
|
||
|
||
def nearest_s3_composite(prepared_s3_dir: Path, prediction_iso: str) -> Path | None:
    """Return the prepared S3 composite dated closest to the prediction date.

    Args:
        prepared_s3_dir: directory scanned for ``composite_*.tif`` files.
        prediction_iso: ISO date (or datetime) string of the prediction.

    Returns:
        The closest composite, or ``None`` when the directory is missing or
        holds no matching file. When two composites are equally distant, the
        one that sorts first by path wins.
    """
    if not prepared_s3_dir.is_dir():
        return None
    target = _iso_to_date(prediction_iso)
    best_path: Path | None = None
    best_delta: int | None = None
    # Path.glob yields entries in a filesystem-dependent order; sorting makes
    # the strict "<" below resolve ties identically on every machine.
    for p in sorted(prepared_s3_dir.glob("composite_*.tif")):
        m = S3_COMPOSITE_RE.search(p.name)
        if not m:
            continue
        delta = abs((datetime.strptime(m.group(1), "%Y%m%d").date() - target).days)
        if best_delta is None or delta < best_delta:
            best_delta, best_path = delta, p
    return best_path
|
||
|
||
|
||
def _crop_window_from_fusion(fusion_path: Path) -> dict | None:
    """Bounding box of fusion-valid pixels (``postprocessing.process_cropped`` style).

    A pixel counts when any band is finite and above ``VALID_REFL_THRESHOLD``.

    Returns:
        ``None`` when the file is missing or has no valid pixel; otherwise a
        dict with ``window`` (the crop on the full grid), ``crop_transform``
        (transform of that window), ``full_transform``, ``crs`` and a copy of
        the raster ``profile``.
    """
    if not fusion_path.is_file():
        return None
    with rasterio.open(fusion_path) as src:
        cube = src.read()
        usable = np.isfinite(cube) & (cube > VALID_REFL_THRESHOLD)
        # Reduce over every axis except the one being located.
        hit_rows = np.flatnonzero(usable.any(axis=(0, 2)))
        hit_cols = np.flatnonzero(usable.any(axis=(0, 1)))
        if hit_rows.size == 0 or hit_cols.size == 0:
            return None
        top, bottom = int(hit_rows[0]), int(hit_rows[-1])
        left, right = int(hit_cols[0]), int(hit_cols[-1])
        win = windows.Window(left, top, right - left + 1, bottom - top + 1)
        return {
            "window": win,
            "crop_transform": windows.transform(win, src.transform),
            "full_transform": src.transform,
            "crs": src.crs,
            "profile": src.profile.copy(),
        }
|
||
|
||
|
||
def _read_bgr_prepared_s2(prepared_refl: Path, crop: dict) -> tuple[np.ndarray, ...] | None:
    """Read bands 1-3 of a prepared S2 REFL raster on ``crop["window"]`` as float64.

    The three arrays are returned in band order (callers unpack them as blue,
    green, red). Returns ``None`` when the file is missing or has fewer than
    three bands.
    """
    if not prepared_refl.is_file():
        return None
    with rasterio.open(prepared_refl) as src:
        if src.count < 3:
            return None
        first, second, third = src.read(indexes=(1, 2, 3), window=crop["window"])
    return tuple(band.astype(np.float64) for band in (first, second, third))
|
||
|
||
|
||
def _read_bgr_gap_fusion(fusion_path: Path, crop: dict) -> tuple[np.ndarray, ...] | None:
    """Read bands 1-3 of the gap-fusion raster on ``crop["window"]`` as float64.

    The three arrays are returned in band order (callers unpack them as blue,
    green, red). Returns ``None`` when the file is missing or has fewer than
    three bands.
    """
    if not fusion_path.is_file():
        return None
    with rasterio.open(fusion_path) as src:
        if src.count < 3:
            return None
        first, second, third = src.read(indexes=(1, 2, 3), window=crop["window"])
    return tuple(band.astype(np.float64) for band in (first, second, third))
|
||
|
||
|
||
def _read_bgr_prepared_s3(s3_path: Path, crop: dict) -> tuple[np.ndarray, ...] | None:
    """Resample an S3 composite to the fusion grid, then crop (matches ``process_cropped``).

    Bands 1-3 of the composite are reprojected (nearest neighbour) onto the
    full fusion grid described by ``crop["profile"]`` / ``crop["full_transform"]``
    / ``crop["crs"]``, and ``crop["window"]`` is read from the result.

    Returns:
        Three float64 arrays in band order (callers unpack them as blue,
        green, red), or ``None`` when the file is missing or has fewer than
        three bands.
    """
    if not s3_path.is_file():
        return None
    with rasterio.open(s3_path) as src:
        if src.count < 3:
            return None
        # Only bands 1-3 are ever read back, so size the scratch dataset for
        # three bands instead of reprojecting every band of the composite.
        temp_profile = crop["profile"].copy()
        temp_profile.update({"dtype": "float32", "count": 3})
        with rasterio.MemoryFile() as memfile:
            with memfile.open(**temp_profile) as resampled:
                for i in (1, 2, 3):
                    reproject(
                        source=rasterio.band(src, i),
                        destination=rasterio.band(resampled, i),
                        src_transform=src.transform,
                        src_crs=src.crs,
                        dst_transform=crop["full_transform"],
                        dst_crs=crop["crs"],
                        resampling=Resampling.nearest,
                    )
                b, g, r = resampled.read(indexes=(1, 2, 3), window=crop["window"])
    return b.astype(np.float64), g.astype(np.float64), r.astype(np.float64)
|
||
|
||
|
||
def _refl_valid(blue: np.ndarray, green: np.ndarray, red: np.ndarray) -> np.ndarray:
    """Boolean mask of pixels where all three bands are finite and above ``VALID_REFL_THRESHOLD``."""
    blue_ok, green_ok, red_ok = (
        np.isfinite(band) & (band > VALID_REFL_THRESHOLD)
        for band in (blue, green, red)
    )
    return blue_ok & green_ok & red_ok
|
||
|
||
|
||
def _panel_stretch_limits(
|
||
blue: np.ndarray, green: np.ndarray, red: np.ndarray, valid: np.ndarray
|
||
) -> tuple[float, float]:
|
||
"""Per-panel 2--98 % stretch on positive reflectance (webapp ``common.js`` style)."""
|
||
if not valid.any():
|
||
return 0.0, 1.0
|
||
vals = np.concatenate([red[valid], green[valid], blue[valid]])
|
||
lo, hi = np.percentile(vals, (2, 98))
|
||
if hi <= lo:
|
||
return 0.0, 1.0
|
||
return float(lo), float(hi)
|
||
|
||
|
||
def _bgr_to_rgba(
    blue: np.ndarray,
    green: np.ndarray,
    red: np.ndarray,
    *,
    valid: np.ndarray,
    vmin: float,
    vmax: float,
) -> np.ndarray:
    """Build an opaque float32 RGBA image from three bands.

    Valid pixels are scaled linearly from ``[vmin, vmax]`` to ``[0, 1]`` and
    clipped; every other pixel is painted ``NODATA_RGB``. A zero span is
    replaced by 1.0 to avoid dividing by zero.
    """
    span = vmax - vmin
    if not span:
        span = 1.0

    # Start from a fully opaque nodata canvas, then overwrite the valid pixels.
    rgba = np.empty((*blue.shape, 4), dtype=np.float32)
    rgba[..., :3] = NODATA_RGB
    rgba[..., 3] = 1.0
    for channel, band in enumerate((red, green, blue)):
        scaled = np.clip((band - vmin) / span, 0.0, 1.0)
        rgba[valid, channel] = scaled[valid]
    return rgba
|
||
|
||
|
||
def _phenocam_pixel_cropped(
    crop: dict, site_position_lat_lon: tuple[float, float]
) -> tuple[int, int] | None:
    """Locate the site position as a ``(row, col)`` pixel of the cropped panel.

    Args:
        crop: crop description carrying ``crop_transform`` and ``crs``.
        site_position_lat_lon: ``(lat, lon)`` of the site.
            NOTE(review): assumed to be WGS84 (EPSG:4326) degrees — confirm
            against the caller.

    Returns:
        ``(row, col)`` relative to the crop window (may lie outside the crop;
        the caller bounds-checks), or ``None`` when the position cannot be
        converted.
    """
    # Local imports: only this helper needs them.
    from rasterio.crs import CRS
    from rasterio.warp import transform as warp_transform

    lat, lon = site_position_lat_lon
    try:
        # The affine transform works in the raster CRS, so project the
        # geographic position first. (``rowcol``'s ``op`` argument is the
        # rounding callable, not a CRS.)
        xs, ys = warp_transform(CRS.from_epsg(4326), crop["crs"], [lon], [lat])
        rows, cols = rowcol(crop["crop_transform"], xs, ys)
        return int(rows[0]), int(cols[0])
    except Exception:
        # The marker is a best-effort decoration; never fail the figure for it.
        return None
|
||
|
||
|
||
def _resolve_row_paths(
    data_dir: Path,
    site: str,
    season: int,
    entry: dict,
    strategy: str,
    sigma: int,
    *,
    gap_days: int,
    manifest: dict,
) -> tuple[Path, Path, Path, Path] | None:
    """Resolve the four rasters of one panel row from a manifest entry.

    Returns ``(withheld S2, gap fusion, S3 composite, nearest S2)``, or
    ``None`` when the entry names no withheld file or any of the four cannot
    be found.
    """
    prediction_iso = entry["prediction_date"]
    pred_ymd = yyyymmdd_from_iso(prediction_iso)
    transition = entry["transition"]
    prepared_root = _prepared_base(data_dir, site, season, strategy)
    fallbacks = _s2_strategy_fallbacks(strategy, manifest)

    withheld_name = entry.get("withheld_s2_filename")
    if not withheld_name:
        return None
    withheld = _find_prepared_s2_refl(data_dir, site, season, withheld_name, fallbacks)

    fusion_dir = _gap_spatial_fusion_dir(
        data_dir, site, season, gap_days, transition, strategy, sigma
    )
    fusion = fusion_dir / f"REFL_{pred_ymd}.tif"

    # Prefer the composite dated exactly on the prediction day, else the closest one.
    s3_dir = prepared_root / "s3"
    s3: Path | None = s3_dir / f"composite_{pred_ymd}.tif"
    if not s3.is_file():
        s3 = nearest_s3_composite(s3_dir, prediction_iso)

    window_start = _iso_to_date(entry["window_start"])
    window_end = _iso_to_date(entry["window_end"])
    withheld_ymds = _exclude_ymds(entry)

    # The "nearest S2" must come from outside the gap window and must not be
    # the withheld scene; take it from the first prepared tree that has one.
    nearest: Path | None = None
    for strat in fallbacks:
        s2_dir = _prepared_base(data_dir, site, season, strat) / "s2"
        in_window = acquisition_yyyymmdd_in_window(s2_dir, window_start, window_end)
        nearest = nearest_stack_s2(
            s2_dir, prediction_iso, exclude_ymds=in_window | withheld_ymds
        )
        if nearest is not None:
            break

    if withheld is not None and fusion.is_file() and s3 is not None and nearest is not None:
        return withheld, fusion, s3, nearest
    return None
|
||
|
||
|
||
def _load_row_layers(
    paths: tuple[Path, Path, Path, Path],
) -> tuple[dict, list[tuple[np.ndarray, ...]]] | None:
    """Read the crop and the four band triples of one row; ``None`` if any piece is unreadable."""
    # Reader order matches the (withheld, fusion, s3, nearest) path tuple.
    readers = (
        _read_bgr_prepared_s2,
        _read_bgr_gap_fusion,
        _read_bgr_prepared_s3,
        _read_bgr_prepared_s2,
    )
    crop = _crop_window_from_fusion(paths[1])
    if crop is None:
        return None
    layers: list[tuple[np.ndarray, ...]] = []
    for path, read_fn in zip(paths, readers, strict=True):
        bgr = read_fn(path, crop)
        if bgr is None:
            return None
        layers.append(bgr)
    return crop, layers


def build_site_panel(
    site: str,
    season: int,
    data_dir: Path,
    out_png: Path,
    *,
    best_bti_scenario: str,
    site_label: str,
    site_position_lat_lon: tuple[float, float] | None = None,
    gap_days: int = 30,
) -> bool:
    """Write the RGB validation panel for one site/season to *out_png*.

    The figure has one row per transition in ``TRANSITIONS`` that has a
    manifest entry for *gap_days*, whose four inputs resolve, and whose
    rasters can be read (so up to 2x4 panels); columns follow ``COL_TITLES``.
    Each panel gets its own 2--98 % stretch. When *site_position_lat_lon* is
    given and falls inside the crop, a red ``+`` marks it on the first column.

    Returns:
        ``True`` when a figure with at least one row was written. ``False``
        when the manifest file is missing or no row could be rendered; in
        that case nothing is written.

    Raises:
        ValueError: if *best_bti_scenario* is not a valid BtI scenario key.
    """
    manifest_path = data_dir / site / str(season) / "validation" / "gap_manifest.json"
    if not manifest_path.is_file():
        return False
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    strategy, sigma = _parse_bti_scenario(best_bti_scenario)
    # A manifest without entries simply yields no rows.
    entries = manifest.get("entries") or []

    # Load every row before sizing the figure, so unreadable rows neither
    # leave blank space nor swallow the column titles of the first row.
    rows: list[tuple[str, dict, list[tuple[np.ndarray, ...]]]] = []
    for transition in TRANSITIONS:
        entry = next(
            (
                e
                for e in entries
                if e.get("gap_days") == gap_days and e.get("transition") == transition
            ),
            None,
        )
        if not entry:
            continue
        paths = _resolve_row_paths(
            data_dir,
            site,
            season,
            entry,
            strategy,
            sigma,
            gap_days=gap_days,
            manifest=manifest,
        )
        if paths is None:
            continue
        loaded = _load_row_layers(paths)
        if loaded is None:
            continue
        crop, layers = loaded
        rows.append((transition, crop, layers))

    if not rows:
        return False

    fig, axes = plt.subplots(
        len(rows),
        4,
        figsize=(12.0, 2.8 * len(rows)),
        squeeze=False,
        constrained_layout=True,
    )
    try:
        for row_idx, (transition, crop, layers) in enumerate(rows):
            row_title = ROW_LABELS.get(transition, transition)

            mark: tuple[int, int] | None = None
            if site_position_lat_lon:
                mark = _phenocam_pixel_cropped(crop, site_position_lat_lon)

            for col_idx, (col_title, bgr) in enumerate(zip(COL_TITLES, layers, strict=True)):
                ax = axes[row_idx, col_idx]
                blue, green, red = bgr
                valid = _refl_valid(blue, green, red)
                vmin, vmax = _panel_stretch_limits(blue, green, red, valid)
                rgba = _bgr_to_rgba(blue, green, red, valid=valid, vmin=vmin, vmax=vmax)
                ax.imshow(rgba, origin="upper", aspect="equal", interpolation="nearest")
                h, w = rgba.shape[:2]
                # Only the first column carries the marker, and only when the
                # site falls inside the crop.
                if col_idx == 0 and mark and 0 <= mark[0] < h and 0 <= mark[1] < w:
                    ax.plot(
                        mark[1],
                        mark[0],
                        "+",
                        color="red",
                        markersize=8,
                        markeredgewidth=1.2,
                    )
                if row_idx == 0:
                    ax.set_title(col_title, fontsize=9)
                if col_idx == 0:
                    ax.set_ylabel(row_title, fontsize=9)
                ax.set_xticks([])
                ax.set_yticks([])

        fig.suptitle(f"{site_label} ({season})", fontsize=10)
        out_png.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_png, dpi=150)
    finally:
        # pyplot keeps figures alive until closed; release it even when
        # plotting or saving fails, so batch exports do not accumulate figures.
        plt.close(fig)
    return True
|