# Source code for spectraxgk.vmec_jax_transport_gradient

"""Boundary-gradient diagnostics for VMEC-JAX transport objectives.

The helpers here intentionally avoid importing ``vmec_jax`` at module import
time.  They operate on the small optimizer protocol exposed by VMEC-JAX
(``residual_fun``, ``objective_and_gradient_fun``, and ``_specs``), which keeps
SPECTRAX-GK tests fast while letting examples evaluate the real full-chain
transport sensitivity on machines with VMEC-JAX installed.
"""

from __future__ import annotations

from collections.abc import Sequence
from pathlib import Path
from typing import Any
import json

import numpy as np


def _finite_float(value: Any) -> float | None:
    """Coerce *value* to a finite ``float``, or return ``None``.

    ``None`` is returned both when ``float(value)`` raises (any ``Exception``
    subclass is swallowed on purpose, this is a best-effort conversion) and
    when the converted number is NaN or infinite.
    """
    try:
        number = float(value)
    except Exception:
        # Best-effort conversion: anything float() rejects is "no value".
        return None
    if np.isfinite(number):
        return number
    return None


def boundary_spec_record(spec: Any, *, fallback_index: int) -> dict[str, Any]:
    """Return a JSON-safe summary for a VMEC-JAX boundary parameter spec.

    Parameters
    ----------
    spec:
        Any object; the attributes ``name``, ``kind``, ``index``, ``m`` and
        ``n`` are read when present.
    fallback_index:
        Index used to synthesise the name ``p<fallback_index>`` when the spec
        has no usable ``name``.

    Returns
    -------
    dict
        Keys ``name`` (``str``), ``kind`` (``str`` or ``None``) and
        ``mode_index``, ``m``, ``n`` (``int`` or ``None``).  ``mode_index``
        is taken from the spec's ``index`` attribute.

    Notes
    -----
    An attribute that is missing and an attribute that is present but set to
    ``None`` are treated the same way, so a spec with unset optional fields
    yields ``None`` entries instead of raising from ``int(None)`` or emitting
    the literal string ``"None"``.
    """

    def _coerce(attribute: str, cast: Any) -> Any:
        # Single lookup; None doubles as the "not available" marker.
        value = getattr(spec, attribute, None)
        return None if value is None else cast(value)

    name = getattr(spec, "name", None)
    return {
        "name": f"p{fallback_index}" if name is None else str(name),
        "kind": _coerce("kind", str),
        "mode_index": _coerce("index", int),
        "m": _coerce("m", int),
        "n": _coerce("n", int),
    }
def _top_gradient_components(
    gradient: np.ndarray,
    specs: Sequence[Any],
    *,
    top_n: int,
) -> list[dict[str, Any]]:
    """Describe the ``top_n`` gradient entries with the largest magnitude.

    The gradient is flattened to 1-D and ranked by descending absolute value.
    Each returned row is the ``boundary_spec_record`` of the matching spec
    (a bare ``object()`` stands in when ``specs`` is shorter than the
    gradient) extended with ``rank`` (1-based), ``parameter_index``,
    ``gradient``, ``abs_gradient`` and ``fraction_of_l2_norm``.  A negative
    ``top_n`` yields an empty list.
    """
    flat = np.asarray(gradient, dtype=float).reshape(-1)
    ranking = np.argsort(-np.abs(flat))
    # Floor the norm so an all-zero gradient does not divide by zero.
    norm_floor = max(float(np.linalg.norm(flat)), 1.0e-300)
    keep = max(0, int(top_n))

    records: list[dict[str, Any]] = []
    for position, raw_index in enumerate(ranking[:keep]):
        index = int(raw_index)
        spec = specs[index] if index < len(specs) else object()
        record = boundary_spec_record(spec, fallback_index=index)
        component = float(flat[index])
        magnitude = abs(component)
        records.append(
            {
                **record,
                "rank": position + 1,
                "parameter_index": index,
                "gradient": component,
                "abs_gradient": magnitude,
                "fraction_of_l2_norm": magnitude / norm_floor,
            }
        )
    return records
def build_boundary_transport_gradient_report(
    optimizer: Any,
    *,
    params: Sequence[float] | np.ndarray | None = None,
    label: str = "vmec_jax_transport_gradient",
    top_n: int = 12,
    sensitivity_atol: float = 1.0e-12,
    include_jacobian: bool = False,
) -> dict[str, Any]:
    """Build a diagnostic report for a transport objective's boundary gradient.

    Parameters
    ----------
    optimizer:
        Object providing ``residual_fun(params)`` and
        ``objective_and_gradient_fun(params)``.  An optional ``_specs``
        attribute supplies per-parameter labels and the default parameter
        count; an optional ``jacobian_fun(params)`` is used only when
        ``include_jacobian`` is true.
    params:
        Parameter vector, flattened to a 1-D float array.  ``None`` selects a
        zero vector with one entry per item of ``optimizer._specs``.
    label:
        Free-form label stored in the report.
    top_n:
        How many of the largest-magnitude gradient components to list.
    sensitivity_atol:
        The objective is flagged as sensitive when every evaluated quantity
        is finite and the gradient L2 norm is strictly greater than this
        threshold.
    include_jacobian:
        When true, add a ``"jacobian"`` entry holding either norm summaries
        of ``optimizer.jacobian_fun(params)`` or a note that the optimizer
        has no such callable.

    Returns
    -------
    dict
        Report with counts, the objective value (``None`` if not finite),
        residual and gradient norms, the ``finite`` and
        ``transport_sensitivity_detected`` flags, a ``classification``
        string, the leading gradient components and a ``next_action`` hint.
        For a non-finite residual or gradient the matching L2 norm is NaN
        and the L-infinity norm is reported as ``0.0``.

    Raises
    ------
    ValueError
        If ``params`` is ``None`` and the optimizer has no specs, if
        ``params`` contains non-finite entries, or if the gradient length
        differs from the parameter length.
    """
    specs = tuple(getattr(optimizer, "_specs", ()))

    # Resolve the evaluation point.
    if params is not None:
        point = np.asarray(params, dtype=float).reshape(-1)
    elif specs:
        point = np.zeros(len(specs), dtype=float)
    else:
        raise ValueError("params must be provided when optimizer._specs is unavailable")
    if not np.isfinite(point).all():
        raise ValueError("params must be finite")

    # Evaluate the residual vector, then the scalar objective and gradient.
    residual = np.asarray(optimizer.residual_fun(point), dtype=float).reshape(-1)
    raw_objective, raw_gradient = optimizer.objective_and_gradient_fun(point)
    objective_value = _finite_float(raw_objective)
    gradient = np.asarray(raw_gradient, dtype=float).reshape(-1)
    if gradient.size != point.size:
        raise ValueError(
            f"gradient size {gradient.size} does not match parameter size {point.size}"
        )

    residual_ok = bool(np.isfinite(residual).all())
    gradient_ok = bool(np.isfinite(gradient).all())

    if gradient_ok:
        gradient_l2 = float(np.linalg.norm(gradient))
        gradient_linf = float(np.max(np.abs(gradient))) if gradient.size else 0.0
    else:
        gradient_l2 = float("nan")
        gradient_linf = 0.0

    if residual_ok:
        residual_l2 = float(np.linalg.norm(residual))
        residual_linf = float(np.max(np.abs(residual))) if residual.size else 0.0
    else:
        residual_l2 = float("nan")
        residual_linf = 0.0

    all_finite = bool(objective_value is not None and residual_ok and gradient_ok)
    is_sensitive = bool(all_finite and gradient_l2 > float(sensitivity_atol))

    if is_sensitive:
        classification = "sensitive_boundary_transport_objective"
        next_action = (
            "use a constraint-preserving projected update or constrained line search along the leading "
            "transport-gradient components"
        )
    else:
        if all_finite:
            classification = "locally_flat_or_underconditioned_transport_objective"
        else:
            classification = "invalid_nonfinite_transport_gradient"
        next_action = (
            "do not launch another blind scalar-weight ladder; change the transport observable, "
            "sample set, or finite-difference scale until the local boundary sensitivity is measurable"
        )

    report: dict[str, Any] = {
        "kind": "vmec_jax_transport_gradient_diagnostic",
        "label": str(label),
        "parameter_count": int(point.size),
        "residual_count": int(residual.size),
        "objective_value": objective_value,
        "residual_norm_l2": residual_l2,
        "residual_norm_linf": residual_linf,
        "gradient_norm_l2": gradient_l2,
        "gradient_norm_linf": gradient_linf,
        "sensitivity_atol": float(sensitivity_atol),
        "finite": all_finite,
        "transport_sensitivity_detected": is_sensitive,
        "classification": classification,
        "top_gradient_components": _top_gradient_components(
            gradient,
            specs,
            top_n=int(top_n),
        ),
        "next_action": next_action,
    }

    if not bool(include_jacobian):
        return report

    jacobian_fun = getattr(optimizer, "jacobian_fun", None)
    if not callable(jacobian_fun):
        report["jacobian"] = {
            "available": False,
            "reason": "optimizer_has_no_jacobian_fun",
        }
        return report

    jacobian = np.asarray(jacobian_fun(point), dtype=float)
    if jacobian.ndim == 1:
        # A flat result is treated as a single residual row.
        jacobian = jacobian.reshape(1, -1)
    jacobian_ok = bool(np.isfinite(jacobian).all())
    summary: dict[str, Any] = {
        "available": True,
        "shape": [int(extent) for extent in jacobian.shape],
        "finite": jacobian_ok,
        "frobenius_norm": None,
        "row_norms": None,
    }
    if jacobian_ok:
        summary["frobenius_norm"] = float(np.linalg.norm(jacobian))
        summary["row_norms"] = [float(norm) for norm in np.linalg.norm(jacobian, axis=1)]
    report["jacobian"] = summary
    return report
def _nonfinite_to_none(value: Any) -> Any:
    """Return a copy of *value* with NaN and infinite floats replaced by ``None``.

    Dicts, lists and tuples are walked recursively (tuples come back as
    lists, which is how JSON encodes them anyway); every other object is
    returned unchanged.
    """
    if isinstance(value, dict):
        return {key: _nonfinite_to_none(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_nonfinite_to_none(item) for item in value]
    if isinstance(value, float) and not np.isfinite(value):
        return None
    return value


def write_boundary_transport_gradient_report(report: dict[str, Any], path: str | Path) -> Path:
    """Write a boundary-gradient diagnostic JSON artifact.

    Parameters
    ----------
    report:
        Report dictionary to serialise.  It is not modified.
    path:
        Destination file; missing parent directories are created.

    Returns
    -------
    Path
        The path that was written.

    Notes
    -----
    The file is strict JSON (``allow_nan=False``).  Non-finite floats in the
    report are written as ``null``, so a report that records a NaN norm can
    still be persisted instead of failing with ``ValueError``.
    """
    out = Path(path)
    # Serialise before touching the filesystem so a failure leaves nothing behind.
    payload = json.dumps(_nonfinite_to_none(report), indent=2, allow_nan=False)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(payload + "\n", encoding="utf-8")
    return out
# Public API of this module; underscore-prefixed helpers are intentionally left out.
__all__ = [
    "boundary_spec_record",
    "build_boundary_transport_gradient_report",
    "write_boundary_transport_gradient_report",
]