efast-phenocam-validation/7-gcc-suitability.py
2026-06-17 12:29:35 +02:00

742 lines
23 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Step 7: PhenoCam GCC suitability as a fusion-accuracy reference.
Inputs (``data/``, ``{year}`` = ``--evaluation-year``):
- ``metrics/{year}/{site}/gcc_s2.json``, ``gcc_phenocam.json`` — Step 5 timeseries
- ``metrics/manifest.json`` — site lat/lon
- ``sentinel_data/{year}/{site}/prepared/s2/`` — S2 REFL/GCC/DIST_CLOUD (Steps 3-4)
- ``sentinel_data/{year}/{site}/prepared/gcc_s3/``, ``prepared/s3_rgb/`` — Step 4
Outputs (``data/gcc_suitability/``):
- ``{year}.json`` — representativeness (Line A), LOOCV concordance (Line B),
per-site and aggregate suitability verdict
CLI:
- ``--evaluation-year`` (default 2025)
- ``--min-cloudfree-s2`` (default 10) — minimum cloud-free S2 dates for LOOCV
- ``--alpha`` (default 0.05) — reserved for future significance tests
Full-sample aggregate; does not accept ``--site``.
"""
from __future__ import annotations
import argparse
import json
import re
import shutil
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any
import numpy as np
import rasterio
from pyproj import Transformer
from rasterio.crs import CRS
from rasterio.transform import rowcol
from scipy.stats import linregress, pearsonr, spearmanr
from tqdm import tqdm
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Root of all pipeline inputs and of this step's output (relative path).
DATA_DIR: Path = Path("data")

# CLI defaults.
DEFAULT_YEAR: int = 2025
DEFAULT_ALPHA: float = 0.05  # --alpha; reserved for future significance tests
MIN_CLOUDFREE_S2: int = 10  # --min-cloudfree-s2; S2 dates needed for LOOCV

# Line A: Pearson r (PhenoCam gcc_90 vs S2 GCC) required to call a site
# representative.
REPR_R_THRESHOLD: float = 0.7
# Largest |day gap| allowed when pairing a date with its nearest reference.
MATCH_TOLERANCE_DAYS: int = 5

# Forwarded verbatim to efast.fusion as ratio / max_days /
# minimum_acquisition_importance (semantics defined by efast).
RESOLUTION_RATIO: int = 30
MAX_DAYS: int = 100
MINIMUM_ACQUISITION_IMPORTANCE: int = 0

# With fewer sites than this the printed summary adds a caution note.
SMALL_SAMPLE_SITES: int = 6
# ---------------------------------------------------------------------------
# efast import
# ---------------------------------------------------------------------------
def _import_efast():
try:
import efast.efast as efast_module
return efast_module
except ImportError as exc:
raise ImportError("efast not found. Install with: uv sync") from exc
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _r4(v: float | None) -> float | None:
return round(v, 4) if v is not None else None
def _window_mean(data: np.ndarray) -> float | None:
valid = data[~np.isnan(data)]
if valid.size == 0:
return None
return float(np.mean(valid))
def _read_center_pixel(path: Path, lat: float, lon: float) -> float | None:
    """Mean of band 1 of *path* over the 3x3 pixel window centred on a site.

    The EPSG:4326 coordinate (*lat*, *lon*) is projected into the raster CRS
    and converted to a pixel (row, col). The 3x3 window around that pixel is
    clipped at the raster edges. Pixels equal to the raster nodata value, or
    exactly 0, are treated as missing.

    Returns:
        The mean of the valid window pixels. Returns ``None`` when the site
        falls outside the raster, no pixel is valid, or the file cannot be
        read or projected.
    """
    try:
        with rasterio.open(path) as src:
            transformer = Transformer.from_crs(
                CRS.from_epsg(4326), src.crs, always_xy=True
            )
            x, y = transformer.transform(lon, lat)
            row, col = rowcol(src.transform, x, y)
            h, w = src.height, src.width
            # Without this check, a site just outside the raster would be
            # reported as the mean of an adjacent edge strip, or would produce
            # an empty or negative-sized window.
            if not (0 <= row < h and 0 <= col < w):
                return None
            r0, r1 = max(0, row - 1), min(h, row + 2)
            c0, c1 = max(0, col - 1), min(w, col + 2)
            window = rasterio.windows.Window(c0, r0, c1 - c0, r1 - r0)
            data = src.read(1, window=window).astype(float)
            nodata = src.nodata
            if nodata is not None:
                data = np.where(data == nodata, np.nan, data)
            data[data == 0] = np.nan
            return _window_mean(data)
    except Exception:
        # Best-effort sampling: callers treat None as "no prediction".
        return None
def _date_from_s2_tif(path: Path) -> str | None:
parts = path.stem.split("_")
if len(parts) >= 3:
m = re.match(r"(\d{8})", parts[2])
return m.group(1) if m else None
return None
def _iso_to_yyyymmdd(iso: str) -> str:
return iso.replace("-", "")
def _yyyymmdd_to_iso(d: str) -> str:
return f"{d[:4]}-{d[4:6]}-{d[6:]}"
def _day_gap(a: str, b: str) -> float:
return abs((np.datetime64(a) - np.datetime64(b)) / np.timedelta64(1, "D"))
def _match_series(
    ref: list[dict],
    ref_key: str,
    pred: list[dict],
    pred_key: str,
    tolerance_days: int = MATCH_TOLERANCE_DAYS,
) -> tuple[list[float], list[float], list[str]]:
    """Pair each *pred* value with the nearest-in-time *ref* value.

    Entries whose value under the respective key is missing or ``None`` are
    ignored. For every remaining *pred* point, the closest *ref* date is
    chosen; ties go to the earliest date. The pair is kept only if the gap is
    at most *tolerance_days*. One *ref* date may pair with several *pred*
    points.

    Returns:
        ``(ref_vals, pred_vals, ref_dates)``: parallel lists in *pred* order.
    """
    lookup = {p["date"]: p[ref_key] for p in ref if p.get(ref_key) is not None}
    if not lookup:
        return [], [], []
    candidates = sorted(lookup)
    pairs: list[tuple[float, float, str]] = []
    for point in pred:
        value = point.get(pred_key)
        if value is None:
            continue
        when = point["date"]
        closest = min(candidates, key=lambda d: _day_gap(when, d))
        if _day_gap(when, closest) > tolerance_days:
            continue
        pairs.append((lookup[closest], value, closest))
    ref_vals = [pair[0] for pair in pairs]
    pred_vals = [pair[1] for pair in pairs]
    ref_dates = [pair[2] for pair in pairs]
    return ref_vals, pred_vals, ref_dates
def _accuracy_metrics(obs: list[float], sim: list[float]) -> dict[str, Any] | None:
    """Error statistics of *sim* against the reference *obs*.

    Errors are ``sim - obs``. Returns ``None`` when there are fewer than two
    pairs. Otherwise returns ``n`` plus ``rmse``, ``mae``, ``bias`` (mean
    error) and Pearson ``r``, each rounded with ``_r4``.
    """
    if len(obs) < 2:
        return None
    reference = np.array(obs, dtype=float)
    predicted = np.array(sim, dtype=float)
    errors = predicted - reference
    pearson_r, _ = pearsonr(reference, predicted)
    return {
        "n": len(obs),
        "rmse": _r4(float(np.sqrt(np.mean(errors**2)))),
        "mae": _r4(float(np.mean(np.abs(errors)))),
        "bias": _r4(float(np.mean(errors))),
        "r": _r4(float(pearson_r)),
    }
def _gcc_from_refl_file(refl_path: Path, gcc_path: Path) -> None:
    """Write a single-band float32 GCC raster derived from a 3-band REFL raster.

    Bands 1, 2 and 3 of *refl_path* are used as blue, green and red. The
    output *gcc_path* holds ``G / (B + G + R)`` and reuses the source profile,
    with count and dtype adjusted. A pixel is NaN where any band is negative,
    any band equals the source nodata value, or the band sum is zero.
    """
    with rasterio.open(refl_path) as src:
        # Promote to float64 so integer-typed reflectances cannot overflow in
        # the band sum and the division is always floating point.
        b, g, r = (src.read(i).astype("float64") for i in (1, 2, 3))
        profile = src.profile
        nodata = src.nodata
    total = b + g + r
    invalid = (b < 0) | (g < 0) | (r < 0) | (total == 0)
    # Non-negative nodata sentinels (e.g. 65535) would otherwise pass as data.
    if nodata is not None and not np.isnan(nodata):
        invalid |= (b == nodata) | (g == nodata) | (r == nodata)
    # Masked pixels are replaced below, so divide-by-zero warnings are noise.
    with np.errstate(divide="ignore", invalid="ignore"):
        gcc = np.where(invalid, np.nan, g / total)
    profile.update(count=1, dtype="float32")
    with rasterio.open(gcc_path, "w", **profile) as dst:
        dst.write(gcc[np.newaxis].astype("float32"))
def _load_json_series(path: Path) -> list[dict]:
if not path.is_file():
return []
return json.loads(path.read_text())
def _load_site_coords(year: int) -> dict[str, tuple[float, float]]:
    """Map site name to ``(lat, lon)`` using ``DATA_DIR/metrics/manifest.json``.

    Reads ``manifest["sites"][str(year)]``. Sites lacking ``lat`` or ``lon``
    are skipped. Returns ``{}`` when the manifest file does not exist.
    """
    manifest_path = DATA_DIR / "metrics" / "manifest.json"
    if not manifest_path.is_file():
        return {}
    manifest = json.loads(manifest_path.read_text())
    per_year = manifest.get("sites", {}).get(str(year), {})
    return {
        name: (float(info["lat"]), float(info["lon"]))
        for name, info in per_year.items()
        if info.get("lat") is not None and info.get("lon") is not None
    }
def _discover_sites(year: int) -> list[str]:
    """Return sorted site folder names under ``metrics/{year}`` that contain ``gcc_s2.json``."""
    year_dir = DATA_DIR / "metrics" / str(year)
    if not year_dir.is_dir():
        return []
    names = [
        entry.name
        for entry in year_dir.iterdir()
        if entry.is_dir() and (entry / "gcc_s2.json").is_file()
    ]
    names.sort()
    return names
def _build_hr_symlink_dir(s2_dir: Path, holdout_yyyymmdd: str, dest: Path) -> None:
    """Populate *dest* with symlinks to the S2 inputs, excluding the held-out date.

    Every ``*_REFL.tif``, ``*_GCC.tif`` and ``*_DIST_CLOUD.tif`` in *s2_dir*
    is linked (to its resolved path), except files whose acquisition date,
    as parsed by ``_date_from_s2_tif``, equals *holdout_yyyymmdd*. An existing
    file or symlink of the same name in *dest* is replaced.
    """
    dest.mkdir(parents=True, exist_ok=True)
    for pattern in ("*_REFL.tif", "*_GCC.tif", "*_DIST_CLOUD.tif"):
        for src in sorted(s2_dir.glob(pattern)):
            # Use the shared file-name parser so the LOOCV exclusion follows
            # the same naming convention as every other S2 date lookup.
            if _date_from_s2_tif(src) == holdout_yyyymmdd:
                continue
            link = dest / src.name
            if link.exists() or link.is_symlink():
                link.unlink()
            link.symlink_to(src.resolve())
# ---------------------------------------------------------------------------
# Line A — representativeness
# ---------------------------------------------------------------------------
def compute_representativeness(phenocam: list[dict], s2: list[dict]) -> dict[str, Any]:
    """Line A: how well PhenoCam ``gcc_90`` tracks co-located observed S2 GCC.

    Each S2 ``gcc`` value is paired with the nearest PhenoCam ``gcc_90`` date
    (see ``_match_series``). On those pairs the result reports:

    - Pearson ``r`` and ``spearman`` rho;
    - an OLS fit of PhenoCam on S2 (``slope``/``intercept``);
    - ``rmse`` and ``bias`` of ``S2 - PhenoCam``.

    ``peak_offset_days`` is the gap between the dates of the maximum of each
    full series. ``representative`` is ``True`` iff ``r >= REPR_R_THRESHOLD``.

    Fields that cannot be computed stay ``None``:

    - all metrics when there are fewer than two pairs;
    - the correlations when either paired series is constant (they would be
      NaN);
    - the regression when the S2 values are constant (``linregress`` raises
      ValueError).
    """
    obs, sim, _ = _match_series(phenocam, "gcc_90", s2, "gcc")
    result: dict[str, Any] = {
        "n": len(obs),
        "r": None,
        "spearman": None,
        "slope": None,
        "intercept": None,
        "rmse": None,
        "bias": None,
        "peak_offset_days": None,
        "representative": False,
    }
    if len(obs) < 2:
        return result
    obs_arr = np.array(obs, dtype=float)
    sim_arr = np.array(sim, dtype=float)
    diff = sim_arr - obs_arr
    result["rmse"] = _r4(float(np.sqrt(np.mean(diff**2))))
    result["bias"] = _r4(float(np.mean(diff)))

    sim_varies = bool(np.ptp(sim_arr) > 0)
    obs_varies = bool(np.ptp(obs_arr) > 0)
    if sim_varies and obs_varies:
        r, _ = pearsonr(obs_arr, sim_arr)
        sp, _ = spearmanr(obs_arr, sim_arr)
        result["r"] = _r4(float(r))
        result["spearman"] = _r4(float(sp))
        result["representative"] = float(r) >= REPR_R_THRESHOLD
    if sim_varies:
        # Regress PhenoCam (y) on S2 (x); undefined when every x is equal.
        reg = linregress(sim_arr, obs_arr)
        result["slope"] = _r4(float(reg.slope))
        result["intercept"] = _r4(float(reg.intercept))

    pc_valid = [p for p in phenocam if p.get("gcc_90") is not None]
    s2_valid = [p for p in s2 if p.get("gcc") is not None]
    if pc_valid and s2_valid:
        pc_peak = max(pc_valid, key=lambda p: p["gcc_90"])["date"]
        s2_peak = max(s2_valid, key=lambda p: p["gcc"])["date"]
        result["peak_offset_days"] = int(_day_gap(pc_peak, s2_peak))
    return result
# ---------------------------------------------------------------------------
# Line B — LOOCV
# ---------------------------------------------------------------------------
def _phenocam_lookup(phenocam: list[dict]) -> dict[str, float]:
return {p["date"]: p["gcc_90"] for p in phenocam if p.get("gcc_90") is not None}
def _nearest_phenocam(iso_date: str, lookup: dict[str, float]) -> float | None:
    """Return the PhenoCam value on the date closest to *iso_date*, if close enough.

    Ties resolve to the earliest date. Returns ``None`` for an empty *lookup*,
    or when the nearest date is more than ``MATCH_TOLERANCE_DAYS`` away.
    """
    if not lookup:
        return None
    best = min(sorted(lookup), key=lambda d: _day_gap(iso_date, d))
    within_tolerance = _day_gap(iso_date, best) <= MATCH_TOLERANCE_DAYS
    return lookup[best] if within_tolerance else None
def run_loocv_site(
    site: str,
    year: int,
    lat: float,
    lon: float,
    s2_series: list[dict],
    phenocam: list[dict],
    efast,
) -> list[dict[str, Any]]:
    """Run leave-one-out EFAST per S2 date with a GCC value; return per-date records.

    For each such date, the S2 inputs of that acquisition are withheld (the
    HR folder is rebuilt by ``_build_hr_symlink_dir``). ``efast.fusion`` then
    runs twice for the held-out date:

    * ``pred_itb``: fusion of the S3 GCC inputs (``prepared/gcc_s3``,
      ``product="GCC"``), sampled at the site with ``_read_center_pixel``;
    * ``pred_bti``: fusion of the S3 RGB inputs (``prepared/s3_rgb``,
      ``product="REFL"``), converted to GCC by ``_gcc_from_refl_file``, then
      sampled the same way.

    NOTE(review): the names presumably mean index-then-blend and
    blend-then-index.

    Args:
        site, year: Locate ``sentinel_data/{year}/{site}/prepared``.
        lat, lon: Site coordinates (EPSG:4326) used to sample fused rasters.
        s2_series: Step 5 S2 series; ``gcc`` is the held-out truth per ``date``.
        phenocam: Step 5 PhenoCam series. The nearest ``gcc_90`` within
            ``MATCH_TOLERANCE_DAYS`` is attached to each record.
        efast: The module returned by ``_import_efast``.

    Returns:
        One dict per processed date with keys ``date``, ``s2_truth``,
        ``pred_bti``, ``pred_itb`` and ``phenocam``. A prediction is ``None``
        when its expected output file is absent or cannot be sampled.
    """
    prepared = DATA_DIR / "sentinel_data" / str(year) / site / "prepared"
    s2_dir = prepared / "s2"
    gcc_s3_dir = prepared / "gcc_s3"
    s3_rgb_dir = prepared / "s3_rgb"
    pc_lookup = _phenocam_lookup(phenocam)
    # Use .get so an entry without a "gcc" key is skipped (like one whose GCC
    # is None) instead of aborting the whole site with a KeyError.
    s2_truth = {p["date"]: p.get("gcc") for p in s2_series}
    fusion_kwargs = dict(
        ratio=RESOLUTION_RATIO,
        max_days=MAX_DAYS,
        minimum_acquisition_importance=MINIMUM_ACQUISITION_IMPORTANCE,
    )
    records: list[dict[str, Any]] = []
    dates = [p["date"] for p in s2_series]
    with tempfile.TemporaryDirectory(prefix=f"loocv_{site}_") as tmp_root:
        tmp = Path(tmp_root)
        hr_dir = tmp / "hr"
        itb_out = tmp / "itb"
        bti_out = tmp / "bti"
        bti_gcc = tmp / "bti_gcc"
        output_dirs = (itb_out, bti_out, bti_gcc)
        for out_dir in output_dirs:
            out_dir.mkdir()
        for iso_date in tqdm(dates, desc=f"{site} LOOCV", leave=False):
            truth = s2_truth.get(iso_date)
            if truth is None:
                continue
            yyyymmdd = _iso_to_yyyymmdd(iso_date)
            # Fresh HR input folder that lacks the held-out acquisition.
            if hr_dir.exists():
                shutil.rmtree(hr_dir)
            _build_hr_symlink_dir(s2_dir, yyyymmdd, hr_dir)
            # Empty the output folders so only this iteration's files exist.
            for out_dir in output_dirs:
                for f in out_dir.glob("*.tif"):
                    f.unlink()
            pred_date = datetime.strptime(yyyymmdd, "%Y%m%d")
            efast.fusion(
                pred_date,
                gcc_s3_dir,
                hr_dir,
                itb_out,
                product="GCC",
                **fusion_kwargs,
            )
            efast.fusion(
                pred_date,
                s3_rgb_dir,
                hr_dir,
                bti_out,
                product="REFL",
                **fusion_kwargs,
            )
            # NOTE(review): assumes efast names outputs {PRODUCT}_{YYYYMMDD}.tif.
            itb_path = itb_out / f"GCC_{yyyymmdd}.tif"
            refl_path = bti_out / f"REFL_{yyyymmdd}.tif"
            bti_path = bti_gcc / f"GCC_{yyyymmdd}.tif"
            pred_itb = (
                _read_center_pixel(itb_path, lat, lon) if itb_path.is_file() else None
            )
            pred_bti = None
            if refl_path.is_file():
                _gcc_from_refl_file(refl_path, bti_path)
                if bti_path.is_file():
                    pred_bti = _read_center_pixel(bti_path, lat, lon)
            records.append(
                {
                    "date": iso_date,
                    "s2_truth": truth,
                    "pred_bti": pred_bti,
                    "pred_itb": pred_itb,
                    "phenocam": _nearest_phenocam(iso_date, pc_lookup),
                }
            )
    return records
def _method_accuracy(records: list[dict], pred_key: str, ref_key: str) -> dict | None:
    """Return the accuracy of ``rec[pred_key]`` against ``rec[ref_key]`` over *records*.

    Records missing either value (absent or ``None``) are skipped. The kept
    pairs are passed to ``_accuracy_metrics`` with the reference first.
    """
    pairs = [
        (rec.get(ref_key), rec.get(pred_key))
        for rec in records
        if rec.get(pred_key) is not None and rec.get(ref_key) is not None
    ]
    refs = [ref for ref, _ in pairs]
    preds = [pred for _, pred in pairs]
    return _accuracy_metrics(refs, preds)
def _winner(rmse_bti: float | None, rmse_itb: float | None) -> str | None:
if rmse_bti is None or rmse_itb is None:
return None
if rmse_bti < rmse_itb:
return "bti"
if rmse_itb < rmse_bti:
return "itb"
return "tie"
def summarize_loocv(records: list[dict]) -> dict[str, Any]:
    """Summarise one site's LOOCV records.

    Scores both methods (``bti``, ``itb``) against the S2 truth and against
    PhenoCam, and picks the lower-RMSE method per reference.
    ``winner_agreement`` is ``True``/``False`` only when both winners are
    decisive (neither missing nor a tie); otherwise it is ``None``.

    NOTE(review): each accuracy uses only the records where that prediction
    and reference both exist. The bti/itb RMSEs, and the S2/PhenoCam winners,
    may therefore be computed on different date subsets.
    """

    def rmse_of(metrics: dict | None) -> float | None:
        return metrics["rmse"] if metrics else None

    bti = {
        "vs_s2": _method_accuracy(records, "pred_bti", "s2_truth"),
        "vs_phenocam": _method_accuracy(records, "pred_bti", "phenocam"),
    }
    itb = {
        "vs_s2": _method_accuracy(records, "pred_itb", "s2_truth"),
        "vs_phenocam": _method_accuracy(records, "pred_itb", "phenocam"),
    }
    winner_s2 = _winner(rmse_of(bti["vs_s2"]), rmse_of(itb["vs_s2"]))
    winner_pc = _winner(rmse_of(bti["vs_phenocam"]), rmse_of(itb["vs_phenocam"]))
    decisive = winner_s2 not in (None, "tie") and winner_pc not in (None, "tie")
    return {
        "n_dates": len(records),
        "bti": bti,
        "itb": itb,
        "winner_s2": winner_s2,
        "winner_phenocam": winner_pc,
        "winner_agreement": (winner_s2 == winner_pc) if decisive else None,
    }
# ---------------------------------------------------------------------------
# Aggregate concordance
# ---------------------------------------------------------------------------
def _pooled_concordance(
    all_records: list[dict[str, Any]],
) -> dict[str, Any]:
    """Pool LOOCV errors over every site, held-out date and method.

    *all_records* holds per-site results (keys ``records``, ``eligible``,
    ``winner_agreement``). A record contributes one error pair per available
    prediction (``pred_bti``, ``pred_itb``), provided it has both an S2 truth
    and a PhenoCam value. Each pair is the absolute error of the prediction
    against S2 and against PhenoCam.

    Returns:
        A dict with the following keys:

        - ``pooled_spearman``: Spearman rho between the two error vectors.
        - ``residual_corr``: Pearson r between the same two vectors.
        - ``winner_agreement_rate``: share of eligible sites whose S2 and
          PhenoCam winners agree, among sites where agreement is defined.
        - ``n_loocv_dates``: total number of LOOCV records across sites.

        Correlations need at least 3 pairs and are ``None`` when undefined,
        e.g. NaN on constant input.
    """
    errors_vs_s2: list[float] = []
    errors_vs_pc: list[float] = []
    for site_data in all_records:
        for rec in site_data.get("records", []):
            truth = rec.get("s2_truth")
            pc = rec.get("phenocam")
            if truth is None or pc is None:
                continue
            for key in ("pred_bti", "pred_itb"):
                pred = rec.get(key)
                if pred is None:
                    continue
                errors_vs_s2.append(abs(pred - truth))
                errors_vs_pc.append(abs(pred - pc))

    pooled_spearman = None
    residual_corr = None
    if len(errors_vs_s2) >= 3:
        xs = np.asarray(errors_vs_s2, dtype=float)
        ys = np.asarray(errors_vs_pc, dtype=float)
        sp, _ = spearmanr(xs, ys)
        if not np.isnan(sp):
            pooled_spearman = _r4(float(sp))
        # pearsonr also yields NaN on constant input. Guard it like the
        # Spearman value so NaN never reaches the JSON output.
        rc, _ = pearsonr(xs, ys)
        if not np.isnan(rc):
            residual_corr = _r4(float(rc))

    agreements = [
        s.get("winner_agreement")
        for s in all_records
        if s.get("eligible") and s.get("winner_agreement") is not None
    ]
    winner_agreement_rate = (
        _r4(sum(1 for a in agreements if a) / len(agreements)) if agreements else None
    )
    n_loocv_dates = sum(len(s.get("records", [])) for s in all_records)
    return {
        "pooled_spearman": pooled_spearman,
        "residual_corr": residual_corr,
        "winner_agreement_rate": winner_agreement_rate,
        "n_loocv_dates": n_loocv_dates,
    }
def _suitability_verdict(
n_repr_pass: int,
n_eligible: int,
n_total: int,
pooled: dict[str, Any],
) -> str:
if n_eligible == 0:
return "insufficient data"
repr_rate = n_repr_pass / n_total if n_total else 0
agree = pooled.get("winner_agreement_rate")
sp = pooled.get("pooled_spearman")
rc = pooled.get("residual_corr")
strong = 0
if repr_rate >= 0.6:
strong += 1
if agree is not None and agree >= 0.7:
strong += 1
if sp is not None and sp >= 0.8:
strong += 1
if rc is not None and rc >= 0.5:
strong += 1
if strong >= 3:
return "suitable"
if strong >= 1 or repr_rate >= 0.4:
return "partially suitable"
return "not suitable"
# ---------------------------------------------------------------------------
# Per-site processing
# ---------------------------------------------------------------------------
def process_site(
    site: str,
    year: int,
    lat: float,
    lon: float,
    min_cloudfree: int,
    efast,
) -> dict[str, Any]:
    """Run Line A for one site, and Line B (LOOCV) when the site is eligible.

    Loads ``gcc_phenocam.json`` and ``gcc_s2.json`` from
    ``metrics/{year}/{site}`` and computes representativeness. LOOCV runs only
    when at least *min_cloudfree* S2 entries carry a GCC value.

    Returns:
        A dict with ``eligible``, ``n_cloudfree_s2``, ``representativeness``,
        ``loocv`` (summary, or ``None`` if not run), ``winner_s2``,
        ``winner_phenocam``, ``winner_agreement`` and the raw LOOCV
        ``records`` (empty if not run).
    """
    metrics_dir = DATA_DIR / "metrics" / str(year) / site
    phenocam = _load_json_series(metrics_dir / "gcc_phenocam.json")
    s2_series = _load_json_series(metrics_dir / "gcc_s2.json")
    repr_metrics = compute_representativeness(phenocam, s2_series)
    # Count only dates LOOCV can hold out: entries without a GCC value are
    # skipped by run_loocv_site, so they must not count towards eligibility.
    n_cloudfree = sum(1 for p in s2_series if p.get("gcc") is not None)
    eligible = n_cloudfree >= min_cloudfree
    result: dict[str, Any] = {
        "eligible": eligible,
        "n_cloudfree_s2": n_cloudfree,
        "representativeness": repr_metrics,
        "loocv": None,
        "winner_s2": None,
        "winner_phenocam": None,
        "winner_agreement": None,
        "records": [],
    }
    if not eligible:
        return result
    records = run_loocv_site(site, year, lat, lon, s2_series, phenocam, efast)
    loocv = summarize_loocv(records)
    result["loocv"] = loocv
    result["winner_s2"] = loocv["winner_s2"]
    result["winner_phenocam"] = loocv["winner_phenocam"]
    result["winner_agreement"] = loocv["winner_agreement"]
    result["records"] = records
    return result
# ---------------------------------------------------------------------------
# Output / summary
# ---------------------------------------------------------------------------
def _compact_site_payload(site_result: dict[str, Any]) -> dict[str, Any]:
"""Drop raw LOOCV records from JSON output (keep summaries only)."""
out = {
"eligible": site_result["eligible"],
"n_cloudfree_s2": site_result["n_cloudfree_s2"],
"representativeness": site_result["representativeness"],
"winner_s2": site_result.get("winner_s2"),
"winner_phenocam": site_result.get("winner_phenocam"),
"winner_agreement": site_result.get("winner_agreement"),
}
if site_result.get("loocv"):
out["loocv"] = site_result["loocv"]
return out
def _print_summary(payload: dict[str, Any]) -> None:
    """Print the aggregate verdict and a per-site table for *payload* to stdout.

    Missing values (``None``) are printed as blank fields.
    """

    def shown(value: Any) -> Any:
        # dict.get(key, "") does not help when the key is present with a None
        # value, and `value or ""` would also blank a legitimate 0.0.
        return "" if value is None else value

    year = payload["year"]
    agg = payload["aggregate"]
    print(
        f"\nPhenoCam GCC suitability — {year} "
        f"({payload['n_sites_total']} site(s), "
        f"{payload['n_sites_eligible']} LOOCV-eligible, "
        f"{payload['n_sites_repr_pass']} representative)"
    )
    print(f"Verdict: {agg['suitability_verdict']}")
    print(
        f" pooled Spearman (method errors): {shown(agg.get('pooled_spearman'))} "
        f"residual corr: {shown(agg.get('residual_corr'))} "
        f"winner agreement: {shown(agg.get('winner_agreement_rate'))} "
        f"LOOCV dates: {shown(agg.get('n_loocv_dates'))}"
    )
    print(f"\n{'site':<28} {'repr r':>8} {'pass':>5} {'LOOCV n':>8} {'win agree':>10}")
    print("-" * 65)
    for site, data in sorted(payload["sites"].items()):
        rep = data["representativeness"]
        loocv_n = data.get("loocv", {}).get("n_dates") if data.get("loocv") else ""
        agree = data.get("winner_agreement")
        agree_s = "yes" if agree else ("no" if agree is False else "")
        pass_s = "yes" if rep.get("representative") else "no"
        print(
            f"{site:<28} {shown(rep.get('r')):>8} {pass_s:>5} "
            f"{loocv_n!s:>8} {agree_s:>10}"
        )
    if payload["n_sites_total"] < SMALL_SAMPLE_SITES:
        print(
            f"\nNote: only {payload['n_sites_total']} site(s); "
            "interpret cross-site aggregates cautiously."
        )
    if payload.get("dropped_sites"):
        print(f"Dropped/ineligible: {', '.join(payload['dropped_sites'])}")
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser; the module docstring doubles as the help text."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--evaluation-year", type=int, default=DEFAULT_YEAR)
    parser.add_argument(
        "--min-cloudfree-s2",
        type=int,
        default=MIN_CLOUDFREE_S2,
        help="Minimum cloud-free S2 dates for LOOCV (default 10)",
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=DEFAULT_ALPHA,
        help="Significance threshold (reserved; default 0.05)",
    )
    return parser


def main() -> None:
    """Run the full-sample suitability analysis and write ``gcc_suitability/{year}.json``.

    Every site with Step 5 metrics is processed. Sites without manifest
    coordinates, and sites that are not LOOCV-eligible, are listed under
    ``dropped_sites``. Exits with a message when no Step 5 metrics exist for
    the year.
    """
    args = _build_parser().parse_args()
    year = args.evaluation_year
    min_cloudfree = args.min_cloudfree_s2

    sites = _discover_sites(year)
    if not sites:
        raise SystemExit(
            f"No Step 5 metrics found under {DATA_DIR / 'metrics' / str(year)}"
        )
    coords = _load_site_coords(year)
    efast = _import_efast()

    site_results: dict[str, dict[str, Any]] = {}
    dropped: list[str] = []
    for site in tqdm(sites, desc="Sites"):
        location = coords.get(site)
        if location is None:
            dropped.append(site)
            continue
        outcome = process_site(site, year, *location, min_cloudfree, efast)
        site_results[site] = outcome
        if not outcome["eligible"]:
            dropped.append(site)

    results = list(site_results.values())
    n_eligible = sum(1 for res in results if res["eligible"])
    n_repr_pass = sum(
        1 for res in results if res["representativeness"].get("representative")
    )
    pooled = _pooled_concordance(results)
    verdict = _suitability_verdict(n_repr_pass, n_eligible, len(sites), pooled)

    payload = {
        "year": year,
        "alpha": args.alpha,
        "repr_r_threshold": REPR_R_THRESHOLD,
        "min_cloudfree_s2": min_cloudfree,
        "n_sites_total": len(sites),
        "n_sites_eligible": n_eligible,
        "n_sites_repr_pass": n_repr_pass,
        "aggregate": {
            "suitability_verdict": verdict,
            **pooled,
        },
        "sites": {
            name: _compact_site_payload(data)
            for name, data in sorted(site_results.items())
        },
        "dropped_sites": sorted(set(dropped)),
    }

    out_dir = DATA_DIR / "gcc_suitability"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"{year}.json"
    out_path.write_text(json.dumps(payload, separators=(",", ":")))
    _print_summary(payload)
    print(f"\nWritten → {out_path}")
if __name__ == "__main__":
    # Script entry point; importing the module runs nothing.
    main()