"""Boundary-chain diagnostics for VMEC-JAX/SPECTRAX-GK gradients.
These helpers classify the scalar contractions produced by the expensive
``tools/probe_vmec_jax_boundary_chain.py`` diagnostic. The diagnostic compares
raw exact-solve finite differences, frozen-axis initial-state finite
differences, and VMEC-JAX exact-tape JVP/VJP contractions. Keeping the
classification logic in the package makes the paper-facing convention explicit
and unit-testable without launching VMEC solves.
"""
from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any
import math
def _finite_float(value: float | int | None) -> float | None:
    """Coerce *value* to a finite ``float``, or return ``None``.

    Args:
        value: Candidate scalar. ``None`` is accepted and passed through.

    Returns:
        ``float(value)`` when the conversion succeeds and the result is
        finite; ``None`` when *value* is ``None``, cannot be converted, is an
        integer too large for a float, or converts to NaN/infinity.
    """
    if value is None:
        return None
    try:
        out = float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: ``float()`` of an int outside the double range.
        return None
    return out if math.isfinite(out) else None
def _relative_error(a: float, b: float, *, floor: float) -> float:
    """Return ``|a - b|`` divided by ``max(|a|, |b|, floor)``.

    Args:
        a: First scalar.
        b: Second scalar.
        floor: Lower bound on the normalising scale, so that two tiny values
            are not compared against a vanishing denominator.

    Returns:
        The scaled difference. When the scale itself is zero (both values are
        zero and ``floor`` is not positive) the values are identical and
        ``0.0`` is returned rather than dividing by zero.
    """
    first = float(a)
    second = float(b)
    scale = max(abs(first), abs(second), float(floor))
    if scale == 0.0:
        # Only reachable when first == second == 0 and floor <= 0.
        return 0.0
    return abs(first - second) / scale
def _norm_ratio(numerator: float | None, denominator: float | None) -> float | None:
    """Return ``|numerator| / |denominator|`` when both are usable.

    Both inputs are first passed through ``_finite_float``. The result is
    ``None`` if either input is rejected by that helper or if the denominator
    is exactly zero.
    """
    top = _finite_float(numerator)
    bottom = _finite_float(denominator)
    usable = top is not None and bottom is not None and bottom != 0.0
    if not usable:
        return None
    return abs(top) / abs(bottom)
def _passes_error(
    abs_error: float | None,
    rel_error: float | None,
    *,
    absolute_tolerance: float,
    relative_tolerance: float,
) -> bool:
    """Decide whether an (absolute, relative) error pair is within tolerance.

    Args:
        abs_error: Absolute error, or ``None`` when the comparison could not
            be formed.
        rel_error: Relative error, or ``None`` when unavailable.
        absolute_tolerance: Inclusive threshold for ``abs_error``.
        relative_tolerance: Inclusive threshold for ``rel_error``.

    Returns:
        ``False`` when ``abs_error`` is ``None``. Otherwise ``True`` if the
        absolute error is within its tolerance, or if a relative error is
        present and within its tolerance.
    """
    # A missing absolute error means the comparison was never made.
    if abs_error is None:
        return False
    if abs_error <= absolute_tolerance:
        return True
    if rel_error is None:
        return False
    return bool(rel_error <= relative_tolerance)
[docs]
def build_boundary_chain_summary(
    *,
    exact_fd_cost_gradient: float,
    final_cot_dot_exact_final_fd: float,
    frozen_axis_replay_cost_gradient: float,
    frozen_axis_vjp_cost_gradient: float,
    frozen_axis_linear_replay_cost_gradient: float | None = None,
    frozen_axis_linear_vjp_cost_gradient: float | None = None,
    frozen_axis_initial_fd_vs_linear_abs_norm: float | None = None,
    frozen_axis_initial_fd_vs_linear_rel: float | None = None,
    raw_initial_replay_cost_gradient: float | None = None,
    raw_initial_fd_norm: float | None = None,
    frozen_axis_initial_fd_norm: float | None = None,
    exact_relative_tolerance: float = 1.0e-1,
    internal_relative_tolerance: float = 1.0e-8,
    absolute_tolerance: float = 1.0e-10,
) -> dict[str, Any]:
    """Classify a boundary-gradient chain probe.

    Args:
        exact_fd_cost_gradient: Central finite difference through plus/minus
            exact VMEC solves.
        final_cot_dot_exact_final_fd: Final-state SPECTRAX-GK cotangent dotted
            into the exact final-state finite-difference direction.
        frozen_axis_replay_cost_gradient: VMEC-JAX tape JVP contraction using
            the frozen-axis initial-state tangent used by the optimizer.
        frozen_axis_vjp_cost_gradient: VMEC-JAX tape VJP contraction projected
            back through the same frozen initial-state map.
        frozen_axis_linear_replay_cost_gradient: Optional contraction using
            VMEC-JAX's explicit frozen-axis tangent column.
        frozen_axis_linear_vjp_cost_gradient: Optional VJP contraction using
            VMEC-JAX's explicit frozen-axis tangent column.
        frozen_axis_initial_fd_vs_linear_abs_norm: Optional norm of the
            frozen-axis finite-difference tangent minus the explicit tangent
            column.
        frozen_axis_initial_fd_vs_linear_rel: Optional relative norm of the
            frozen-axis finite-difference tangent minus the explicit tangent
            column.
        raw_initial_replay_cost_gradient: Optional tape JVP contraction using
            raw plus/minus initial-state finite differences. This diagnoses
            magnetic-axis branch sensitivity, but it is not the optimizer's
            advertised derivative.
        raw_initial_fd_norm: Optional norm used as the numerator of the
            ``raw_to_frozen_initial_norm_ratio`` metric.
        frozen_axis_initial_fd_norm: Optional norm used as the denominator of
            the ``raw_to_frozen_initial_norm_ratio`` metric. A ratio above 10
            marks the probe as branch-sensitive.
        exact_relative_tolerance: Relative tolerance for comparisons against
            ``exact_fd_cost_gradient``.
        internal_relative_tolerance: Relative tolerance for JVP/VJP and
            explicit-tangent consistency checks.
        absolute_tolerance: Absolute pass threshold for every check; also the
            floor of the scale used when forming relative errors.

    Returns:
        A dictionary with ``kind``, ``finite``, ``classification``,
        ``metrics``, ``errors``, ``passes`` and ``next_action``. When all four
        required inputs are finite the three tolerances are echoed as well;
        otherwise ``errors`` and ``passes`` are empty and the classification
        is ``"nonfinite_boundary_chain_probe"``.
    """
    exact = _finite_float(exact_fd_cost_gradient)
    final = _finite_float(final_cot_dot_exact_final_fd)
    frozen_jvp = _finite_float(frozen_axis_replay_cost_gradient)
    frozen_vjp = _finite_float(frozen_axis_vjp_cost_gradient)
    frozen_linear_jvp = _finite_float(frozen_axis_linear_replay_cost_gradient)
    frozen_linear_vjp = _finite_float(frozen_axis_linear_vjp_cost_gradient)
    tangent_diff_abs = _finite_float(frozen_axis_initial_fd_vs_linear_abs_norm)
    tangent_diff_rel = _finite_float(frozen_axis_initial_fd_vs_linear_rel)
    raw = _finite_float(raw_initial_replay_cost_gradient)

    metrics: dict[str, float | None] = {
        "exact_fd_cost_gradient": exact,
        "final_cot_dot_exact_final_fd": final,
        "frozen_axis_replay_cost_gradient": frozen_jvp,
        "frozen_axis_vjp_cost_gradient": frozen_vjp,
        "frozen_axis_linear_replay_cost_gradient": frozen_linear_jvp,
        "frozen_axis_linear_vjp_cost_gradient": frozen_linear_vjp,
        "frozen_axis_initial_fd_vs_linear_abs_norm": tangent_diff_abs,
        "frozen_axis_initial_fd_vs_linear_rel": tangent_diff_rel,
        "raw_initial_replay_cost_gradient": raw,
        "raw_to_frozen_initial_norm_ratio": _norm_ratio(
            raw_initial_fd_norm, frozen_axis_initial_fd_norm
        ),
    }

    # The four required contractions must all be finite; an explicit check
    # (rather than ``assert``) keeps the guard active under ``python -O``.
    if exact is None or final is None or frozen_jvp is None or frozen_vjp is None:
        return {
            "kind": "vmec_jax_boundary_chain_summary",
            "finite": False,
            "classification": "nonfinite_boundary_chain_probe",
            "metrics": metrics,
            "errors": {},
            "passes": {},
            "next_action": (
                "repair nonfinite VMEC/Boozer/SPECTRAX derivatives before "
                "interpreting boundary gradients"
            ),
        }

    def _pair_errors(
        a: float | None, b: float | None
    ) -> tuple[float | None, float | None]:
        # (absolute, relative) error of a vs. b, or (None, None) if either
        # side is unavailable.
        if a is None or b is None:
            return None, None
        return abs(a - b), _relative_error(a, b, floor=absolute_tolerance)

    def _within(
        abs_error: float | None,
        rel_error: float | None,
        relative_tolerance: float,
    ) -> bool:
        return _passes_error(
            abs_error,
            rel_error,
            absolute_tolerance=absolute_tolerance,
            relative_tolerance=relative_tolerance,
        )

    final_state_abs, final_state_rel = _pair_errors(final, exact)
    frozen_axis_abs, frozen_axis_rel = _pair_errors(frozen_jvp, exact)
    frozen_jvp_vjp_abs, frozen_jvp_vjp_rel = _pair_errors(frozen_jvp, frozen_vjp)
    frozen_fd_jvp_vs_linear_abs, frozen_fd_jvp_vs_linear_rel = _pair_errors(
        frozen_jvp, frozen_linear_jvp
    )
    frozen_linear_jvp_vjp_abs, frozen_linear_jvp_vjp_rel = _pair_errors(
        frozen_linear_jvp, frozen_linear_vjp
    )
    frozen_fd_vjp_vs_linear_abs, frozen_fd_vjp_vs_linear_rel = _pair_errors(
        frozen_vjp, frozen_linear_vjp
    )
    raw_abs, raw_rel = _pair_errors(raw, exact)

    errors: dict[str, float | None] = {
        "final_state_vs_exact_fd_abs": final_state_abs,
        "final_state_vs_exact_fd_rel": final_state_rel,
        "frozen_axis_vs_exact_fd_abs": frozen_axis_abs,
        "frozen_axis_vs_exact_fd_rel": frozen_axis_rel,
        "frozen_axis_jvp_vjp_abs": frozen_jvp_vjp_abs,
        "frozen_axis_jvp_vjp_rel": frozen_jvp_vjp_rel,
        "frozen_axis_fd_jvp_vs_linear_jvp_abs": frozen_fd_jvp_vs_linear_abs,
        "frozen_axis_fd_jvp_vs_linear_jvp_rel": frozen_fd_jvp_vs_linear_rel,
        "frozen_axis_linear_jvp_vjp_abs": frozen_linear_jvp_vjp_abs,
        "frozen_axis_linear_jvp_vjp_rel": frozen_linear_jvp_vjp_rel,
        "frozen_axis_fd_vjp_vs_linear_vjp_abs": frozen_fd_vjp_vs_linear_abs,
        "frozen_axis_fd_vjp_vs_linear_vjp_rel": frozen_fd_vjp_vs_linear_rel,
        "frozen_axis_initial_fd_vs_linear_abs_norm": tangent_diff_abs,
        "frozen_axis_initial_fd_vs_linear_rel": tangent_diff_rel,
        "raw_initial_vs_exact_fd_abs": raw_abs,
        "raw_initial_vs_exact_fd_rel": raw_rel,
    }

    # The tangent check is keyed on the relative norm being present; the
    # absolute norm is an optional additional way to pass.
    tangent_ok = bool(
        tangent_diff_rel is not None
        and (
            (tangent_diff_abs is not None and tangent_diff_abs <= absolute_tolerance)
            or tangent_diff_rel <= internal_relative_tolerance
        )
    )
    fd_jvp_linear_ok = _within(
        frozen_fd_jvp_vs_linear_abs,
        frozen_fd_jvp_vs_linear_rel,
        internal_relative_tolerance,
    )
    linear_jvp_vjp_ok = _within(
        frozen_linear_jvp_vjp_abs,
        frozen_linear_jvp_vjp_rel,
        internal_relative_tolerance,
    )
    fd_vjp_linear_ok = _within(
        frozen_fd_vjp_vs_linear_abs,
        frozen_fd_vjp_vs_linear_rel,
        internal_relative_tolerance,
    )

    passes = {
        "final_state_matches_exact_fd": _within(
            final_state_abs, final_state_rel, exact_relative_tolerance
        ),
        "frozen_axis_matches_exact_fd": _within(
            frozen_axis_abs, frozen_axis_rel, exact_relative_tolerance
        ),
        "frozen_axis_jvp_vjp_consistent": _within(
            frozen_jvp_vjp_abs, frozen_jvp_vjp_rel, internal_relative_tolerance
        ),
        "frozen_axis_fd_matches_linear_tangent": tangent_ok,
        "frozen_axis_fd_jvp_matches_linear_jvp": fd_jvp_linear_ok,
        "frozen_axis_linear_jvp_vjp_consistent": linear_jvp_vjp_ok,
        "frozen_axis_fd_vjp_matches_linear_vjp": fd_vjp_linear_ok,
        "frozen_axis_convention_verified": bool(
            tangent_ok
            and fd_jvp_linear_ok
            and linear_jvp_vjp_ok
            and fd_vjp_linear_ok
        ),
        # Goes through the shared helper so that an exactly-zero relative
        # error is compared as 0.0 instead of being treated as missing.
        "raw_initial_matches_exact_fd": _within(
            raw_abs, raw_rel, exact_relative_tolerance
        ),
    }

    norm_ratio = metrics["raw_to_frozen_initial_norm_ratio"]
    # Raw initial-state tangent more than 10x larger than the frozen-axis one.
    branch_sensitive = bool(norm_ratio is not None and norm_ratio > 10.0)

    # Order matters: internal transposition failures dominate, then agreement
    # with exact FD, then the explicit-tangent convention checks.
    if not passes["frozen_axis_jvp_vjp_consistent"]:
        classification = "frozen_axis_replay_internally_inconsistent"
        next_action = (
            "debug VMEC-JAX exact-tape JVP/VJP replay; the optimizer derivative "
            "is not internally transposed"
        )
    elif passes["frozen_axis_matches_exact_fd"]:
        classification = "exact_fd_and_frozen_axis_replay_consistent"
        next_action = (
            "use the frozen-axis derivative as an optimization diagnostic; keep "
            "sparse FD checks and solved-equilibrium gates before promotion"
        )
    elif passes["frozen_axis_convention_verified"] and branch_sensitive:
        classification = "frozen_axis_convention_verified_but_exact_fd_branch_sensitive"
        next_action = (
            "raw exact-solve FD is branch-sensitive, but the frozen-axis finite "
            "difference, explicit tangent column, tape JVP, and tape VJP agree; "
            "use only with solved-equilibrium, growth-branch, and projected "
            "line-search gates"
        )
    elif passes["frozen_axis_convention_verified"]:
        classification = "frozen_axis_convention_verified_but_exact_fd_inconsistent"
        next_action = (
            "raw exact-solve FD is inconsistent with the optimizer convention, "
            "but the frozen-axis tangent convention is verified; require "
            "projected admission and matched nonlinear audits before promotion"
        )
    elif branch_sensitive:
        classification = "frozen_axis_replay_consistent_but_exact_fd_branch_sensitive"
        next_action = (
            "tighten VMEC solve convergence and compare against the frozen-axis "
            "finite-difference convention; raw exact-solve FD is moving the "
            "magnetic-axis initialization branch"
        )
    elif not passes["final_state_matches_exact_fd"]:
        classification = "final_state_cotangent_mismatch"
        next_action = (
            "audit the SPECTRAX final-state objective cotangent or the exact "
            "final-state finite-difference branch before blaming boundary replay"
        )
    else:
        classification = "frozen_axis_replay_consistent_but_exact_fd_inconsistent"
        next_action = (
            "treat the raw exact-solve FD as a convergence/branch diagnostic; "
            "increase VMEC iterations or reduce branch sensitivity before promotion"
        )

    return {
        "kind": "vmec_jax_boundary_chain_summary",
        "finite": True,
        "classification": classification,
        "exact_relative_tolerance": float(exact_relative_tolerance),
        "internal_relative_tolerance": float(internal_relative_tolerance),
        "absolute_tolerance": float(absolute_tolerance),
        "metrics": metrics,
        "errors": errors,
        "passes": passes,
        "next_action": next_action,
    }
[docs]
def boundary_chain_summary_from_probe(
    payload: Mapping[str, Any], **kwargs: Any
) -> dict[str, Any]:
    """Build a chain summary from a probe JSON payload.

    Args:
        payload: Probe record. The keys ``exact_fd_cost_gradient``,
            ``final_cot_dot_exact_final_fd``,
            ``final_cot_dot_tape_jvp_frozen_axis_fd`` and
            ``initial_cot_dot_frozen_axis_fd`` must be present; the remaining
            keys read here are optional and default to ``None``.
        **kwargs: Forwarded unchanged to ``build_boundary_chain_summary``
            (for example the tolerance overrides).

    Returns:
        The dictionary produced by ``build_boundary_chain_summary``.

    Raises:
        KeyError: If one of the four required keys is missing.
    """
    # Required values are handed over unconverted: the summary builder does
    # its own float coercion and reports a null or unparsable entry as a
    # nonfinite probe, whereas an eager ``float(None)`` here would raise.
    return build_boundary_chain_summary(
        exact_fd_cost_gradient=payload["exact_fd_cost_gradient"],
        final_cot_dot_exact_final_fd=payload["final_cot_dot_exact_final_fd"],
        frozen_axis_replay_cost_gradient=payload[
            "final_cot_dot_tape_jvp_frozen_axis_fd"
        ],
        frozen_axis_vjp_cost_gradient=payload["initial_cot_dot_frozen_axis_fd"],
        frozen_axis_linear_replay_cost_gradient=payload.get(
            "final_cot_dot_tape_jvp_frozen_axis_linear"
        ),
        frozen_axis_linear_vjp_cost_gradient=payload.get(
            "initial_cot_dot_frozen_axis_linear"
        ),
        frozen_axis_initial_fd_vs_linear_abs_norm=payload.get(
            "frozen_axis_initial_fd_vs_linear_abs_norm"
        ),
        frozen_axis_initial_fd_vs_linear_rel=payload.get(
            "frozen_axis_initial_fd_vs_linear_rel"
        ),
        raw_initial_replay_cost_gradient=payload.get(
            "final_cot_dot_tape_jvp_raw_initial_fd"
        ),
        raw_initial_fd_norm=payload.get("raw_initial_fd_norm"),
        frozen_axis_initial_fd_norm=payload.get("frozen_axis_initial_fd_norm"),
        **kwargs,
    )
[docs]
def build_boundary_chain_collection_summary(
    probes: Sequence[Mapping[str, Any]],
    *,
    exact_relative_tolerance: float = 1.0e-1,
    internal_relative_tolerance: float = 1.0e-8,
    absolute_tolerance: float = 1.0e-10,
) -> dict[str, Any]:
    """Summarize several boundary-chain probes as one promotion gate.

    A single coefficient can look well-conditioned while neighboring boundary
    modes still move the raw exact-solve initialization branch. The collection
    summary keeps the stricter manuscript/release decision explicit: frozen-axis
    JVP/VJP replay must be internally transposed for every component, while
    exact finite-difference agreement is counted separately from branch
    sensitivity.

    Args:
        probes: Probe payloads. A payload whose ``"summary"`` entry is a
            mapping is used as-is; otherwise the summary is computed with
            ``boundary_chain_summary_from_probe`` and the tolerances below.
        exact_relative_tolerance: Forwarded when a summary must be computed.
        internal_relative_tolerance: Forwarded when a summary must be computed.
        absolute_tolerance: Forwarded when a summary must be computed.

    Returns:
        A dictionary with ``kind``, ``finite``, ``classification``, ``counts``,
        ``rows`` (one flattened record per probe) and ``next_action``. For a
        non-empty collection the three tolerances are echoed as well.
    """
    kind = "vmec_jax_boundary_chain_collection_summary"

    if not probes:
        count_names = (
            "n_total",
            "n_finite",
            "n_frozen_axis_internal_pass",
            "n_frozen_axis_convention_verified",
            "n_exact_fd_consistent",
            "n_branch_sensitive",
            "n_growth_branch_locality_checked",
            "n_growth_branch_locality_passed",
        )
        return {
            "kind": kind,
            "finite": False,
            "classification": "empty_boundary_chain_collection",
            "rows": [],
            "counts": dict.fromkeys(count_names, 0),
            "next_action": (
                "run at least one boundary-chain probe before interpreting the "
                "VMEC-JAX transport-gradient convention"
            ),
        }

    def _flag(table: Any, key: str, default: bool = False) -> bool:
        # Truthiness of table[key]; False when table is not a mapping.
        return bool(isinstance(table, Mapping) and table.get(key, default))

    def _lookup(table: Any, key: str) -> Any:
        # table[key] or None; None when table is not a mapping.
        return table.get(key) if isinstance(table, Mapping) else None

    # (row column, source key) pairs, listed in output order.
    pass_columns = (
        ("frozen_axis_jvp_vjp_consistent", "frozen_axis_jvp_vjp_consistent"),
        ("frozen_axis_matches_exact_fd", "frozen_axis_matches_exact_fd"),
        ("exact_fd_consistent", "frozen_axis_matches_exact_fd"),
        ("frozen_axis_convention_verified", "frozen_axis_convention_verified"),
        ("final_state_matches_exact_fd", "final_state_matches_exact_fd"),
    )
    metric_columns = (
        "exact_fd_cost_gradient",
        "frozen_axis_replay_cost_gradient",
    )
    error_columns = (
        "frozen_axis_vs_exact_fd_rel",
        "frozen_axis_initial_fd_vs_linear_rel",
        "frozen_axis_linear_jvp_vjp_rel",
        "raw_initial_vs_exact_fd_rel",
    )

    rows: list[dict[str, Any]] = []
    for probe in probes:
        stored = probe.get("summary")
        if isinstance(stored, Mapping):
            summary = dict(stored)
        else:
            summary = boundary_chain_summary_from_probe(
                probe,
                exact_relative_tolerance=exact_relative_tolerance,
                internal_relative_tolerance=internal_relative_tolerance,
                absolute_tolerance=absolute_tolerance,
            )
        pass_table = summary.get("passes", {})
        error_table = summary.get("errors", {})
        metric_table = summary.get("metrics", {})
        locality = probe.get("growth_branch_locality")

        row: dict[str, Any] = {
            "index": probe.get("index"),
            "name": probe.get("name"),
            "classification": summary.get("classification"),
            "finite": bool(summary.get("finite", False)),
        }
        for column, source in pass_columns:
            row[column] = _flag(pass_table, source)
        for column in metric_columns:
            row[column] = _lookup(metric_table, column)
        for column in error_columns:
            row[column] = _lookup(error_table, column)
        # A locality record counts as checked unless it sets "enabled" falsy.
        row["growth_branch_locality_checked"] = _flag(locality, "enabled", True)
        row["growth_branch_locality_passed"] = _flag(locality, "passed")
        row["growth_branch_locality_classification"] = _lookup(
            locality, "classification"
        )
        rows.append(row)

    def _tally(column: str) -> int:
        return sum(1 for entry in rows if entry[column])

    n_total = len(rows)
    n_finite = _tally("finite")
    n_internal = _tally("frozen_axis_jvp_vjp_consistent")
    n_exact = _tally("frozen_axis_matches_exact_fd")
    n_convention = _tally("frozen_axis_convention_verified")
    n_growth_checked = _tally("growth_branch_locality_checked")
    n_growth_passed = _tally("growth_branch_locality_passed")
    n_branch = 0
    for entry in rows:
        if entry["classification"] in {
            "frozen_axis_replay_consistent_but_exact_fd_branch_sensitive",
            "frozen_axis_convention_verified_but_exact_fd_branch_sensitive",
        }:
            n_branch += 1

    finite = n_finite == n_total

    def _decide() -> tuple[str, str]:
        # First matching rule wins; the order encodes the gate's priorities.
        if not finite:
            return (
                "nonfinite_boundary_chain_collection",
                "repair nonfinite VMEC/Boozer/SPECTRAX derivatives before using "
                "the boundary-gradient collection",
            )
        if n_internal != n_total:
            return (
                "internal_replay_failure",
                "debug VMEC-JAX exact-tape replay because at least one frozen-axis "
                "JVP/VJP contraction is not internally transposed",
            )
        if n_exact == n_total:
            return (
                "all_components_exact_fd_and_frozen_axis_consistent",
                "promote the frozen-axis convention for these sparse components, "
                "while retaining solved-equilibrium and sparse-FD gates",
            )
        if n_exact > 0 and n_convention == 0 and n_branch > 0:
            return (
                "mixed_exact_fd_consistency_with_branch_sensitive_modes",
                "use frozen-axis derivatives only as diagnostics; exclude or "
                "regularize branch-sensitive modes before projected VMEC updates",
            )
        if n_convention == n_total:
            return (
                "all_components_frozen_axis_convention_verified",
                "raw exact-solve FD remains inconsistent, but every component "
                "passes the explicit frozen-axis tangent convention; projected "
                "updates may use these directions only with solved-equilibrium, "
                "growth-branch, and nonlinear-audit gates",
            )
        if n_exact + n_convention > 0 and n_branch > 0:
            return (
                "mixed_exact_or_frozen_axis_convention_verified",
                "use only components with exact-FD consistency or explicit "
                "frozen-axis convention verification; unresolved branch-sensitive "
                "modes remain excluded",
            )
        return (
            "branch_sensitive_boundary_chain_collection",
            "do not promote this boundary-gradient collection until exact-solve "
            "branch sensitivity is reduced or the frozen-axis convention is "
            "validated against a better-conditioned finite-difference protocol",
        )

    classification, next_action = _decide()

    return {
        "kind": kind,
        "finite": finite,
        "classification": classification,
        "exact_relative_tolerance": float(exact_relative_tolerance),
        "internal_relative_tolerance": float(internal_relative_tolerance),
        "absolute_tolerance": float(absolute_tolerance),
        "counts": {
            "n_total": n_total,
            "n_finite": n_finite,
            "n_frozen_axis_internal_pass": n_internal,
            "n_frozen_axis_convention_verified": n_convention,
            "n_exact_fd_consistent": n_exact,
            "n_branch_sensitive": n_branch,
            "n_growth_branch_locality_checked": n_growth_checked,
            "n_growth_branch_locality_passed": n_growth_passed,
        },
        "rows": rows,
        "next_action": next_action,
    }
# Public API: the three summary builders. The underscore-prefixed helpers
# defined in this module are intentionally not exported.
__all__ = [
    "boundary_chain_summary_from_probe",
    "build_boundary_chain_collection_summary",
    "build_boundary_chain_summary",
]