# efast-phenocam-validation/6-statistics-fusion-order.py
# 2026-06-17 12:04:27 +02:00
#
# 257 lines
# 7.9 KiB
# Python

"""Step 6: Paired ItB-vs-BtI significance test across the full sample.
Inputs (``data/``, ``{year}`` = ``--evaluation-year``):
- ``metrics/{year}/{site}/metrics.json`` — per-site validation metrics (Step 5)
Outputs (``data/statistics_fusion_order/``):
- ``{year}.json`` — paired Wilcoxon + t-test summary for NSE, RMSE, nRMSE, r;
includes ``dropped_sites`` (union) and per-metric ``dropped_sites`` lists
CLI:
- ``--evaluation-year`` (default 2025)
- ``--alpha`` (default 0.05) — significance threshold for ``better_order``
This step aggregates across all sites with Step 5 output; it does not accept
``--site`` (a single-site filter would not support a sample-level test).
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
import numpy as np
from scipy.stats import ttest_rel, wilcoxon
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Root of the pipeline's data tree. The path is relative, so it is resolved
# against the current working directory at run time.
DATA_DIR = Path("data")
# Defaults for the ``--evaluation-year`` and ``--alpha`` CLI options.
DEFAULT_YEAR = 2025
DEFAULT_ALPHA = 0.05
# Metric keys that are tested and reported, in output order.
METRICS = ["nse", "rmse", "nrmse", "r"]
# Metrics for which a smaller value is the better one; for every other metric
# a larger value is treated as better when naming the winning fusion order.
LOWER_IS_BETTER = {"rmse", "nrmse"}
# The printed summary emits a caution when a metric has fewer pairs than this.
MIN_PAIRS_WARNING = 6
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _r4(v: float | None) -> float | None:
return round(v, 4) if v is not None else None
def _load_site_metrics(year: int) -> list[tuple[str, dict[str, Any]]]:
    """Load every per-site ``metrics.json`` found under ``metrics/{year}``.

    Returns a list of ``(site directory name, parsed JSON)`` tuples in sorted
    directory order. Entries that are not directories, or directories without
    a ``metrics.json`` file, are skipped. A missing year directory yields an
    empty list.
    """
    year_root = DATA_DIR / "metrics" / str(year)
    if not year_root.is_dir():
        return []
    # Lazy generator: each site's file is only probed when the list below
    # reaches it, matching a plain one-pass loop.
    candidates = (
        (entry.name, entry / "metrics.json")
        for entry in sorted(year_root.iterdir())
        if entry.is_dir()
    )
    return [
        (site, json.loads(metrics_file.read_text()))
        for site, metrics_file in candidates
        if metrics_file.is_file()
    ]
def collect_pairs(
    site_metrics: list[tuple[str, dict[str, Any]]], metric: str
) -> tuple[list[float], list[float], list[str]]:
    """Return paired BtI / ItB values for ``metric`` and dropped site names.

    A site contributes a pair only when both its ``"bti"`` and ``"itb"``
    entries are dicts holding a usable value for ``metric``. A site is dropped
    (and its name returned in the third list) when either entry is not a dict,
    or either value is ``None``, NaN or infinite.

    Returns ``(bti_vals, itb_vals, dropped_sites)``; the two value lists are
    index-aligned and all three keep the order of ``site_metrics``.
    """
    bti_vals: list[float] = []
    itb_vals: list[float] = []
    dropped_sites: list[str] = []
    for site, payload in site_metrics:
        bti = payload.get("bti")
        itb = payload.get("itb")
        if not isinstance(bti, dict) or not isinstance(itb, dict):
            dropped_sites.append(site)
            continue
        bti_v = bti.get(metric)
        itb_v = itb.get(metric)
        if bti_v is None or itb_v is None:
            dropped_sites.append(site)
            continue
        bti_f = float(bti_v)
        itb_f = float(itb_v)
        # Python's json module round-trips NaN/Infinity, so a metric that could
        # not be computed upstream may arrive as a non-finite float. Keeping
        # such a pair would turn every statistic of the whole sample into NaN,
        # so treat it like a missing value.
        if not (np.isfinite(bti_f) and np.isfinite(itb_f)):
            dropped_sites.append(site)
            continue
        bti_vals.append(bti_f)
        itb_vals.append(itb_f)
    return bti_vals, itb_vals, dropped_sites
def _better_order(
bti_vals: list[float],
itb_vals: list[float],
metric: str,
p_value: float | None,
alpha: float,
) -> str:
"""Name the better fusion order when Wilcoxon p < alpha, else no difference."""
if p_value is None or p_value >= alpha:
return "no significant difference"
mean_diff = float(np.mean(itb_vals) - np.mean(bti_vals))
if metric in LOWER_IS_BETTER:
return "itb" if mean_diff < 0 else "bti"
return "itb" if mean_diff > 0 else "bti"
def paired_test(
    bti_vals: list[float],
    itb_vals: list[float],
    metric: str,
    alpha: float,
) -> dict[str, Any]:
    """Run paired Wilcoxon (primary) and t-test; return summary dict.

    The summary always carries ``n_pairs`` plus mean/median of each order and
    of the ItB-minus-BtI differences (``None`` when there are no pairs). With
    fewer than two pairs no test is run and ``better_order`` stays
    ``"insufficient data"``. Otherwise the t-test is always run, the Wilcoxon
    test only when at least one difference is non-zero, and ``better_order``
    is decided from the unrounded Wilcoxon p-value.
    """
    n_pairs = len(bti_vals)
    bti_arr = np.array(bti_vals, dtype=float)
    itb_arr = np.array(itb_vals, dtype=float)
    diffs = itb_arr - bti_arr

    def _describe(reducer: Any, values: np.ndarray) -> float | None:
        # Descriptive statistics are undefined for an empty sample.
        if not n_pairs:
            return None
        return _r4(float(reducer(values)))

    untested = {"statistic": None, "p_value": None}
    result: dict[str, Any] = {
        "n_pairs": n_pairs,
        "bti_mean": _describe(np.mean, bti_arr),
        "bti_median": _describe(np.median, bti_arr),
        "itb_mean": _describe(np.mean, itb_arr),
        "itb_median": _describe(np.median, itb_arr),
        "mean_diff": _describe(np.mean, diffs),
        "median_diff": _describe(np.median, diffs),
        "wilcoxon": dict(untested),
        "ttest": dict(untested),
        "better_order": "insufficient data",
    }
    if n_pairs < 2:
        return result

    wilcoxon_stat: float | None = None
    wilcoxon_p: float | None = None
    # The signed-rank test is skipped when every paired difference is zero.
    if (diffs != 0).any():
        try:
            w_stat, w_p = wilcoxon(itb_arr, bti_arr)
            wilcoxon_stat = float(w_stat)
            wilcoxon_p = float(w_p)
        except ValueError:
            # Best effort: leave the Wilcoxon fields as None.
            pass

    t_stat, t_p = ttest_rel(itb_arr, bti_arr)
    result.update(
        {
            "wilcoxon": {
                "statistic": _r4(wilcoxon_stat),
                "p_value": _r4(wilcoxon_p),
            },
            "ttest": {
                "statistic": _r4(float(t_stat)),
                "p_value": _r4(float(t_p)),
            },
            "better_order": _better_order(
                bti_vals, itb_vals, metric, wilcoxon_p, alpha
            ),
        }
    )
    return result
def _print_summary(
    year: int, alpha: float, n_sites_total: int, metrics_out: dict
) -> None:
    """Print a fixed-width table with one row per metric in ``METRICS``.

    Each row shows the pair count, both means, the mean difference, the
    Wilcoxon and t-test p-values and the ``better_order`` verdict; ``None``
    values are rendered as blank cells. A caution line follows any metric
    whose pair count is above zero but below ``MIN_PAIRS_WARNING``.
    """

    def _num(value: float | None) -> str:
        # 10-wide numeric cell, blank when the value is missing.
        if value is None:
            return f"{'':>10}"
        return f"{value:10.4f}"

    def _pval(value: float | None) -> str:
        # 8-wide right-aligned p-value cell, blank when the test was not run.
        cell = "" if value is None else value
        return f"{cell:>8}"

    print(f"\nPaired ItB vs BtI test — {year} (alpha={alpha}, sites={n_sites_total})")
    print(
        f"{'metric':<8} {'n':>4} {'BtI mean':>10} {'ItB mean':>10} "
        f"{'diff':>10} {'W p':>8} {'t p':>8} better"
    )
    print("-" * 78)

    for metric in METRICS:
        row = metrics_out[metric]
        better = row["better_order"]
        print(
            f"{metric:<8} {row['n_pairs']:>4} {_num(row['bti_mean'])} {_num(row['itb_mean'])} "
            f"{_num(row['mean_diff'])} "
            f"{_pval(row['wilcoxon']['p_value'])} "
            f"{_pval(row['ttest']['p_value'])} {better}"
        )
        if 0 < row["n_pairs"] < MIN_PAIRS_WARNING:
            print(
                f" warning: only {row['n_pairs']} pair(s) for {metric}; "
                "interpret p-values cautiously"
            )
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main() -> None:
    """Parse the CLI, test every metric across all sites, write and print results.

    Exits with a message when no Step 5 metrics exist for the requested year.
    Otherwise writes ``statistics_fusion_order/{year}.json`` under ``DATA_DIR``
    (creating the directory if needed) and prints the summary table.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--evaluation-year", type=int, default=DEFAULT_YEAR)
    cli.add_argument(
        "--alpha",
        type=float,
        default=DEFAULT_ALPHA,
        help="Significance threshold for better_order (default 0.05)",
    )
    opts = cli.parse_args()
    year, alpha = opts.evaluation_year, opts.alpha

    site_metrics = _load_site_metrics(year)
    n_sites_total = len(site_metrics)
    if not n_sites_total:
        raise SystemExit(
            f"No Step 5 metrics found under {DATA_DIR / 'metrics' / str(year)}"
        )

    metrics_out: dict[str, Any] = {}
    dropped_union: set[str] = set()
    for metric in METRICS:
        bti_vals, itb_vals, dropped = collect_pairs(site_metrics, metric)
        # Per-metric summary, extended with the sites that lacked a usable pair.
        metrics_out[metric] = {
            **paired_test(bti_vals, itb_vals, metric, alpha),
            "n_dropped": len(dropped),
            "dropped_sites": dropped,
        }
        dropped_union |= set(dropped)

    payload = {
        "year": year,
        "alpha": alpha,
        "n_sites_total": n_sites_total,
        "dropped_sites": sorted(dropped_union),
        "metrics": metrics_out,
    }

    out_dir = DATA_DIR / "statistics_fusion_order"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"{year}.json"
    # Compact separators keep the file free of insignificant whitespace.
    out_path.write_text(json.dumps(payload, separators=(",", ":")))

    _print_summary(year, alpha, n_sites_total, metrics_out)
    print(f"\nWritten → {out_path}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()