"""Full-season gap-degraded fusion → temporal NSE_PC vs PhenoCam (tier after spatial validation)."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import json
|
||
import re
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
|
||
from metrics_indices import _get_gcc_from_original
|
||
from metrics_stats import (
|
||
WHITTAKER_LAMBDA_DAYS_SQ,
|
||
_norm_date_key,
|
||
_s2_gcc_series_from_preselection,
|
||
_whittaker_smooth_dict,
|
||
calculate_temporal_metrics,
|
||
load_timeseries,
|
||
)
|
||
|
||
from gap_validation.calendar import TRANSITIONS, load_manifest, validation_dir, write_manifest
|
||
from gap_validation.fusion_masked import run_masked_fusion_season
|
||
from gap_validation.run import (
|
||
_filter_entries,
|
||
_scenario_key,
|
||
_withheld_iso,
|
||
_yyyymmdd_from_withheld_filename,
|
||
)
|
||
from gap_validation.whittaker_compare import first_gap_where_fusion_below_whittaker
|
||
|
||
|
||
def _fusion_gcc_timeseries(
|
||
fusion_dir: Path, site_position: tuple[float, float], mode: str
|
||
) -> dict[str, float]:
|
||
"""3×3 mean GCC at site from fused REFL/GCC rasters in ``fusion_dir``."""
|
||
pattern = "REFL_*.tif" if mode == "bti" else "GCC_*.tif"
|
||
out: dict[str, float] = {}
|
||
for p in sorted(fusion_dir.glob(pattern)):
|
||
m = re.search(r"_(\d{8})\.tif$", p.name)
|
||
if not m:
|
||
continue
|
||
d = datetime.strptime(m.group(1), "%Y%m%d").date().isoformat()
|
||
gcc = _get_gcc_from_original(p, site_position)
|
||
if gcc is not None:
|
||
out[d] = float(gcc)
|
||
return out
|
||
|
||
|
||
def whittaker_timeseries_gap_degraded(
    base: Path,
    strategy: str,
    window_start_iso: str,
    window_end_iso: str,
    withheld_iso: str,
    lam: float = WHITTAKER_LAMBDA_DAYS_SQ,
) -> dict[str, float]:
    """Whittaker-smooth the S2 preselection GCC series with the gap removed.

    An observation is dropped when it has no flag entry, when the flag picked
    by ``strategy`` (index 0 for ``"aggressive"``, index 1 otherwise) is set,
    when its normalised date equals the withheld date, or when it falls inside
    the inclusive ``[window_start_iso, window_end_iso]`` window. A date key
    that does not parse as ``YYYY-MM-DD`` counts as outside the window.

    Returns:
        The result of ``_whittaker_smooth_dict`` on the surviving
        observations, or ``{}`` when the preselection is empty or fewer than
        two observations survive.
    """
    series, flags = _s2_gcc_series_from_preselection(base)
    if not series:
        return {}

    flag_slot = 0 if strategy == "aggressive" else 1
    gap_first = datetime.strptime(window_start_iso[:10], "%Y-%m-%d").date()
    gap_last = datetime.strptime(window_end_iso[:10], "%Y-%m-%d").date()
    withheld_key = _norm_date_key(withheld_iso)

    def _inside_gap(date_key: str) -> bool:
        """Return True when ``date_key`` parses and lies within the gap window."""
        try:
            day = datetime.strptime(date_key[:10], "%Y-%m-%d").date()
        except ValueError:
            return False
        return not (day < gap_first or day > gap_last)

    survivors = []
    for raw_date, value in series.items():
        # Unflagged dates and dates rejected by the chosen strategy are out.
        if raw_date not in flags or flags[raw_date][flag_slot]:
            continue
        norm = _norm_date_key(raw_date)
        if norm == withheld_key or _inside_gap(norm or ""):
            continue
        survivors.append((raw_date, value))
    survivors.sort()

    if len(survivors) < 2:
        return {}
    obs_dates = tuple(pair[0] for pair in survivors)
    obs_values = tuple(pair[1] for pair in survivors)
    return _whittaker_smooth_dict(obs_dates, obs_values, lam)
|
||
|
||
|
||
def run_temporal_pc(
    site_name: str,
    season: int,
    site_position: tuple[float, float],
    strategy: str,
    sigma: int | None,
    mode: str,
    *,
    skip_manifest: bool,
    skip_fusion: bool,
    gap_days_filter: list[int] | None,
    transition_filter: list[str] | None,
    s2_calendar_strategy: str,
) -> Path:
    """Run full-season gap fusion + NSE_PC and write ``gap_metrics_<mode>.json``.

    For every manifest entry that passes the gap-day / transition filters the
    function (unless ``skip_fusion``) runs the masked fusion for the season,
    reads the fused GCC series back from the entry's output directory, scores
    it and the gap-degraded Whittaker series against the PhenoCam series, and
    records one result row.

    Entries that cannot be processed are still recorded, with an ``"error"``
    key: ``"no_withheld_s2"``, ``"bad_withheld_filename"``, or the message of
    a ``RuntimeError`` raised by the fusion step.

    When ``data/<site>/<season>/metrics.json`` holds a no-gap ``nse_pc`` for
    this scenario, rows also get ``delta_rmse`` (gap minus no-gap) and
    ``delta_nse`` (no-gap minus gap) where both operands are available.

    Args:
        site_name: Site identifier used in the data and validation paths.
        season: Season (year) used in the data and validation paths.
        site_position: Site coordinates, forwarded to the manifest, fusion
            and raster-sampling helpers.
        strategy: Strategy name forwarded to fusion and the Whittaker baseline.
        sigma: Sigma setting forwarded to fusion and the scenario key.
        mode: Fusion mode; also selects the output file name.
        skip_manifest: Reuse the existing manifest instead of rewriting it.
        skip_fusion: Reuse rasters already present in each output directory.
        gap_days_filter: Optional gap lengths to keep.
        transition_filter: Optional transitions to keep.
        s2_calendar_strategy: Forwarded to ``write_manifest``.

    Returns:
        Path of the written ``gap_metrics_<mode>.json``. For ``mode == "bti"``
        the same payload is also written to the legacy ``gap_metrics.json``.
    """
    base = Path(f"data/{site_name}/{season}")
    vdir = validation_dir(site_name, season)
    vdir.mkdir(parents=True, exist_ok=True)

    if not skip_manifest:
        write_manifest(
            site_name,
            season,
            site_position,
            s2_calendar_strategy=s2_calendar_strategy,
        )

    manifest = load_manifest(site_name, season)
    entries = _filter_entries(manifest["entries"], gap_days_filter, transition_filter)
    phenocam_ts_path = base / "raw" / "phenocam" / "phenocam_gcc.json"
    phenocam_ts = load_timeseries(phenocam_ts_path)

    # The scenario key is invariant for the whole run; compute it once.
    scenario = _scenario_key(strategy, sigma, mode)

    # No-gap baseline: read once up front instead of re-parsing the file for
    # every entry. ``or {}`` tolerates a missing or null scenario block.
    nogap_nse_pc = None
    nogap_rmse = None
    nogap_metrics_path = base / "metrics.json"
    if nogap_metrics_path.is_file():
        baseline = json.loads(nogap_metrics_path.read_text(encoding="utf-8"))
        baseline_block = (baseline.get("temporal") or {}).get(scenario) or {}
        nogap_nse_pc = baseline_block.get("nse_pc")
        nogap_rmse = baseline_block.get("rmse")

    results: list[dict] = []
    crossover_rows: list[dict] = []

    for entry in entries:
        transition = entry.get("transition", "green_up")
        gap_days = entry["gap_days"]
        pred = entry["prediction_date"]
        w0, w1 = entry["window_start"], entry["window_end"]
        fn = entry.get("withheld_s2_filename")
        if not fn:
            results.append(
                {"transition": transition, "gap_days": gap_days, "error": "no_withheld_s2"}
            )
            continue
        wh_ymd = _yyyymmdd_from_withheld_filename(fn)
        if not wh_ymd:
            results.append(
                {
                    "transition": transition,
                    "gap_days": gap_days,
                    "error": "bad_withheld_filename",
                }
            )
            continue
        # Prefer the manifest's withheld date; fall back to the file-name stamp.
        withheld_iso = _withheld_iso(entry) or f"{wh_ymd[:4]}-{wh_ymd[4:6]}-{wh_ymd[6:8]}"

        temporal_dir = vdir / "temporal" / f"gap_{gap_days}_{transition}" / scenario
        if not skip_fusion:
            try:
                run_masked_fusion_season(
                    season,
                    site_position,
                    site_name,
                    strategy,
                    sigma,
                    mode,
                    w0,
                    w1,
                    wh_ymd,
                    temporal_dir,
                )
            except RuntimeError as e:
                results.append(
                    {
                        "transition": transition,
                        "gap_days": gap_days,
                        "error": str(e),
                    }
                )
                continue
        # Same read-back whether fusion just ran or was skipped.
        fusion_ts = _fusion_gcc_timeseries(temporal_dir, site_position, mode)

        fused_metrics = calculate_temporal_metrics(fusion_ts, phenocam_ts)
        wh_ts = whittaker_timeseries_gap_degraded(base, strategy, w0, w1, withheld_iso)
        wh_metrics = calculate_temporal_metrics(wh_ts, phenocam_ts)

        row: dict = {
            "transition": transition,
            "gap_days": gap_days,
            "prediction_date": pred,
            "window_start": w0,
            "window_end": w1,
            "withheld_s2_filename": fn,
            "temporal": {
                "fused": fused_metrics,
                "whittaker": wh_metrics,
            },
            "fusion_dir": str(temporal_dir),
        }
        # Deltas are only reported when a no-gap NSE_PC baseline exists.
        if fused_metrics and nogap_nse_pc is not None:
            g_rmse = fused_metrics.get("rmse")
            n_g = fused_metrics.get("nse_pc")
            if g_rmse is not None and nogap_rmse is not None:
                row["delta_rmse"] = float(g_rmse - nogap_rmse)
            if n_g is not None:
                row["delta_nse"] = float(nogap_nse_pc - n_g)

        crossover = {
            "transition": transition,
            "gap_days": gap_days,
            "nse_pc_fusion": (fused_metrics or {}).get("nse_pc"),
            "nse_pc_whittaker": (wh_metrics or {}).get("nse_pc"),
        }
        row["utility_crossover_row"] = crossover
        crossover_rows.append(crossover)
        results.append(row)

    payload = {
        "site_name": site_name,
        "season": season,
        "scenario": scenario,
        "tier": "temporal_nse_pc",
        "manifest": str(vdir / "gap_manifest.json"),
        "results": results,
        "utility_crossover": {
            scenario: {
                "metric": "nse_pc_vs_phenocam_gcc90",
                "first_gap_days_fusion_below_whittaker": first_gap_where_fusion_below_whittaker(
                    crossover_rows,
                    fusion_key="nse_pc_fusion",
                    whittaker_key="nse_pc_whittaker",
                ),
                "by_gap": crossover_rows,
            }
        },
    }
    # Serialise once; the legacy alias receives the identical text.
    payload_text = json.dumps(payload, indent=2) + "\n"
    out_path = vdir / f"gap_metrics_{mode}.json"
    out_path.write_text(payload_text, encoding="utf-8")
    if mode == "bti":
        # Legacy alias for backward-compatible readers.
        (vdir / "gap_metrics.json").write_text(payload_text, encoding="utf-8")
    return out_path
|
||
|
||
|
||
def main() -> None:
    """Parse the command line, run the temporal NSE_PC tier, and print the output path."""
    parser = argparse.ArgumentParser(description="Gap-degraded full-season NSE_PC tier.")

    # Required site/season/location arguments.
    parser.add_argument("--site", required=True)
    for flag, caster in (("--season", int), ("--lat", float), ("--lon", float)):
        parser.add_argument(flag, type=caster, required=True)

    # Scenario selection.
    parser.add_argument("--strategy", default="aggressive")
    parser.add_argument("--sigma", type=int, default=20, choices=[20, 30])
    parser.add_argument("--mode", default="bti", choices=["bti", "itb"])

    # Repeatable filters on the manifest entries.
    parser.add_argument("--gap-days", type=int, action="append")
    parser.add_argument("--transition", choices=list(TRANSITIONS), action="append")

    # Switches for reusing earlier outputs.
    for switch in ("--skip-manifest", "--skip-fusion"):
        parser.add_argument(switch, action="store_true")
    parser.add_argument("--s2-calendar-strategy", default="aggressive")

    opts = parser.parse_args()

    # Only sigma 30 is passed through explicitly; the default 20 maps to None.
    sigma_kw = None if opts.sigma != 30 else 30
    keyword_options = {
        "skip_manifest": opts.skip_manifest,
        "skip_fusion": opts.skip_fusion,
        "gap_days_filter": opts.gap_days,
        "transition_filter": opts.transition,
        "s2_calendar_strategy": opts.s2_calendar_strategy,
    }
    written = run_temporal_pc(
        opts.site,
        opts.season,
        (opts.lat, opts.lon),
        opts.strategy,
        sigma_kw,
        opts.mode,
        **keyword_options,
    )
    print(written)
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|