Source code for spectraxgk.nonlinear_replicate_diagnostics

"""Diagnostics for replicated nonlinear transport-window spread.

The routines here are intentionally data-only. They consume already-generated
ensemble JSON artifacts and classify why a replicated nonlinear window passed
or failed before more GPU time is spent on follow-up runs.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Mapping, Sequence
import math
import re


# Parameter-state tag (baseline / plus_delta / minus_delta), delimited by "_"
# or by the start/end of the string; group 1 is the tag itself.
_STATE_RE = re.compile(r"(?:^|_)(baseline|plus_delta|minus_delta)(?:_|$)")
# Random-seed token such as "seed3"; group 1 captures the digits.
_SEED_RE = re.compile(r"seed(\d+)")
# Timestep token such as "dt0p05"; group 1 captures the alphanumeric tail
# (presumably "p" encodes a decimal point, e.g. 0p05 -> 0.05 — TODO confirm).
# NOTE(review): the pattern is unanchored, so it also matches "dt" inside other
# words (e.g. "width_..." yields "dth_..."-style captures); verify artifact names.
_DT_RE = re.compile(r"dt([0-9A-Za-zp]+)")


@dataclass(frozen=True)
class NonlinearReplicateSpreadConfig:
    """Thresholds for classifying replicated nonlinear-window spread.

    Instances are immutable (``frozen=True``).
    """

    # Default gate on the relative spread of replicate late-window means; used
    # when an ensemble payload does not carry its own ``max_mean_rel_spread``.
    max_mean_rel_spread: float = 0.15
    # Smallest magnitude allowed for the normalizing scale of relative spreads
    # and relative deltas (guards against division by a near-zero mean).
    value_floor: float = 1.0e-12
def _finite_float(value: Any) -> float | None:
    """Convert ``value`` to a finite ``float``, or return ``None``.

    ``None``, non-numeric objects, unparsable strings, integers too large to
    represent as a float, and NaN/infinity all map to ``None`` so callers can
    treat "missing" and "unusable" the same way.
    """
    try:
        out = float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: Python ints beyond the float range (e.g. 10**400).
        return None
    return out if math.isfinite(out) else None


def _json_number(value: float | int | None) -> float | int | None:
    """Return ``value`` unchanged when it is a finite number, else ``None``.

    Presumably used so the emitted report never contains NaN/Infinity, which
    strict JSON cannot represent.
    """
    if value is None:
        return None
    return value if math.isfinite(float(value)) else None


def _state_label(ensemble: Mapping[str, Any], index: int) -> str:
    """Infer the parameter-state tag of one ensemble payload.

    The ensemble's ``state``/``case``/``comparison`` values are searched first.
    After them come each Mapping row's ``source_artifact``/``summary_artifact``/
    ``case`` values. Each is checked for a ``baseline``/``plus_delta``/
    ``minus_delta`` tag matched by ``_STATE_RE``. The first match wins;
    otherwise ``state_<index>`` is returned.
    """
    candidates: list[str] = []
    for key in ("state", "case", "comparison"):
        value = ensemble.get(key)
        if value is not None:
            candidates.append(str(value))
    rows = ensemble.get("rows")
    if isinstance(rows, Sequence):
        for row in rows:
            if isinstance(row, Mapping):
                for key in ("source_artifact", "summary_artifact", "case"):
                    value = row.get(key)
                    if value is not None:
                        candidates.append(str(value))
    for candidate in candidates:
        match = _STATE_RE.search(candidate)
        if match is not None:
            return match.group(1)
    return f"state_{index}"


def _variant_label(row: Mapping[str, Any]) -> tuple[str, str]:
    """Return ``(label, axis)`` identifying which variant a replicate row is.

    Resolution order:

    1. **Explicit label.** A non-empty ``variant_label`` (or ``label``) string
       wins. Its axis is ``row["variant_axis"]`` when set. Otherwise the axis
       is inferred from the label: ``seed…`` containing ``_dt`` gives
       ``seed_timestep``, ``seed…`` gives ``seed``, ``dt…`` gives
       ``timestep``, and anything else gives ``unknown``.
    2. **Artifact names.** Without a label, the first of ``source_artifact``/
       ``summary_artifact``/``case`` that contains a ``seed<N>`` and/or
       ``dt<tag>`` token is used.
    3. **Fallback.** ``replicate_<index>`` with axis ``unknown``.
    """
    for key in ("variant_label", "label"):
        value = row.get(key)
        if isinstance(value, str) and value:
            axis = str(row.get("variant_axis", "") or "")
            if not axis:
                if value.startswith("seed") and "_dt" in value:
                    axis = "seed_timestep"
                elif value.startswith("seed"):
                    axis = "seed"
                elif value.startswith("dt"):
                    axis = "timestep"
                else:
                    axis = "unknown"
            return value, axis
    candidates = [
        str(value)
        for value in (row.get("source_artifact"), row.get("summary_artifact"), row.get("case"))
        if value is not None
    ]
    for candidate in candidates:
        seed = _SEED_RE.search(candidate)
        dt = _DT_RE.search(candidate)
        if seed is not None and dt is not None:
            return f"seed{seed.group(1)}_dt{dt.group(1)}", "seed_timestep"
        if seed is not None:
            return f"seed{seed.group(1)}", "seed"
        if dt is not None:
            return f"dt{dt.group(1)}", "timestep"
    return f"replicate_{row.get('index', 'unknown')}", "unknown"


def _recommendation(classification: str) -> str:
    """Return the follow-up advice text for a ``_classify_state`` result.

    Unrecognized classifications (including ``spread_limited_unknown_axis``)
    get the fail-closed "not classifiable" advice.
    """
    if classification == "passed_replicate_spread_gate":
        return "Replicate spread is within the configured gate; no extra replicas are indicated."
    if classification == "mixed_seed_timestep_spread":
        return (
            "Do not add same-bracket replicas blindly. The high and low windows are on different "
            "variant axes, so first disambiguate seed sensitivity from timestep sensitivity or shrink "
            "the finite-difference bracket."
        )
    if classification == "seed_spread_limited":
        return (
            "Seed variability dominates. Add a matched seed at the same parameter state or switch to "
            "paired-seed finite differences before using this bracket for a gradient claim."
        )
    if classification == "timestep_spread_limited":
        return (
            "Timestep sensitivity dominates. Retune the timestep/window convergence before adding "
            "more random seeds at this state."
        )
    return (
        "Replicate spread is not classifiable from the available labels. Preserve the fail-closed "
        "claim boundary and add explicit seed/timestep metadata to the next run manifest."
    )


def _classify_state(
    *,
    passed: bool,
    mean_rel_spread: float | None,
    spread_gate: float,
    high_axis: str | None,
    low_axis: str | None,
) -> str:
    """Classify one ensemble state by its replicate spread.

    Checks are applied in this order:

    1. A known spread at or below ``spread_gate`` gives
       ``passed_replicate_spread_gate``.
    2. High/low extreme rows on different variant axes give
       ``mixed_seed_timestep_spread``.
    3. Both extremes on the ``seed`` axis give ``seed_spread_limited``.
    4. Both extremes on the ``timestep`` axis give
       ``timestep_spread_limited``.
    5. Anything else, including an unknown spread, gives
       ``spread_limited_unknown_axis``.

    ``passed`` (the ensemble's own pass flag) does not affect the result: the
    spread gate alone decides. It is kept as a keyword for interface
    compatibility.
    """
    if mean_rel_spread is not None and mean_rel_spread <= spread_gate:
        return "passed_replicate_spread_gate"
    if high_axis is not None and low_axis is not None and high_axis != low_axis:
        return "mixed_seed_timestep_spread"
    if high_axis == "seed" and low_axis == "seed":
        return "seed_spread_limited"
    if high_axis == "timestep" and low_axis == "timestep":
        return "timestep_spread_limited"
    return "spread_limited_unknown_axis"
def _mapping_rows(ensemble: Mapping[str, Any]) -> list[Mapping[str, Any]]:
    """Return the Mapping entries of ``ensemble["rows"]`` ([] if absent/not a sequence)."""
    rows_obj = ensemble.get("rows")
    if not isinstance(rows_obj, Sequence):
        return []
    return [row for row in rows_obj if isinstance(row, Mapping)]


def _resolve_spread_gate(
    ensemble: Mapping[str, Any],
    stats: Mapping[str, Any],
    default: float,
) -> float:
    """Return the spread gate for one ensemble, falling back to ``default``.

    Lookup order: ``statistics["max_mean_rel_spread"]``, then
    ``config["max_mean_rel_spread"]`` of the ensemble payload, then
    ``default``. Missing, non-numeric, or non-finite entries are skipped.
    """
    gate = _finite_float(stats.get("max_mean_rel_spread"))
    if gate is None:
        raw_config = ensemble.get("config")
        if isinstance(raw_config, Mapping):
            gate = _finite_float(raw_config.get("max_mean_rel_spread"))
    return float(default) if gate is None else gate


def _row_index(row: Mapping[str, Any], default: int) -> int:
    """Return ``row["index"]`` as an int; ``default`` if missing or unconvertible."""
    try:
        return int(row.get("index", default))
    except (TypeError, ValueError, OverflowError):
        # e.g. ``"index": null``, a non-numeric string, or NaN/inf in the artifact.
        return default


def nonlinear_replicate_spread_report(
    ensembles: Sequence[Mapping[str, Any]],
    *,
    case: str = "nonlinear_replicate_spread_diagnostic",
    config: NonlinearReplicateSpreadConfig | None = None,
) -> dict[str, Any]:
    """Classify which replicate/state drives nonlinear-window ensemble spread.

    Parameters
    ----------
    ensembles:
        Sequence of ensemble JSON payloads, typically produced by
        ``tools/check_nonlinear_window_ensemble.py``.
    case:
        Human-readable label for the diagnostic artifact.
    config:
        Spread threshold and numerical floor used for relative deviations.
        The threshold is only a fallback: a gate stored in an ensemble's
        ``statistics`` or ``config`` takes precedence.

    Returns
    -------
    dict
        JSON-ready report containing:

        * ``state_rows``: one entry per ensemble.
        * ``replicate_rows``: one entry per Mapping row across all ensembles.
        * ``summary``: lists the states whose classification is not
          ``passed_replicate_spread_gate``.

        The report's ``passed`` is true only when no state failed.

    Raises
    ------
    ValueError
        If ``config.max_mean_rel_spread`` is negative or NaN, or if
        ``config.value_floor`` is not positive and finite.
    """
    cfg = config if config is not None else NonlinearReplicateSpreadConfig()
    default_gate = float(cfg.max_mean_rel_spread)
    value_floor = float(cfg.value_floor)
    # NaN compares False against everything, so reject it explicitly. A NaN
    # floor can make ``scale`` zero (division by zero below). An infinite
    # floor would silently make every relative spread zero.
    if math.isnan(default_gate) or default_gate < 0.0:
        raise ValueError("max_mean_rel_spread must be non-negative")
    if not math.isfinite(value_floor) or value_floor <= 0.0:
        raise ValueError("value_floor must be positive and finite")

    state_rows: list[dict[str, Any]] = []
    replicate_rows: list[dict[str, Any]] = []
    failed_states: list[str] = []
    classifications: dict[str, int] = {}

    for state_index, ensemble in enumerate(ensembles):
        state = _state_label(ensemble, state_index)
        statistics = ensemble.get("statistics")
        stats = statistics if isinstance(statistics, Mapping) else {}
        rows = _mapping_rows(ensemble)

        # Parse each row's late-window mean once. Rows without a finite mean
        # are excluded from the mean/spread/extreme computations but are
        # still reported in ``replicate_rows``.
        row_means = [_finite_float(row.get("late_mean")) for row in rows]
        finite_pairs = [(row, mean) for row, mean in zip(rows, row_means) if mean is not None]
        finite_means = [mean for _, mean in finite_pairs]

        ensemble_mean = _finite_float(stats.get("ensemble_mean"))
        if ensemble_mean is None and finite_means:
            ensemble_mean = sum(finite_means) / len(finite_means)
        # Normalizing scale for relative spread/deltas, bounded below by the floor.
        scale = max(abs(float(ensemble_mean or 0.0)), value_floor)

        high_label: str | None = None
        high_axis: str | None = None
        low_label: str | None = None
        low_axis: str | None = None
        if finite_pairs:
            # max/min keep the first row on ties, matching row order.
            high_row = max(finite_pairs, key=lambda item: item[1])[0]
            low_row = min(finite_pairs, key=lambda item: item[1])[0]
            high_label, high_axis = _variant_label(high_row)
            low_label, low_axis = _variant_label(low_row)

        mean_rel_spread = _finite_float(stats.get("mean_rel_spread"))
        if mean_rel_spread is None and finite_means:
            mean_rel_spread = (max(finite_means) - min(finite_means)) / scale

        spread_gate = _resolve_spread_gate(ensemble, stats, default_gate)
        passed = bool(ensemble.get("passed", False))
        classification = _classify_state(
            passed=passed,
            mean_rel_spread=mean_rel_spread,
            spread_gate=spread_gate,
            high_axis=high_axis,
            low_axis=low_axis,
        )
        classifications[classification] = classifications.get(classification, 0) + 1
        if classification != "passed_replicate_spread_gate":
            failed_states.append(state)

        state_rows.append(
            {
                "state": state,
                "passed": passed,
                "classification": classification,
                "recommendation": _recommendation(classification),
                "ensemble_mean": _json_number(ensemble_mean),
                "mean_rel_spread": _json_number(mean_rel_spread),
                "mean_rel_spread_gate": _json_number(spread_gate),
                "combined_sem_rel": _json_number(_finite_float(stats.get("combined_sem_rel"))),
                "high_variant_label": high_label,
                "high_variant_axis": high_axis,
                "low_variant_label": low_label,
                "low_variant_axis": low_axis,
            }
        )

        for row, mean in zip(rows, row_means):
            label, axis = _variant_label(row)
            rel_delta = (
                None if mean is None or ensemble_mean is None else (mean - ensemble_mean) / scale
            )
            window_stats = row.get("window_statistics")
            window = window_stats if isinstance(window_stats, Mapping) else {}
            replicate_rows.append(
                {
                    "state": state,
                    # Default index is the running count across all states.
                    "index": _row_index(row, len(replicate_rows)),
                    "variant_label": label,
                    "variant_axis": axis,
                    "late_mean": _json_number(mean),
                    "sem": _json_number(_finite_float(row.get("sem"))),
                    "ensemble_mean": _json_number(ensemble_mean),
                    "relative_delta": _json_number(rel_delta),
                    "passed": bool(row.get("passed", False)),
                    "promotion_ready": bool(row.get("promotion_ready", False)),
                    "source_artifact": row.get("source_artifact"),
                    "summary_artifact": row.get("summary_artifact"),
                    "running_mean_rel_drift": _json_number(
                        _finite_float(window.get("running_mean_rel_drift"))
                    ),
                    "terminal_mean_rel_delta": _json_number(
                        _finite_float(window.get("terminal_mean_rel_delta"))
                    ),
                    "sem_rel": _json_number(_finite_float(window.get("sem_rel"))),
                    "n_blocks": _json_number(_finite_float(window.get("n_blocks"))),
                }
            )

    report_passed = not failed_states
    return {
        "kind": "nonlinear_replicate_spread_diagnostic",
        "claim_level": "replicate_spread_diagnostic_not_simulation_claim",
        "case": str(case),
        "passed": report_passed,
        "summary": {
            "n_states": len(ensembles),
            "n_replicates": len(replicate_rows),
            "failed_states": failed_states,
            "classifications": classifications,
            "recommendation": (
                "All replicated ensembles are within spread gates."
                if report_passed
                else "Keep the nonlinear-gradient claim fail-closed and target the failed states first."
            ),
        },
        "state_rows": state_rows,
        "replicate_rows": replicate_rows,
        "config": {
            "max_mean_rel_spread": default_gate,
            "value_floor": value_floor,
        },
    }
# Public API of this module; all other names are private helpers.
__all__ = [
    "NonlinearReplicateSpreadConfig",
    "nonlinear_replicate_spread_report",
]