added itb bti comparison.
This commit is contained in:
parent
a8852bc997
commit
f188dd38ab
3 changed files with 444 additions and 14 deletions
252
6-statistics-fusion-order.py
Normal file
252
6-statistics-fusion-order.py
Normal file
|
|
@ -0,0 +1,252 @@
|
|||
"""Step 6: Paired ItB-vs-BtI significance test across the full sample.
|
||||
|
||||
Inputs (``data/``, ``{year}`` = ``--evaluation-year``):
|
||||
|
||||
- ``metrics/{year}/{site}/metrics.json`` — per-site validation metrics (Step 5)
|
||||
|
||||
Outputs (``data/statistics_fusion_order/``):
|
||||
|
||||
- ``{year}.json`` — paired Wilcoxon + t-test summary for NSE, RMSE, nRMSE, r
|
||||
|
||||
CLI:
|
||||
|
||||
- ``--evaluation-year`` (default 2025)
|
||||
- ``--alpha`` (default 0.05) — significance threshold for ``better_order``
|
||||
|
||||
This step aggregates across all sites with Step 5 output; it does not accept
|
||||
``--site`` (a single-site filter would not support a sample-level test).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from scipy.stats import ttest_rel, wilcoxon
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DATA_DIR = Path("data")
|
||||
DEFAULT_YEAR = 2025
|
||||
DEFAULT_ALPHA = 0.05
|
||||
|
||||
METRICS = ["nse", "rmse", "nrmse", "r"]
|
||||
LOWER_IS_BETTER = {"rmse", "nrmse"}
|
||||
|
||||
MIN_PAIRS_WARNING = 6
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _r4(v: float | None) -> float | None:
|
||||
return round(v, 4) if v is not None else None
|
||||
|
||||
|
||||
def _load_site_metrics(year: int) -> list[dict[str, Any]]:
|
||||
"""Return parsed ``metrics.json`` payloads for every site under ``{year}``."""
|
||||
metrics_dir = DATA_DIR / "metrics" / str(year)
|
||||
if not metrics_dir.is_dir():
|
||||
return []
|
||||
|
||||
payloads: list[dict[str, Any]] = []
|
||||
for site_dir in sorted(metrics_dir.iterdir()):
|
||||
if not site_dir.is_dir():
|
||||
continue
|
||||
path = site_dir / "metrics.json"
|
||||
if not path.is_file():
|
||||
continue
|
||||
payloads.append(json.loads(path.read_text()))
|
||||
return payloads
|
||||
|
||||
|
||||
def collect_pairs(
|
||||
site_metrics: list[dict[str, Any]], metric: str
|
||||
) -> tuple[list[float], list[float], int]:
|
||||
"""Return paired BtI / ItB values for ``metric`` and count of dropped sites."""
|
||||
bti_vals: list[float] = []
|
||||
itb_vals: list[float] = []
|
||||
n_dropped = 0
|
||||
|
||||
for payload in site_metrics:
|
||||
bti = payload.get("bti")
|
||||
itb = payload.get("itb")
|
||||
if not isinstance(bti, dict) or not isinstance(itb, dict):
|
||||
n_dropped += 1
|
||||
continue
|
||||
|
||||
bti_v = bti.get(metric)
|
||||
itb_v = itb.get(metric)
|
||||
if bti_v is None or itb_v is None:
|
||||
n_dropped += 1
|
||||
continue
|
||||
|
||||
bti_vals.append(float(bti_v))
|
||||
itb_vals.append(float(itb_v))
|
||||
|
||||
return bti_vals, itb_vals, n_dropped
|
||||
|
||||
|
||||
def _better_order(
|
||||
bti_vals: list[float],
|
||||
itb_vals: list[float],
|
||||
metric: str,
|
||||
p_value: float | None,
|
||||
alpha: float,
|
||||
) -> str:
|
||||
"""Name the better fusion order when Wilcoxon p < alpha, else no difference."""
|
||||
if p_value is None or p_value >= alpha:
|
||||
return "no significant difference"
|
||||
|
||||
mean_diff = float(np.mean(itb_vals) - np.mean(bti_vals))
|
||||
if metric in LOWER_IS_BETTER:
|
||||
return "itb" if mean_diff < 0 else "bti"
|
||||
return "itb" if mean_diff > 0 else "bti"
|
||||
|
||||
|
||||
def paired_test(
|
||||
bti_vals: list[float],
|
||||
itb_vals: list[float],
|
||||
metric: str,
|
||||
alpha: float,
|
||||
) -> dict[str, Any]:
|
||||
"""Run paired Wilcoxon (primary) and t-test; return summary dict."""
|
||||
n_pairs = len(bti_vals)
|
||||
bti_arr = np.array(bti_vals, dtype=float)
|
||||
itb_arr = np.array(itb_vals, dtype=float)
|
||||
diffs = itb_arr - bti_arr
|
||||
|
||||
result: dict[str, Any] = {
|
||||
"n_pairs": n_pairs,
|
||||
"bti_mean": _r4(float(bti_arr.mean())) if n_pairs else None,
|
||||
"bti_median": _r4(float(np.median(bti_arr))) if n_pairs else None,
|
||||
"itb_mean": _r4(float(itb_arr.mean())) if n_pairs else None,
|
||||
"itb_median": _r4(float(np.median(itb_arr))) if n_pairs else None,
|
||||
"mean_diff": _r4(float(diffs.mean())) if n_pairs else None,
|
||||
"median_diff": _r4(float(np.median(diffs))) if n_pairs else None,
|
||||
"wilcoxon": {"statistic": None, "p_value": None},
|
||||
"ttest": {"statistic": None, "p_value": None},
|
||||
"better_order": "insufficient data",
|
||||
}
|
||||
|
||||
if n_pairs < 2:
|
||||
return result
|
||||
|
||||
wilcoxon_stat: float | None = None
|
||||
wilcoxon_p: float | None = None
|
||||
if np.any(diffs != 0):
|
||||
try:
|
||||
w_stat, w_p = wilcoxon(itb_arr, bti_arr)
|
||||
wilcoxon_stat = float(w_stat)
|
||||
wilcoxon_p = float(w_p)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
t_stat, t_p = ttest_rel(itb_arr, bti_arr)
|
||||
|
||||
result["wilcoxon"] = {
|
||||
"statistic": _r4(wilcoxon_stat),
|
||||
"p_value": _r4(wilcoxon_p),
|
||||
}
|
||||
result["ttest"] = {
|
||||
"statistic": _r4(float(t_stat)),
|
||||
"p_value": _r4(float(t_p)),
|
||||
}
|
||||
result["better_order"] = _better_order(
|
||||
bti_vals, itb_vals, metric, wilcoxon_p, alpha
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _print_summary(
|
||||
year: int, alpha: float, n_sites_total: int, metrics_out: dict
|
||||
) -> None:
|
||||
print(f"\nPaired ItB vs BtI test — {year} (alpha={alpha}, sites={n_sites_total})")
|
||||
print(
|
||||
f"{'metric':<8} {'n':>4} {'BtI mean':>10} {'ItB mean':>10} "
|
||||
f"{'diff':>10} {'W p':>8} {'t p':>8} better"
|
||||
)
|
||||
print("-" * 78)
|
||||
for metric in METRICS:
|
||||
m = metrics_out[metric]
|
||||
bti_mean = m["bti_mean"]
|
||||
itb_mean = m["itb_mean"]
|
||||
mean_diff = m["mean_diff"]
|
||||
w_p = m["wilcoxon"]["p_value"]
|
||||
t_p = m["ttest"]["p_value"]
|
||||
better = m["better_order"]
|
||||
|
||||
def _fmt(v: float | None) -> str:
|
||||
return f"{v:10.4f}" if v is not None else f"{'—':>10}"
|
||||
|
||||
print(
|
||||
f"{metric:<8} {m['n_pairs']:>4} {_fmt(bti_mean)} {_fmt(itb_mean)} "
|
||||
f"{_fmt(mean_diff)} "
|
||||
f"{w_p if w_p is not None else '—':>8} "
|
||||
f"{t_p if t_p is not None else '—':>8} {better}"
|
||||
)
|
||||
if 0 < m["n_pairs"] < MIN_PAIRS_WARNING:
|
||||
print(
|
||||
f" warning: only {m['n_pairs']} pair(s) for {metric}; "
|
||||
"interpret p-values cautiously"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--evaluation-year", type=int, default=DEFAULT_YEAR)
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
type=float,
|
||||
default=DEFAULT_ALPHA,
|
||||
help="Significance threshold for better_order (default 0.05)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
year = args.evaluation_year
|
||||
alpha = args.alpha
|
||||
|
||||
site_metrics = _load_site_metrics(year)
|
||||
n_sites_total = len(site_metrics)
|
||||
if n_sites_total == 0:
|
||||
raise SystemExit(
|
||||
f"No Step 5 metrics found under {DATA_DIR / 'metrics' / str(year)}"
|
||||
)
|
||||
|
||||
metrics_out: dict[str, Any] = {}
|
||||
for metric in METRICS:
|
||||
bti_vals, itb_vals, n_dropped = collect_pairs(site_metrics, metric)
|
||||
summary = paired_test(bti_vals, itb_vals, metric, alpha)
|
||||
summary["n_dropped"] = n_dropped
|
||||
metrics_out[metric] = summary
|
||||
|
||||
payload = {
|
||||
"year": year,
|
||||
"alpha": alpha,
|
||||
"n_sites_total": n_sites_total,
|
||||
"metrics": metrics_out,
|
||||
}
|
||||
|
||||
out_dir = DATA_DIR / "statistics_fusion_order"
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = out_dir / f"{year}.json"
|
||||
out_path.write_text(json.dumps(payload, separators=(",", ":")))
|
||||
|
||||
_print_summary(year, alpha, n_sites_total, metrics_out)
|
||||
print(f"\nWritten → {out_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue