diff --git a/metrics_stats.py b/metrics_stats.py index 530652e..ea2bc83 100644 --- a/metrics_stats.py +++ b/metrics_stats.py @@ -196,6 +196,84 @@ def derived_tier1(temporal: dict) -> dict: } +MATCHED_PAIR_CONFIGS = ( + "aggressive_sigma20", + "aggressive_sigma30", + "nonaggressive_sigma20", + "nonaggressive_sigma30", +) + + +def derived_matched_pair_workflow(temporal: dict) -> dict: + """Per-config BtI vs ItB NSE_PC/RMSE pairs and site-level consistency flags.""" + per_config = [] + nse_deltas: list[float] = [] + nse_bti_wins_count = 0 + residual_bti_wins_count = 0 + + for config in MATCHED_PAIR_CONFIGS: + kb = config + ki = f"{config}_itb" + tb = temporal.get(kb) or {} + ti = temporal.get(ki) or {} + nse_bti = tb.get("nse_pc") + nse_itb = ti.get("nse_pc") + rmse_bti = tb.get("rmse") + rmse_itb = ti.get("rmse") + mb = (tb.get("residual_vs_phenocam") or {}).get("mean") + mi = (ti.get("residual_vs_phenocam") or {}).get("mean") + + delta_nse = None + delta_rmse = None + bti_wins = None + residual_bti_wins = None + + if isinstance(nse_bti, (int, float)) and isinstance(nse_itb, (int, float)): + delta_nse = float(nse_bti) - float(nse_itb) + bti_wins = delta_nse > 0 + nse_deltas.append(delta_nse) + if bti_wins: + nse_bti_wins_count += 1 + + if isinstance(rmse_bti, (int, float)) and isinstance(rmse_itb, (int, float)): + delta_rmse = float(rmse_bti) - float(rmse_itb) + + if isinstance(mb, (int, float)) and isinstance(mi, (int, float)): + if float(mb) > float(mi): + residual_bti_wins_count += 1 + residual_bti_wins = True + elif float(mb) < float(mi): + residual_bti_wins = False + else: + residual_bti_wins = None + + per_config.append( + { + "config": config, + "nse_pc_bti": float(nse_bti) if isinstance(nse_bti, (int, float)) else None, + "nse_pc_itb": float(nse_itb) if isinstance(nse_itb, (int, float)) else None, + "rmse_bti": float(rmse_bti) if isinstance(rmse_bti, (int, float)) else None, + "rmse_itb": float(rmse_itb) if isinstance(rmse_itb, (int, float)) else None, + "delta_nse_bti_minus_itb": delta_nse, + "delta_rmse_bti_minus_itb": delta_rmse, + "bti_wins": bti_wins, + "residual_bti_wins": residual_bti_wins, + } + ) + + mean_delta_nse = ( + float(sum(nse_deltas) / len(nse_deltas)) if nse_deltas else None + ) + return { + "per_config": per_config, + "consistency": nse_bti_wins_count, + "nse_bti_wins_count": nse_bti_wins_count, + "residual_bti_wins_count": residual_bti_wins_count, + "residual_nse_mismatch": residual_bti_wins_count != nse_bti_wins_count, + "mean_delta_nse": mean_delta_nse, + } + + def calculate_phenocam_stats(phenocam_ts): """Calculate phenocam summary statistics.""" values = [v for v in phenocam_ts.values() if v is not None] @@ -407,7 +485,11 @@ def calculate_all_metrics(season, site_name, site_position): results["temporal"][scenario_name] = temporal_metrics if results["temporal"]: - results["derived"] = derived_tier1(results["temporal"]) + derived = derived_tier1(results["temporal"]) + derived["matched_pair_workflow"] = derived_matched_pair_workflow( + results["temporal"] + ) + results["derived"] = derived # Save results output_path = Path(f"data/{site_name}/{season}/metrics.json")