"""Acceptance gates for VMEC-JAX stellarator-optimization candidates.
The helpers here are intentionally independent of the SPECTRAX-GK time
integrator. They answer a narrower question: is a solved VMEC-JAX equilibrium
candidate physically acceptable enough to spend expensive nonlinear GK audit
time on it?
"""
from __future__ import annotations
from collections.abc import Mapping
from pathlib import Path
from types import SimpleNamespace
from typing import Any, cast
import numpy as np
DEFAULT_QS_WOUT_SURFACES = tuple(float(x) for x in np.linspace(0.0, 1.0, 11))
def _finite_float_or_none(value: Any) -> float | None:
try:
result = float(value)
except Exception:
return None
return result if np.isfinite(result) else None
def _finite_gate(value: float | None, *, lower: float | None = None, upper: float | None = None) -> bool:
if value is None:
return False
if lower is not None and value < float(lower):
return False
if upper is not None and value > float(upper):
return False
return True
[docs]
def final_iota_profiles_from_vmec_result(result: Any) -> tuple[np.ndarray, np.ndarray] | None:
"""Return final solved iota profiles from a VMEC-JAX result if available."""
state = getattr(result, "final_state", None)
optimizer = getattr(result, "final_optimizer", None)
if state is None or optimizer is None:
return None
try:
import vmec_jax as vj # type: ignore[import-not-found]
_chips, iotas, iotaf = vj.equilibrium_iota_profiles_from_state(
state=state,
static=getattr(optimizer, "_static"),
indata=getattr(optimizer, "_indata"),
signgs=int(getattr(optimizer, "_signgs")),
)
except Exception:
return None
return np.asarray(iotas, dtype=float), np.asarray(iotaf, dtype=float)
def _final_quasisymmetry_from_vmec_result(result: Any) -> float | None:
"""Return an independent final QS residual from a VMEC-JAX result."""
optimizer = getattr(result, "final_optimizer", None)
if optimizer is None:
return None
state = getattr(result, "final_state", None)
if state is not None:
try:
import jax.numpy as jnp
import vmec_jax as vj # type: ignore[import-not-found]
static = getattr(optimizer, "_static")
qs = vj.QuasisymmetryRatioResidual(
helicity_m=int(getattr(optimizer, "_helicity_m", 1) or 1),
helicity_n=int(getattr(optimizer, "_helicity_n", 0) or 0),
surfaces=np.arange(0.0, 1.01, 0.1),
)
ctx = SimpleNamespace(
static=static,
indata=getattr(optimizer, "_indata"),
signgs=int(getattr(optimizer, "_signgs")),
flux=getattr(optimizer, "_flux"),
pressure=jnp.zeros_like(jnp.asarray(getattr(static, "s"))),
)
value = _finite_float_or_none(qs.total(ctx, state))
if value is not None:
return value
except Exception:
pass
try:
residuals = getattr(optimizer, "_evaluate_residuals_from_state")(state)
qs_total = getattr(optimizer, "_qs_total_from_state")(state, residuals)
value = _finite_float_or_none(qs_total)
if value is not None:
return value
except Exception:
pass
params = getattr(result, "final_params", None)
if params is not None:
try:
return _finite_float_or_none(getattr(optimizer, "quasisymmetry_objective")(params))
except Exception:
return None
return None
def _history_from_candidate(candidate: Any) -> Mapping[str, Any]:
if isinstance(candidate, Mapping):
return candidate
history = getattr(candidate, "history", None)
return history if isinstance(history, Mapping) else {}
def _wout_summary(source: str | Path | Mapping[str, Any]) -> dict[str, Any]:
if isinstance(source, Mapping):
return {
"source": str(source.get("source", "mapping")),
"aspect": _finite_float_or_none(source.get("aspect")),
"mean_iota": _finite_float_or_none(source.get("mean_iota")),
"min_iotas_excluding_axis": _finite_float_or_none(source.get("min_iotas_excluding_axis")),
"min_iotaf": _finite_float_or_none(source.get("min_iotaf")),
}
path = Path(source)
try:
import netCDF4 # type: ignore[import-not-found]
with netCDF4.Dataset(path) as dataset:
aspect = _finite_float_or_none(np.asarray(dataset.variables["aspect"][:]))
iotas = np.asarray(dataset.variables["iotas"][:], dtype=float)
iotaf = np.asarray(dataset.variables["iotaf"][:], dtype=float)
except Exception as exc:
return {
"source": str(path),
"error": f"{type(exc).__name__}: {exc}",
"aspect": None,
"mean_iota": None,
"min_iotas_excluding_axis": None,
"min_iotaf": None,
}
finite_iotas = iotas[np.isfinite(iotas)]
finite_iotaf = iotaf[np.isfinite(iotaf)]
profile = finite_iotas[1:] if finite_iotas.size > 1 else finite_iotas
return {
"source": str(path),
"aspect": aspect,
"mean_iota": _finite_float_or_none(np.nanmean(profile)) if profile.size else None,
"min_iotas_excluding_axis": _finite_float_or_none(np.nanmin(profile)) if profile.size else None,
"min_iotaf": _finite_float_or_none(np.nanmin(finite_iotaf)) if finite_iotaf.size else None,
}
def _wout_quasisymmetry(
source: str | Path | Mapping[str, Any],
*,
helicity_m: int,
helicity_n: int,
surfaces: tuple[float, ...],
ntheta: int,
nphi: int,
) -> tuple[float | None, str, str | None]:
if isinstance(source, Mapping):
value = _finite_float_or_none(source.get("qs_residual", source.get("quasisymmetry")))
return value, str(source.get("qs_source", "mapping")), None if value is not None else "missing_qs_residual"
try:
import vmec_jax as vj # type: ignore[import-not-found]
wout = vj.load_wout(source)
qs = vj.quasisymmetry_ratio_residual_from_wout(
wout,
surfaces=np.asarray(surfaces, dtype=float),
helicity_m=int(helicity_m),
helicity_n=int(helicity_n),
ntheta=int(ntheta),
nphi=int(nphi),
)
if isinstance(qs, Mapping):
value = _finite_float_or_none(qs.get("total"))
else:
value = _finite_float_or_none(qs)
return value, "vmec_jax_wout", None if value is not None else "nonfinite_qs_residual"
except Exception as exc:
return None, "vmec_jax_wout_error", f"{type(exc).__name__}: {exc}"
[docs]
def build_authoritative_wout_candidate_gate(
wout: str | Path | Mapping[str, Any],
*,
target_aspect: float,
aspect_atol: float,
min_abs_mean_iota: float,
qs_residual_max: float,
iota_profile_floor: float | None,
helicity_m: int = 1,
helicity_n: int = 0,
qs_surfaces: tuple[float, ...] = DEFAULT_QS_WOUT_SURFACES,
qs_ntheta: int = 63,
qs_nphi: int = 64,
) -> dict[str, Any]:
"""Build a solved-equilibrium gate directly from a WOUT artifact.
Use this when the deterministic replayed WOUT is the publication-facing
equilibrium. It does not assert that the replayed WOUT matches an
optimizer-state WOUT; that remains the role of
:func:`build_wout_reproducibility_gate`.
"""
summary = _wout_summary(wout)
aspect = _finite_float_or_none(summary.get("aspect"))
mean_iota = _finite_float_or_none(summary.get("mean_iota"))
min_iotas = _finite_float_or_none(summary.get("min_iotas_excluding_axis"))
min_iotaf = _finite_float_or_none(summary.get("min_iotaf"))
profile_floor_passed = True
if iota_profile_floor is not None:
profile_floor_passed = _finite_gate(min_iotas, lower=float(iota_profile_floor)) and _finite_gate(
min_iotaf,
lower=float(iota_profile_floor),
)
qs_value, qs_source, qs_error = _wout_quasisymmetry(
wout,
helicity_m=int(helicity_m),
helicity_n=int(helicity_n),
surfaces=tuple(float(x) for x in qs_surfaces),
ntheta=int(qs_ntheta),
nphi=int(qs_nphi),
)
checks = {
"aspect": {
"value": aspect,
"target": float(target_aspect),
"absolute_error": None if aspect is None else abs(aspect - float(target_aspect)),
"absolute_tolerance": float(aspect_atol),
"passed": _finite_gate(
None if aspect is None else abs(aspect - float(target_aspect)),
upper=float(aspect_atol),
),
},
"mean_iota": {
"value": None if mean_iota is None else abs(mean_iota),
"minimum_abs": float(min_abs_mean_iota),
"margin": None if mean_iota is None else abs(mean_iota) - float(min_abs_mean_iota),
"passed": _finite_gate(None if mean_iota is None else abs(mean_iota), lower=float(min_abs_mean_iota)),
},
"iota_profile": {
"minimum_iotas_excluding_axis": min_iotas,
"minimum_iotaf": min_iotaf,
"floor": None if iota_profile_floor is None else float(iota_profile_floor),
"source": "wout",
"passed": bool(profile_floor_passed),
},
"quasisymmetry": {
"value": qs_value,
"maximum": float(qs_residual_max),
"margin": None if qs_value is None else float(qs_residual_max) - qs_value,
"source": qs_source,
"error": qs_error,
"helicity_m": int(helicity_m),
"helicity_n": int(helicity_n),
"surfaces": [float(x) for x in qs_surfaces],
"ntheta": int(qs_ntheta),
"nphi": int(qs_nphi),
"passed": _finite_gate(qs_value, upper=float(qs_residual_max)),
},
}
passed = all(bool(cast(Mapping[str, Any], check).get("passed")) for check in checks.values())
return {
"kind": "vmec_jax_authoritative_wout_candidate_gate",
"passed": bool(passed),
"authoritative_wout": summary,
"checks": checks,
"claim_level": (
"deterministic WOUT artifact passes solved-equilibrium admission; "
"optimizer-state reproducibility must be reported separately"
),
"next_action": (
"this WOUT may be used as the authoritative equilibrium for downstream SPECTRAX-GK audits"
if passed
else "do not use this WOUT for downstream SPECTRAX-GK transport promotion"
),
}
[docs]
def build_wout_reproducibility_gate(
reference_wout: str | Path | Mapping[str, Any],
rerun_wout: str | Path | Mapping[str, Any],
*,
target_aspect: float,
aspect_atol: float,
min_abs_mean_iota: float,
iota_profile_floor: float | None,
mean_iota_repro_atol: float = 5.0e-4,
aspect_repro_atol: float = 1.0e-6,
profile_repro_atol: float = 5.0e-4,
) -> dict[str, Any]:
"""Check that a saved VMEC input reproduces the optimizer-state WOUT.
VMEC-JAX can write both an optimizer-state ``wout_final.nc`` and an
``input.final`` deck. For publication-facing transport claims, the deck
must reproduce the WOUT when rerun; otherwise downstream SPECTRAX-GK
metrics may be attached to a different equilibrium than the optimized
state. This gate compares the original WOUT against a fresh rerun WOUT and
also applies the solved-equilibrium aspect/iota/profile admission checks to
the rerun.
"""
reference = _wout_summary(reference_wout)
rerun = _wout_summary(rerun_wout)
ref_aspect = _finite_float_or_none(reference.get("aspect"))
rerun_aspect = _finite_float_or_none(rerun.get("aspect"))
ref_iota = _finite_float_or_none(reference.get("mean_iota"))
rerun_iota = _finite_float_or_none(rerun.get("mean_iota"))
ref_min_iotas = _finite_float_or_none(reference.get("min_iotas_excluding_axis"))
rerun_min_iotas = _finite_float_or_none(rerun.get("min_iotas_excluding_axis"))
ref_min_iotaf = _finite_float_or_none(reference.get("min_iotaf"))
rerun_min_iotaf = _finite_float_or_none(rerun.get("min_iotaf"))
aspect_drift = None if ref_aspect is None or rerun_aspect is None else abs(rerun_aspect - ref_aspect)
iota_drift = None if ref_iota is None or rerun_iota is None else abs(abs(rerun_iota) - abs(ref_iota))
min_iotas_drift = (
None
if ref_min_iotas is None or rerun_min_iotas is None
else abs(rerun_min_iotas - ref_min_iotas)
)
min_iotaf_drift = (
None
if ref_min_iotaf is None or rerun_min_iotaf is None
else abs(rerun_min_iotaf - ref_min_iotaf)
)
profile_floor_passed = True
if iota_profile_floor is not None:
profile_floor_passed = _finite_gate(rerun_min_iotas, lower=float(iota_profile_floor)) and _finite_gate(
rerun_min_iotaf,
lower=float(iota_profile_floor),
)
checks = {
"rerun_aspect_admission": {
"value": rerun_aspect,
"target": float(target_aspect),
"absolute_error": None if rerun_aspect is None else abs(rerun_aspect - float(target_aspect)),
"absolute_tolerance": float(aspect_atol),
"passed": _finite_gate(
None if rerun_aspect is None else abs(rerun_aspect - float(target_aspect)),
upper=float(aspect_atol),
),
},
"rerun_mean_iota_admission": {
"value": None if rerun_iota is None else abs(rerun_iota),
"minimum_abs": float(min_abs_mean_iota),
"margin": None if rerun_iota is None else abs(rerun_iota) - float(min_abs_mean_iota),
"passed": _finite_gate(None if rerun_iota is None else abs(rerun_iota), lower=float(min_abs_mean_iota)),
},
"rerun_iota_profile_admission": {
"minimum_iotas_excluding_axis": rerun_min_iotas,
"minimum_iotaf": rerun_min_iotaf,
"floor": None if iota_profile_floor is None else float(iota_profile_floor),
"passed": bool(profile_floor_passed),
},
"aspect_reproducibility": {
"reference": ref_aspect,
"rerun": rerun_aspect,
"absolute_drift": aspect_drift,
"absolute_tolerance": float(aspect_repro_atol),
"passed": _finite_gate(aspect_drift, upper=float(aspect_repro_atol)),
},
"mean_iota_reproducibility": {
"reference": None if ref_iota is None else abs(ref_iota),
"rerun": None if rerun_iota is None else abs(rerun_iota),
"absolute_drift": iota_drift,
"absolute_tolerance": float(mean_iota_repro_atol),
"passed": _finite_gate(iota_drift, upper=float(mean_iota_repro_atol)),
},
"iota_profile_reproducibility": {
"min_iotas_drift": min_iotas_drift,
"min_iotaf_drift": min_iotaf_drift,
"absolute_tolerance": float(profile_repro_atol),
"passed": _finite_gate(min_iotas_drift, upper=float(profile_repro_atol))
and _finite_gate(min_iotaf_drift, upper=float(profile_repro_atol)),
},
}
passed = all(bool(cast(Mapping[str, Any], check).get("passed")) for check in checks.values())
return {
"kind": "vmec_jax_wout_reproducibility_gate",
"passed": bool(passed),
"reference_wout": reference,
"rerun_wout": rerun,
"checks": checks,
"claim_level": "saved VMEC input must reproduce optimizer-state WOUT before SPECTRAX-GK transport promotion",
"next_action": (
"saved input/WOUT pair is reproducible enough for transport admission"
if passed
else "do not promote this saved input; rerun/refine the VMEC-JAX solve until input.final reproduces wout_final.nc"
),
}
[docs]
def build_solved_vmec_candidate_gate(
candidate: Any,
*,
target_aspect: float,
aspect_atol: float,
min_abs_mean_iota: float,
qs_residual_max: float,
iota_profile_floor: float | None,
iota_profiles: tuple[np.ndarray, np.ndarray] | None = None,
profile_source: str = "provided",
) -> dict[str, Any]:
"""Build a JSON-safe solved-equilibrium gate report.
``candidate`` may be a VMEC-JAX optimization result with a ``history``
property or a history mapping loaded from ``history.json``.
"""
history = _history_from_candidate(candidate)
aspect = _finite_float_or_none(history.get("aspect_final"))
mean_iota = _finite_float_or_none(history.get("iota_final"))
mean_abs_iota = None if mean_iota is None else abs(mean_iota)
qs_residual = None
qs_source = "history"
if not isinstance(candidate, Mapping):
qs_residual = _final_quasisymmetry_from_vmec_result(candidate)
if qs_residual is not None:
qs_source = "vmec_jax_state"
if qs_residual is None:
qs_residual = _finite_float_or_none(history.get("qs_final"))
aspect_error = None if aspect is None else abs(aspect - float(target_aspect))
if iota_profiles is None and not isinstance(candidate, Mapping):
profile_source = "vmec_jax_state"
iota_profiles = final_iota_profiles_from_vmec_result(candidate)
min_iota_profile: float | None = None
min_iotaf_profile: float | None = None
profile_passed = iota_profile_floor is None
if iota_profiles is not None:
iotas, iotaf = iota_profiles
iotas = np.asarray(iotas, dtype=float)
iotaf = np.asarray(iotaf, dtype=float)
profile = iotas[1:] if iotas.size > 1 else iotas
full_profile = iotaf[np.isfinite(iotaf)]
min_iota_profile = _finite_float_or_none(np.nanmin(profile)) if profile.size else None
min_iotaf_profile = _finite_float_or_none(np.nanmin(full_profile)) if full_profile.size else None
if iota_profile_floor is not None:
profile_passed = _finite_gate(min_iota_profile, lower=float(iota_profile_floor)) and _finite_gate(
min_iotaf_profile,
lower=float(iota_profile_floor),
)
elif iota_profile_floor is not None:
profile_source = "missing"
profile_passed = False
checks = {
"aspect": {
"value": aspect,
"target": float(target_aspect),
"absolute_error": aspect_error,
"absolute_tolerance": float(aspect_atol),
"passed": _finite_gate(aspect_error, upper=float(aspect_atol)),
},
"mean_iota": {
"value": mean_abs_iota,
"minimum_abs": float(min_abs_mean_iota),
"margin": None if mean_abs_iota is None else mean_abs_iota - float(min_abs_mean_iota),
"passed": _finite_gate(mean_abs_iota, lower=float(min_abs_mean_iota)),
},
"quasisymmetry": {
"value": qs_residual,
"maximum": float(qs_residual_max),
"margin": None if qs_residual is None else float(qs_residual_max) - qs_residual,
"source": qs_source,
"passed": _finite_gate(qs_residual, upper=float(qs_residual_max)),
},
"iota_profile": {
"minimum_iotas_excluding_axis": min_iota_profile,
"minimum_iotaf": min_iotaf_profile,
"floor": None if iota_profile_floor is None else float(iota_profile_floor),
"source": profile_source,
"passed": bool(profile_passed),
},
}
passed = all(bool(cast(Mapping[str, Any], check).get("passed")) for check in checks.values())
return {
"kind": "vmec_jax_solved_wout_candidate_gate",
"passed": bool(passed),
"checks": checks,
"claim_level": "solved VMEC candidate gate before expensive SPECTRAX-GK nonlinear transport audit",
"next_action": (
"candidate may proceed to matched long-window nonlinear transport audits"
if passed
else "do not promote this candidate; refine constraints or reduce/re-scale the transport residual"
),
}
__all__ = [
"DEFAULT_QS_WOUT_SURFACES",
"build_authoritative_wout_candidate_gate",
"build_solved_vmec_candidate_gate",
"build_wout_reproducibility_gate",
"final_iota_profiles_from_vmec_result",
]