252 lines
7.5 KiB
Python
252 lines
7.5 KiB
Python
"""Step 6: Paired ItB-vs-BtI significance test across the full sample.

Inputs (``data/``, ``{year}`` = ``--evaluation-year``):

- ``metrics/{year}/{site}/metrics.json`` — per-site validation metrics (Step 5)

Outputs (``data/statistics_fusion_order/``):

- ``{year}.json`` — paired Wilcoxon signed-rank and t-test summary for NSE,
  RMSE, nRMSE, and r

CLI:

- ``--evaluation-year`` (default 2025)
- ``--alpha`` (default 0.05) — significance threshold for ``better_order``

This step aggregates across all sites with Step 5 output; it does not accept
``--site`` (a single-site filter would not support a sample-level test).
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
from scipy.stats import ttest_rel, wilcoxon
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Constants
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Root of the pipeline's on-disk data tree (a path relative to the CWD).
DATA_DIR: Path = Path("data")

# CLI defaults for --evaluation-year and --alpha.
DEFAULT_YEAR: int = 2025
DEFAULT_ALPHA: float = 0.05

# Metrics compared between the two fusion orders, in report/output order.
METRICS: list[str] = ["nse", "rmse", "nrmse", "r"]
# Error-type metrics where a smaller value wins; every other metric in
# METRICS is treated as higher-is-better.
LOWER_IS_BETTER: set[str] = {"nrmse", "rmse"}

# The console summary flags metrics with fewer pairs than this as having
# p-values that should be read with caution.
MIN_PAIRS_WARNING: int = 6
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _r4(v: float | None) -> float | None:
|
|
return round(v, 4) if v is not None else None
|
|
|
|
|
|
def _load_site_metrics(year: int) -> list[dict[str, Any]]:
|
|
"""Return parsed ``metrics.json`` payloads for every site under ``{year}``."""
|
|
metrics_dir = DATA_DIR / "metrics" / str(year)
|
|
if not metrics_dir.is_dir():
|
|
return []
|
|
|
|
payloads: list[dict[str, Any]] = []
|
|
for site_dir in sorted(metrics_dir.iterdir()):
|
|
if not site_dir.is_dir():
|
|
continue
|
|
path = site_dir / "metrics.json"
|
|
if not path.is_file():
|
|
continue
|
|
payloads.append(json.loads(path.read_text()))
|
|
return payloads
|
|
|
|
|
|
def collect_pairs(
    site_metrics: list[dict[str, Any]], metric: str
) -> tuple[list[float], list[float], int]:
    """Return paired BtI / ItB values for ``metric`` and count of dropped sites.

    A site contributes a pair only when both its ``"bti"`` and ``"itb"``
    entries are dicts holding a finite value for ``metric``. Every other site
    is counted as dropped, so pairs stay index-aligned across the two lists.

    Args:
        site_metrics: Parsed per-site ``metrics.json`` payloads.
        metric: Metric key to extract (e.g. ``"nse"``).

    Returns:
        ``(bti_vals, itb_vals, n_dropped)``.

    Raises:
        ValueError / TypeError: If a present metric value cannot be converted
            with ``float()``.
    """
    bti_vals: list[float] = []
    itb_vals: list[float] = []
    n_dropped = 0

    for payload in site_metrics:
        bti = payload.get("bti")
        itb = payload.get("itb")
        if not isinstance(bti, dict) or not isinstance(itb, dict):
            n_dropped += 1
            continue

        bti_v = bti.get(metric)
        itb_v = itb.get(metric)
        if bti_v is None or itb_v is None:
            n_dropped += 1
            continue

        bti_f = float(bti_v)
        itb_f = float(itb_v)
        # json.loads accepts NaN/Infinity; a single non-finite value would
        # poison the means and make both paired tests return NaN.
        if not (np.isfinite(bti_f) and np.isfinite(itb_f)):
            n_dropped += 1
            continue

        bti_vals.append(bti_f)
        itb_vals.append(itb_f)

    return bti_vals, itb_vals, n_dropped
|
|
|
|
|
|
def _better_order(
|
|
bti_vals: list[float],
|
|
itb_vals: list[float],
|
|
metric: str,
|
|
p_value: float | None,
|
|
alpha: float,
|
|
) -> str:
|
|
"""Name the better fusion order when Wilcoxon p < alpha, else no difference."""
|
|
if p_value is None or p_value >= alpha:
|
|
return "no significant difference"
|
|
|
|
mean_diff = float(np.mean(itb_vals) - np.mean(bti_vals))
|
|
if metric in LOWER_IS_BETTER:
|
|
return "itb" if mean_diff < 0 else "bti"
|
|
return "itb" if mean_diff > 0 else "bti"
|
|
|
|
|
|
def paired_test(
    bti_vals: list[float],
    itb_vals: list[float],
    metric: str,
    alpha: float,
) -> dict[str, Any]:
    """Run paired Wilcoxon (primary) and t-test; return summary dict.

    Differences are taken per site as ``itb - bti``, so a positive
    ``mean_diff`` / ``median_diff`` means ItB scored higher.

    Args:
        bti_vals: Per-site values of ``metric`` for the BtI order.
        itb_vals: Per-site values for the ItB order, index-aligned with
            ``bti_vals`` (same length).
        metric: Metric name, forwarded to ``_better_order`` to decide which
            direction counts as better.
        alpha: Significance threshold applied to the Wilcoxon p-value.

    Returns:
        A JSON-serialisable dict with:

        - ``n_pairs``
        - 4-decimal means and medians for each order and for the differences
          (``None`` when there are no pairs)
        - ``wilcoxon`` and ``ttest`` sub-dicts with ``statistic`` and
          ``p_value`` (``None`` when the test was not run or produced a
          non-finite value)
        - ``better_order``

        With fewer than two pairs no test runs and ``better_order`` stays
        ``"insufficient data"``.
    """
    n_pairs = len(bti_vals)
    bti_arr = np.asarray(bti_vals, dtype=float)
    itb_arr = np.asarray(itb_vals, dtype=float)
    diffs = itb_arr - bti_arr

    def _finite_r4(value: float | None) -> float | None:
        # NaN/inf would be serialised by json.dumps as NaN/Infinity (invalid
        # JSON), so report any non-finite statistic as None.
        if value is None:
            return None
        value = float(value)
        return _r4(value) if np.isfinite(value) else None

    has_pairs = n_pairs > 0
    result: dict[str, Any] = {
        "n_pairs": n_pairs,
        "bti_mean": _finite_r4(bti_arr.mean()) if has_pairs else None,
        "bti_median": _finite_r4(np.median(bti_arr)) if has_pairs else None,
        "itb_mean": _finite_r4(itb_arr.mean()) if has_pairs else None,
        "itb_median": _finite_r4(np.median(itb_arr)) if has_pairs else None,
        "mean_diff": _finite_r4(diffs.mean()) if has_pairs else None,
        "median_diff": _finite_r4(np.median(diffs)) if has_pairs else None,
        "wilcoxon": {"statistic": None, "p_value": None},
        "ttest": {"statistic": None, "p_value": None},
        "better_order": "insufficient data",
    }

    if n_pairs < 2:
        return result

    wilcoxon_stat: float | None = None
    wilcoxon_p: float | None = None
    t_stat: float | None = None
    t_p: float | None = None

    # If every site scores identically under both orders there is nothing to
    # test: Wilcoxon has no non-zero differences to rank and the t-test's
    # standard error is zero (NaN result). Leave both tests as None.
    if np.any(diffs != 0):
        try:
            w_stat, w_p = wilcoxon(itb_arr, bti_arr)
        except ValueError:
            # Deliberate best effort: scipy rejects some degenerate inputs;
            # report the Wilcoxon test as unavailable instead of aborting the
            # whole multi-metric run.
            pass
        else:
            wilcoxon_stat = float(w_stat)
            # Only a finite p-value may drive better_order; a NaN would slip
            # past the ``p >= alpha`` check and name a spurious winner.
            wilcoxon_p = float(w_p) if np.isfinite(w_p) else None

        t_stat_raw, t_p_raw = ttest_rel(itb_arr, bti_arr)
        t_stat, t_p = float(t_stat_raw), float(t_p_raw)

    result["wilcoxon"] = {
        "statistic": _finite_r4(wilcoxon_stat),
        "p_value": _finite_r4(wilcoxon_p),
    }
    result["ttest"] = {
        "statistic": _finite_r4(t_stat),
        "p_value": _finite_r4(t_p),
    }
    # Pass the unrounded p-value so rounding cannot flip the alpha decision.
    result["better_order"] = _better_order(
        bti_vals, itb_vals, metric, wilcoxon_p, alpha
    )
    return result
|
|
|
|
|
|
def _print_summary(
    year: int, alpha: float, n_sites_total: int, metrics_out: dict
) -> None:
    """Print the per-metric test results as a fixed-width console table.

    Rows follow the order of ``METRICS``. Means and the mean difference are
    shown with four decimals, p-values exactly as stored, and ``None`` as an
    em dash. A caution line follows any metric that has at least one pair
    but fewer than ``MIN_PAIRS_WARNING``.
    """

    def as_number(value: float | None) -> str:
        return f"{'—':>10}" if value is None else f"{value:10.4f}"

    def as_pvalue(value: float | None) -> str:
        shown = "—" if value is None else value
        return f"{shown:>8}"

    header_cells = [
        f"{'metric':<8}",
        f"{'n':>4}",
        f"{'BtI mean':>10}",
        f"{'ItB mean':>10}",
        f"{'diff':>10}",
        f"{'W p':>8}",
        f"{'t p':>8}",
        "better",
    ]
    print(f"\nPaired ItB vs BtI test — {year} (alpha={alpha}, sites={n_sites_total})")
    print(" ".join(header_cells))
    print("-" * 78)

    for name in METRICS:
        stats = metrics_out[name]
        row_cells = [
            f"{name:<8}",
            f"{stats['n_pairs']:>4}",
            as_number(stats["bti_mean"]),
            as_number(stats["itb_mean"]),
            as_number(stats["mean_diff"]),
            as_pvalue(stats["wilcoxon"]["p_value"]),
            as_pvalue(stats["ttest"]["p_value"]),
            f"{stats['better_order']}",
        ]
        print(" ".join(row_cells))

        pairs = stats["n_pairs"]
        if 0 < pairs < MIN_PAIRS_WARNING:
            print(
                f"  warning: only {pairs} pair(s) for {name}; "
                "interpret p-values cautiously"
            )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CLI
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point: run the paired tests and write ``{year}.json``.

    Loads every site's Step 5 metrics for ``--evaluation-year``, runs
    ``paired_test`` for each metric in ``METRICS``, writes the compact JSON
    summary to ``<DATA_DIR>/statistics_fusion_order/{year}.json``, and prints
    a console table.

    Raises:
        SystemExit: If ``--alpha`` is not strictly between 0 and 1, or if no
            Step 5 metrics exist for the requested year.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's line breaks and bullet lists in --help.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--evaluation-year",
        type=int,
        default=DEFAULT_YEAR,
        help="Year whose Step 5 metrics are tested (default %(default)s)",
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=DEFAULT_ALPHA,
        help="Significance threshold for better_order (default %(default)s)",
    )
    args = parser.parse_args()

    year = args.evaluation_year
    alpha = args.alpha
    # A threshold outside (0, 1), or NaN, makes every comparison (or none)
    # "significant", so better_order would be meaningless.
    if not 0.0 < alpha < 1.0:
        parser.error(f"--alpha must be strictly between 0 and 1, got {alpha}")

    site_metrics = _load_site_metrics(year)
    n_sites_total = len(site_metrics)
    if n_sites_total == 0:
        raise SystemExit(
            f"No Step 5 metrics found under {DATA_DIR / 'metrics' / str(year)}"
        )

    metrics_out: dict[str, Any] = {}
    for metric in METRICS:
        bti_vals, itb_vals, n_dropped = collect_pairs(site_metrics, metric)
        summary = paired_test(bti_vals, itb_vals, metric, alpha)
        summary["n_dropped"] = n_dropped
        metrics_out[metric] = summary

    payload = {
        "year": year,
        "alpha": alpha,
        "n_sites_total": n_sites_total,
        "metrics": metrics_out,
    }

    out_dir = DATA_DIR / "statistics_fusion_order"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"{year}.json"
    out_path.write_text(
        json.dumps(payload, separators=(",", ":")), encoding="utf-8"
    )

    _print_summary(year, alpha, n_sites_total, metrics_out)
    print(f"\nWritten → {out_path}")


if __name__ == "__main__":
    main()
|