"""Differentiable geometry bridge contracts for VMEC/JAX pipelines."""
from __future__ import annotations
import importlib
import os
from collections.abc import Mapping, Sequence
from dataclasses import replace as dc_replace
from functools import lru_cache
from pathlib import Path
import tempfile
from types import SimpleNamespace
from typing import Any, Callable, cast
import jax
import jax.numpy as jnp
import numpy as np
from spectraxgk.autodiff_validation import covariance_diagnostics
from spectraxgk.geometry import FluxTubeGeometryData
# Names of the one-dimensional profile arrays read from a solver-ready
# flux-tube mapping. ``theta`` is read first; every other entry is validated
# against its length by ``flux_tube_geometry_from_mapping``.
_ARRAY_FIELDS = (
    "theta",
    "gradpar",
    "bmag",
    "bgrad",
    "gds2",
    "gds21",
    "gds22",
    "cvdrift",
    "gbdrift",
    "cvdrift0",
    "gbdrift0",
)
# Ordered labels returned by ``geometry_observable_names``; the order mirrors
# the vector assembled in ``flux_tube_geometry_observables``.
_GEOMETRY_OBSERVABLE_NAMES = (
    "mean_bmag",
    "relative_bmag_ripple",
    "metric_frobenius_rms",
    "drift_rms",
    "mean_jacobian",
    "mean_gradpar",
)
# Ordered labels returned by ``vmec_metric_tensor_observable_names``.
_VMEC_METRIC_OBSERVABLE_NAMES = (
    "sqrtg_rms",
    "mean_g_ss",
    "mean_g_tt",
    "mean_g_pp",
    "g_st_rms",
    "g_sp_rms",
    "g_tp_rms",
)
# Ordered labels returned by ``vmec_field_line_tensor_observable_names``.
_VMEC_FIELD_LINE_OBSERVABLE_NAMES = (
    "mean_bmag",
    "relative_bmag_ripple",
    "sqrtg_rms",
    "mean_g_tt",
    "mean_g_pp",
    "g_tp_rms",
    "mean_g_ss",
)
# NOTE(review): not referenced in the visible part of this file; presumably the
# minimum Boozer mode count required by a VMEC/Boozer parity gate -- confirm.
_VMEC_BOOZER_PARITY_MIN_MODE_COUNT = 21
def _jax_float_dtype() -> Any:
return jnp.float64 if bool(jax.config.read("jax_enable_x64")) else jnp.float32
def _candidate_paths(env_names: Sequence[str], defaults: Sequence[Path]) -> list[Path]:
paths: list[Path] = []
for name in env_names:
raw = os.environ.get(name)
if raw:
paths.append(Path(os.path.expandvars(raw)).expanduser())
paths.extend(defaults)
out: list[Path] = []
seen: set[Path] = set()
for base in paths:
for candidate in (base, base / "src"):
resolved = candidate.resolve(strict=False)
if resolved in seen or not resolved.exists():
continue
seen.add(resolved)
out.append(resolved)
return out
def _find_importable_module(name: str, paths: Sequence[Path]) -> Any | None:
import sys
def _module_file(module_name: str) -> Path | None:
module = sys.modules.get(module_name)
raw = None if module is None else getattr(module, "__file__", None)
if raw is None:
return None
return Path(str(raw)).resolve(strict=False)
def _inside_candidates(path: Path | None) -> bool:
if path is None:
return False
for root in paths:
try:
path.relative_to(root.resolve(strict=False))
return True
except ValueError:
continue
return False
for path in reversed(paths):
path_str = str(path)
if path_str not in sys.path:
sys.path.insert(0, path_str)
# Prefer explicitly configured/local differentiable-geometry checkouts over
# globally installed packages. Some editable VMEC checkouts carry example
# data that the site package does not ship, so importing the site package
# first can silently disable the real validation gates.
root_name = name.split(".", maxsplit=1)[0]
root_path = _module_file(root_name)
if root_path is not None and not _inside_candidates(root_path):
for module_name in list(sys.modules):
if module_name == root_name or module_name.startswith(f"{root_name}."):
sys.modules.pop(module_name, None)
try:
return importlib.import_module(name)
except Exception:
pass
return None
def _is_traced(value: Any) -> bool:
"""Return true when host NumPy validation would break JAX tracing."""
if isinstance(value, jax.core.Tracer):
return True
if isinstance(value, (tuple, list)):
return any(_is_traced(item) for item in value)
if isinstance(value, Mapping):
return any(_is_traced(item) for item in value.values())
return False
# [docs]
def discover_differentiable_geometry_backends() -> dict[str, object]:
    """Probe for the optional ``vmec_jax`` / ``booz_xform_jax`` backends.

    Search roots come from environment overrides plus sibling-checkout and
    home-directory defaults. The returned dictionary reports whether each
    package could be imported, whether the attributes this bridge relies on
    are present, and the search roots (as strings) that were considered.
    """
    checkout_parent = Path(__file__).resolve().parents[3].parent
    home_dir = Path.home()

    def _default_roots(package: str) -> tuple[Path, Path, Path]:
        """Default checkout locations for ``package``."""
        return (
            checkout_parent / package,
            home_dir / package,
            home_dir / "local" / package,
        )

    vmec_roots = _candidate_paths(
        ("SPECTRAX_VMEC_JAX_PATH", "VMEC_JAX_PATH"),
        _default_roots("vmec_jax"),
    )
    booz_roots = _candidate_paths(
        ("SPECTRAX_BOOZ_XFORM_JAX_PATH", "BOOZ_XFORM_JAX_PATH"),
        _default_roots("booz_xform_jax"),
    )

    vmec_module = _find_importable_module("vmec_jax", vmec_roots)
    booz_module = _find_importable_module("booz_xform_jax", booz_roots)
    booz_api_module = None
    if booz_module is not None:
        # The JAX API submodule is only worth probing once the package imports.
        booz_api_module = _find_importable_module("booz_xform_jax.jax_api", booz_roots)

    required_boundary_attrs = (
        "BoundaryCoeffs",
        "boundary_aspect_ratio",
        "build_helical_basis",
        "make_angle_grid",
        "vmec_mode_table",
    )
    has_boundary_api = vmec_module is not None and all(
        hasattr(vmec_module, attr) for attr in required_boundary_attrs
    )

    required_booz_attrs = (
        "prepare_booz_xform_constants_from_inputs",
        "booz_xform_from_inputs",
        "booz_xform_jax_impl",
    )
    has_booz_api = booz_api_module is not None and all(
        hasattr(booz_api_module, attr) for attr in required_booz_attrs
    )

    return {
        "vmec_jax_available": vmec_module is not None,
        "vmec_jax_boundary_api_available": has_boundary_api,
        "booz_xform_jax_available": booz_module is not None,
        "booz_xform_jax_api_available": has_booz_api,
        "vmec_jax_paths": [str(root) for root in vmec_roots],
        "booz_xform_jax_paths": [str(root) for root in booz_roots],
    }
def _array(
    mapping: Mapping[str, Any],
    key: str,
    ntheta: int | None = None,
    *,
    validate_finite: bool = True,
) -> jnp.ndarray:
    """Fetch ``mapping[key]`` as a validated one-dimensional JAX array.

    Raises ``ValueError`` when the key is missing, the value is not
    one-dimensional, its length differs from ``ntheta`` (when given), or -- for
    concrete arrays with ``validate_finite`` set -- it holds non-finite entries.
    Traced arrays skip the finiteness check.
    """
    if key not in mapping:
        raise ValueError(f"missing differentiable geometry field {key!r}")
    values = jnp.asarray(mapping[key])
    if values.ndim != 1:
        raise ValueError(f"{key} must be one-dimensional")
    if ntheta is not None and int(values.shape[0]) != int(ntheta):
        raise ValueError(
            f"{key} length {values.shape[0]} does not match theta length {ntheta}"
        )
    # Host-side inspection is only possible (and requested) for concrete data.
    if not validate_finite or _is_traced(values):
        return values
    if not bool(np.all(np.isfinite(np.asarray(values)))):
        raise ValueError(f"{key} contains non-finite values")
    return values
def _scalar(
    mapping: Mapping[str, Any],
    key: str,
    default: float,
    *,
    validate_finite: bool = True,
) -> Any:
    """Read a scalar entry, falling back to ``default`` when ``key`` is absent.

    Traced values are returned as JAX arrays untouched; concrete values are
    converted to a Python ``float``. Raises ``ValueError`` for non-scalar
    input, or for a non-finite concrete value when ``validate_finite`` is set.
    """
    candidate = jnp.asarray(mapping.get(key, default))
    if candidate.ndim != 0:
        raise ValueError(f"{key} must be scalar")
    if not _is_traced(candidate):
        number = float(np.asarray(candidate))
        if validate_finite and not np.isfinite(number):
            raise ValueError(f"{key} contains a non-finite value")
        return number
    return candidate
# [docs]
def flux_tube_geometry_from_mapping(
    data: Mapping[str, Any],
    *,
    source_model: str = "vmec_jax",
    validate_finite: bool = True,
) -> FluxTubeGeometryData:
    """Assemble ``FluxTubeGeometryData`` from in-memory field-line samples.

    ``data`` must already be the solver-ready flux-tube contract (not a fake
    equilibrium): a ``vmec_jax`` / ``booz_xform_jax`` pipeline produces the
    sampled field-line arrays, and this function checks shapes and -- for
    concrete inputs with ``validate_finite`` -- finiteness before handing them
    to the existing solver.

    Optional entries: ``jacobian`` defaults to ``1 / gradpar / bmag``, ``grho``
    defaults to ones, and scalar parameters fall back to the defaults listed
    below (``s_hat`` also accepts the alias ``shat``).

    Raises ``ValueError`` for missing or mis-shaped arrays, an empty ``theta``,
    a concrete ``gradpar`` profile that is not constant (when
    ``validate_finite`` is set), or ``nfp < 1``.
    """
    theta = _array(data, "theta", validate_finite=validate_finite)
    ntheta = int(theta.shape[0])
    if ntheta < 1:
        raise ValueError("theta must contain at least one sample")

    def _profile(name: str) -> jnp.ndarray:
        """Read a per-theta profile with the caller's validation settings."""
        return _array(data, name, ntheta, validate_finite=validate_finite)

    profiles = {name: _profile(name) for name in _ARRAY_FIELDS if name != "theta"}

    if "jacobian" in data:
        jacobian = _profile("jacobian")
    else:
        jacobian = 1.0 / profiles["gradpar"] / profiles["bmag"]
    grho = _profile("grho") if "grho" in data else jnp.ones_like(theta)

    gradpar_profile = profiles["gradpar"]
    gradpar_value: Any
    if _is_traced(gradpar_profile):
        # Under tracing the constancy check is skipped; keep the mean traceable.
        gradpar_value = jnp.mean(gradpar_profile)
    else:
        host_gradpar = np.asarray(gradpar_profile)
        gradpar_value = float(np.mean(host_gradpar))
        if validate_finite:
            is_uniform = np.allclose(
                host_gradpar, gradpar_value, rtol=1.0e-5, atol=1.0e-7
            )
            if not is_uniform:
                raise ValueError(
                    "gradpar must be constant along the sampled field line"
                )

    nfp = int(data.get("nfp", 1))
    if nfp < 1:
        raise ValueError("nfp must be a positive integer")

    # Scalars are read in a fixed order so validation errors surface
    # deterministically.
    scalar_defaults = (
        ("q", 1.0),
        ("s_hat", data.get("shat", 0.0)),
        ("epsilon", 0.0),
        ("R0", 1.0),
        ("B0", 1.0),
        ("alpha", 0.0),
        ("drift_scale", 1.0),
        ("kxfac", 1.0),
        ("theta_scale", 1.0),
    )
    scalars = {
        key: _scalar(data, key, default, validate_finite=validate_finite)
        for key, default in scalar_defaults
    }

    return FluxTubeGeometryData(
        theta=theta,
        gradpar_value=gradpar_value,
        bmag_profile=profiles["bmag"],
        bgrad_profile=profiles["bgrad"],
        gds2_profile=profiles["gds2"],
        gds21_profile=profiles["gds21"],
        gds22_profile=profiles["gds22"],
        cv_profile=profiles["cvdrift"],
        gb_profile=profiles["gbdrift"],
        cv0_profile=profiles["cvdrift0"],
        gb0_profile=profiles["gbdrift0"],
        jacobian_profile=jacobian,
        grho_profile=grho,
        q=scalars["q"],
        s_hat=scalars["s_hat"],
        epsilon=scalars["epsilon"],
        R0=scalars["R0"],
        B0=scalars["B0"],
        alpha=scalars["alpha"],
        drift_scale=scalars["drift_scale"],
        kxfac=scalars["kxfac"],
        theta_scale=scalars["theta_scale"],
        nfp=nfp,
        kperp2_bmag=bool(data.get("kperp2_bmag", True)),
        bessel_bmag_power=(
            float(data["bessel_bmag_power"]) if "bessel_bmag_power" in data else 0.0
        ),
        source_model=str(source_model),
        theta_closed_interval=bool(data.get("theta_closed_interval", False)),
    )
# [docs]
def geometry_observable_names() -> tuple[str, ...]:
    """Ordered labels of the geometry observables used by bridge AD checks.

    The order matches the vector assembled by
    ``flux_tube_geometry_observables``.
    """
    return _GEOMETRY_OBSERVABLE_NAMES
# [docs]
def vmec_metric_tensor_observable_names() -> tuple[str, ...]:
    """Ordered labels of the observables checked by the VMEC metric-tensor gate."""
    return _VMEC_METRIC_OBSERVABLE_NAMES
# [docs]
def vmec_field_line_tensor_observable_names() -> tuple[str, ...]:
    """Ordered labels of the observables checked by the VMEC field-line tensor gate."""
    return _VMEC_FIELD_LINE_OBSERVABLE_NAMES
# [docs]
def flux_tube_geometry_observables(geom: FluxTubeGeometryData) -> jnp.ndarray:
    """Return differentiable scalar observables from solver-ready geometry.

    The observables are intentionally geometry-level quantities, returned in
    this order: Jacobian-weighted mean field strength, relative ripple, metric
    norm, drift norm, mean (signed) Jacobian, and mean parallel-gradient
    factor. They are used to validate the differentiable ``vmec_jax`` /
    ``booz_xform_jax`` bridge before any turbulence observable is promoted
    into an optimization claim.

    Weights are ``|jacobian|`` normalised by their sum. The two divisions are
    guarded by a positive floor that is representable in the working dtype:
    the literal ``1e-300`` underflows to zero in float32, which would leave
    the guard ineffective and turn a vanishing denominator into NaN. The floor
    stays ``1e-300`` in float64.
    """

    def _positive_floor(dtype: Any) -> float:
        """Smallest division guard that does not underflow to zero in ``dtype``."""
        if jnp.issubdtype(dtype, jnp.floating):
            return max(1.0e-300, float(jnp.finfo(dtype).tiny))
        return 1.0e-300

    bmag = jnp.asarray(geom.bmag_profile)
    jac = jnp.abs(jnp.asarray(geom.jacobian_profile))
    weights = jac / jnp.maximum(
        jnp.sum(jac), jnp.asarray(_positive_floor(jac.dtype), dtype=jac.dtype)
    )
    mean_b = jnp.sum(weights * bmag)
    mean_b_scale = jnp.maximum(
        jnp.abs(mean_b),
        jnp.asarray(_positive_floor(mean_b.dtype), dtype=mean_b.dtype),
    )
    ripple = jnp.sqrt(jnp.sum(weights * (bmag / mean_b_scale - 1.0) ** 2))
    # The off-diagonal metric component is counted twice.
    metric = jnp.sqrt(
        jnp.sum(
            weights
            * (
                jnp.asarray(geom.gds2_profile) ** 2
                + 2.0 * jnp.asarray(geom.gds21_profile) ** 2
                + jnp.asarray(geom.gds22_profile) ** 2
            )
        )
    )
    drift = jnp.sqrt(
        jnp.sum(
            weights
            * (
                jnp.asarray(geom.cv_profile) ** 2
                + jnp.asarray(geom.gb_profile) ** 2
                + jnp.asarray(geom.cv0_profile) ** 2
                + jnp.asarray(geom.gb0_profile) ** 2
            )
        )
    )
    return jnp.asarray(
        [
            mean_b,
            ripple,
            metric,
            drift,
            jnp.mean(jnp.asarray(geom.jacobian_profile)),
            jnp.mean(jnp.asarray(geom.gradpar_value)),
        ]
    )
# [docs]
def finite_difference_jacobian(
    fn: Any, params: jnp.ndarray, *, step: float = 1.0e-4
) -> jnp.ndarray:
    """Central finite-difference Jacobian for small validation problems.

    ``fn`` is evaluated at ``params +/- step`` along each coordinate, so it is
    called ``2 * len(params)`` times. Column ``j`` of the result is the
    central difference with respect to ``params[j]``; outputs are stacked along
    axis 1, so ``fn`` must return an array with at least one dimension.

    Raises ``ValueError`` when ``params`` is not one-dimensional, ``step`` is
    not positive, or ``params`` is empty. The empty case previously surfaced
    as an opaque ``ValueError`` from ``jnp.stack``.
    """
    p = jnp.asarray(params, dtype=_jax_float_dtype())
    h = float(step)
    if p.ndim != 1:
        raise ValueError("params must be one-dimensional")
    if h <= 0.0:
        raise ValueError("step must be positive")
    n_params = int(p.shape[0])
    if n_params == 0:
        raise ValueError("params must contain at least one parameter")

    def _central_difference(index: int) -> jnp.ndarray:
        """Derivative estimate of ``fn`` along coordinate ``index``."""
        offset = jnp.zeros_like(p).at[index].set(h)
        return (jnp.asarray(fn(p + offset)) - jnp.asarray(fn(p - offset))) / (2.0 * h)

    return jnp.stack([_central_difference(idx) for idx in range(n_params)], axis=1)
def _json_ready(value: Any) -> Any:
"""Return a strict JSON-compatible copy, replacing nonfinite floats by null."""
if isinstance(value, Mapping):
return {str(key): _json_ready(item) for key, item in value.items()}
if isinstance(value, (list, tuple)):
return [_json_ready(item) for item in value]
if isinstance(value, (np.bool_, bool)):
return bool(value)
if isinstance(value, (np.integer, int)):
return int(value)
if isinstance(value, (np.floating, float)):
out = float(value)
return out if np.isfinite(out) else None
if isinstance(value, np.ndarray):
return _json_ready(value.tolist())
return value
def _sensitivity_conditioning_metadata(
jacobian_ad: Any,
jacobian_fd: Any,
params: Any,
*,
fd_step: float,
observable_names: Sequence[str] | None = None,
param_names: Sequence[str] | None = None,
relative_floor: float = 1.0e-12,
) -> dict[str, object]:
"""Return JSON-friendly conditioning metadata for AD/FD Jacobian gates."""
jac_ad = np.asarray(jacobian_ad, dtype=float)
jac_fd = np.asarray(jacobian_fd, dtype=float)
p = np.asarray(params, dtype=float).reshape(-1)
if jac_ad.ndim != 2 or jac_fd.ndim != 2:
raise ValueError("jacobians must be two-dimensional")
if jac_ad.shape != jac_fd.shape:
raise ValueError("AD and finite-difference jacobians must have matching shapes")
if jac_ad.shape[1] != p.size:
raise ValueError("parameter length must match jacobian columns")
obs_names = (
[str(name) for name in observable_names]
if observable_names is not None
else [f"observable_{idx}" for idx in range(jac_ad.shape[0])]
)
par_names = (
[str(name) for name in param_names]
if param_names is not None
else [f"param_{idx}" for idx in range(jac_ad.shape[1])]
)
if len(obs_names) != jac_ad.shape[0]:
raise ValueError("observable_names length must match jacobian rows")
if len(par_names) != jac_ad.shape[1]:
raise ValueError("param_names length must match jacobian columns")
finite_ad = bool(np.all(np.isfinite(jac_ad)))
finite_fd = bool(np.all(np.isfinite(jac_fd)))
finite_params = bool(np.all(np.isfinite(p)))
if finite_ad and jac_ad.size:
singular_values = np.linalg.svd(jac_ad, compute_uv=False)
rank = int(np.linalg.matrix_rank(jac_ad))
if singular_values.size == 0 or float(singular_values[-1]) <= 0.0:
condition_number = float("inf")
else:
condition_number = float(singular_values[0] / singular_values[-1])
column_norms = np.linalg.norm(jac_ad, axis=0)
row_norms = np.linalg.norm(jac_ad, axis=1)
else:
singular_values = np.asarray([], dtype=float)
rank = 0
condition_number = float("inf")
column_norms = np.full((jac_ad.shape[1],), np.nan)
row_norms = np.full((jac_ad.shape[0],), np.nan)
floor = float(relative_floor)
diff = jac_ad - jac_fd
abs_error = np.abs(diff)
rel_error = abs_error / np.maximum(np.abs(jac_fd), floor)
def _worst_entry(values: np.ndarray) -> dict[str, object] | None:
if values.size == 0 or not np.any(np.isfinite(values)):
return None
flat_idx = int(np.nanargmax(values))
row, col = np.unravel_index(flat_idx, values.shape)
return {
"observable_index": int(row),
"observable_name": obs_names[int(row)],
"parameter_index": int(col),
"parameter_name": par_names[int(col)],
"value": float(values[row, col]),
"ad": float(jac_ad[row, col]),
"finite_difference": float(jac_fd[row, col]),
}
h = abs(float(fd_step))
fd_step_by_parameter = [
{
"parameter_index": int(idx),
"parameter_name": par_names[int(idx)],
"parameter_value": float(p[idx]) if idx < p.size else float("nan"),
"absolute_step": h,
"relative_step": float(h / max(abs(float(p[idx])), 1.0)),
}
for idx in range(p.size)
]
finite_column_norms = column_norms[np.isfinite(column_norms)]
max_column_norm = (
float(np.max(finite_column_norms)) if finite_column_norms.size else float("nan")
)
min_column_norm = (
float(np.min(finite_column_norms)) if finite_column_norms.size else float("nan")
)
return {
"jacobian_shape": [int(jac_ad.shape[0]), int(jac_ad.shape[1])],
"finite_ad_jacobian": finite_ad,
"finite_fd_jacobian": finite_fd,
"finite_parameters": finite_params,
"sensitivity_map_rank": rank,
"jacobian_condition_number": condition_number,
"jacobian_singular_values": singular_values.tolist(),
"ad_column_norms": column_norms.tolist(),
"ad_row_norms": row_norms.tolist(),
"max_ad_column_norm": max_column_norm,
"min_ad_column_norm": min_column_norm,
"fd_step": float(fd_step),
"relative_error_floor": floor,
"finite_difference_step_by_parameter": fd_step_by_parameter,
"fd_near_zero_reference_count": int(np.sum(np.abs(jac_fd) < floor)),
"worst_abs_error": _worst_entry(abs_error),
"worst_rel_error": _worst_entry(rel_error),
}
# [docs]
def observable_gradient_validation_report(
    observable_fn: Callable[[jnp.ndarray], Any],
    params: jnp.ndarray | np.ndarray,
    *,
    fd_step: float = 1.0e-4,
    rtol: float = 1.0e-4,
    atol: float = 1.0e-6,
    observable_names: Sequence[str] | None = None,
    param_names: Sequence[str] | None = None,
    tangent: jnp.ndarray | np.ndarray | None = None,
    relative_floor: float = 1.0e-12,
    min_rank: int | None = None,
    condition_number_max: float | None = 1.0e12,
    report_kind: str = "observable_gradient_validation",
) -> dict[str, object]:
    """Validate observable gradients by AD, finite differences, and conditioning.

    ``observable_fn(params)`` may return any array-like observable vector. The
    returned report is strict JSON-compatible: nonfinite diagnostic numbers are
    represented as ``None`` while finite flags and failure reasons preserve why
    the gate failed.

    The output of ``observable_fn`` is flattened; complex outputs are split
    into their real parts followed by their imaginary parts. The report
    combines four gates, all of which must hold for ``"passed"``:

    * finiteness of parameters, observables, both Jacobians, the error arrays
      and the tangent quantities;
    * per-entry agreement of the forward-mode AD Jacobian with the central
      finite-difference Jacobian (an entry passes when its absolute error is
      within ``atol`` *or* its relative error is within ``rtol``);
    * the same tolerance test on one directional derivative along ``tangent``
      (defaults to the all-ones vector divided by ``max(norm, 1)``);
    * conditioning: finite Jacobians, rank at least ``min_rank`` (defaults to
      ``min(n_observables, n_params)``), and, unless ``condition_number_max``
      is ``None``, a finite condition number not exceeding it.

    Raises ``ValueError`` for non-1-D or empty ``params``, non-positive
    ``fd_step`` / ``relative_floor`` / ``condition_number_max``, negative
    tolerances, an empty observable vector, name lists or a ``tangent`` whose
    size does not match, or a negative ``min_rank``.
    """
    p = jnp.asarray(params, dtype=_jax_float_dtype())
    if p.ndim != 1:
        raise ValueError("params must be one-dimensional")
    if int(p.size) == 0:
        raise ValueError("params must contain at least one parameter")
    step = float(fd_step)
    if step <= 0.0:
        raise ValueError("fd_step must be positive")
    rel_tol = float(rtol)
    abs_tol = float(atol)
    if rel_tol < 0.0 or abs_tol < 0.0:
        raise ValueError("rtol and atol must be non-negative")
    floor = float(relative_floor)
    if floor <= 0.0:
        raise ValueError("relative_floor must be positive")
    if condition_number_max is not None and float(condition_number_max) <= 0.0:
        raise ValueError("condition_number_max must be positive or None")
    # Flatten the observable output; complex values become [real..., imag...]
    # so every downstream quantity is real-valued.
    def flat_fn(x: jnp.ndarray) -> jnp.ndarray:
        out = jnp.ravel(jnp.asarray(observable_fn(x)))
        if jnp.iscomplexobj(out):
            return jnp.concatenate([jnp.real(out), jnp.imag(out)])
        return out
    observables = flat_fn(p)
    if int(observables.size) == 0:
        raise ValueError("observable_fn must return at least one observable")
    # Rows are observables, columns are parameters for both Jacobians.
    jac_ad = jax.jacfwd(flat_fn)(p)
    jac_fd = finite_difference_jacobian(flat_fn, p, step=step)
    if jac_ad.ndim != 2 or jac_fd.ndim != 2 or jac_ad.shape != jac_fd.shape:
        raise ValueError("AD and finite-difference jacobians must be two-dimensional and shape-aligned")
    n_obs, n_params = int(jac_ad.shape[0]), int(jac_ad.shape[1])
    obs_names = (
        [str(name) for name in observable_names]
        if observable_names is not None
        else [f"observable_{idx}" for idx in range(n_obs)]
    )
    par_names = (
        [str(name) for name in param_names]
        if param_names is not None
        else [f"param_{idx}" for idx in range(n_params)]
    )
    if len(obs_names) != n_obs:
        raise ValueError("observable_names length must match jacobian rows")
    if len(par_names) != n_params:
        raise ValueError("param_names length must match jacobian columns")
    if tangent is None:
        # Default direction: all-ones vector divided by max(norm, 1).
        direction = jnp.ones_like(p)
        direction = direction / jnp.maximum(
            jnp.linalg.norm(direction), jnp.asarray(1.0, dtype=p.dtype)
        )
    else:
        direction = jnp.asarray(tangent, dtype=p.dtype)
    if direction.shape != p.shape:
        raise ValueError("tangent must have the same shape as params")
    # Directional derivative: AD via the Jacobian, FD via one central difference.
    tangent_ad = jac_ad @ direction
    tangent_fd = (flat_fn(p + step * direction) - flat_fn(p - step * direction)) / (
        2.0 * step
    )
    # From here on everything is host-side NumPy.
    ad_np = np.asarray(jac_ad, dtype=float)
    fd_np = np.asarray(jac_fd, dtype=float)
    diff = ad_np - fd_np
    abs_error = np.abs(diff)
    # Relative errors are measured against max(|FD|, floor) to avoid dividing
    # by (near-)zero reference entries.
    rel_error = abs_error / np.maximum(np.abs(fd_np), floor)
    tangent_ad_np = np.asarray(tangent_ad, dtype=float)
    tangent_fd_np = np.asarray(tangent_fd, dtype=float)
    tangent_diff = tangent_ad_np - tangent_fd_np
    tangent_abs_error = np.abs(tangent_diff)
    tangent_rel_error = tangent_abs_error / np.maximum(np.abs(tangent_fd_np), floor)
    finite_flags = {
        "params": bool(np.all(np.isfinite(np.asarray(p, dtype=float)))),
        "observables": bool(np.all(np.isfinite(np.asarray(observables, dtype=float)))),
        "autodiff_jacobian": bool(np.all(np.isfinite(ad_np))),
        "finite_difference_jacobian": bool(np.all(np.isfinite(fd_np))),
        "abs_error": bool(np.all(np.isfinite(abs_error))),
        "rel_error": bool(np.all(np.isfinite(rel_error))),
        "tangent_direction": bool(np.all(np.isfinite(np.asarray(direction, dtype=float)))),
        "tangent_autodiff": bool(np.all(np.isfinite(tangent_ad_np))),
        "tangent_finite_difference": bool(np.all(np.isfinite(tangent_fd_np))),
    }
    finite_passed = bool(all(finite_flags.values()))
    # One row per (observable, parameter) pair; an entry passes when either the
    # absolute or the relative tolerance is met (NaN errors fail both).
    gradient_checks: list[dict[str, object]] = []
    for i, observable_name in enumerate(obs_names):
        for j, parameter_name in enumerate(par_names):
            entry_abs = float(abs_error[i, j])
            entry_rel = float(rel_error[i, j])
            gradient_checks.append(
                {
                    "observable": observable_name,
                    "parameter": parameter_name,
                    "autodiff": float(ad_np[i, j]),
                    "finite_difference": float(fd_np[i, j]),
                    "abs_error": entry_abs,
                    "rel_error": entry_rel,
                    "atol": abs_tol,
                    "rtol": rel_tol,
                    "passed": bool(entry_abs <= abs_tol or entry_rel <= rel_tol),
                }
            )
    derivative_passed = bool(
        gradient_checks and all(bool(row["passed"]) for row in gradient_checks)
    )
    # nanmax ignores NaN entries here; non-finite tangent values are caught by
    # the finite flags above instead.
    tangent_max_abs = (
        float(np.nanmax(tangent_abs_error)) if tangent_abs_error.size else 0.0
    )
    tangent_max_rel = (
        float(np.nanmax(tangent_rel_error)) if tangent_rel_error.size else 0.0
    )
    tangent_passed = bool(tangent_max_abs <= abs_tol or tangent_max_rel <= rel_tol)
    conditioning = _sensitivity_conditioning_metadata(
        jac_ad,
        jac_fd,
        p,
        fd_step=step,
        observable_names=obs_names,
        param_names=par_names,
        relative_floor=floor,
    )
    # Full rank of the (n_obs x n_params) map is required unless overridden.
    required_rank = (
        min(n_obs, n_params) if min_rank is None else int(min_rank)
    )
    if required_rank < 0:
        raise ValueError("min_rank must be non-negative")
    rank = int(cast(int | float, conditioning["sensitivity_map_rank"]))
    condition_number = float(cast(float, conditioning["jacobian_condition_number"]))
    condition_number_finite = bool(np.isfinite(condition_number))
    # With condition_number_max=None the condition-number gate is disabled.
    condition_number_passed = bool(
        condition_number_max is None
        or (
            condition_number_finite
            and condition_number <= float(condition_number_max)
        )
    )
    rank_passed = bool(rank >= required_rank)
    conditioning_passed = bool(
        finite_flags["autodiff_jacobian"]
        and finite_flags["finite_difference_jacobian"]
        and rank_passed
        and condition_number_passed
    )
    conditioning_gate = {
        "passed": conditioning_passed,
        "required_rank": int(required_rank),
        "rank_passed": rank_passed,
        "condition_number_max": (
            None if condition_number_max is None else float(condition_number_max)
        ),
        "condition_number_finite": condition_number_finite,
        "condition_number_passed": condition_number_passed,
    }
    # Human-readable tags explaining which gate(s) failed.
    failure_reasons: list[str] = []
    if not finite_passed:
        failed = [name for name, flag in finite_flags.items() if not flag]
        failure_reasons.append("nonfinite:" + ",".join(failed))
    if not derivative_passed:
        failure_reasons.append("ad_fd_tolerance")
    if not tangent_passed:
        failure_reasons.append("tangent_ad_fd_tolerance")
    if not conditioning_passed:
        if not rank_passed:
            failure_reasons.append("rank_below_required")
        if not condition_number_passed:
            failure_reasons.append("ill_conditioned")
    report = {
        "kind": str(report_kind),
        "passed": bool(
            finite_passed and derivative_passed and tangent_passed and conditioning_passed
        ),
        "finite_passed": finite_passed,
        "derivative_tolerance_passed": derivative_passed,
        "tangent_tolerance_passed": tangent_passed,
        "conditioning_passed": conditioning_passed,
        "failure_reasons": failure_reasons,
        "fd_step": step,
        "rtol": rel_tol,
        "atol": abs_tol,
        "relative_error_floor": floor,
        "observable_names": obs_names,
        "parameter_names": par_names,
        "params": np.asarray(p, dtype=float).tolist(),
        "observables": np.asarray(observables, dtype=float).tolist(),
        "jacobian_ad": ad_np.tolist(),
        "jacobian_fd": fd_np.tolist(),
        "abs_error": abs_error.tolist(),
        "rel_error": rel_error.tolist(),
        "max_abs_ad_fd_error": float(np.nanmax(abs_error)) if abs_error.size else 0.0,
        "max_rel_ad_fd_error": float(np.nanmax(rel_error)) if rel_error.size else 0.0,
        "gradient_checks": gradient_checks,
        "finite_flags": finite_flags,
        "tangent_direction": np.asarray(direction, dtype=float).tolist(),
        "tangent_direction_norm": float(
            np.linalg.norm(np.asarray(direction, dtype=float))
        ),
        "tangent_ad": tangent_ad_np.tolist(),
        "tangent_fd": tangent_fd_np.tolist(),
        "tangent_abs_error": tangent_abs_error.tolist(),
        "tangent_rel_error": tangent_rel_error.tolist(),
        "tangent_max_abs_error": tangent_max_abs,
        "tangent_max_rel_error": tangent_max_rel,
        "tangent_ad_norm": float(np.linalg.norm(tangent_ad_np)),
        "tangent_fd_norm": float(np.linalg.norm(tangent_fd_np)),
        "conditioning_gate": conditioning_gate,
        "conditioning": conditioning,
    }
    # Replace any remaining non-finite floats by None for strict JSON output.
    return _json_ready(report)
def _array_parity_metrics(
candidate: Any, reference: Any, *, floor: float = 1.0e-12
) -> dict[str, object]:
cand = np.asarray(candidate, dtype=float)
ref = np.asarray(reference, dtype=float)
metrics: dict[str, object] = {
"candidate_shape": [int(v) for v in cand.shape],
"reference_shape": [int(v) for v in ref.shape],
"shape_match": bool(cand.shape == ref.shape),
}
if cand.shape != ref.shape:
return metrics
diff = cand - ref
ref_scale = max(float(np.nanmax(np.abs(ref))) if ref.size else 0.0, float(floor))
local_scale = np.maximum(np.abs(ref), float(floor))
metrics.update(
{
"max_abs": float(np.nanmax(np.abs(diff))) if diff.size else 0.0,
"rms_abs": float(np.sqrt(np.nanmean(diff * diff))) if diff.size else 0.0,
"max_rel_pointwise": float(np.nanmax(np.abs(diff) / local_scale))
if diff.size
else 0.0,
"rms_rel_pointwise": float(np.sqrt(np.nanmean((diff / local_scale) ** 2)))
if diff.size
else 0.0,
"reference_scale": ref_scale,
"normalized_max_abs": float(np.nanmax(np.abs(diff)) / ref_scale)
if diff.size
else 0.0,
"candidate_min": float(np.nanmin(cand)) if cand.size else 0.0,
"candidate_max": float(np.nanmax(cand)) if cand.size else 0.0,
"reference_min": float(np.nanmin(ref)) if ref.size else 0.0,
"reference_max": float(np.nanmax(ref)) if ref.size else 0.0,
}
)
return metrics
def _scalar_parity_metrics(
candidate: Any, reference: Any, *, floor: float = 1.0e-12
) -> dict[str, float]:
cand = float(np.asarray(candidate))
ref = float(np.asarray(reference))
diff = cand - ref
scale = max(abs(ref), float(floor))
return {
"candidate": cand,
"reference": ref,
"abs": abs(diff),
"rel": abs(diff) / scale,
}
def _interp_radial(
values: jnp.ndarray, s_grid: jnp.ndarray, s_value: float
) -> jnp.ndarray:
arr = jnp.asarray(values)
s = jnp.asarray(s_grid, dtype=arr.dtype)
s_val = jnp.asarray(float(s_value), dtype=arr.dtype)
if arr.ndim == 1:
return jnp.interp(s_val, s, arr)
if arr.ndim == 2:
return jax.vmap(lambda col: jnp.interp(s_val, s, col), in_axes=1, out_axes=0)(
arr
)
raise ValueError("radial interpolation expects a one- or two-dimensional array")
def _interp_equal_arc_profile(
theta_uniform_closed: jnp.ndarray,
theta_equal_arc_closed: jnp.ndarray,
values_closed: jnp.ndarray,
) -> jnp.ndarray:
"""Interpolate onto the equal-arc grid through the coordinate map.
The Boozer transform feeding this remap must use safe divisions in inactive
Fourier branches; otherwise nonfinite upstream cotangents contaminate the
moving equal-arc coordinate sensitivity. Sparse boundary finite-difference
gates remain the promotion criterion for VMEC-JAX transport gradients.
"""
values = jnp.asarray(values_closed)
theta_uniform = jnp.asarray(theta_uniform_closed, dtype=values.dtype)
theta_equal_arc = jnp.asarray(theta_equal_arc_closed, dtype=values.dtype)
return jnp.interp(theta_uniform, theta_equal_arc, values)
def _boozer_half_mesh_s_grid(
raw_jlist: Any | None,
*,
ns_b: int,
ns_b_full: int,
dtype: Any,
) -> jnp.ndarray:
"""Return normalized Boozer half-mesh coordinates from API surface indices.
``booz_xform_jax`` exposes ``jlist`` using VMEC/Fortran half-mesh indexing
where the first interior half mesh is ``j=2``. The corresponding normalized
radial coordinate is therefore ``(j - 1.5) / ns_b_full``, matching the
imported VMEC/EIK half mesh ``0.5 * (s_full[:-1] + s_full[1:])``.
"""
if raw_jlist is None:
return (jnp.arange(int(ns_b), dtype=dtype) + 0.5) / float(
max(int(ns_b_full), 1)
)
jlist = jnp.asarray(raw_jlist, dtype=dtype)
return (jlist - 1.5) / float(max(int(ns_b_full), 1))
def _radial_derivative_profile(values: jnp.ndarray, spacing: float) -> jnp.ndarray:
arr = jnp.asarray(values)
if arr.ndim != 1:
raise ValueError("radial derivative expects a one-dimensional profile")
if int(arr.shape[0]) < 2:
return jnp.zeros_like(arr)
ds = jnp.asarray(float(spacing), dtype=arr.dtype)
deriv = jnp.empty_like(arr)
deriv = deriv.at[0].set((arr[1] - arr[0]) / ds)
deriv = deriv.at[-1].set((arr[-1] - arr[-2]) / ds)
if int(arr.shape[0]) > 2:
deriv = deriv.at[1:-1].set((arr[2:] - arr[:-2]) / (2.0 * ds))
return deriv
def _radial_derivative_array(values: jnp.ndarray, spacing: float) -> jnp.ndarray:
arr = jnp.asarray(values)
if arr.ndim != 2:
raise ValueError(
"radial derivative expects a two-dimensional surface-mode array"
)
if int(arr.shape[0]) < 2:
return jnp.zeros_like(arr)
ds = jnp.asarray(float(spacing), dtype=arr.dtype)
deriv = jnp.empty_like(arr)
deriv = deriv.at[0].set((arr[1] - arr[0]) / ds)
deriv = deriv.at[-1].set((arr[-1] - arr[-2]) / ds)
if int(arr.shape[0]) > 2:
deriv = deriv.at[1:-1].set((arr[2:] - arr[:-2]) / (2.0 * ds))
return deriv
def _evaluate_boozer_cosine_series_on_field_line(
theta: jnp.ndarray,
*,
coeffs: jnp.ndarray,
ixm_b: jnp.ndarray,
ixn_b: jnp.ndarray,
iota: jnp.ndarray,
alpha: float,
) -> tuple[jnp.ndarray, jnp.ndarray]:
theta_arr = jnp.asarray(theta)
coeff_arr = jnp.asarray(coeffs, dtype=theta_arr.dtype)
m = jnp.asarray(ixm_b, dtype=theta_arr.dtype)
n = jnp.asarray(ixn_b, dtype=theta_arr.dtype)
iota_safe = jnp.where(
jnp.abs(iota) < 1.0e-12, jnp.sign(iota + 1.0e-30) * 1.0e-12, iota
)
zeta = (theta_arr - jnp.asarray(float(alpha), dtype=theta_arr.dtype)) / iota_safe
phase = m[:, None] * theta_arr[None, :] - n[:, None] * zeta[None, :]
dphase_dtheta = m[:, None] - n[:, None] / iota_safe
values = jnp.sum(coeff_arr[:, None] * jnp.cos(phase), axis=0)
dvalues_dtheta = jnp.sum(
-coeff_arr[:, None] * dphase_dtheta * jnp.sin(phase), axis=0
)
return values, dvalues_dtheta
def _cumulative_trapezoid(y: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
values = jnp.asarray(y)
grid = jnp.asarray(x, dtype=values.dtype)
if values.ndim != 1 or grid.ndim != 1 or int(values.shape[0]) != int(grid.shape[0]):
raise ValueError("cumulative trapezoid expects matching one-dimensional arrays")
if int(values.shape[0]) < 2:
return jnp.zeros_like(values)
areas = 0.5 * (values[1:] + values[:-1]) * (grid[1:] - grid[:-1])
return jnp.concatenate([jnp.zeros((1,), dtype=values.dtype), jnp.cumsum(areas)])
def _periodic_bilinear_sample_2d(
values: jnp.ndarray, theta: jnp.ndarray, zeta: jnp.ndarray
) -> jnp.ndarray:
"""Sample a uniform periodic ``(theta,zeta)`` grid with fixed bilinear weights."""
arr = jnp.asarray(values)
if arr.ndim != 2:
raise ValueError("values must be a two-dimensional theta-zeta array")
theta_arr = jnp.asarray(theta)
zeta_arr = jnp.asarray(zeta)
if theta_arr.shape != zeta_arr.shape:
raise ValueError("theta and zeta samples must have the same shape")
ntheta = int(arr.shape[0])
nzeta = int(arr.shape[1])
if ntheta < 1 or nzeta < 1:
raise ValueError("values must have non-empty theta and zeta dimensions")
twopi = jnp.asarray(2.0 * np.pi, dtype=theta_arr.dtype)
theta_index = jnp.mod(theta_arr, twopi) * (float(ntheta) / twopi)
zeta_index = jnp.mod(zeta_arr, twopi) * (float(nzeta) / twopi)
i0 = jnp.floor(theta_index).astype(jnp.int32)
j0 = jnp.floor(zeta_index).astype(jnp.int32)
wi = theta_index - i0
wj = zeta_index - j0
i1 = (i0 + 1) % int(ntheta)
j1 = (j0 + 1) % int(nzeta)
v00 = arr[i0, j0]
v10 = arr[i1, j0]
v01 = arr[i0, j1]
v11 = arr[i1, j1]
return (
(1.0 - wi) * (1.0 - wj) * v00
+ wi * (1.0 - wj) * v10
+ (1.0 - wi) * wj * v01
+ wi * wj * v11
)
# [docs]
def vmec_boundary_aspect_sensitivity_report(
    params: jnp.ndarray,
    *,
    fd_step: float = 2.0e-5,
    mpol: int = 2,
    ntor: int = 0,
    ntheta: int = 96,
    nphi: int = 1,
    nfp: int = 1,
) -> dict[str, object]:
    """Check a real ``vmec_jax`` boundary-aspect gradient against finite differences.

    ``params`` is the length-2 vector ``(ripple, elongation)`` that shapes a
    simple boundary. The check intentionally stops at the boundary Fourier
    API: full VMEC solves are too expensive and environment-sensitive for the
    default package tests, but this path verifies that SPECTRAX-GK can
    discover a ``vmec_jax`` checkout and differentiate through its JAX-native
    boundary data structures before higher-cost optimization workflows are
    promoted.

    When the boundary API is not available the report has
    ``"available": False`` and ``None`` placeholders. Otherwise it carries the
    aspect ratio, the AD and finite-difference gradients, their maximum
    absolute difference, conditioning metadata and the grid settings used.

    Raises ``ValueError`` when ``params`` is not a one-dimensional length-2
    vector.
    """
    point = jnp.asarray(params, dtype=_jax_float_dtype())
    if not (point.ndim == 1 and int(point.shape[0]) == 2):
        raise ValueError("params must be a one-dimensional length-2 vector")

    backend_info = discover_differentiable_geometry_backends()
    if not backend_info.get("vmec_jax_boundary_api_available", False):
        return {
            "available": False,
            "backend_info": backend_info,
            "aspect": None,
            "grad_ad": None,
            "grad_fd": None,
            "max_abs_ad_fd_error": None,
            "fd_step": float(fd_step),
        }

    import vmec_jax as vj  # type: ignore[import-untyped, import-not-found]

    mode_table = vj.vmec_mode_table(int(mpol), int(ntor))
    angle_grid = vj.make_angle_grid(int(ntheta), int(nphi), int(nfp))
    helical_basis = vj.build_helical_basis(mode_table, angle_grid)

    def aspect_fn(x: jnp.ndarray) -> jnp.ndarray:
        """Aspect ratio of the two-parameter boundary."""
        ripple, elongation = x
        major_radius = 1.0
        minor_radius = 0.22 * (1.0 + 0.5 * ripple)
        blank = jnp.zeros(mode_table.K, dtype=point.dtype)
        r_cos = blank.at[0].set(major_radius).at[1].set(minor_radius)
        z_sin = blank.at[1].set(minor_radius * (1.0 + elongation))
        no_modes = jnp.zeros_like(r_cos)
        boundary = vj.BoundaryCoeffs(
            R_cos=r_cos, R_sin=no_modes, Z_cos=no_modes, Z_sin=z_sin
        )
        return vj.boundary_aspect_ratio(boundary, helical_basis)

    def _as_vector(x: jnp.ndarray) -> jnp.ndarray:
        # The finite-difference helper stacks array-valued outputs.
        return jnp.asarray([aspect_fn(x)])

    grad_ad = jax.grad(aspect_fn)(point)
    grad_fd = finite_difference_jacobian(_as_vector, point, step=fd_step)[0]
    mismatch = np.abs(np.asarray(grad_ad - grad_fd))
    conditioning = _sensitivity_conditioning_metadata(
        jnp.asarray(grad_ad)[None, :],
        jnp.asarray(grad_fd)[None, :],
        point,
        fd_step=float(fd_step),
        observable_names=("aspect_ratio",),
        param_names=("ripple", "elongation"),
    )
    return {
        "available": True,
        "backend_info": backend_info,
        "aspect": float(aspect_fn(point)),
        "grad_ad": np.asarray(grad_ad).tolist(),
        "grad_fd": np.asarray(grad_fd).tolist(),
        "max_abs_ad_fd_error": float(np.max(mismatch)),
        "conditioning": conditioning,
        "fd_step": float(fd_step),
        "mpol": int(mpol),
        "ntor": int(ntor),
        "ntheta": int(ntheta),
        "nphi": int(nphi),
        "nfp": int(nfp),
    }
[docs]
def evaluate_boozer_bmag_on_field_line(
    theta: jnp.ndarray,
    *,
    bmnc_b: jnp.ndarray,
    ixm_b: jnp.ndarray,
    ixn_b: jnp.ndarray,
    iota: jnp.ndarray | float,
    alpha: float = 0.0,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Evaluate a Boozer ``|B|`` cosine spectrum and its theta derivative.

    The field line is labelled by ``alpha = theta - iota * zeta``, so the
    toroidal angle used in the phase is ``zeta = (theta - alpha) / iota``.
    The function is JAX-native, so it can be differentiated with respect to
    the spectrum and ``iota``.

    Args:
        theta: Poloidal-angle samples of any shape (scalar, 1-D or N-D).
            Integer or boolean input is promoted to the default JAX float
            dtype so that the spectrum and ``iota`` are not truncated.
        bmnc_b: Cosine coefficients, one per mode.
        ixm_b: Poloidal mode numbers, same length as ``bmnc_b``.
        ixn_b: Toroidal mode numbers, same length as ``bmnc_b``.
        iota: Scalar rotational transform.  Magnitudes below ``1e-12`` are
            replaced by a signed ``1e-12`` to keep the division finite.
        alpha: Field-line label; converted with ``float`` (not traceable).

    Returns:
        ``(bmag, dbmag_dtheta)``, each with the same shape as ``theta``.
    """
    theta_arr = jnp.asarray(theta)
    if not jnp.issubdtype(theta_arr.dtype, jnp.inexact):
        # All other inputs are cast to theta's dtype below; an integer dtype
        # would silently round coefficients and iota to whole numbers.
        theta_arr = theta_arr.astype(float)
    work_dtype = theta_arr.dtype
    modes_m = jnp.asarray(ixm_b, dtype=work_dtype)
    modes_n = jnp.asarray(ixn_b, dtype=work_dtype)
    coeffs = jnp.asarray(bmnc_b, dtype=work_dtype)
    iota_arr = jnp.asarray(iota, dtype=work_dtype)
    # Guard the division: a vanishing iota becomes a tiny signed value.
    iota_safe = jnp.where(
        jnp.abs(iota_arr) < 1.0e-12, jnp.sign(iota_arr + 1.0e-30) * 1.0e-12, iota_arr
    )
    zeta = (theta_arr - jnp.asarray(float(alpha), dtype=work_dtype)) / iota_safe
    # The mode axis is appended last so theta may have any leading shape.
    phase = theta_arr[..., None] * modes_m - zeta[..., None] * modes_n
    dphase_dtheta = modes_m - modes_n / iota_safe
    bmag = jnp.sum(coeffs * jnp.cos(phase), axis=-1)
    dbmag_dtheta = jnp.sum(-coeffs * dphase_dtheta * jnp.sin(phase), axis=-1)
    return bmag, dbmag_dtheta
[docs]
def vmec_jax_boozer_flux_tube_sensitivity_report(  # pragma: no cover
    *,
    params: jnp.ndarray | None = None,
    case_name: str = "circular_tokamak",
    radial_index: int | None = None,
    mode_index: int = 1,
    surface_index: int | None = None,
    fd_step: float = 1.0e-5,
    mboz: int = 2,
    nboz: int = 0,
    ntheta: int = 32,
) -> dict[str, object]:
    """AD/FD-check the ``vmec_jax`` state -> Boozer flux-tube bridge.

    A bundled ``vmec_jax`` example is loaded through its ``wout`` reference.
    The two parameters are added to ``Rcos[radial_index, mode_index]`` and
    ``Zsin[radial_index, mode_index]``, the perturbed state is converted to
    ``booz_xform_jax`` inputs, mapped through
    ``booz_xform_flux_tube_mapping_from_inputs`` and handed to
    ``geometry_sensitivity_report``.

    Defaults: ``params`` is ``[1e-3, 1e-3]``, ``radial_index`` is the middle
    of the radial grid, and ``surface_index`` is ``radial_index - 1`` clamped
    to the valid half-mesh range.

    Returns:
        A report dictionary.  ``available`` is ``False`` with a ``reason``
        entry when the backends are not discovered, and ``False`` with an
        ``error`` entry when anything raised while loading or evaluating the
        example.  Otherwise the report carries the sensitivity result plus
        the indices, paths and Boozer spectrum that were used.

    Raises:
        ValueError: if ``params`` is not a length-2 vector.
    """
    seed = [1.0e-3, 1.0e-3] if params is None else params
    point = jnp.asarray(seed, dtype=jnp.float64)
    if not (point.ndim == 1 and int(point.shape[0]) == 2):
        raise ValueError("params must be a length-2 vector")

    backends = discover_differentiable_geometry_backends()

    def _unavailable(detail_key: str, detail: str) -> dict[str, object]:
        # Shared layout of the two "could not run" reports.
        return {
            "available": False,
            "backend_info": backends,
            "sensitivity": None,
            "fd_step": float(fd_step),
            "case_name": str(case_name),
            detail_key: detail,
        }

    if not backends.get("vmec_jax_available", False) or not backends.get(
        "booz_xform_jax_api_available", False
    ):
        return _unavailable(
            "reason", "vmec_jax or booz_xform_jax functional API is not available"
        )

    try:
        vmec_driver = importlib.import_module("vmec_jax.driver")
        vmec_config = importlib.import_module("vmec_jax.config")
        vmec_static = importlib.import_module("vmec_jax.static")
        vmec_wout = importlib.import_module("vmec_jax.wout")
        vmec_booz_input = importlib.import_module("vmec_jax.booz_input")

        input_path, wout_path = vmec_driver.example_paths(str(case_name))
        if wout_path is None:
            raise RuntimeError(
                f"vmec_jax example {case_name!r} has no bundled wout reference"
            )
        cfg, indata = vmec_config.load_config(str(input_path))
        static = vmec_static.build_static(cfg)
        wout = vmec_wout.read_wout(wout_path)
        state = vmec_wout.state_from_wout(wout)

        rcos0 = jnp.asarray(state.Rcos)
        zsin0 = jnp.asarray(state.Zsin)
        if rcos0.ndim != 2 or zsin0.ndim != 2:
            raise RuntimeError(
                "vmec_jax state Rcos/Zsin arrays must be two-dimensional"
            )
        n_radial = int(rcos0.shape[0])
        n_modes = int(rcos0.shape[1])

        ridx = n_radial // 2 if radial_index is None else int(radial_index)
        midx = int(mode_index)
        if ridx < 0 or ridx >= n_radial:
            raise ValueError("radial_index is outside the VMEC state radial grid")
        if midx < 0 or midx >= n_modes:
            raise ValueError("mode_index is outside the VMEC state mode table")

        if surface_index is None:
            # Default: the surface just inside the perturbed one, clamped.
            sidx = max(0, min(ridx - 1, n_radial - 2))
        else:
            sidx = int(surface_index)
        if sidx < 0 or sidx >= n_radial - 1:
            raise ValueError(
                "surface_index is outside the VMEC half-mesh Boozer surface grid"
            )

        def mapping_fn(x: jnp.ndarray) -> dict[str, Any]:
            perturbed = dc_replace(
                state,
                Rcos=rcos0.at[ridx, midx].add(x[0]),
                Zsin=zsin0.at[ridx, midx].add(x[1]),
            )
            booz_inputs = vmec_booz_input.booz_xform_inputs_from_state(
                state=perturbed,
                static=static,
                indata=indata,
                signgs=wout.signgs,
            )
            return booz_xform_flux_tube_mapping_from_inputs(
                booz_inputs,
                mboz=int(mboz),
                nboz=int(nboz),
                ntheta=int(ntheta),
                surface_index=int(sidx),
                magnetic_shear=0.35,
                jit=False,
            )

        sensitivity = geometry_sensitivity_report(
            mapping_fn,
            point,
            fd_step=float(fd_step),
            source_model="vmec_jax:state->booz_xform_jax:field-line-bmag",
        )
        booz_meta = mapping_fn(point)["booz_xform"]
    except Exception as exc:
        # Best-effort gate: any failure is reported instead of raised.
        return _unavailable("error", f"{type(exc).__name__}: {exc}")

    return {
        "available": True,
        "backend_info": backends,
        "case_name": str(case_name),
        "input_path": str(input_path),
        "wout_path": str(wout_path),
        "param_names": ["delta_Rcos", "delta_Zsin"],
        "params": np.asarray(point).tolist(),
        "radial_index": int(ridx),
        "mode_index": int(midx),
        "surface_index": int(sidx),
        "state_shape": [n_radial, n_modes],
        "sensitivity": sensitivity,
        "fd_step": float(fd_step),
        "mboz": int(mboz),
        "nboz": int(nboz),
        "ntheta": int(ntheta),
        "bmnc_b": np.asarray(booz_meta["bmnc_b"]).tolist(),
        "ixm_b": np.asarray(booz_meta["ixm_b"]).tolist(),
        "ixn_b": np.asarray(booz_meta["ixn_b"]).tolist(),
        "iota_b": float(np.asarray(booz_meta["iota_b"])),
    }
[docs]
def vmec_jax_metric_tensor_sensitivity_report(  # pragma: no cover
    *,
    params: jnp.ndarray | None = None,
    case_name: str = "circular_tokamak",
    radial_index: int | None = None,
    mode_index: int = 1,
    surface_index: int | None = None,
    fd_step: float = 1.0e-5,
    rms_epsilon: float = 1.0e-24,
) -> dict[str, object]:
    """AD/FD-check ``vmec_jax`` metric-tensor observables from a ``VMECState``.

    A bundled ``vmec_jax`` example is loaded through its ``wout`` reference.
    The two parameters are added to ``Rcos[radial_index, mode_index]`` and
    ``Zsin[radial_index, mode_index]``, ``vmec_jax.geom.eval_geom`` is
    evaluated on the perturbed state, and seven scalars are formed on the
    selected surface: the RMS of ``sqrtg``, the means of ``g_ss``, ``g_tt``
    and ``g_pp``, and the RMS of ``g_st``, ``g_sp`` and ``g_tp``.  Their
    forward-mode Jacobian is compared with ``finite_difference_jacobian``.

    ``rms_epsilon`` is added under the square root of every RMS so the
    observable stays differentiable when the array is identically zero.

    Returns:
        A report dictionary.  ``available`` is ``False`` with a ``reason``
        entry when ``vmec_jax`` is not discovered, and ``False`` with an
        ``error`` entry when anything raised while loading or evaluating the
        example.  Otherwise the report carries the observables, both
        Jacobians, their absolute/relative mismatch, conditioning metadata
        and the indices and grid shapes that were used.

    Raises:
        ValueError: if ``params`` is not a length-2 vector.
    """
    seed = [1.0e-3, 1.0e-3] if params is None else params
    point = jnp.asarray(seed, dtype=jnp.float64)
    if not (point.ndim == 1 and int(point.shape[0]) == 2):
        raise ValueError("params must be a length-2 vector")

    backends = discover_differentiable_geometry_backends()

    def _unavailable(detail_key: str, detail: str) -> dict[str, object]:
        # Shared layout of the two "could not run" reports.
        return {
            "available": False,
            "backend_info": backends,
            "sensitivity": None,
            "fd_step": float(fd_step),
            "case_name": str(case_name),
            detail_key: detail,
        }

    if not backends.get("vmec_jax_available", False):
        return _unavailable("reason", "vmec_jax is not available")

    try:
        vmec_driver = importlib.import_module("vmec_jax.driver")
        vmec_config = importlib.import_module("vmec_jax.config")
        vmec_static = importlib.import_module("vmec_jax.static")
        vmec_wout = importlib.import_module("vmec_jax.wout")
        vmec_geom = importlib.import_module("vmec_jax.geom")

        input_path, wout_path = vmec_driver.example_paths(str(case_name))
        if wout_path is None:
            raise RuntimeError(
                f"vmec_jax example {case_name!r} has no bundled wout reference"
            )
        cfg, _ = vmec_config.load_config(str(input_path))
        static = vmec_static.build_static(cfg)
        wout = vmec_wout.read_wout(wout_path)
        state = vmec_wout.state_from_wout(wout)

        rcos0 = jnp.asarray(state.Rcos)
        zsin0 = jnp.asarray(state.Zsin)
        if rcos0.ndim != 2 or zsin0.ndim != 2:
            raise RuntimeError(
                "vmec_jax state Rcos/Zsin arrays must be two-dimensional"
            )
        n_radial = int(rcos0.shape[0])
        n_modes = int(rcos0.shape[1])

        ridx = n_radial // 2 if radial_index is None else int(radial_index)
        midx = int(mode_index)
        if ridx < 0 or ridx >= n_radial:
            raise ValueError("radial_index is outside the VMEC state radial grid")
        if midx < 0 or midx >= n_modes:
            raise ValueError("mode_index is outside the VMEC state mode table")

        if surface_index is None:
            sidx = max(0, min(ridx - 1, n_radial - 1))
        else:
            sidx = int(surface_index)
        if sidx < 0 or sidx >= n_radial:
            raise ValueError("surface_index is outside the VMEC metric radial grid")

        eps = jnp.asarray(float(rms_epsilon), dtype=point.dtype)

        def _rms(values: jnp.ndarray) -> jnp.ndarray:
            values = jnp.asarray(values)
            return jnp.sqrt(jnp.mean(values * values) + eps)

        def metric_observables(x: jnp.ndarray) -> jnp.ndarray:
            perturbed = dc_replace(
                state,
                Rcos=rcos0.at[ridx, midx].add(x[0]),
                Zsin=zsin0.at[ridx, midx].add(x[1]),
            )
            geom = vmec_geom.eval_geom(perturbed, static)
            # Slice every tensor component on the selected radial surface.
            sqrtg, g_ss, g_st, g_sp, g_tt, g_tp, g_pp = (
                jnp.asarray(getattr(geom, field))[sidx]
                for field in ("sqrtg", "g_ss", "g_st", "g_sp", "g_tt", "g_tp", "g_pp")
            )
            return jnp.asarray(
                [
                    _rms(sqrtg),
                    jnp.mean(g_ss),
                    jnp.mean(g_tt),
                    jnp.mean(g_pp),
                    _rms(g_st),
                    _rms(g_sp),
                    _rms(g_tp),
                ]
            )

        observables = metric_observables(point)
        jac_ad = jax.jacfwd(metric_observables)(point)
        jac_fd = finite_difference_jacobian(
            metric_observables, point, step=float(fd_step)
        )
        abs_mismatch = jnp.abs(jac_ad - jac_fd)
        max_abs = jnp.max(abs_mismatch)
        max_rel = jnp.max(abs_mismatch / (jnp.abs(jac_fd) + 1.0e-12))
        # Unperturbed geometry, evaluated only to report the metric grid shape.
        geom0 = vmec_geom.eval_geom(state, static)
    except Exception as exc:
        # Best-effort gate: any failure is reported instead of raised.
        return _unavailable("error", f"{type(exc).__name__}: {exc}")

    return {
        "available": True,
        "backend_info": backends,
        "case_name": str(case_name),
        "input_path": str(input_path),
        "wout_path": str(wout_path),
        "source_model": "vmec_jax:state->metric-tensors",
        "param_names": ["delta_Rcos", "delta_Zsin"],
        "observable_names": list(_VMEC_METRIC_OBSERVABLE_NAMES),
        "params": np.asarray(point).tolist(),
        "observables": np.asarray(observables).tolist(),
        "jacobian_ad": np.asarray(jac_ad).tolist(),
        "jacobian_fd": np.asarray(jac_fd).tolist(),
        "max_abs_ad_fd_error": float(np.asarray(max_abs)),
        "max_rel_ad_fd_error": float(np.asarray(max_rel)),
        "conditioning": _sensitivity_conditioning_metadata(
            jac_ad,
            jac_fd,
            point,
            fd_step=float(fd_step),
            observable_names=_VMEC_METRIC_OBSERVABLE_NAMES,
            param_names=("delta_Rcos", "delta_Zsin"),
            relative_floor=1.0e-12,
        ),
        "radial_index": int(ridx),
        "mode_index": int(midx),
        "surface_index": int(sidx),
        "state_shape": [n_radial, n_modes],
        "metric_grid_shape": [int(extent) for extent in np.asarray(geom0.sqrtg).shape],
        "fd_step": float(fd_step),
        "rms_epsilon": float(rms_epsilon),
    }
[docs]
def vmec_jax_field_line_tensor_sensitivity_report(  # pragma: no cover
    *,
    params: jnp.ndarray | None = None,
    case_name: str = "nfp4_QH_warm_start",
    radial_index: int | None = None,
    mode_index: int = 1,
    surface_index: int | None = None,
    alpha: float = 0.0,
    ntheta: int = 32,
    fd_step: float = 1.0e-6,
    b2_floor: float = 1.0e-24,
    rms_epsilon: float = 1.0e-24,
) -> dict[str, object]:
    """AD/FD-check field-line samples of ``vmec_jax`` metric and ``|B|`` tensors.

    A bundled ``vmec_jax`` example is loaded through its ``wout`` reference.
    The two parameters are added to ``Rcos[radial_index, mode_index]`` and
    ``Zsin[radial_index, mode_index]``; ``vmec_jax.geom.eval_geom`` and
    ``vmec_jax.vmec_bcovar`` are evaluated on the perturbed state; and
    ``B^2``, ``sqrtg``, ``g_tt``, ``g_tp``, ``g_pp`` and ``g_ss`` are sampled
    with periodic bilinear interpolation along
    ``zeta = (theta - alpha) / iota`` on the selected surface.  Seven scalar
    observables are built from those samples and their forward-mode Jacobian
    is compared with ``finite_difference_jacobian``.

    ``b2_floor`` bounds the sampled ``B^2`` from below before the square
    root, and ``rms_epsilon`` is added under the square root of every RMS.

    Returns:
        A report dictionary.  ``available`` is ``False`` with a ``reason``
        entry when ``vmec_jax`` is not discovered, and ``False`` with an
        ``error`` entry when anything raised while loading or evaluating the
        example.  Otherwise the report carries the observables, both
        Jacobians, their mismatch, conditioning metadata and the sampling
        settings that were used.

    Raises:
        ValueError: if ``params`` is not a length-2 vector or ``ntheta < 4``.
    """
    seed = [1.0e-4, 1.0e-4] if params is None else params
    point = jnp.asarray(seed, dtype=jnp.float64)
    if not (point.ndim == 1 and int(point.shape[0]) == 2):
        raise ValueError("params must be a length-2 vector")
    n_points = int(ntheta)
    if n_points < 4:
        raise ValueError("ntheta must be >= 4")

    backends = discover_differentiable_geometry_backends()

    def _unavailable(detail_key: str, detail: str) -> dict[str, object]:
        # Shared layout of the two "could not run" reports.
        return {
            "available": False,
            "backend_info": backends,
            "sensitivity": None,
            "fd_step": float(fd_step),
            "case_name": str(case_name),
            detail_key: detail,
        }

    if not backends.get("vmec_jax_available", False):
        return _unavailable("reason", "vmec_jax is not available")

    try:
        vmec_driver = importlib.import_module("vmec_jax.driver")
        vmec_config = importlib.import_module("vmec_jax.config")
        vmec_static = importlib.import_module("vmec_jax.static")
        vmec_wout = importlib.import_module("vmec_jax.wout")
        vmec_geom = importlib.import_module("vmec_jax.geom")
        vmec_bcovar = importlib.import_module("vmec_jax.vmec_bcovar")
        vmec_field = importlib.import_module("vmec_jax.field")

        input_path, wout_path = vmec_driver.example_paths(str(case_name))
        if wout_path is None:
            raise RuntimeError(
                f"vmec_jax example {case_name!r} has no bundled wout reference"
            )
        cfg, _ = vmec_config.load_config(str(input_path))
        static = vmec_static.build_static(cfg)
        wout = vmec_wout.read_wout(wout_path)
        state = vmec_wout.state_from_wout(wout)

        rcos0 = jnp.asarray(state.Rcos)
        zsin0 = jnp.asarray(state.Zsin)
        if rcos0.ndim != 2 or zsin0.ndim != 2:
            raise RuntimeError(
                "vmec_jax state Rcos/Zsin arrays must be two-dimensional"
            )
        n_radial = int(rcos0.shape[0])
        n_modes = int(rcos0.shape[1])

        ridx = n_radial // 2 if radial_index is None else int(radial_index)
        midx = int(mode_index)
        if ridx < 0 or ridx >= n_radial:
            raise ValueError("radial_index is outside the VMEC state radial grid")
        if midx < 0 or midx >= n_modes:
            raise ValueError("mode_index is outside the VMEC state mode table")

        if surface_index is None:
            sidx = max(1, min(ridx, n_radial - 2))
        else:
            sidx = int(surface_index)
        if sidx < 0 or sidx >= n_radial:
            raise ValueError("surface_index is outside the VMEC metric radial grid")

        iota_profile = jnp.asarray(wout.iotas)
        if iota_profile.ndim != 1 or int(iota_profile.shape[0]) <= sidx:
            raise RuntimeError(
                "vmec_jax wout iotas profile is missing or incompatible with the state grid"
            )
        iota_line = iota_profile[sidx]
        # A vanishing iota is replaced by a tiny signed value before dividing.
        iota_safe = jnp.where(
            jnp.abs(iota_line) < 1.0e-12,
            jnp.sign(iota_line + 1.0e-30) * 1.0e-12,
            iota_line,
        )

        theta_line = jnp.linspace(
            -jnp.pi, jnp.pi, n_points, endpoint=False, dtype=point.dtype
        )
        theta_vmec = jnp.mod(theta_line + jnp.pi, 2.0 * jnp.pi)
        alpha_arr = jnp.asarray(float(alpha), dtype=point.dtype)
        zeta_line = jnp.mod((theta_vmec - alpha_arr) / iota_safe, 2.0 * jnp.pi)

        b2_floor_arr = jnp.asarray(float(b2_floor), dtype=point.dtype)
        eps = jnp.asarray(float(rms_epsilon), dtype=point.dtype)

        def _rms(values: jnp.ndarray) -> jnp.ndarray:
            values = jnp.asarray(values)
            return jnp.sqrt(jnp.mean(values * values) + eps)

        def _on_line(grid: Any) -> jnp.ndarray:
            # Sample the selected surface of a radial stack along the line.
            return _periodic_bilinear_sample_2d(grid[sidx], theta_vmec, zeta_line)

        def field_line_observables(x: jnp.ndarray) -> jnp.ndarray:
            perturbed = dc_replace(
                state,
                Rcos=rcos0.at[ridx, midx].add(x[0]),
                Zsin=zsin0.at[ridx, midx].add(x[1]),
            )
            geom = vmec_geom.eval_geom(perturbed, static)
            bcovar = vmec_bcovar.vmec_bcovar_half_mesh_from_wout(
                state=perturbed,
                static=static,
                wout=wout,
                pres=getattr(wout, "pres", None),
            )
            b2 = vmec_field.b2_from_bsup(geom, bcovar.bsupu, bcovar.bsupv)
            bmag = jnp.sqrt(jnp.maximum(_on_line(b2), b2_floor_arr))
            sqrtg = _on_line(geom.sqrtg)
            g_tt = _on_line(geom.g_tt)
            g_tp = _on_line(geom.g_tp)
            g_pp = _on_line(geom.g_pp)
            g_ss = _on_line(geom.g_ss)

            mean_b = jnp.mean(bmag)
            mean_b_floor = jnp.asarray(1.0e-30, dtype=bmag.dtype)
            ripple = jnp.std(bmag) / jnp.maximum(jnp.abs(mean_b), mean_b_floor)
            return jnp.asarray(
                [
                    mean_b,
                    ripple,
                    _rms(sqrtg),
                    jnp.mean(g_tt),
                    jnp.mean(g_pp),
                    _rms(g_tp),
                    jnp.mean(g_ss),
                ]
            )

        observables = field_line_observables(point)
        jac_ad = jax.jacfwd(field_line_observables)(point)
        jac_fd = finite_difference_jacobian(
            field_line_observables, point, step=float(fd_step)
        )
        abs_mismatch = jnp.abs(jac_ad - jac_fd)
        max_abs = jnp.max(abs_mismatch)
        max_rel = jnp.max(abs_mismatch / (jnp.abs(jac_fd) + 1.0e-10))
        # Unperturbed geometry, evaluated only to report the metric grid shape.
        geom0 = vmec_geom.eval_geom(state, static)
    except Exception as exc:
        # Best-effort gate: any failure is reported instead of raised.
        return _unavailable("error", f"{type(exc).__name__}: {exc}")

    return {
        "available": True,
        "backend_info": backends,
        "case_name": str(case_name),
        "input_path": str(input_path),
        "wout_path": str(wout_path),
        "source_model": "vmec_jax:state->field-line-metric-and-b",
        "field_line_convention": "VMEC theta, zeta=(theta-alpha)/iota with periodic bilinear sampling",
        "param_names": ["delta_Rcos", "delta_Zsin"],
        "observable_names": list(_VMEC_FIELD_LINE_OBSERVABLE_NAMES),
        "params": np.asarray(point).tolist(),
        "observables": np.asarray(observables).tolist(),
        "jacobian_ad": np.asarray(jac_ad).tolist(),
        "jacobian_fd": np.asarray(jac_fd).tolist(),
        "max_abs_ad_fd_error": float(np.asarray(max_abs)),
        "max_rel_ad_fd_error": float(np.asarray(max_rel)),
        "conditioning": _sensitivity_conditioning_metadata(
            jac_ad,
            jac_fd,
            point,
            fd_step=float(fd_step),
            observable_names=_VMEC_FIELD_LINE_OBSERVABLE_NAMES,
            param_names=("delta_Rcos", "delta_Zsin"),
            relative_floor=1.0e-10,
        ),
        "radial_index": int(ridx),
        "mode_index": int(midx),
        "surface_index": int(sidx),
        "iota": float(np.asarray(iota_line)),
        "alpha": float(alpha),
        "ntheta": int(n_points),
        "state_shape": [n_radial, n_modes],
        "metric_grid_shape": [int(extent) for extent in np.asarray(geom0.sqrtg).shape],
        "fd_step": float(fd_step),
        "b2_floor": float(b2_floor),
        "rms_epsilon": float(rms_epsilon),
    }
[docs]
def vmec_jax_flux_tube_mapping_from_state(  # pragma: no cover
    state: Any,
    static: Any,
    wout: Any,
    *,
    surface_index: int | None = None,
    alpha: float = 0.0,
    ntheta: int = 32,
    b2_floor: float = 1.0e-24,
    metric_floor: float = 1.0e-24,
    reference_length: float | None = None,
    reference_b: float | None = None,
    drift_scale: float = 1.0,
) -> dict[str, Any]:
    """Assemble a flux-tube mapping dictionary from ``vmec_jax`` tensors.

    The function evaluates ``vmec_jax.geom.eval_geom`` and
    ``vmec_jax.vmec_bcovar``, samples the covariant metric, ``sqrtg`` and the
    normalized ``|B|`` (plus centred-difference ``|B|`` derivatives) along
    ``zeta = (theta - alpha) / iota`` on one interior radial surface, inverts
    the sampled 3x3 metric, and derives ``gds2``/``gds21``/``gds22``,
    ``grho``, a constant ``gradpar``, ``bgrad`` and grad-B drift
    coefficients from it.  ``cvdrift``/``cvdrift0`` are emitted as the same
    arrays as ``gbdrift``/``gbdrift0``.

    Reference scales default to ``wout.Aminor_p`` for the length and to
    ``2 * |phi[-1] / (2 pi)| / L**2`` for the field; a non-finite or zero
    scale falls back to ``1.0``.  The magnetic shear is a centred difference
    of ``wout.iotas`` on a uniform ``s`` grid spanning ``[0, 1]``.

    Returns:
        A dictionary with the sampled arrays, the scalars ``q``, ``s_hat``,
        ``epsilon``, ``R0``, ``B0``, ``alpha``, ``drift_scale``, ``nfp`` and a
        nested ``vmec_jax`` metadata entry.

    Raises:
        ValueError: if ``ntheta < 4`` or ``surface_index`` is not interior.
        RuntimeError: if ``state.Rcos`` is not 2-D, has fewer than three
            radial surfaces, or ``wout.iotas`` does not cover the stencil.
    """
    n_points = int(ntheta)
    if n_points < 4:
        raise ValueError("ntheta must be >= 4")

    vmec_geom = importlib.import_module("vmec_jax.geom")
    vmec_bcovar = importlib.import_module("vmec_jax.vmec_bcovar")
    vmec_field = importlib.import_module("vmec_jax.field")

    rcos = jnp.asarray(state.Rcos)
    if rcos.ndim != 2:
        raise RuntimeError("vmec_jax state Rcos array must be two-dimensional")
    dtype = rcos.dtype
    ns = int(rcos.shape[0])
    if ns < 3:
        raise RuntimeError("vmec_jax state needs at least three radial surfaces")

    if surface_index is None:
        sidx = max(1, min(ns // 2, ns - 2))
    else:
        sidx = int(surface_index)
    # Radial centred differences below need both neighbouring surfaces.
    if sidx <= 0 or sidx >= ns - 1:
        raise ValueError("surface_index must be an interior VMEC radial index")

    iota_profile = jnp.asarray(wout.iotas)
    if iota_profile.ndim != 1 or int(iota_profile.shape[0]) <= sidx + 1:
        raise RuntimeError(
            "vmec_jax wout iotas profile is missing or incompatible with the state grid"
        )
    iota_line = iota_profile[sidx]
    # A vanishing iota is replaced by a tiny signed value before dividing.
    iota_safe = jnp.where(
        jnp.abs(iota_line) < 1.0e-12, jnp.sign(iota_line + 1.0e-30) * 1.0e-12, iota_line
    )

    def _const(value: float) -> jnp.ndarray:
        # Scalar constant in the working dtype of the VMEC state.
        return jnp.asarray(value, dtype=dtype)

    tiny = _const(1.0e-30)

    s_grid = jnp.linspace(0.0, 1.0, ns, dtype=dtype)
    s_val = jnp.maximum(s_grid[sidx], _const(1.0e-12))
    ds = s_grid[1] - s_grid[0]
    d_iota_ds = (iota_profile[sidx + 1] - iota_profile[sidx - 1]) / (2.0 * ds)
    s_hat = -2.0 * s_val * d_iota_ds / iota_safe

    if reference_length is None:
        raw_length = float(getattr(wout, "Aminor_p", 1.0))
    else:
        raw_length = float(reference_length)
    length_ok = np.isfinite(raw_length) and abs(raw_length) > 0.0
    L_reference = raw_length if length_ok else 1.0

    if reference_b is None:
        phi_profile = np.asarray(getattr(wout, "phi", [0.0, np.pi]), dtype=float)
        edge_toroidal_flux_over_2pi = abs(float(phi_profile[-1]) / (2.0 * np.pi))
        b_candidate = 2.0 * edge_toroidal_flux_over_2pi / (L_reference * L_reference)
    else:
        b_candidate = float(reference_b)
    field_ok = np.isfinite(b_candidate) and abs(b_candidate) > 0.0
    B_reference = b_candidate if field_ok else 1.0

    theta_line = jnp.linspace(-jnp.pi, jnp.pi, n_points, endpoint=False, dtype=dtype)
    theta_vmec = jnp.mod(theta_line + jnp.pi, 2.0 * jnp.pi)
    zeta_line = jnp.mod(
        (theta_vmec - _const(float(alpha))) / iota_safe,
        2.0 * jnp.pi,
    )

    geom = vmec_geom.eval_geom(state, static)
    bcovar = vmec_bcovar.vmec_bcovar_half_mesh_from_wout(
        state=state,
        static=static,
        wout=wout,
        pres=getattr(wout, "pres", None),
    )
    b2 = vmec_field.b2_from_bsup(geom, bcovar.bsupu, bcovar.bsupv)
    b2_floor_arr = _const(float(b2_floor))
    metric_floor_arr = _const(float(metric_floor))

    # Normalized |B| on the full grid and its centred-difference derivatives.
    bmag_all = jnp.sqrt(jnp.maximum(jnp.asarray(b2), b2_floor_arr)) / float(B_reference)
    bmag_grid = bmag_all[sidx]
    dtheta_grid = 2.0 * jnp.pi / float(bmag_grid.shape[0])
    dzeta_grid = 2.0 * jnp.pi / float(bmag_grid.shape[1])
    db_dtheta_grid = (
        jnp.roll(bmag_grid, -1, axis=0) - jnp.roll(bmag_grid, 1, axis=0)
    ) / (2.0 * dtheta_grid)
    db_dzeta_grid = (
        jnp.roll(bmag_grid, -1, axis=1) - jnp.roll(bmag_grid, 1, axis=1)
    ) / (2.0 * dzeta_grid)
    db_ds_grid = (bmag_all[sidx + 1] - bmag_all[sidx - 1]) / (2.0 * ds)

    def _line(values: jnp.ndarray) -> jnp.ndarray:
        return _periodic_bilinear_sample_2d(values, theta_vmec, zeta_line)

    def _surface_line(name: str) -> jnp.ndarray:
        # Sample one geometry tensor on the selected surface along the line.
        return _line(jnp.asarray(getattr(geom, name))[sidx])

    bmag = _line(bmag_grid)
    sqrtg = _surface_line("sqrtg")
    g_ss = _surface_line("g_ss")
    g_st = _surface_line("g_st")
    g_sp = _surface_line("g_sp")
    g_tt = _surface_line("g_tt")
    g_tp = _surface_line("g_tp")
    g_pp = _surface_line("g_pp")
    db_ds = _line(db_ds_grid)
    db_dt = _line(db_dtheta_grid)
    db_dp = _line(db_dzeta_grid)

    # Symmetric covariant metric per sample, inverted to the contravariant one.
    metric_rows = (
        jnp.stack((g_ss, g_st, g_sp), axis=-1),
        jnp.stack((g_st, g_tt, g_tp), axis=-1),
        jnp.stack((g_sp, g_tp, g_pp), axis=-1),
    )
    cov_metric = jnp.stack(metric_rows, axis=-2)
    contra = jnp.linalg.inv(cov_metric)

    # Covariant components of grad(alpha) with alpha = theta - iota * zeta.
    alpha_coeff = jnp.asarray([0.0, 1.0, -1.0], dtype=dtype).at[2].set(-iota_safe)
    grad_alpha_metric = jnp.einsum("i,zij,j->z", alpha_coeff, contra, alpha_coeff)
    grad_s_metric = contra[:, 0, 0]
    grad_s_dot_alpha = contra[:, 0, 1] - iota_safe * contra[:, 0, 2]
    grad_alpha_metric = jnp.maximum(grad_alpha_metric, metric_floor_arr)
    grad_s_metric = jnp.maximum(grad_s_metric, metric_floor_arr)

    sqrt_s = jnp.sqrt(s_val)
    length = _const(float(L_reference))
    b_ref = _const(float(B_reference))
    # Shear below the threshold is treated as exactly zero.
    shat_safe = jnp.where(jnp.abs(s_hat) < 1.0e-12, _const(0.0), s_hat)

    gds2 = length * length * s_val * grad_alpha_metric
    gds21 = shat_safe * grad_s_dot_alpha / jnp.maximum(b_ref, tiny)
    gds22 = (
        shat_safe
        * shat_safe
        * grad_s_metric
        / jnp.maximum(length * length * b_ref * b_ref * s_val, tiny)
    )
    grho = jnp.sqrt(grad_s_metric) / jnp.maximum(length * b_ref * sqrt_s, tiny)

    gradpar_profile = jnp.abs(
        length * iota_safe / jnp.maximum(jnp.abs(bmag * sqrtg), tiny)
    )
    # The mapping carries one constant gradpar: the mean over the line.
    gradpar_value = jnp.mean(gradpar_profile)
    gradpar = gradpar_value * jnp.ones_like(theta_line)

    log_b = jnp.log(jnp.maximum(bmag, 1.0e-30))
    dlogb_dtheta = (jnp.roll(log_b, -1) - jnp.roll(log_b, 1)) / (
        2.0 * (2.0 * jnp.pi / float(n_points))
    )
    bgrad = gradpar_value * dlogb_dtheta

    grad_b_cov = jnp.stack((db_ds, db_dt, db_dp), axis=-1)
    grad_b_dot_alpha = jnp.einsum("zi,zij,j->z", grad_b_cov, contra, alpha_coeff)
    grad_b_dot_s = jnp.einsum("zi,zi->z", grad_b_cov, contra[:, :, 0])
    bmag_sq = jnp.maximum(bmag * bmag, tiny)
    drift = float(drift_scale)
    gbdrift = -2.0 * drift * length * length * sqrt_s * grad_b_dot_alpha / bmag_sq
    gbdrift0 = (
        -2.0
        * drift
        * shat_safe
        * grad_b_dot_s
        / jnp.maximum(b_ref * bmag_sq * sqrt_s, tiny)
    )

    abs_sqrtg = jnp.abs(sqrtg)
    backend_meta = {
        "surface_index": int(sidx),
        "iota": iota_line,
        "reference_length": float(L_reference),
        "reference_b": float(B_reference),
        "field_line_convention": "VMEC theta, zeta=(theta-alpha)/iota with periodic bilinear sampling",
    }
    return {
        "theta": theta_line,
        "gradpar": gradpar,
        "bmag": bmag,
        "bgrad": bgrad,
        "gds2": gds2,
        "gds21": gds21,
        "gds22": gds22,
        "cvdrift": gbdrift,
        "gbdrift": gbdrift,
        "cvdrift0": gbdrift0,
        "gbdrift0": gbdrift0,
        "jacobian": abs_sqrtg / jnp.maximum(jnp.mean(jnp.abs(sqrtg)), tiny),
        "grho": grho,
        "q": 1.0 / jnp.maximum(jnp.abs(iota_safe), tiny),
        "s_hat": shat_safe,
        "epsilon": jnp.std(bmag) / jnp.maximum(jnp.mean(bmag), tiny),
        "R0": float(L_reference),
        "B0": float(B_reference),
        "alpha": float(alpha),
        "drift_scale": float(drift_scale),
        "nfp": int(getattr(wout, "nfp", 1)),
        "vmec_jax": backend_meta,
    }
@lru_cache(maxsize=32)
def _cached_booz_xform_constants(
*,
nfp: int,
mpol: int,
ntor: int,
ntheta: int,
nzeta: int,
mboz: int,
nboz: int,
asym: bool,
) -> tuple[Any, Any]:
"""Prepare Boozer constants outside traced VMEC-JAX residual callbacks."""
modes_mod = importlib.import_module("vmec_jax.modes")
bx = importlib.import_module("booz_xform_jax.jax_api")
main_modes = modes_mod.vmec_mode_table(int(mpol), int(ntor))
nyq_modes = modes_mod.nyquist_mode_table_from_grid(
mpol=int(mpol),
ntor=int(ntor),
ntheta=int(ntheta),
nzeta=int(nzeta),
)
return bx.prepare_booz_xform_constants(
nfp=int(nfp),
mboz=int(mboz),
nboz=int(nboz),
asym=bool(asym),
xm=np.asarray(main_modes.m, dtype=np.int32),
xn=np.asarray(main_modes.n * int(nfp), dtype=np.int32),
xm_nyq=np.asarray(nyq_modes.m, dtype=np.int32),
xn_nyq=np.asarray(nyq_modes.n * int(nfp), dtype=np.int32),
)
def prewarm_vmec_boozer_equal_arc_cache(
    static: Any,
    wout: Any,
    *,
    mboz: int = _VMEC_BOOZER_PARITY_MIN_MODE_COUNT,
    nboz: int = _VMEC_BOOZER_PARITY_MIN_MODE_COUNT,
    asym: bool | None = None,
) -> None:  # pragma: no cover - exercised by optional VMEC-JAX optimizer smoke tests.
    """Populate the Boozer-constants cache for one VMEC configuration.

    The resolution is read from ``static.cfg``.  The number of field periods
    comes from ``wout.nfp``, then ``static.cfg.nfp``, then ``1``.  When
    ``asym`` is ``None`` the ``lasym`` flag of the configuration (default
    ``False``) is used.  The cached result itself is discarded.
    """
    config = static.cfg

    field_periods = getattr(wout, "nfp", None)
    if field_periods is None:
        field_periods = getattr(config, "nfp", 1)
    if field_periods is None:
        field_periods = 1

    # Keyword order is kept stable because it is part of the lru_cache key.
    cache_key: dict[str, Any] = {
        "nfp": int(field_periods),
        "mpol": int(config.mpol),
        "ntor": int(config.ntor),
        "ntheta": int(config.ntheta),
        "nzeta": int(config.nzeta),
        "mboz": int(mboz),
        "nboz": int(nboz),
        "asym": bool(getattr(config, "lasym", False) if asym is None else asym),
    }
    _cached_booz_xform_constants(**cache_key)
[docs]
def vmec_jax_boozer_equal_arc_core_profiles_from_state(  # pragma: no cover
    state: Any,
    static: Any,
    indata: Any,
    wout: Any,
    *,
    surface_index: int | None = None,
    torflux: float | None = None,
    alpha: float = 0.0,
    ntheta: int = 32,
    mboz: int = _VMEC_BOOZER_PARITY_MIN_MODE_COUNT,
    nboz: int = _VMEC_BOOZER_PARITY_MIN_MODE_COUNT,
    jit: bool = False,
    surface_stencil_width: int | None = None,
    reference_length: float | None = None,
    reference_b: float | None = None,
) -> dict[str, Any]:
    """Return Boozer equal-arc core profiles from a real ``vmec_jax`` state.

    This bridge follows the same high-level convention as the imported VMEC/EIK
    runtime path for scalar/core field-line quantities and the zero-beta Boozer
    metric/drift terms that can be reconstructed directly from
    ``booz_xform_jax`` output: Boozer ``|B|``, equal-arc constant ``gradpar``,
    ``q``, magnetic shear, solver Jacobian normalization, ``gds*``/``grho``,
    and loaded-convention ``cvdrift``/``gbdrift`` coefficients. General
    finite-beta pressure corrections and broader-equilibrium drift gates remain
    separate promotion steps.

    Args:
        state: VMEC state; ``state.Rcos`` must be a 2-D array whose first axis
            is treated as the radial-surface axis (at least three surfaces).
        static: Object whose optional ``cfg`` attribute supplies ``mpol``,
            ``ntor``, ``ntheta``, ``nzeta`` and ``nfp`` for the cached Boozer
            constants; missing attributes fall back to values derived from
            ``state.Rcos`` and ``ntheta``.
        indata: Forwarded to ``vmec_jax.booz_input.booz_xform_inputs_from_state``.
        wout: Object read for ``Aminor_p``, ``phi``, ``signgs`` and ``nfp``
            (each with a default when missing).
        surface_index: Interior radial index; defaults to the mid-radius index
            clamped to ``[1, ns - 2]``.
        torflux: Radial coordinate in ``(0, 1)`` at which radial profiles are
            interpolated; defaults to ``surface_index / (ns - 1)``.
        alpha: Field-line label (converted with ``float``).
        ntheta: Number of points of the returned open theta grid (``>= 4``).
        mboz: Boozer poloidal resolution; must be at least
            ``_VMEC_BOOZER_PARITY_MIN_MODE_COUNT``.
        nboz: Boozer toroidal resolution; same lower bound as ``mboz``.
        jit: Forwarded to ``booz_xform_from_inputs``.
        surface_stencil_width: When given (``>= 3``), only a contiguous window
            of that many Boozer surfaces around ``torflux`` is requested.
        reference_length: Normalizing length; defaults to ``wout.Aminor_p``.
            Non-finite or zero values are replaced by ``1.0``.
        reference_b: Normalizing field; defaults to
            ``2 * |phi[-1] / (2 pi)| / reference_length**2``. Non-finite or
            zero values are replaced by ``1.0``.

    Returns:
        A dict with the open-grid profiles (``theta``, ``gradpar``, ``bmag``,
        ``bgrad``, ``jacobian``, ``gds2``, ``gds21``, ``gds22``, ``cvdrift``,
        ``gbdrift``, ``cvdrift0``, ``gbdrift0``, ``grho``), the two closed
        theta grids, the scalars ``q``, ``s_hat``, ``iota``, and bookkeeping
        entries (``torflux``, ``surface_index``, reference values, ``mboz``,
        ``nboz``, stencil information, convention and scope strings).

    Raises:
        ValueError: If ``ntheta < 4``, ``mboz``/``nboz`` are below the parity
            minimum, ``surface_index`` is not interior, ``torflux`` is outside
            ``(0, 1)``, or ``surface_stencil_width < 3``.
        RuntimeError: If the ``vmec_jax``/``booz_xform_jax`` functional APIs
            are unavailable, ``state.Rcos`` is not 2-D or has fewer than three
            surfaces, or the Boozer output has the wrong rank or fewer than two
            surfaces.
    """
    # --- Argument validation -------------------------------------------------
    ntheta_int = int(ntheta)
    if ntheta_int < 4:
        raise ValueError("ntheta must be >= 4")
    mboz_int = int(mboz)
    nboz_int = int(nboz)
    if (
        mboz_int < _VMEC_BOOZER_PARITY_MIN_MODE_COUNT
        or nboz_int < _VMEC_BOOZER_PARITY_MIN_MODE_COUNT
    ):
        raise ValueError(
            "mboz and nboz must both be >= "
            f"{_VMEC_BOOZER_PARITY_MIN_MODE_COUNT} for VMEC/Boozer parity gates"
        )
    # --- Optional backends ---------------------------------------------------
    info = discover_differentiable_geometry_backends()
    if not (
        info.get("vmec_jax_available", False)
        and info.get("booz_xform_jax_api_available", False)
    ):
        raise RuntimeError("vmec_jax and booz_xform_jax functional APIs are required")
    booz_input_mod = importlib.import_module("vmec_jax.booz_input")
    bx = importlib.import_module("booz_xform_jax.jax_api")
    # base_Rcos fixes the working dtype for every array built below.
    base_Rcos = jnp.asarray(state.Rcos)
    if base_Rcos.ndim != 2:
        raise RuntimeError("vmec_jax state Rcos array must be two-dimensional")
    ns_full = int(base_Rcos.shape[0])
    if ns_full < 3:
        raise RuntimeError("vmec_jax state needs at least three radial surfaces")
    # Default surface: mid-radius, clamped away from the first and last index.
    sidx = (
        max(1, min(ns_full // 2, ns_full - 2))
        if surface_index is None
        else int(surface_index)
    )
    if not (0 < sidx < ns_full - 1):
        raise ValueError("surface_index must be an interior VMEC radial index")
    # Radial coordinate used for all interpolations; an explicit torflux wins
    # over the value implied by the surface index.
    s_value = (
        float(sidx) / float(max(ns_full - 1, 1)) if torflux is None else float(torflux)
    )
    if not (0.0 < s_value < 1.0):
        raise ValueError("torflux must lie inside (0, 1)")
    if surface_stencil_width is not None and int(surface_stencil_width) < 3:
        raise ValueError("surface_stencil_width must be >= 3 when provided")
    # --- Reference length and field ------------------------------------------
    raw_length = (
        float(getattr(wout, "Aminor_p", 1.0))
        if reference_length is None
        else float(reference_length)
    )
    L_reference = (
        raw_length if np.isfinite(raw_length) and abs(raw_length) > 0.0 else 1.0
    )
    # Both branches compute the same edge flux, -phi[-1] / (2 pi); they differ
    # only in where B_reference comes from.
    if reference_b is None:
        phi_profile = np.asarray(getattr(wout, "phi", [0.0, np.pi]), dtype=float)
        edge_toroidal_flux_over_2pi = -float(phi_profile[-1]) / (2.0 * np.pi)
        raw_b = 2.0 * abs(edge_toroidal_flux_over_2pi) / (L_reference * L_reference)
        B_reference = raw_b if np.isfinite(raw_b) and abs(raw_b) > 0.0 else 1.0
    else:
        edge_toroidal_flux_over_2pi = -float(
            np.asarray(getattr(wout, "phi", [0.0, np.pi]))[-1]
        ) / (2.0 * np.pi)
        B_reference = float(reference_b)
        B_reference = (
            B_reference if np.isfinite(B_reference) and abs(B_reference) > 0.0 else 1.0
        )
    # --- Boozer transform inputs and constants -------------------------------
    inputs = booz_input_mod.booz_xform_inputs_from_state(
        state=state,
        static=static,
        indata=indata,
        signgs=getattr(wout, "signgs", 1),
    )
    # The asymmetric flag is inferred from the presence of ``inputs.bmns``.
    asym = bool(getattr(inputs, "bmns", None) is not None)
    cfg = getattr(static, "cfg", SimpleNamespace())
    # Field-period lookup order: wout -> cfg -> 1.
    nfp_raw = getattr(wout, "nfp", None)
    if nfp_raw is None:
        nfp_raw = getattr(cfg, "nfp", 1)
    nfp_int = 1 if nfp_raw is None else int(nfp_raw)
    try:
        constants, grids = _cached_booz_xform_constants(
            nfp=nfp_int,
            mpol=int(getattr(cfg, "mpol", max(2, base_Rcos.shape[1]))),
            ntor=int(getattr(cfg, "ntor", max(1, base_Rcos.shape[1] - 1))),
            ntheta=int(getattr(cfg, "ntheta", max(16, ntheta_int))),
            nzeta=int(getattr(cfg, "nzeta", max(16, 2 * ntheta_int))),
            mboz=mboz_int,
            nboz=nboz_int,
            asym=asym,
        )
    except (AttributeError, ModuleNotFoundError):
        # Fallback when the cached path cannot be built: derive the constants
        # directly from the transform inputs.
        constants, grids = bx.prepare_booz_xform_constants_from_inputs(
            inputs=inputs,
            mboz=mboz_int,
            nboz=nboz_int,
            asym=asym,
        )
    # --- Optional radial stencil ---------------------------------------------
    surface_indices = None
    if surface_stencil_width is not None:
        # Contiguous window of ``width`` surfaces, shifted as needed to stay
        # inside [0, ns_b_est).
        ns_b_est = max(1, ns_full - 1)
        width = min(int(surface_stencil_width), ns_b_est)
        center = int(round(s_value * float(ns_b_est) - 0.5))
        half_width = width // 2
        start = max(0, min(center - half_width, ns_b_est - width))
        surface_indices = jnp.arange(start, start + width, dtype=jnp.int32)
    out = bx.booz_xform_from_inputs(
        inputs=inputs,
        constants=constants,
        grids=grids,
        surface_indices=surface_indices,
        jit=bool(jit),
    )
    bmnc_b_all = jnp.asarray(out["bmnc_b"], dtype=base_Rcos.dtype)
    if bmnc_b_all.ndim != 2:
        raise RuntimeError(
            "booz_xform_jax bmnc_b output must have shape (surface, mode)"
        )
    ns_b = int(bmnc_b_all.shape[0])
    if ns_b < 2:
        raise RuntimeError("booz_xform_jax output needs at least two radial surfaces")
    ns_b_full = max(int(ns_full) - 1, int(ns_b))
    # NOTE(review): s_half is presumably the radial coordinate of each returned
    # Boozer surface (from ``jlist`` when present) - confirm against
    # _boozer_half_mesh_s_grid.
    s_half = _boozer_half_mesh_s_grid(
        out.get("jlist"),
        ns_b=ns_b,
        ns_b_full=ns_b_full,
        dtype=base_Rcos.dtype,
    )
    radial_spacing = 1.0 / float(max(ns_b_full, 1))
    # --- Radial evaluation at s_value ----------------------------------------
    # NOTE(review): _interp_radial presumably interpolates along the surface
    # axis to s_value - confirm against the helper.
    bmnc_b = _interp_radial(bmnc_b_all, s_half, s_value)
    rmnc_b = _interp_radial(
        jnp.asarray(out["rmnc_b"], dtype=base_Rcos.dtype), s_half, s_value
    )
    zmns_b = _interp_radial(
        jnp.asarray(out["zmns_b"], dtype=base_Rcos.dtype), s_half, s_value
    )
    # ``numns_b`` is the negated ``pmns_b`` output.
    numns_b = -_interp_radial(
        jnp.asarray(out["pmns_b"], dtype=base_Rcos.dtype), s_half, s_value
    )
    # Radial derivatives of the same coefficient arrays, evaluated at s_value.
    d_bmnc_b_d_s = _interp_radial(
        _radial_derivative_array(
            jnp.asarray(out["bmnc_b"], dtype=base_Rcos.dtype), radial_spacing
        ),
        s_half,
        s_value,
    )
    d_rmnc_b_d_s = _interp_radial(
        _radial_derivative_array(
            jnp.asarray(out["rmnc_b"], dtype=base_Rcos.dtype), radial_spacing
        ),
        s_half,
        s_value,
    )
    d_zmns_b_d_s = _interp_radial(
        _radial_derivative_array(
            jnp.asarray(out["zmns_b"], dtype=base_Rcos.dtype), radial_spacing
        ),
        s_half,
        s_value,
    )
    d_numns_b_d_s = -_interp_radial(
        _radial_derivative_array(
            jnp.asarray(out["pmns_b"], dtype=base_Rcos.dtype), radial_spacing
        ),
        s_half,
        s_value,
    )
    iota_profile = jnp.asarray(out["iota_b"], dtype=base_Rcos.dtype)
    iota = _interp_radial(iota_profile, s_half, s_value)
    d_iota_ds = _interp_radial(
        _radial_derivative_profile(iota_profile, radial_spacing), s_half, s_value
    )
    # Keep |iota| >= 1e-12 so the divisions below stay finite.
    iota_safe = jnp.where(
        jnp.abs(iota) < 1.0e-12, jnp.sign(iota + 1.0e-30) * 1.0e-12, iota
    )
    # s_hat = -2 * s * (d iota / d s) / iota.
    s_hat = -2.0 * jnp.asarray(s_value, dtype=base_Rcos.dtype) * d_iota_ds / iota_safe
    boozer_i = _interp_radial(
        jnp.asarray(out["buco_b"], dtype=base_Rcos.dtype), s_half, s_value
    )
    boozer_g = _interp_radial(
        jnp.asarray(out["bvco_b"], dtype=base_Rcos.dtype), s_half, s_value
    )
    # --- |B| along the field line and the equal-arc remap --------------------
    # Closed grid: ntheta + 1 points spanning [-pi, pi].
    theta_closed = jnp.linspace(-jnp.pi, jnp.pi, ntheta_int + 1, dtype=base_Rcos.dtype)
    # The derivative returned by the helper is not used here.
    mod_b, _dmod_b_dtheta = _evaluate_boozer_cosine_series_on_field_line(
        theta_closed,
        coeffs=bmnc_b,
        ixm_b=jnp.asarray(out["ixm_b"]),
        ixn_b=jnp.asarray(out["ixn_b"]),
        iota=iota_safe,
        alpha=float(alpha),
    )
    mod_b_safe = jnp.maximum(
        jnp.abs(mod_b), jnp.asarray(1.0e-30, dtype=base_Rcos.dtype)
    )
    # sqrt_g_booz = (G + iota * I) / |B|^2, with I and G taken from the
    # ``buco_b`` / ``bvco_b`` outputs.
    sqrt_g_booz = (boozer_g + iota_safe * boozer_i) / (mod_b_safe * mod_b_safe)
    # gradpar_raw = |L * iota / (|B| * sqrt_g_booz)| with a floored denominator.
    gradpar_raw = jnp.abs(
        jnp.asarray(float(L_reference), dtype=base_Rcos.dtype)
        * iota_safe
        / jnp.maximum(
            jnp.abs(mod_b_safe * sqrt_g_booz),
            jnp.asarray(1.0e-30, dtype=base_Rcos.dtype),
        )
    )
    # The running integral of 1 / gradpar over theta defines the equal-arc
    # coordinate; the constant gradpar is 2 pi divided by its total.
    inv_gradpar_int = _cumulative_trapezoid(1.0 / gradpar_raw, theta_closed)
    gradpar_eqarc = (
        2.0
        * jnp.pi
        / jnp.maximum(inv_gradpar_int[-1], jnp.asarray(1.0e-30, dtype=base_Rcos.dtype))
    )
    theta_eqarc = gradpar_eqarc * inv_gradpar_int - jnp.pi
    # Same values as theta_closed; used as the target grid of the remap.
    theta_uniform_closed = jnp.linspace(
        -jnp.pi, jnp.pi, ntheta_int + 1, dtype=base_Rcos.dtype
    )
    # NOTE(review): _interp_equal_arc_profile presumably resamples values
    # given on theta_eqarc onto theta_uniform_closed - confirm against the helper.
    bmag_closed = jnp.asarray(
        _interp_equal_arc_profile(
            theta_uniform_closed,
            theta_eqarc,
            mod_b_safe / float(B_reference),
        )
    )
    # Drop the closing point to obtain the open ntheta-point grid.
    theta = theta_uniform_closed[:-1]
    bmag = bmag_closed[:-1]
    gradpar = gradpar_eqarc * jnp.ones_like(theta)
    dtheta = 2.0 * jnp.pi / float(ntheta_int)
    bmag_safe = jnp.maximum(jnp.abs(bmag), jnp.asarray(1.0e-30, dtype=base_Rcos.dtype))
    # Spectral (FFT) theta-derivative of bmag on the periodic open grid.
    wave_number = 2.0 * jnp.pi * jnp.fft.fftfreq(ntheta_int, d=float(dtheta))
    dbmag_dtheta = jnp.real(jnp.fft.ifft(1j * wave_number * jnp.fft.fft(bmag)))
    bgrad = gradpar_eqarc * dbmag_dtheta / bmag_safe
    # dpsidrho = 2 * sqrt(s) * (edge toroidal flux / 2 pi).
    dpsidrho = (
        2.0
        * jnp.sqrt(jnp.asarray(s_value, dtype=base_Rcos.dtype))
        * jnp.asarray(
            edge_toroidal_flux_over_2pi,
            dtype=base_Rcos.dtype,
        )
    )
    drhodpsi = 1.0 / jnp.maximum(
        jnp.abs(dpsidrho), jnp.asarray(1.0e-30, dtype=base_Rcos.dtype)
    )
    jacobian = 1.0 / jnp.maximum(
        jnp.abs(drhodpsi * gradpar_eqarc * bmag_safe),
        jnp.asarray(1.0e-30, dtype=base_Rcos.dtype),
    )
    # --- Fourier sums on the closed field-line grid --------------------------
    m = jnp.asarray(out["ixm_b"], dtype=base_Rcos.dtype)
    n = jnp.asarray(out["ixn_b"], dtype=base_Rcos.dtype)
    # Toroidal angle on the field line: phi_b = (theta - alpha) / iota.
    phi_b = (
        theta_closed - jnp.asarray(float(alpha), dtype=base_Rcos.dtype)
    ) / iota_safe
    # phase has shape (mode, theta point); sums below reduce over the mode axis.
    phase = m[:, None] * theta_closed[None, :] - n[:, None] * phi_b[None, :]
    cos_phase = jnp.cos(phase)
    sin_phase = jnp.sin(phase)
    m_cos = m[:, None] * cos_phase
    m_sin = m[:, None] * sin_phase
    n_cos = n[:, None] * cos_phase
    n_sin = n[:, None] * sin_phase
    r_b = jnp.sum(rmnc_b[:, None] * cos_phase, axis=0)
    d_mod_b_d_s = jnp.sum(d_bmnc_b_d_s[:, None] * cos_phase, axis=0)
    d_mod_b_d_theta = -jnp.sum(bmnc_b[:, None] * m_sin, axis=0)
    d_mod_b_d_phi = jnp.sum(bmnc_b[:, None] * n_sin, axis=0)
    d_r_b_d_s = jnp.sum(d_rmnc_b_d_s[:, None] * cos_phase, axis=0)
    d_r_b_d_theta = -jnp.sum(rmnc_b[:, None] * m_sin, axis=0)
    d_r_b_d_phi = jnp.sum(rmnc_b[:, None] * n_sin, axis=0)
    d_z_b_d_s = jnp.sum(d_zmns_b_d_s[:, None] * sin_phase, axis=0)
    d_z_b_d_theta = jnp.sum(zmns_b[:, None] * m_cos, axis=0)
    d_z_b_d_phi = -jnp.sum(zmns_b[:, None] * n_cos, axis=0)
    nu_b = jnp.sum(numns_b[:, None] * sin_phase, axis=0)
    d_nu_b_d_s = jnp.sum(d_numns_b_d_s[:, None] * sin_phase, axis=0)
    d_nu_b_d_theta = jnp.sum(numns_b[:, None] * m_cos, axis=0)
    d_nu_b_d_phi = -jnp.sum(numns_b[:, None] * n_cos, axis=0)
    # --- Cartesian tangent vectors -------------------------------------------
    # x = R cos(phi_cyl), y = R sin(phi_cyl) with phi_cyl = phi_b - nu_b, so
    # d(phi_cyl) is -d(nu) for theta and s, and 1 - d(nu)/d(phi) for phi.
    phi_cyl = phi_b - nu_b
    sin_phi = jnp.sin(phi_cyl)
    cos_phi = jnp.cos(phi_cyl)
    d_x_d_theta = d_r_b_d_theta * cos_phi - r_b * sin_phi * (-d_nu_b_d_theta)
    d_x_d_phi = d_r_b_d_phi * cos_phi - r_b * sin_phi * (1.0 - d_nu_b_d_phi)
    d_x_d_s = d_r_b_d_s * cos_phi - r_b * sin_phi * (-d_nu_b_d_s)
    d_y_d_theta = d_r_b_d_theta * sin_phi + r_b * cos_phi * (-d_nu_b_d_theta)
    d_y_d_phi = d_r_b_d_phi * sin_phi + r_b * cos_phi * (1.0 - d_nu_b_d_phi)
    d_y_d_s = d_r_b_d_s * sin_phi + r_b * cos_phi * (-d_nu_b_d_s)
    # --- Gradient vectors from cross products of the tangent vectors ---------
    # grad psi = (e_theta x e_phi) / sqrt_g_booz.
    grad_psi_x = (d_y_d_theta * d_z_b_d_phi - d_z_b_d_theta * d_y_d_phi) / sqrt_g_booz
    grad_psi_y = (d_z_b_d_theta * d_x_d_phi - d_x_d_theta * d_z_b_d_phi) / sqrt_g_booz
    grad_psi_z = (d_x_d_theta * d_y_d_phi - d_y_d_theta * d_x_d_phi) / sqrt_g_booz
    g_sup_psi_psi = grad_psi_x**2 + grad_psi_y**2 + grad_psi_z**2
    g_sup_psi_psi_safe = jnp.maximum(
        g_sup_psi_psi, jnp.asarray(1.0e-30, dtype=base_Rcos.dtype)
    )
    etf = jnp.asarray(edge_toroidal_flux_over_2pi, dtype=base_Rcos.dtype)
    # grad theta = (e_phi x e_s) / (sqrt_g_booz * etf).
    grad_theta_x = (d_y_d_phi * d_z_b_d_s - d_z_b_d_phi * d_y_d_s) / (sqrt_g_booz * etf)
    grad_theta_y = (d_z_b_d_phi * d_x_d_s - d_x_d_phi * d_z_b_d_s) / (sqrt_g_booz * etf)
    grad_theta_z = (d_x_d_phi * d_y_d_s - d_y_d_phi * d_x_d_s) / (sqrt_g_booz * etf)
    # grad phi = (e_s x e_theta) / (sqrt_g_booz * etf).
    grad_phi_x = (d_y_d_s * d_z_b_d_theta - d_z_b_d_s * d_y_d_theta) / (
        sqrt_g_booz * etf
    )
    grad_phi_y = (d_z_b_d_s * d_x_d_theta - d_x_d_s * d_z_b_d_theta) / (
        sqrt_g_booz * etf
    )
    grad_phi_z = (d_x_d_s * d_y_d_theta - d_y_d_s * d_x_d_theta) / (sqrt_g_booz * etf)
    # grad alpha = grad theta - iota * grad phi plus a secular shear term that
    # grows with the toroidal distance from zeta_center.
    zeta_center = -jnp.asarray(float(alpha), dtype=base_Rcos.dtype) / iota_safe
    shear_phase = phi_b - zeta_center
    grad_alpha_x = (
        -shear_phase * d_iota_ds * grad_psi_x / etf
        + grad_theta_x
        - iota_safe * grad_phi_x
    )
    grad_alpha_y = (
        -shear_phase * d_iota_ds * grad_psi_y / etf
        + grad_theta_y
        - iota_safe * grad_phi_y
    )
    grad_alpha_z = (
        -shear_phase * d_iota_ds * grad_psi_z / etf
        + grad_theta_z
        - iota_safe * grad_phi_z
    )
    grad_alpha_dot_grad_psi = (
        grad_alpha_x * grad_psi_x
        + grad_alpha_y * grad_psi_y
        + grad_alpha_z * grad_psi_z
    )
    # --- Metric coefficients on the closed grid ------------------------------
    local_shear_l1 = grad_alpha_dot_grad_psi / g_sup_psi_psi_safe
    s_arr = jnp.asarray(s_value, dtype=base_Rcos.dtype)
    L = jnp.asarray(float(L_reference), dtype=base_Rcos.dtype)
    Bref = jnp.asarray(float(B_reference), dtype=base_Rcos.dtype)
    shat_metric = s_hat
    metric_bmag_sq = mod_b_safe * mod_b_safe
    gds2_raw = (
        (metric_bmag_sq / g_sup_psi_psi_safe + g_sup_psi_psi_safe * local_shear_l1**2)
        * L
        * L
        * s_arr
    )
    gds21_raw = g_sup_psi_psi_safe * local_shear_l1 * shat_metric / Bref
    gds22_raw = (
        g_sup_psi_psi_safe * shat_metric * shat_metric / (L * L * Bref * Bref * s_arr)
    )
    grho_raw = jnp.sqrt(g_sup_psi_psi_safe / (L * L * Bref * Bref * s_arr))
    # --- Curvature and drift coefficients ------------------------------------
    boozer_current_sum = boozer_g + iota_safe * boozer_i
    # Angular derivatives of sqrt_g_booz = (G + iota I) / |B|^2 via the chain rule.
    d_sqrt_g_booz_d_theta = (
        -2.0 * boozer_current_sum * d_mod_b_d_theta / (mod_b_safe**3)
    )
    d_sqrt_g_booz_d_phi = -2.0 * boozer_current_sum * d_mod_b_d_phi / (mod_b_safe**3)
    curvature_numerator = (
        boozer_g * d_sqrt_g_booz_d_theta - boozer_i * d_sqrt_g_booz_d_phi
    )
    curvature_denom = 2.0 * sqrt_g_booz * boozer_current_sum
    eps = jnp.asarray(1.0e-30, dtype=base_Rcos.dtype)
    # Sign-preserving floors so the divisions below cannot hit zero.
    curvature_denom_safe = jnp.where(
        jnp.abs(curvature_denom) < eps,
        jnp.sign(curvature_denom + eps) * eps,
        curvature_denom,
    )
    etf_safe = jnp.where(jnp.abs(etf) < eps, jnp.sign(etf + eps) * eps, etf)
    kappa_g = curvature_numerator / curvature_denom_safe
    local_shear_l0 = -(local_shear_l1 + d_iota_ds / etf_safe * shear_phase)
    kappa_n = d_mod_b_d_s / (mod_b_safe * etf_safe) + local_shear_l0 * kappa_g
    b_cross_kappa_dot_grad_alpha = (kappa_n + kappa_g * local_shear_l1) * metric_bmag_sq
    b_cross_kappa_dot_grad_psi = kappa_g * metric_bmag_sq
    toroidal_flux_sign = jnp.sign(etf)
    sqrt_s = jnp.sqrt(s_arr)
    drift_cvdrift0_raw = (
        -b_cross_kappa_dot_grad_psi
        * 2.0
        * shat_metric
        / jnp.maximum(metric_bmag_sq * sqrt_s, eps)
        * toroidal_flux_sign
    )
    drift_cvdrift_raw = (
        -2.0
        * Bref
        * L
        * L
        * sqrt_s
        * b_cross_kappa_dot_grad_alpha
        / metric_bmag_sq
        * toroidal_flux_sign
    )
    # Root-level VMEC/EIK drift coefficients are stored at the pre-loader (2x)
    # level; SPECTRAX-GK compares against the loaded solver convention.
    drift_loader_factor = jnp.asarray(0.5, dtype=base_Rcos.dtype)
    # --- Remap closed-grid profiles and drop the closing point ---------------
    gds2 = _interp_equal_arc_profile(theta_uniform_closed, theta_eqarc, gds2_raw)[:-1]
    gds21 = _interp_equal_arc_profile(theta_uniform_closed, theta_eqarc, gds21_raw)[:-1]
    gds22 = _interp_equal_arc_profile(theta_uniform_closed, theta_eqarc, gds22_raw)[:-1]
    grho = _interp_equal_arc_profile(theta_uniform_closed, theta_eqarc, grho_raw)[:-1]
    cvdrift = (
        drift_loader_factor
        * _interp_equal_arc_profile(theta_uniform_closed, theta_eqarc, drift_cvdrift_raw)[:-1]
    )
    # The grad-B drift arrays reuse the curvature drift arrays; no separate
    # pressure term is computed in this function.
    gbdrift = cvdrift
    cvdrift0 = (
        drift_loader_factor
        * _interp_equal_arc_profile(theta_uniform_closed, theta_eqarc, drift_cvdrift0_raw)[:-1]
    )
    gbdrift0 = cvdrift0
    return {
        "theta": theta,
        "theta_equal_arc_closed": theta_eqarc,
        "theta_uniform_closed": theta_uniform_closed,
        "gradpar": gradpar,
        "bmag": bmag,
        "bgrad": bgrad,
        "jacobian": jacobian,
        "gds2": gds2,
        "gds21": gds21,
        "gds22": gds22,
        "cvdrift": cvdrift,
        "gbdrift": gbdrift,
        "cvdrift0": cvdrift0,
        "gbdrift0": gbdrift0,
        "grho": grho,
        # q is reported as 1 / |iota| using the floored iota.
        "q": 1.0
        / jnp.maximum(jnp.abs(iota_safe), jnp.asarray(1.0e-30, dtype=base_Rcos.dtype)),
        "s_hat": s_hat,
        # Unfloored interpolated value, unlike the iota used for q.
        "iota": iota,
        "torflux": float(s_value),
        "surface_index": int(sidx),
        "reference_length": float(L_reference),
        "reference_b": float(B_reference),
        "mboz": mboz_int,
        "nboz": nboz_int,
        "surface_stencil_width": None
        if surface_stencil_width is None
        else int(surface_stencil_width),
        "boozer_surface_indices": None
        if surface_indices is None
        else [int(x) for x in np.asarray(surface_indices)],
        "field_line_convention": "Boozer theta, alpha=theta-iota*zeta, equal-arc remap",
        "scope": (
            "Boozer equal-arc bmag/gradpar/Jacobian plus zero-beta metric/drift parity; "
            "finite-beta pressure corrections and broad-equilibrium drift gates remain open"
        ),
    }
def flux_tube_geometry_from_vmec_boozer_state(  # pragma: no cover
    state: Any,
    static: Any,
    indata: Any,
    wout: Any,
    *,
    surface_index: int | None = None,
    torflux: float | None = None,
    alpha: float = 0.0,
    ntheta: int = 32,
    mboz: int = _VMEC_BOOZER_PARITY_MIN_MODE_COUNT,
    nboz: int = _VMEC_BOOZER_PARITY_MIN_MODE_COUNT,
    jit: bool = False,
    surface_stencil_width: int | None = None,
    reference_length: float | None = None,
    reference_b: float | None = None,
    source_model: str = "mode21_vmec_boozer_state",
    validate_finite: bool = True,
) -> FluxTubeGeometryData:
    """Turn a solved ``vmec_jax`` state into solver-ready flux-tube geometry.

    In-memory bridge intended for differentiable optimization workflows:
    ``VMECState -> BoozXformInputs -> booz_xform_jax -> FluxTubeGeometryData``.
    The NetCDF/EIK route remains available for runtime VMEC file generation,
    but differentiable stellarator optimization should go through this
    function (or an objective wrapper around it) so that gradients never pass
    through filesystem artifacts.

    All profile options are forwarded unchanged to
    :func:`vmec_jax_boozer_equal_arc_core_profiles_from_state`; the resulting
    mapping is then converted by :func:`flux_tube_geometry_from_mapping` with
    the given ``source_model`` and ``validate_finite``.
    """
    profile_options: dict[str, Any] = {
        "surface_index": surface_index,
        "torflux": torflux,
        "alpha": alpha,
        "ntheta": ntheta,
        "mboz": mboz,
        "nboz": nboz,
        "jit": jit,
        "surface_stencil_width": surface_stencil_width,
        "reference_length": reference_length,
        "reference_b": reference_b,
    }
    core_profiles = vmec_jax_boozer_equal_arc_core_profiles_from_state(
        state, static, indata, wout, **profile_options
    )
    return flux_tube_geometry_from_mapping(
        core_profiles,
        source_model=source_model,
        validate_finite=validate_finite,
    )
def vmec_jax_flux_tube_sensitivity_report(  # pragma: no cover
    *,
    params: jnp.ndarray | None = None,
    case_name: str = "nfp4_QH_warm_start",
    radial_index: int | None = None,
    mode_index: int = 1,
    surface_index: int | None = None,
    alpha: float = 0.0,
    ntheta: int = 24,
    fd_step: float = 2.0e-6,
) -> dict[str, object]:
    """AD/FD-check VMEC-state derivatives through a solver-ready flux tube.

    Starting from a real ``vmec_jax`` example state (unlike the Boozer-only
    bridge), the two entries of ``params`` are added to one ``Rcos`` and one
    ``Zsin`` coefficient at ``[radial_index, mode_index]``. The perturbed state
    is mapped to the SPECTRAX-GK flux-tube mapping and the geometry observables
    are differentiated through that path by :func:`geometry_sensitivity_report`.

    Returns:
        A report dict. ``available`` is ``False`` with a ``reason`` entry when
        ``vmec_jax`` is missing, and ``False`` with an ``error`` entry when
        anything inside the evaluation raises; otherwise it is ``True`` and the
        dict carries the sensitivity report plus case metadata.

    Raises:
        ValueError: If ``params`` is not a length-2 vector (checked before any
            backend is touched).
    """
    seed = [1.0e-4, 1.0e-4] if params is None else params
    p = jnp.asarray(seed, dtype=jnp.float64)
    if not (p.ndim == 1 and int(p.shape[0]) == 2):
        raise ValueError("params must be a length-2 vector")

    backend_info = discover_differentiable_geometry_backends()

    def _unavailable_report() -> dict[str, object]:
        # Shared prefix of both "not available" reports; the caller appends
        # either a "reason" or an "error" entry.
        return {
            "available": False,
            "backend_info": backend_info,
            "sensitivity": None,
            "fd_step": float(fd_step),
            "case_name": str(case_name),
        }

    if not backend_info.get("vmec_jax_available", False):
        missing = _unavailable_report()
        missing["reason"] = "vmec_jax is not available"
        return missing

    try:
        vmec_driver = importlib.import_module("vmec_jax.driver")
        vmec_config = importlib.import_module("vmec_jax.config")
        vmec_static = importlib.import_module("vmec_jax.static")
        vmec_wout = importlib.import_module("vmec_jax.wout")

        input_path, wout_path = vmec_driver.example_paths(str(case_name))
        if wout_path is None:
            raise RuntimeError(
                f"vmec_jax example {case_name!r} has no bundled wout reference"
            )
        cfg, _unused_indata = vmec_config.load_config(str(input_path))
        static = vmec_static.build_static(cfg)
        wout = vmec_wout.read_wout(wout_path)
        state = vmec_wout.state_from_wout(wout)

        base_Rcos = jnp.asarray(state.Rcos)
        base_Zsin = jnp.asarray(state.Zsin)
        if not (base_Rcos.ndim == 2 and base_Zsin.ndim == 2):
            raise RuntimeError(
                "vmec_jax state Rcos/Zsin arrays must be two-dimensional"
            )
        n_radial = int(base_Rcos.shape[0])
        n_modes = int(base_Rcos.shape[1])

        if radial_index is None:
            ridx = n_radial // 2
        else:
            ridx = int(radial_index)
        midx = int(mode_index)
        if ridx < 0 or ridx >= n_radial:
            raise ValueError("radial_index is outside the VMEC state radial grid")
        if midx < 0 or midx >= n_modes:
            raise ValueError("mode_index is outside the VMEC state mode table")

        def perturbed_mapping(x: jnp.ndarray) -> dict[str, Any]:
            # x[0] shifts one Rcos coefficient, x[1] the matching Zsin one.
            shifted_state = dc_replace(
                state,
                Rcos=base_Rcos.at[ridx, midx].add(x[0]),
                Zsin=base_Zsin.at[ridx, midx].add(x[1]),
            )
            return vmec_jax_flux_tube_mapping_from_state(
                shifted_state,
                static,
                wout,
                surface_index=surface_index,
                alpha=float(alpha),
                ntheta=int(ntheta),
            )

        sensitivity = geometry_sensitivity_report(
            perturbed_mapping,
            p,
            fd_step=float(fd_step),
            source_model="vmec_jax:state->tensor-flux-tube",
        )
        vmec_meta = perturbed_mapping(p)["vmec_jax"]
    except Exception as exc:
        # Best-effort diagnostic: any failure is reported, not raised.
        failed = _unavailable_report()
        failed["error"] = f"{type(exc).__name__}: {exc}"
        return failed

    return {
        "available": True,
        "backend_info": backend_info,
        "case_name": str(case_name),
        "input_path": str(input_path),
        "wout_path": str(wout_path),
        "source_model": "vmec_jax:state->tensor-flux-tube",
        "param_names": ["delta_Rcos", "delta_Zsin"],
        "params": np.asarray(p).tolist(),
        "radial_index": int(ridx),
        "mode_index": int(midx),
        "surface_index": int(vmec_meta["surface_index"]),
        "iota": float(np.asarray(vmec_meta["iota"])),
        "alpha": float(alpha),
        "ntheta": int(ntheta),
        "state_shape": [n_radial, n_modes],
        "sensitivity": sensitivity,
        "fd_step": float(fd_step),
        "reference_length": float(vmec_meta["reference_length"]),
        "reference_b": float(vmec_meta["reference_b"]),
        "field_line_convention": str(vmec_meta["field_line_convention"]),
        "scope": (
            "VMEC-native metric and grad-B flux-tube mapping; full imported-VMEC/EIK "
            "drift parity and production transport-gradient gates remain separate."
        ),
    }
def _parity_finite_or_inf(value: float) -> float:
    """Return ``value`` when it is finite, otherwise ``+inf``."""
    # NaN compares False against everything, so the built-in max() can skip it
    # (max([0.01, nan]) == 0.01). Mapping it to +inf guarantees the tolerance
    # comparison fails instead of silently passing.
    return value if np.isfinite(value) else float("inf")


def _parity_metric_value(raw_value: object) -> float:
    """Coerce one ``normalized_max_abs`` entry to a float usable in a gate."""
    if not isinstance(raw_value, int | float | np.floating):
        return float("inf")
    return _parity_finite_or_inf(float(raw_value))


def _worst_shape_matched_parity_value(
    metrics_list: Sequence[Mapping[str, object]],
) -> float:
    """Return the largest ``normalized_max_abs`` among shape-matched entries.

    Entries whose ``shape_match`` flag is falsy are skipped; when nothing is
    left the result is ``+inf`` so the gate cannot pass vacuously.
    """
    values = [
        _parity_metric_value(metrics.get("normalized_max_abs"))
        for metrics in metrics_list
        if bool(metrics.get("shape_match", False))
    ]
    return max(values) if values else float("inf")


def vmec_jax_flux_tube_array_parity_report(  # pragma: no cover
    *,
    case_name: str = "nfp4_QH_warm_start",
    surface_index: int | None = None,
    alpha: float = 0.0,
    ntheta: int = 16,
    mboz: int = _VMEC_BOOZER_PARITY_MIN_MODE_COUNT,
    nboz: int = _VMEC_BOOZER_PARITY_MIN_MODE_COUNT,
    boundary: str = "none",
    include_shear_variation: bool = True,
    include_pressure_variation: bool = True,
    core_tolerance: float = 5.0e-2,
    scalar_tolerance: float = 5.0e-3,
    equal_arc_core_tolerance: float = 1.0e-2,
    equal_arc_derivative_tolerance: float = 3.0e-2,
    equal_arc_metric_tolerance: float = 8.0e-2,
    equal_arc_drift_tolerance: float = 8.0e-2,
) -> dict[str, object]:
    """Compare the direct ``vmec_jax`` flux-tube arrays to imported VMEC/EIK.

    This is a diagnostic promotion gate, not a differentiability check. It
    starts from the same real ``vmec_jax`` example state used by
    :func:`vmec_jax_flux_tube_sensitivity_report`, builds the direct
    VMEC-tensor-derived flux-tube mapping, then generates the existing imported
    VMEC/EIK geometry on the same surface and compares solver-facing arrays.

    The expected current result is that ``q`` and magnetic shear are close
    while metric/drift arrays remain open because the direct path still uses a
    VMEC-coordinate/equal-theta convention and a local grad-:math:`B` closure
    instead of the production Boozer equal-arc/Hegna-Nakajima convention.

    When the ``booz_xform_jax`` functional API is available, a second audit
    compares the Boozer equal-arc profiles against the same imported geometry;
    a failure there is recorded in ``equal_arc_core_error`` and does not abort
    the report.

    Every "worst" value treats a non-finite metric (NaN or infinity) as
    ``+inf``, so a NaN error can never be hidden behind a smaller finite value
    and let a gate pass.

    Returns:
        A report dict. ``available`` is ``False`` (with ``reason``) when
        ``vmec_jax`` or the internal VMEC/EIK backend is missing, and ``False``
        (with ``error``) when the comparison itself raises. Otherwise it holds
        the per-array and per-scalar metrics, the worst values, tolerances and
        pass flags, and ``status`` is ``"passed"`` or ``"diagnostic_open"``.

    Raises:
        ValueError: If ``ntheta < 4`` or ``mboz``/``nboz`` are below
            ``_VMEC_BOOZER_PARITY_MIN_MODE_COUNT``.
    """
    ntheta_int = int(ntheta)
    if ntheta_int < 4:
        raise ValueError("ntheta must be >= 4")
    mboz_int = int(mboz)
    nboz_int = int(nboz)
    if (
        mboz_int < _VMEC_BOOZER_PARITY_MIN_MODE_COUNT
        or nboz_int < _VMEC_BOOZER_PARITY_MIN_MODE_COUNT
    ):
        raise ValueError(
            "mboz and nboz must both be >= "
            f"{_VMEC_BOOZER_PARITY_MIN_MODE_COUNT} for VMEC/Boozer parity gates"
        )
    info = discover_differentiable_geometry_backends()
    if not info.get("vmec_jax_available", False):
        return {
            "available": False,
            "backend_info": info,
            "case_name": str(case_name),
            "reason": "vmec_jax is not available",
        }
    try:
        from spectraxgk.from_gx.vmec import (
            generate_vmec_eik_internal,
            internal_vmec_backend_available,
        )
        from spectraxgk.geometry import load_gx_geometry_netcdf

        if not internal_vmec_backend_available():
            return {
                "available": False,
                "backend_info": info,
                "case_name": str(case_name),
                "reason": "internal VMEC/EIK backend is not available",
            }
        driver = importlib.import_module("vmec_jax.driver")
        config_mod = importlib.import_module("vmec_jax.config")
        static_mod = importlib.import_module("vmec_jax.static")
        wout_mod = importlib.import_module("vmec_jax.wout")
        input_path, wout_path = driver.example_paths(str(case_name))
        if wout_path is None:
            raise RuntimeError(
                f"vmec_jax example {case_name!r} has no bundled wout reference"
            )
        cfg, indata = config_mod.load_config(str(input_path))
        static = static_mod.build_static(cfg)
        wout = wout_mod.read_wout(wout_path)
        state = wout_mod.state_from_wout(wout)
        ns = int(jnp.asarray(state.Rcos).shape[0])
        # Default surface: mid-radius, clamped away from the first/last index.
        sidx = (
            max(1, min(ns // 2, ns - 2))
            if surface_index is None
            else int(surface_index)
        )
        torflux = float(sidx) / float(max(ns - 1, 1))

        # --- Direct VMEC-tensor path -----------------------------------------
        direct_mapping = vmec_jax_flux_tube_mapping_from_state(
            state,
            static,
            wout,
            surface_index=sidx,
            alpha=float(alpha),
            ntheta=ntheta_int,
        )
        direct = flux_tube_geometry_from_mapping(
            direct_mapping,
            source_model="vmec_jax:state->tensor-flux-tube",
            validate_finite=False,
        )

        # --- Imported VMEC/EIK reference on the same surface -----------------
        request = SimpleNamespace(
            vmec_file=str(wout_path),
            ntheta=ntheta_int,
            boundary=str(boundary),
            y0=10.0,
            x0=None,
            jtwist=None,
            beta=0.0,
            alpha=float(alpha),
            torflux=torflux,
            npol=1.0,
            npol_min=None,
            isaxisym=False,
            which_crossing=None,
            include_shear_variation=bool(include_shear_variation),
            include_pressure_variation=bool(include_pressure_variation),
            betaprim=None,
            z=(1.0, -1.0),
            mass=(1.0, 2.7e-4),
            dens=(1.0, 1.0),
            temp=(1.0, 1.0),
            tprim=(3.0, 0.0),
            fprim=(1.0, 0.0),
            vnewk=(0.0, 0.0),
            species_type=("ion", "electron"),
        )
        with tempfile.TemporaryDirectory(prefix="spectrax_vmec_eik_parity_") as tmp:
            eik_path = Path(tmp) / f"{case_name}.eik.nc"
            generate_vmec_eik_internal(output_path=eik_path, request=request)
            imported = load_gx_geometry_netcdf(eik_path)
        # Align grids when the imported geometry carries one extra theta point.
        if imported.theta.shape[0] == direct.theta.shape[0] + 1:
            imported = imported.trim_terminal_theta_point()

        array_pairs = {
            "theta": (direct.theta, imported.theta),
            "bmag": (direct.bmag_profile, imported.bmag_profile),
            "bgrad": (direct.bgrad_profile, imported.bgrad_profile),
            "gds2": (direct.gds2_profile, imported.gds2_profile),
            "gds21": (direct.gds21_profile, imported.gds21_profile),
            "gds22": (direct.gds22_profile, imported.gds22_profile),
            "cvdrift": (direct.cv_profile, imported.cv_profile),
            "gbdrift": (direct.gb_profile, imported.gb_profile),
            "cvdrift0": (direct.cv0_profile, imported.cv0_profile),
            "gbdrift0": (direct.gb0_profile, imported.gb0_profile),
            "jacobian": (direct.jacobian_profile, imported.jacobian_profile),
            "grho": (direct.grho_profile, imported.grho_profile),
        }
        array_metrics = {
            name: _array_parity_metrics(candidate, reference)
            for name, (candidate, reference) in array_pairs.items()
        }
        scalar_metrics = {
            "gradpar": _scalar_parity_metrics(
                direct.gradpar_value, imported.gradpar_value
            ),
            "q": _scalar_parity_metrics(direct.q, imported.q),
            "s_hat": _scalar_parity_metrics(direct.s_hat, imported.s_hat),
        }
        # Only these arrays feed the production gate of the direct path.
        core_names = (
            "bmag",
            "gds2",
            "gds21",
            "gds22",
            "cvdrift",
            "gbdrift",
            "jacobian",
            "grho",
        )
        worst_core = _worst_shape_matched_parity_value(
            [array_metrics[name] for name in core_names]
        )
        worst_scalar = max(
            _parity_finite_or_inf(float(values["rel"]))
            for values in scalar_metrics.values()
        )
        production_parity_passed = bool(
            worst_core <= float(core_tolerance)
            and worst_scalar <= float(scalar_tolerance)
        )

        # --- Boozer equal-arc audit (optional backend) -----------------------
        # Pessimistic defaults are kept if the audit is unavailable or fails.
        equal_arc_core_error = None
        equal_arc_core_array_metrics: dict[str, dict[str, object]] = {}
        equal_arc_metric_array_metrics: dict[str, dict[str, object]] = {}
        equal_arc_drift_array_metrics: dict[str, dict[str, object]] = {}
        equal_arc_core_scalar_metrics: dict[str, dict[str, float]] = {}
        equal_arc_core_worst = np.inf
        equal_arc_core_worst_scalar = np.inf
        equal_arc_derivative_worst = np.inf
        equal_arc_metric_worst = np.inf
        equal_arc_drift_worst = np.inf
        equal_arc_core_passed = False
        equal_arc_derivative_passed = False
        equal_arc_metric_passed = False
        equal_arc_drift_passed = False
        if bool(info.get("booz_xform_jax_api_available", False)):
            try:
                equal_arc_core = vmec_jax_boozer_equal_arc_core_profiles_from_state(
                    state,
                    static,
                    indata,
                    wout,
                    surface_index=sidx,
                    torflux=torflux,
                    alpha=float(alpha),
                    ntheta=ntheta_int,
                    mboz=mboz_int,
                    nboz=nboz_int,
                    jit=False,
                )
                equal_arc_core_pairs = {
                    "theta": (equal_arc_core["theta"], imported.theta),
                    "bmag": (equal_arc_core["bmag"], imported.bmag_profile),
                    "bgrad": (equal_arc_core["bgrad"], imported.bgrad_profile),
                    "jacobian": (equal_arc_core["jacobian"], imported.jacobian_profile),
                }
                equal_arc_metric_pairs = {
                    "gds2": (equal_arc_core["gds2"], imported.gds2_profile),
                    "gds21": (equal_arc_core["gds21"], imported.gds21_profile),
                    "gds22": (equal_arc_core["gds22"], imported.gds22_profile),
                    "grho": (equal_arc_core["grho"], imported.grho_profile),
                }
                equal_arc_drift_pairs = {
                    "cvdrift": (equal_arc_core["cvdrift"], imported.cv_profile),
                    "gbdrift": (equal_arc_core["gbdrift"], imported.gb_profile),
                    "cvdrift0": (equal_arc_core["cvdrift0"], imported.cv0_profile),
                    "gbdrift0": (equal_arc_core["gbdrift0"], imported.gb0_profile),
                }
                equal_arc_core_array_metrics = {
                    name: _array_parity_metrics(candidate, reference)
                    for name, (candidate, reference) in equal_arc_core_pairs.items()
                }
                equal_arc_metric_array_metrics = {
                    name: _array_parity_metrics(candidate, reference)
                    for name, (candidate, reference) in equal_arc_metric_pairs.items()
                }
                equal_arc_drift_array_metrics = {
                    name: _array_parity_metrics(candidate, reference)
                    for name, (candidate, reference) in equal_arc_drift_pairs.items()
                }
                equal_arc_core_scalar_metrics = {
                    "gradpar": _scalar_parity_metrics(
                        jnp.asarray(equal_arc_core["gradpar"])[0],
                        imported.gradpar_value,
                    ),
                    "q": _scalar_parity_metrics(equal_arc_core["q"], imported.q),
                    "s_hat": _scalar_parity_metrics(
                        equal_arc_core["s_hat"], imported.s_hat
                    ),
                }
                # "bgrad" is gated separately as the derivative check below.
                equal_arc_core_names = ("theta", "bmag", "jacobian")
                equal_arc_core_worst = _worst_shape_matched_parity_value(
                    [equal_arc_core_array_metrics[name] for name in equal_arc_core_names]
                )
                equal_arc_core_worst_scalar = max(
                    _parity_finite_or_inf(float(values["rel"]))
                    for values in equal_arc_core_scalar_metrics.values()
                )
                equal_arc_derivative_worst = _parity_metric_value(
                    equal_arc_core_array_metrics["bgrad"].get("normalized_max_abs")
                )
                equal_arc_core_passed = bool(
                    equal_arc_core_worst <= float(equal_arc_core_tolerance)
                    and equal_arc_core_worst_scalar <= float(equal_arc_core_tolerance)
                )
                equal_arc_derivative_passed = bool(
                    equal_arc_derivative_worst <= float(equal_arc_derivative_tolerance)
                )
                equal_arc_metric_worst = _worst_shape_matched_parity_value(
                    list(equal_arc_metric_array_metrics.values())
                )
                equal_arc_metric_passed = bool(
                    equal_arc_metric_worst <= float(equal_arc_metric_tolerance)
                )
                equal_arc_drift_worst = _worst_shape_matched_parity_value(
                    list(equal_arc_drift_array_metrics.values())
                )
                equal_arc_drift_passed = bool(
                    equal_arc_drift_worst <= float(equal_arc_drift_tolerance)
                )
            except (
                Exception
            ) as exc:  # pragma: no cover - optional-backend diagnostic detail
                # Deliberately best-effort: the main report is still returned.
                equal_arc_core_error = f"{type(exc).__name__}: {exc}"
        else:
            equal_arc_core_error = "booz_xform_jax functional API is not available"
    except Exception as exc:
        return {
            "available": False,
            "backend_info": info,
            "case_name": str(case_name),
            "error": f"{type(exc).__name__}: {exc}",
        }
    return {
        "available": True,
        "backend_info": info,
        "case_name": str(case_name),
        "input_path": str(input_path),
        "wout_path": str(wout_path),
        "source_model": "vmec_jax:state->tensor-flux-tube vs imported-vmec-eik",
        "surface_index": int(sidx),
        "torflux": float(torflux),
        "alpha": float(alpha),
        "ntheta": int(ntheta_int),
        "mboz": mboz_int,
        "nboz": nboz_int,
        "boundary": str(boundary),
        "include_shear_variation": bool(include_shear_variation),
        "include_pressure_variation": bool(include_pressure_variation),
        "array_metrics": array_metrics,
        "scalar_metrics": scalar_metrics,
        "equal_arc_core_array_metrics": equal_arc_core_array_metrics,
        "equal_arc_metric_array_metrics": equal_arc_metric_array_metrics,
        "equal_arc_drift_array_metrics": equal_arc_drift_array_metrics,
        "equal_arc_core_scalar_metrics": equal_arc_core_scalar_metrics,
        "equal_arc_core_worst_normalized_max_abs": float(equal_arc_core_worst),
        "equal_arc_core_worst_scalar_rel": float(equal_arc_core_worst_scalar),
        "equal_arc_derivative_worst_normalized_max_abs": float(
            equal_arc_derivative_worst
        ),
        "equal_arc_metric_worst_normalized_max_abs": float(equal_arc_metric_worst),
        "equal_arc_drift_worst_normalized_max_abs": float(equal_arc_drift_worst),
        "equal_arc_core_tolerance": float(equal_arc_core_tolerance),
        "equal_arc_derivative_tolerance": float(equal_arc_derivative_tolerance),
        "equal_arc_metric_tolerance": float(equal_arc_metric_tolerance),
        "equal_arc_drift_tolerance": float(equal_arc_drift_tolerance),
        "equal_arc_core_passed": bool(equal_arc_core_passed),
        "equal_arc_derivative_passed": bool(equal_arc_derivative_passed),
        "equal_arc_metric_passed": bool(equal_arc_metric_passed),
        "equal_arc_drift_passed": bool(equal_arc_drift_passed),
        "equal_arc_core_error": equal_arc_core_error,
        "worst_core_normalized_max_abs": float(worst_core),
        "worst_scalar_rel": float(worst_scalar),
        "core_tolerance": float(core_tolerance),
        "scalar_tolerance": float(scalar_tolerance),
        "production_parity_passed": production_parity_passed,
        "status": "passed" if production_parity_passed else "diagnostic_open",
        "interpretation": (
            "The direct VMEC tensor path proves state-level differentiability. "
            "The Boozer equal-arc core audit narrows the production gap to the "
            "full imported VMEC/EIK metric and drift convention when it passes."
        ),
    }
def geometry_sensitivity_report(
    mapping_fn: Any,
    params: jnp.ndarray,
    *,
    fd_step: float = 1.0e-4,
    rtol: float = 1.0e-4,
    atol: float = 1.0e-6,
    source_model: str = "vmec_jax:in-memory",
) -> dict[str, object]:
    """Validate geometry-observable sensitivities by AD and finite differences.

    ``mapping_fn(params)`` must return the solver-ready field-line mapping
    accepted by :func:`flux_tube_geometry_from_mapping`. The report is strict
    JSON friendly so examples and CI gates can preserve the derivative
    contract without depending on large VMEC solves.

    Parameters
    ----------
    mapping_fn
        Callable mapping a one-dimensional parameter vector to a field-line
        mapping.
    params
        One-dimensional parameter vector at which the derivatives are checked.
    fd_step, rtol, atol
        Finite-difference step and comparison tolerances, forwarded as floats
        to :func:`observable_gradient_validation_report`.
    source_model
        Label forwarded to :func:`flux_tube_geometry_from_mapping` and stored
        in the report under ``"source_model"``.

    Returns
    -------
    dict
        The report built by :func:`observable_gradient_validation_report`
        with ``report_kind="geometry_sensitivity_ad_fd_gate"`` plus a
        ``"source_model"`` entry.

    Raises
    ------
    ValueError
        If ``params`` is not one-dimensional.
    """
    # Follow the active JAX precision (same helper as the inverse-design
    # report) instead of hard-coding float64, which JAX cannot honour and
    # warns about when x64 is disabled.
    p = jnp.asarray(params, dtype=_jax_float_dtype())
    if p.ndim != 1:
        raise ValueError("params must be one-dimensional")

    def observable_fn(x: jnp.ndarray) -> jnp.ndarray:
        # NOTE(review): finiteness validation is switched off here, presumably
        # so the mapping can be evaluated under AD tracing -- confirm.
        geom = flux_tube_geometry_from_mapping(
            mapping_fn(x),
            source_model=source_model,
            validate_finite=False,
        )
        return flux_tube_geometry_observables(geom)

    # Synthetic labels: the caller supplies a bare vector without names.
    param_names = tuple(f"param_{idx}" for idx in range(int(p.shape[0])))
    report = observable_gradient_validation_report(
        observable_fn,
        p,
        fd_step=float(fd_step),
        rtol=float(rtol),
        atol=float(atol),
        observable_names=_GEOMETRY_OBSERVABLE_NAMES,
        param_names=param_names,
        relative_floor=1.0e-12,
        report_kind="geometry_sensitivity_ad_fd_gate",
    )
    report["source_model"] = str(source_model)
    return report
[docs]
def geometry_inverse_design_report(
    mapping_fn: Any,
    initial_params: jnp.ndarray,
    target_observables: jnp.ndarray,
    *,
    observable_indices: Sequence[int] | None = None,
    max_steps: int = 8,
    damping: float = 1.0e-8,
    fd_step: float = 1.0e-4,
    regularization: float = 1.0e-8,
    source_model: str = "vmec_jax:in-memory",
) -> dict[str, object]:
    """Run a small Gauss-Newton geometry inverse-design validation.

    ``mapping_fn(params)`` must be the same solver-ready field-line mapping
    accepted by :func:`flux_tube_geometry_from_mapping`. The routine is meant
    for differentiable ``vmec_jax`` / ``booz_xform_jax`` workflows: it keeps
    the optimization, sensitivity check, and local UQ covariance in one
    JSON-friendly report so examples can validate the full AD contract without
    depending on a long equilibrium solve in CI.

    Parameters
    ----------
    mapping_fn
        Callable mapping a one-dimensional parameter vector to a field-line
        mapping.
    initial_params
        One-dimensional starting parameter vector.
    target_observables
        One-dimensional target values, one per selected observable.
    observable_indices
        Indices into the geometry observable names selecting which
        observables are matched. ``None`` selects all of them.
    max_steps
        Number of damped Gauss-Newton updates (``0`` only evaluates the
        initial point).
    damping
        Non-negative multiple of the identity added to ``J^T J`` before the
        linear solve.
    fd_step
        Step used for the finite-difference Jacobian at the final point.
    regularization
        Forwarded to ``covariance_diagnostics``.
    source_model
        Label forwarded to :func:`flux_tube_geometry_from_mapping` and stored
        in the report.

    Returns
    -------
    dict
        Report with the per-step history, initial/final parameters, final
        observables and residual, AD and finite-difference Jacobians with
        their max absolute/relative mismatch, conditioning metadata, the
        ``covariance_diagnostics`` output under ``"uq"``, the settings used,
        and backend discovery information.

    Raises
    ------
    ValueError
        If ``initial_params`` is not one-dimensional, ``max_steps`` or
        ``damping`` is negative, ``observable_indices`` is empty, not
        one-dimensional or out of range, or ``target_observables`` does not
        match the number of selected observables.
    """
    params = jnp.asarray(initial_params, dtype=_jax_float_dtype())
    if params.ndim != 1:
        raise ValueError("initial_params must be one-dimensional")
    n_steps = int(max_steps)
    if n_steps < 0:
        raise ValueError("max_steps must be non-negative")
    damping_value = float(damping)
    if damping_value < 0.0:
        raise ValueError("damping must be non-negative")

    n_observables = len(_GEOMETRY_OBSERVABLE_NAMES)
    if observable_indices is None:
        indices_np = np.arange(n_observables, dtype=int)
    else:
        indices_np = np.asarray(list(observable_indices), dtype=int)
    if indices_np.ndim != 1 or indices_np.size == 0:
        raise ValueError(
            "observable_indices must be a non-empty one-dimensional sequence"
        )
    if np.any(indices_np < 0) or np.any(indices_np >= n_observables):
        raise ValueError("observable_indices contains an out-of-range observable index")

    target = jnp.asarray(target_observables, dtype=params.dtype)
    if target.ndim != 1 or int(target.shape[0]) != int(indices_np.size):
        raise ValueError("target_observables length must match observable_indices")
    indices = jnp.asarray(indices_np, dtype=jnp.int32)

    def observable_fn(x: jnp.ndarray) -> jnp.ndarray:
        geom = flux_tube_geometry_from_mapping(
            mapping_fn(x),
            source_model=source_model,
            validate_finite=False,
        )
        return flux_tube_geometry_observables(geom)[indices]

    # Build the forward-mode Jacobian transform once; it is reused for every
    # Gauss-Newton step and for the final AD/FD comparison.
    jacobian_fn = jax.jacfwd(observable_fn)
    fd_step_value = float(fd_step)

    history: list[dict[str, object]] = []
    p = params
    # n_steps + 1 iterations: the final pass only records the converged state,
    # so ``obs`` and ``residual`` always describe the returned ``p``.
    for step in range(n_steps + 1):
        obs = observable_fn(p)
        residual = obs - target
        objective = 0.5 * jnp.dot(residual, residual)
        history.append(
            {
                "step": int(step),
                "params": np.asarray(p).tolist(),
                "observables": np.asarray(obs).tolist(),
                "objective": float(objective),
                "residual_norm": float(jnp.linalg.norm(residual)),
            }
        )
        if step == n_steps:
            break
        # Damped Gauss-Newton update: (J^T J + damping * I) delta = J^T r.
        jac = jacobian_fn(p)
        normal = jac.T @ jac + damping_value * jnp.eye(int(p.shape[0]), dtype=p.dtype)
        delta = jnp.linalg.solve(normal, jac.T @ residual)
        p = p - delta

    jac_ad = jacobian_fn(p)
    jac_fd = finite_difference_jacobian(observable_fn, p, step=fd_step_value)
    diff = jac_ad - jac_fd
    # Floor the denominator so near-zero FD entries do not blow up the
    # relative error (same floor as ``relative_floor`` below).
    scale = jnp.maximum(jnp.abs(jac_fd), 1.0e-12)
    uq = covariance_diagnostics(
        np.asarray(jac_ad), np.asarray(residual), regularization=regularization
    )
    selected_observable_names = [
        str(_GEOMETRY_OBSERVABLE_NAMES[int(i)]) for i in indices_np
    ]
    param_names = tuple(f"param_{idx}" for idx in range(int(params.shape[0])))
    return {
        "observable_names": selected_observable_names,
        "initial_params": np.asarray(params).tolist(),
        "final_params": np.asarray(p).tolist(),
        "target_observables": np.asarray(target).tolist(),
        # ``obs`` was evaluated at the final ``p`` in the last loop pass.
        "final_observables": np.asarray(obs).tolist(),
        "final_residual": np.asarray(residual).tolist(),
        "final_residual_norm": float(jnp.linalg.norm(residual)),
        "history": history,
        "jacobian_ad": np.asarray(jac_ad).tolist(),
        "jacobian_fd": np.asarray(jac_fd).tolist(),
        "max_abs_ad_fd_error": float(np.max(np.abs(np.asarray(diff)))),
        "max_rel_ad_fd_error": float(
            np.max(np.abs(np.asarray(diff) / np.asarray(scale)))
        ),
        "conditioning": _sensitivity_conditioning_metadata(
            jac_ad,
            jac_fd,
            p,
            fd_step=fd_step_value,
            observable_names=selected_observable_names,
            param_names=param_names,
            relative_floor=1.0e-12,
        ),
        "uq": uq,
        "fd_step": fd_step_value,
        "damping": damping_value,
        "regularization": float(regularization),
        "source_model": str(source_model),
        "backend_info": discover_differentiable_geometry_backends(),
    }
# Public API of this module: the names exported by a star-import.
# Keep this list in sync when adding or renaming public helpers.
__all__ = [
"booz_xform_flux_tube_mapping_from_inputs",
"booz_xform_flux_tube_sensitivity_report",
"booz_xform_spectral_sensitivity_report",
"discover_differentiable_geometry_backends",
"evaluate_boozer_bmag_on_field_line",
"finite_difference_jacobian",
"flux_tube_geometry_from_mapping",
"flux_tube_geometry_from_vmec_boozer_state",
"flux_tube_geometry_observables",
"geometry_inverse_design_report",
"geometry_observable_names",
"geometry_sensitivity_report",
"observable_gradient_validation_report",
"vmec_jax_boozer_flux_tube_sensitivity_report",
"vmec_jax_boozer_equal_arc_core_profiles_from_state",
"vmec_jax_field_line_tensor_sensitivity_report",
"vmec_jax_flux_tube_array_parity_report",
"vmec_jax_flux_tube_mapping_from_state",
"vmec_jax_flux_tube_sensitivity_report",
"vmec_jax_metric_tensor_sensitivity_report",
"vmec_boundary_aspect_sensitivity_report",
"vmec_field_line_tensor_observable_names",
"vmec_metric_tensor_observable_names",
]