"""Step 7: PhenoCam GCC suitability as a fusion-accuracy reference. Inputs (``data/``, ``{year}`` = ``--evaluation-year``): - ``metrics/{year}/{site}/gcc_s2.json``, ``gcc_phenocam.json`` — Step 5 timeseries - ``metrics/manifest.json`` — site lat/lon - ``sentinel_data/{year}/{site}/prepared/s2/`` — S2 REFL/GCC/DIST_CLOUD (Steps 3–4) - ``sentinel_data/{year}/{site}/prepared/gcc_s3/``, ``prepared/s3_rgb/`` — Step 4 Outputs (``data/gcc_suitability/``): - ``{year}.json`` — representativeness (Line A), LOOCV concordance (Line B), per-site and aggregate suitability verdict CLI: - ``--evaluation-year`` (default 2025) - ``--min-cloudfree-s2`` (default 10) — minimum cloud-free S2 dates for LOOCV - ``--alpha`` (default 0.05) — reserved for future significance tests Full-sample aggregate; does not accept ``--site``. """ from __future__ import annotations import argparse import json import re import shutil import tempfile from datetime import datetime from pathlib import Path from typing import Any import numpy as np import rasterio from pyproj import Transformer from rasterio.crs import CRS from rasterio.transform import rowcol from scipy.stats import linregress, pearsonr, spearmanr from tqdm import tqdm # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- DATA_DIR = Path("data") DEFAULT_YEAR = 2025 DEFAULT_ALPHA = 0.05 MIN_CLOUDFREE_S2 = 10 REPR_R_THRESHOLD = 0.7 MATCH_TOLERANCE_DAYS = 5 RESOLUTION_RATIO = 30 MAX_DAYS = 100 MINIMUM_ACQUISITION_IMPORTANCE = 0 SMALL_SAMPLE_SITES = 6 # --------------------------------------------------------------------------- # efast import # --------------------------------------------------------------------------- def _import_efast(): try: import efast.efast as efast_module return efast_module except ImportError as exc: raise ImportError("efast not found. Install with: uv sync") from exc # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _r4(v: float | None) -> float | None: return round(v, 4) if v is not None else None def _window_mean(data: np.ndarray) -> float | None: valid = data[~np.isnan(data)] if valid.size == 0: return None return float(np.mean(valid)) def _read_center_pixel(path: Path, lat: float, lon: float) -> float | None: try: with rasterio.open(path) as src: transformer = Transformer.from_crs( CRS.from_epsg(4326), src.crs, always_xy=True ) x, y = transformer.transform(lon, lat) row, col = rowcol(src.transform, x, y) h, w = src.height, src.width r0, r1 = max(0, row - 1), min(h, row + 2) c0, c1 = max(0, col - 1), min(w, col + 2) window = rasterio.windows.Window(c0, r0, c1 - c0, r1 - r0) data = src.read(1, window=window).astype(float) nodata = src.nodata if nodata is not None: data = np.where(data == nodata, np.nan, data) data[data == 0] = np.nan return _window_mean(data) except Exception: return None def _date_from_s2_tif(path: Path) -> str | None: parts = path.stem.split("_") if len(parts) >= 3: m = re.match(r"(\d{8})", parts[2]) return m.group(1) if m else None return None def _iso_to_yyyymmdd(iso: str) -> str: return iso.replace("-", "") def _yyyymmdd_to_iso(d: str) -> str: return f"{d[:4]}-{d[4:6]}-{d[6:]}" def _day_gap(a: str, b: str) -> float: return abs((np.datetime64(a) - np.datetime64(b)) / np.timedelta64(1, "D")) def _match_series( ref: list[dict], ref_key: str, pred: list[dict], pred_key: str, tolerance_days: int = MATCH_TOLERANCE_DAYS, ) -> tuple[list[float], list[float], list[str]]: """Return paired (ref_vals, pred_vals, ref_dates) within tolerance.""" ref_lookup: dict[str, float] = { p["date"]: p[ref_key] for p in ref if p.get(ref_key) is not None } if not ref_lookup: return [], [], [] ref_dates = sorted(ref_lookup) obs, sim, matched_dates = [], [], [] for pt in pred: v = pt.get(pred_key) if v is None: continue nearest = min(ref_dates, key=lambda d: _day_gap(pt["date"], d)) if _day_gap(pt["date"], nearest) <= tolerance_days: obs.append(ref_lookup[nearest]) sim.append(v) matched_dates.append(nearest) return obs, sim, matched_dates def _accuracy_metrics(obs: list[float], sim: list[float]) -> dict[str, Any] | None: if len(obs) < 2: return None obs_arr = np.array(obs, dtype=float) sim_arr = np.array(sim, dtype=float) diff = sim_arr - obs_arr rmse = float(np.sqrt(np.mean(diff**2))) mae = float(np.mean(np.abs(diff))) bias = float(np.mean(diff)) r, _ = pearsonr(obs_arr, sim_arr) return { "n": len(obs), "rmse": _r4(rmse), "mae": _r4(mae), "bias": _r4(bias), "r": _r4(float(r)), } def _gcc_from_refl_file(refl_path: Path, gcc_path: Path) -> None: with rasterio.open(refl_path) as src: b, g, r = src.read(1), src.read(2), src.read(3) profile = src.profile total = b + g + r invalid = (b < 0) | (g < 0) | (r < 0) gcc = np.where(invalid, np.nan, g / (total + 1e-10)) gcc[total == 0] = np.nan profile.update(count=1, dtype="float32") with rasterio.open(gcc_path, "w", **profile) as dst: dst.write(gcc[np.newaxis].astype("float32")) def _load_json_series(path: Path) -> list[dict]: if not path.is_file(): return [] return json.loads(path.read_text()) def _load_site_coords(year: int) -> dict[str, tuple[float, float]]: manifest_path = DATA_DIR / "metrics" / "manifest.json" if not manifest_path.is_file(): return {} manifest = json.loads(manifest_path.read_text()) sites = manifest.get("sites", {}).get(str(year), {}) coords: dict[str, tuple[float, float]] = {} for site, meta in sites.items(): lat, lon = meta.get("lat"), meta.get("lon") if lat is not None and lon is not None: coords[site] = (float(lat), float(lon)) return coords def _discover_sites(year: int) -> list[str]: metrics_dir = DATA_DIR / "metrics" / str(year) if not metrics_dir.is_dir(): return [] return sorted( d.name for d in metrics_dir.iterdir() if d.is_dir() and (d / "gcc_s2.json").is_file() ) def _build_hr_symlink_dir(s2_dir: Path, holdout_yyyymmdd: str, dest: Path) -> None: """Symlink all S2 hr inputs except the held-out acquisition date.""" dest.mkdir(parents=True, exist_ok=True) for pattern in ("*_REFL.tif", "*_GCC.tif", "*_DIST_CLOUD.tif"): for src in sorted(s2_dir.glob(pattern)): date_token = ( src.stem.split("_")[2][:8] if len(src.stem.split("_")) >= 3 else "" ) if date_token == holdout_yyyymmdd: continue link = dest / src.name if link.exists() or link.is_symlink(): link.unlink() link.symlink_to(src.resolve()) # --------------------------------------------------------------------------- # Line A — representativeness # --------------------------------------------------------------------------- def compute_representativeness(phenocam: list[dict], s2: list[dict]) -> dict[str, Any]: """PhenoCam gcc_90 vs co-located observed S2 GCC.""" obs, sim, _ = _match_series(phenocam, "gcc_90", s2, "gcc") result: dict[str, Any] = { "n": len(obs), "r": None, "spearman": None, "slope": None, "intercept": None, "rmse": None, "bias": None, "peak_offset_days": None, "representative": False, } if len(obs) < 2: return result obs_arr = np.array(obs, dtype=float) sim_arr = np.array(sim, dtype=float) r, _ = pearsonr(obs_arr, sim_arr) sp, _ = spearmanr(obs_arr, sim_arr) reg = linregress(sim_arr, obs_arr) diff = sim_arr - obs_arr result.update( { "r": _r4(float(r)), "spearman": _r4(float(sp)), "slope": _r4(float(reg.slope)), "intercept": _r4(float(reg.intercept)), "rmse": _r4(float(np.sqrt(np.mean(diff**2)))), "bias": _r4(float(np.mean(diff))), "representative": float(r) >= REPR_R_THRESHOLD, } ) pc_dates = [p["date"] for p in phenocam if p.get("gcc_90") is not None] s2_dates = [p["date"] for p in s2 if p.get("gcc") is not None] if pc_dates and s2_dates: pc_peak = max( phenocam, key=lambda p: p["gcc_90"] if p.get("gcc_90") is not None else -1, )["date"] s2_peak = max(s2, key=lambda p: p["gcc"] if p.get("gcc") is not None else -1)[ "date" ] result["peak_offset_days"] = int(_day_gap(pc_peak, s2_peak)) return result # --------------------------------------------------------------------------- # Line B — LOOCV # --------------------------------------------------------------------------- def _phenocam_lookup(phenocam: list[dict]) -> dict[str, float]: return {p["date"]: p["gcc_90"] for p in phenocam if p.get("gcc_90") is not None} def _nearest_phenocam(iso_date: str, lookup: dict[str, float]) -> float | None: if not lookup: return None dates = sorted(lookup) nearest = min(dates, key=lambda d: _day_gap(iso_date, d)) if _day_gap(iso_date, nearest) <= MATCH_TOLERANCE_DAYS: return lookup[nearest] return None def run_loocv_site( site: str, year: int, lat: float, lon: float, s2_series: list[dict], phenocam: list[dict], efast, ) -> list[dict[str, Any]]: """Leave-one-out EFAST for each cloud-free S2 date; return per-date records.""" s2_dir = DATA_DIR / "sentinel_data" / str(year) / site / "prepared" / "s2" gcc_s3_dir = DATA_DIR / "sentinel_data" / str(year) / site / "prepared" / "gcc_s3" s3_rgb_dir = DATA_DIR / "sentinel_data" / str(year) / site / "prepared" / "s3_rgb" pc_lookup = _phenocam_lookup(phenocam) s2_truth = {p["date"]: p["gcc"] for p in s2_series} fusion_kwargs = dict( ratio=RESOLUTION_RATIO, max_days=MAX_DAYS, minimum_acquisition_importance=MINIMUM_ACQUISITION_IMPORTANCE, ) records: list[dict[str, Any]] = [] dates = [p["date"] for p in s2_series] with tempfile.TemporaryDirectory(prefix=f"loocv_{site}_") as tmp_root: tmp = Path(tmp_root) hr_dir = tmp / "hr" itb_out = tmp / "itb" bti_out = tmp / "bti" bti_gcc = tmp / "bti_gcc" itb_out.mkdir() bti_out.mkdir() bti_gcc.mkdir() for iso_date in tqdm(dates, desc=f"{site} LOOCV", leave=False): yyyymmdd = _iso_to_yyyymmdd(iso_date) truth = s2_truth.get(iso_date) if truth is None: continue if hr_dir.exists(): shutil.rmtree(hr_dir) _build_hr_symlink_dir(s2_dir, yyyymmdd, hr_dir) pred_date = datetime.strptime(yyyymmdd, "%Y%m%d") for f in itb_out.glob("*.tif"): f.unlink() for f in bti_out.glob("*.tif"): f.unlink() for f in bti_gcc.glob("*.tif"): f.unlink() efast.fusion( pred_date, gcc_s3_dir, hr_dir, itb_out, product="GCC", **fusion_kwargs, ) efast.fusion( pred_date, s3_rgb_dir, hr_dir, bti_out, product="REFL", **fusion_kwargs, ) itb_path = itb_out / f"GCC_{yyyymmdd}.tif" refl_path = bti_out / f"REFL_{yyyymmdd}.tif" bti_path = bti_gcc / f"GCC_{yyyymmdd}.tif" pred_itb = ( _read_center_pixel(itb_path, lat, lon) if itb_path.is_file() else None ) pred_bti = None if refl_path.is_file(): _gcc_from_refl_file(refl_path, bti_path) if bti_path.is_file(): pred_bti = _read_center_pixel(bti_path, lat, lon) pc_val = _nearest_phenocam(iso_date, pc_lookup) records.append( { "date": iso_date, "s2_truth": truth, "pred_bti": pred_bti, "pred_itb": pred_itb, "phenocam": pc_val, } ) return records def _method_accuracy(records: list[dict], pred_key: str, ref_key: str) -> dict | None: obs, sim = [], [] for rec in records: pred = rec.get(pred_key) ref = rec.get(ref_key) if pred is None or ref is None: continue obs.append(ref) sim.append(pred) return _accuracy_metrics(obs, sim) def _winner(rmse_bti: float | None, rmse_itb: float | None) -> str | None: if rmse_bti is None or rmse_itb is None: return None if rmse_bti < rmse_itb: return "bti" if rmse_itb < rmse_bti: return "itb" return "tie" def summarize_loocv(records: list[dict]) -> dict[str, Any]: bti_vs_s2 = _method_accuracy(records, "pred_bti", "s2_truth") itb_vs_s2 = _method_accuracy(records, "pred_itb", "s2_truth") bti_vs_pc = _method_accuracy(records, "pred_bti", "phenocam") itb_vs_pc = _method_accuracy(records, "pred_itb", "phenocam") winner_s2 = _winner( bti_vs_s2["rmse"] if bti_vs_s2 else None, itb_vs_s2["rmse"] if itb_vs_s2 else None, ) winner_pc = _winner( bti_vs_pc["rmse"] if bti_vs_pc else None, itb_vs_pc["rmse"] if itb_vs_pc else None, ) agreement = ( winner_s2 == winner_pc if winner_s2 and winner_pc and winner_s2 != "tie" and winner_pc != "tie" else None ) return { "n_dates": len(records), "bti": {"vs_s2": bti_vs_s2, "vs_phenocam": bti_vs_pc}, "itb": {"vs_s2": itb_vs_s2, "vs_phenocam": itb_vs_pc}, "winner_s2": winner_s2, "winner_phenocam": winner_pc, "winner_agreement": agreement, } # --------------------------------------------------------------------------- # Aggregate concordance # --------------------------------------------------------------------------- def _pooled_concordance( all_records: list[dict[str, Any]], ) -> dict[str, Any]: """Pooled metrics across all held-out dates.""" residual_pairs: list[tuple[float, float]] = [] vec_s2: list[float] = [] vec_pc: list[float] = [] for site_data in all_records: for rec in site_data.get("records", []): truth = rec.get("s2_truth") pc = rec.get("phenocam") for key in ("pred_bti", "pred_itb"): pred = rec.get(key) if pred is None or truth is None: continue err_s2 = abs(pred - truth) if pc is not None: err_pc = abs(pred - pc) vec_s2.append(err_s2) vec_pc.append(err_pc) residual_pairs.append((err_s2, err_pc)) pooled_spearman = None if len(vec_s2) >= 3: sp, _ = spearmanr(vec_s2, vec_pc) if not np.isnan(sp): pooled_spearman = _r4(float(sp)) residual_corr = None if len(residual_pairs) >= 3: xs = np.array([p[0] for p in residual_pairs]) ys = np.array([p[1] for p in residual_pairs]) rc, _ = pearsonr(xs, ys) residual_corr = _r4(float(rc)) agreements = [ s.get("winner_agreement") for s in all_records if s.get("eligible") and s.get("winner_agreement") is not None ] winner_agreement_rate = ( _r4(sum(1 for a in agreements if a) / len(agreements)) if agreements else None ) n_loocv_dates = sum(len(s.get("records", [])) for s in all_records) return { "pooled_spearman": pooled_spearman, "residual_corr": residual_corr, "winner_agreement_rate": winner_agreement_rate, "n_loocv_dates": n_loocv_dates, } def _suitability_verdict( n_repr_pass: int, n_eligible: int, n_total: int, pooled: dict[str, Any], ) -> str: if n_eligible == 0: return "insufficient data" repr_rate = n_repr_pass / n_total if n_total else 0 agree = pooled.get("winner_agreement_rate") sp = pooled.get("pooled_spearman") rc = pooled.get("residual_corr") strong = 0 if repr_rate >= 0.6: strong += 1 if agree is not None and agree >= 0.7: strong += 1 if sp is not None and sp >= 0.8: strong += 1 if rc is not None and rc >= 0.5: strong += 1 if strong >= 3: return "suitable" if strong >= 1 or repr_rate >= 0.4: return "partially suitable" return "not suitable" # --------------------------------------------------------------------------- # Per-site processing # --------------------------------------------------------------------------- def process_site( site: str, year: int, lat: float, lon: float, min_cloudfree: int, efast, ) -> dict[str, Any]: metrics_dir = DATA_DIR / "metrics" / str(year) / site phenocam = _load_json_series(metrics_dir / "gcc_phenocam.json") s2_series = _load_json_series(metrics_dir / "gcc_s2.json") repr_metrics = compute_representativeness(phenocam, s2_series) n_cloudfree = len(s2_series) eligible = n_cloudfree >= min_cloudfree result: dict[str, Any] = { "eligible": eligible, "n_cloudfree_s2": n_cloudfree, "representativeness": repr_metrics, "loocv": None, "winner_s2": None, "winner_phenocam": None, "winner_agreement": None, "records": [], } if not eligible: return result records = run_loocv_site(site, year, lat, lon, s2_series, phenocam, efast) loocv = summarize_loocv(records) result["loocv"] = loocv result["winner_s2"] = loocv["winner_s2"] result["winner_phenocam"] = loocv["winner_phenocam"] result["winner_agreement"] = loocv["winner_agreement"] result["records"] = records return result # --------------------------------------------------------------------------- # Output / summary # --------------------------------------------------------------------------- def _compact_site_payload(site_result: dict[str, Any]) -> dict[str, Any]: """Drop raw LOOCV records from JSON output (keep summaries only).""" out = { "eligible": site_result["eligible"], "n_cloudfree_s2": site_result["n_cloudfree_s2"], "representativeness": site_result["representativeness"], "winner_s2": site_result.get("winner_s2"), "winner_phenocam": site_result.get("winner_phenocam"), "winner_agreement": site_result.get("winner_agreement"), } if site_result.get("loocv"): out["loocv"] = site_result["loocv"] return out def _print_summary(payload: dict[str, Any]) -> None: year = payload["year"] agg = payload["aggregate"] print( f"\nPhenoCam GCC suitability — {year} " f"({payload['n_sites_total']} site(s), " f"{payload['n_sites_eligible']} LOOCV-eligible, " f"{payload['n_sites_repr_pass']} representative)" ) print(f"Verdict: {agg['suitability_verdict']}") print( f" pooled Spearman (method errors): {agg.get('pooled_spearman', '—')} " f"residual corr: {agg.get('residual_corr', '—')} " f"winner agreement: {agg.get('winner_agreement_rate', '—')} " f"LOOCV dates: {agg.get('n_loocv_dates', '—')}" ) print(f"\n{'site':<28} {'repr r':>8} {'pass':>5} {'LOOCV n':>8} {'win agree':>10}") print("-" * 65) for site, data in sorted(payload["sites"].items()): rep = data["representativeness"] loocv_n = data.get("loocv", {}).get("n_dates") if data.get("loocv") else "—" agree = data.get("winner_agreement") agree_s = "yes" if agree else ("no" if agree is False else "—") pass_s = "yes" if rep.get("representative") else "no" print( f"{site:<28} {rep.get('r') or '—':>8} {pass_s:>5} " f"{loocv_n!s:>8} {agree_s:>10}" ) if payload["n_sites_total"] < SMALL_SAMPLE_SITES: print( f"\nNote: only {payload['n_sites_total']} site(s); " "interpret cross-site aggregates cautiously." ) if payload.get("dropped_sites"): print(f"Dropped/ineligible: {', '.join(payload['dropped_sites'])}") # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def main() -> None: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--evaluation-year", type=int, default=DEFAULT_YEAR) parser.add_argument( "--min-cloudfree-s2", type=int, default=MIN_CLOUDFREE_S2, help="Minimum cloud-free S2 dates for LOOCV (default 10)", ) parser.add_argument( "--alpha", type=float, default=DEFAULT_ALPHA, help="Significance threshold (reserved; default 0.05)", ) args = parser.parse_args() year = args.evaluation_year min_cloudfree = args.min_cloudfree_s2 sites = _discover_sites(year) if not sites: raise SystemExit( f"No Step 5 metrics found under {DATA_DIR / 'metrics' / str(year)}" ) coords = _load_site_coords(year) efast = _import_efast() site_results: dict[str, dict[str, Any]] = {} dropped: list[str] = [] for site in tqdm(sites, desc="Sites"): if site not in coords: dropped.append(site) continue lat, lon = coords[site] site_results[site] = process_site(site, year, lat, lon, min_cloudfree, efast) if not site_results[site]["eligible"]: dropped.append(site) n_eligible = sum(1 for s in site_results.values() if s["eligible"]) n_repr_pass = sum( 1 for s in site_results.values() if s["representativeness"].get("representative") ) pooled = _pooled_concordance(list(site_results.values())) verdict = _suitability_verdict(n_repr_pass, n_eligible, len(sites), pooled) payload = { "year": year, "alpha": args.alpha, "repr_r_threshold": REPR_R_THRESHOLD, "min_cloudfree_s2": min_cloudfree, "n_sites_total": len(sites), "n_sites_eligible": n_eligible, "n_sites_repr_pass": n_repr_pass, "aggregate": { "suitability_verdict": verdict, **pooled, }, "sites": { site: _compact_site_payload(data) for site, data in sorted(site_results.items()) }, "dropped_sites": sorted(set(dropped)), } out_dir = DATA_DIR / "gcc_suitability" out_dir.mkdir(parents=True, exist_ok=True) out_path = out_dir / f"{year}.json" out_path.write_text(json.dumps(payload, separators=(",", ":"))) _print_summary(payload) print(f"\nWritten → {out_path}") if __name__ == "__main__": main()