"""Reduced QA low-turbulence stellarator optimization comparison tools.
This module adds a deliberately scoped, fully JAX-differentiable comparison
between two low-order quasi-axisymmetric (QA) stellarator designs:
* a control design constrained by quasisymmetry, aspect ratio, and an iota
floor; and
* a transport-aware design with the same constraints plus a reduced nonlinear
ITG heat-flux envelope in the objective.
The utilities are intended for optimization plumbing, sensitivity validation,
and manuscript figure generation. The heat-flux trace is a smooth reduced
envelope, not a turbulent nonlinear gyrokinetic time trace, and the visualized
surfaces are reduced boundary models rather than solved VMEC-JAX equilibria.
They are not a substitute for a production
VMEC/Boozer/full-nonlinear-GK optimization loop.
"""
from __future__ import annotations
from dataclasses import asdict, dataclass
from typing import Any, Sequence
import jax
import jax.numpy as jnp
import numpy as np
from spectraxgk.autodiff_validation import (
autodiff_finite_difference_report,
covariance_diagnostics,
)
from spectraxgk.stellarator_optimization import (
PARAMETER_NAMES,
_validate_params,
smooth_positive,
)
# Names of the two compared designs: the constraint-only control first, then
# the design whose objective also carries the reduced nonlinear heat flux.
QA_LOW_TURBULENCE_DESIGN_NAMES = ("qa_constraints", "qa_plus_nonlinear_heat_flux")

# Fixed ordering of the entries of ``qa_low_turbulence_observable_vector``.
QA_LOW_TURBULENCE_OBSERVABLE_NAMES = (
    # Geometry and constraint diagnostics.
    "aspect",
    "mean_iota",
    "iota_floor_violation",
    "iota_operating_floor_violation",
    "qa_residual",
    # Linear and quasilinear ITG features.
    "growth_rate",
    "kperp_eff2",
    "linear_heat_flux_weight",
    "quasilinear_heat_flux",
    # Late-window statistics of the reduced nonlinear heat-flux envelope.
    "nonlinear_heat_flux_mean",
    "nonlinear_heat_flux_cv",
    "nonlinear_heat_flux_trend",
)
[docs]
@dataclass(frozen=True)
class QALowTurbulenceConfig:
"""Configuration for the reduced QA low-turbulence comparison."""
target_aspect: float = 6.0
min_iota: float = 0.41
iota_operating_floor: float = 0.70
max_mode: int = 1
aspect_weight: float = 8.0
iota_floor_weight: float = 160.0
iota_operating_weight: float = 70.0
qa_weight: float = 8.0
target_helical_amplitude: float = 0.16
helical_shaping_weight: float = 24.0
regularization: float = 2.0e-3
nonlinear_weight: float = 8.0
learning_rate: float = 0.030
steps: int = 60
nonlinear_dt: float = 0.20
nonlinear_steps: int = 2000
nonlinear_tail_fraction: float = 0.50
long_window_min_time: float = 300.0
long_window_max_cv: float = 0.03
long_window_max_trend: float = 0.02
long_window_max_half_mean_rel_change: float = 0.02
fixed_density_gradient: float = 2.2
fixed_temperature_gradient: float = 6.0
scan_density_gradients: tuple[float, ...] = (
0.6,
1.0,
1.4,
1.8,
2.2,
2.8,
3.4,
4.0,
4.8,
)
fd_step: float = 1.0e-4
surface_ntheta: int = 72
surface_nzeta: int = 72
n_field_periods: int = 2
[docs]
@dataclass(frozen=True)
class QALowTurbulenceResult:
    """JSON-ready record of one reduced QA optimization run.

    Sequence-valued fields are stored as tuples so the frozen instance cannot
    be mutated in place; ``to_dict`` converts them to lists.
    """

    # Which design this is and whether its objective carried the heat flux.
    design_name: str
    includes_nonlinear_heat_flux: bool
    # Labels for the entries of the parameter and observable vectors.
    parameter_names: tuple[str, ...]
    observable_names: tuple[str, ...]
    # State before and after the optimization.
    initial_params: tuple[float, ...]
    final_params: tuple[float, ...]
    initial_objective: float
    final_objective: float
    initial_observables: tuple[float, ...]
    final_observables: tuple[float, ...]
    # One row per logged optimizer step.
    history: tuple[dict[str, Any], ...]
    # AD-vs-finite-difference reports and the residual covariance diagnostics.
    residual_gradient_gate: dict[str, Any]
    scalar_gradient_gate: dict[str, Any]
    observable_gradient_gate: dict[str, Any]
    covariance: dict[str, Any]
    # Plain-dict snapshot of the configuration used for the run.
    config: dict[str, Any]

    def to_dict(self) -> dict[str, Any]:
        """Return a stable JSON-friendly payload.

        Tuples become lists. The gate, covariance and config dicts are passed
        through by reference (they are not copied). A constant ``claim_level``
        entry is appended as the last key.
        """
        return {
            "design_name": self.design_name,
            "includes_nonlinear_heat_flux": self.includes_nonlinear_heat_flux,
            "parameter_names": list(self.parameter_names),
            "observable_names": list(self.observable_names),
            "initial_params": list(self.initial_params),
            "final_params": list(self.final_params),
            "initial_objective": self.initial_objective,
            "final_objective": self.final_objective,
            "initial_observables": list(self.initial_observables),
            "final_observables": list(self.final_observables),
            "history": list(self.history),
            "residual_gradient_gate": self.residual_gradient_gate,
            "scalar_gradient_gate": self.scalar_gradient_gate,
            "observable_gradient_gate": self.observable_gradient_gate,
            "covariance": self.covariance,
            "config": self.config,
            "claim_level": (
                "reduced_differentiable_qa_low_turbulence_comparison_"
                "not_full_vmec_nonlinear_transport_optimization"
            ),
        }
[docs]
def default_qa_low_turbulence_initial_params() -> jnp.ndarray:
    """Return the off-optimum seed shared by both designs of the comparison.

    The four entries follow the order in which the reduced model unpacks its
    controls: ``minor_shift``, ``elong_shift``, ``ripple``, ``shear_shift``.
    """
    seed_values = [0.24, 0.34, 0.30, -0.22]
    return jnp.asarray(seed_values)
def _fd_gate_tolerances(fd_step: float) -> tuple[float, float, float]:
if bool(jax.config.read("jax_enable_x64")):
return float(fd_step), 5.0e-3, 7.0e-4
return max(float(fd_step), 1.0e-3), 8.0e-2, 8.0e-3
def _qa_low_turbulence_core(
    params: jnp.ndarray | Sequence[float],
    config: QALowTurbulenceConfig | None = None,
) -> dict[str, jnp.ndarray]:
    """Evaluate the smooth reduced QA geometry and linear ITG feature map.

    After ``_validate_params`` the four controls are unpacked, in order, as
    ``minor_shift``, ``elong_shift``, ``ripple`` and ``shear_shift``. Every
    feature is a closed-form expression of those controls and of the targets
    stored in ``config``.

    Args:
        params: Control vector accepted by ``_validate_params``.
        config: Comparison settings; defaults are used when omitted.

    Returns:
        Dict with keys ``aspect``, ``mean_iota``, ``iota_floor_violation``,
        ``iota_operating_floor_violation``, ``qa_residual``,
        ``helical_mismatch``, ``bad_curvature``, ``kperp_eff2``,
        ``growth_rate``, ``linear_heat_flux_weight``,
        ``quasilinear_heat_flux`` and ``shear_metric``.
    """
    cfg = config or QALowTurbulenceConfig()
    controls = _validate_params(params)
    minor_shift, elong_shift, ripple, shear_shift = controls
    dtype = controls.dtype

    # Config targets cast to the dtype of the controls.
    aspect_target = jnp.asarray(cfg.target_aspect, dtype=dtype)
    iota_floor = jnp.asarray(cfg.min_iota, dtype=dtype)
    iota_operating = jnp.asarray(cfg.iota_operating_floor, dtype=dtype)

    # Aspect ratio: equals the target when all controls are zero.
    aspect = aspect_target * jnp.exp(
        -0.42 * minor_shift + 0.050 * elong_shift**2 + 0.035 * ripple**2
    )
    helical_target = jnp.asarray(cfg.target_helical_amplitude, dtype=dtype)
    helical_mismatch = ripple - helical_target
    aspect_deviation_sq = (aspect / aspect_target - 1.0) ** 2

    # Rotational transform and its two smooth floor penalties.
    mean_iota = (
        iota_floor
        + 0.235
        + 0.155 * shear_shift
        + 0.115 * ripple
        + 0.018 * elong_shift
    )
    floor_violation = smooth_positive(iota_floor - mean_iota, beta=80.0)
    operating_violation = smooth_positive(iota_operating - mean_iota, beta=45.0)

    # The constant (2e-4)^2 term keeps the square root strictly positive, so
    # the residual stays differentiable at zero helical mismatch.
    qa_residual = jnp.sqrt(
        (0.040 * helical_mismatch) ** 2
        + (0.012 * elong_shift * helical_mismatch) ** 2
        + (0.010 * minor_shift * helical_mismatch) ** 2
        + (2.0e-4) ** 2
    )

    # Bad-curvature proxy: a tanh stabilizer traded against quadratic costs.
    stabilizer = -0.085 * jnp.tanh(1.15 * elong_shift + 0.28 * shear_shift)
    quadratic_cost = (
        0.020 * elong_shift**2
        + 0.030 * shear_shift**2
        + 0.018 * minor_shift**2
        + 0.006 * helical_mismatch**2
    )
    bad_curvature = (
        0.078
        + 0.16 * qa_residual
        + 0.040 * aspect_deviation_sq
        + quadratic_cost
        + stabilizer
    )

    # Smoothed |shear_shift - 0.18| and the effective k_perp^2.
    shear_metric = jnp.sqrt((shear_shift - 0.18) ** 2 + 4.0e-4)
    kperp_eff2 = (
        0.32
        + 0.17 / aspect
        + 0.30 * qa_residual
        + 0.045 * (elong_shift - 0.55) ** 2
        + 0.050 * shear_metric**2
    )

    # Growth rate: constant offset plus a smooth positive part of the drive.
    raw_drive = 1.75 * bad_curvature + 0.070 * kperp_eff2 - 0.185
    growth_rate = 0.018 + smooth_positive(raw_drive, beta=22.0)

    flux_weight = (
        0.34
        + 1.80 * qa_residual
        + 0.085 * (elong_shift - 0.68) ** 2
        + 0.060 * (shear_shift - 0.24) ** 2
        + 0.025 * aspect_deviation_sq
    )
    quasilinear_heat_flux = 0.72 * flux_weight * growth_rate**2 / jnp.maximum(kperp_eff2, 1.0e-10)

    return {
        "aspect": aspect,
        "mean_iota": mean_iota,
        "iota_floor_violation": floor_violation,
        "iota_operating_floor_violation": operating_violation,
        "qa_residual": qa_residual,
        "helical_mismatch": helical_mismatch,
        "bad_curvature": bad_curvature,
        "kperp_eff2": kperp_eff2,
        "growth_rate": growth_rate,
        "linear_heat_flux_weight": flux_weight,
        "quasilinear_heat_flux": quasilinear_heat_flux,
        "shear_metric": shear_metric,
    }
[docs]
def qa_low_turbulence_heat_flux_trace(
    params: jnp.ndarray | Sequence[float],
    config: QALowTurbulenceConfig | None = None,
    *,
    density_gradient: float | None = None,
    temperature_gradient: float | None = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Integrate the reduced nonlinear ITG heat-flux envelope at one gradient point.

    ``density_gradient`` and ``temperature_gradient`` are the normalized
    gradients ``a/L_n`` and ``a/L_T``; when omitted, the fixed values from
    ``config`` are used. The envelope energy obeys
    ``dE/dt = 2 gamma E - alpha E^2`` and is advanced with a fixed-step,
    differentiable RK2 scheme. The returned heat flux is ``Q_i = W_i E``.

    Args:
        params: Control vector accepted by ``_validate_params``.
        config: Comparison settings; defaults are used when omitted.
        density_gradient: Optional override for ``a/L_n``.
        temperature_gradient: Optional override for ``a/L_T``.

    Returns:
        ``(times, heat_flux)``, each of length ``nonlinear_steps + 1``; the
        first sample is the initial condition.
    """
    cfg = config or QALowTurbulenceConfig()
    controls = _validate_params(params)
    features = _qa_low_turbulence_core(controls, cfg)
    dtype = controls.dtype

    aln_value = cfg.fixed_density_gradient if density_gradient is None else density_gradient
    alt_value = (
        cfg.fixed_temperature_gradient if temperature_gradient is None else temperature_gradient
    )
    aln = jnp.asarray(aln_value, dtype=dtype)
    alt = jnp.asarray(alt_value, dtype=dtype)

    # eta_i = L_n / L_T, with a/L_n floored at 0.25 to bound the ratio.
    eta_i = alt / jnp.maximum(aln, jnp.asarray(0.25, dtype=dtype))
    pressure_drive = smooth_positive(
        1.0 + 0.060 * (alt - 6.0) + 0.055 * (aln - 2.2) + 0.018 * (eta_i - 2.7),
        beta=10.0,
    )

    # Sigmoid switches in three of the controls lower both the effective
    # growth rate and the heat-flux weight.
    minor_shift, elong_shift, _ripple, shear_shift = controls
    transport_shaping = (
        jax.nn.sigmoid(8.0 * (elong_shift - 0.82))
        + 0.45 * jax.nn.sigmoid(8.0 * (minor_shift - 0.10))
        + 0.30 * jax.nn.sigmoid(8.0 * (shear_shift - 0.42))
    )
    shaping_suppression = 1.0 / (1.0 + 0.45 * transport_shaping)

    growth = smooth_positive(
        features["growth_rate"] * pressure_drive * shaping_suppression, beta=18.0
    )
    saturation = (
        1.15
        + 2.45 * features["kperp_eff2"]
        + 0.40 * features["qa_residual"]
        + 0.055 * aln
        + 0.030 * alt
    )
    gradient_factor = (
        1.0 + 0.070 * aln + 0.040 * alt + 0.025 * smooth_positive(eta_i - 1.0, beta=6.0)
    )
    drive_weight = features["linear_heat_flux_weight"] * gradient_factor / (
        1.0 + 0.30 * transport_shaping
    )

    dt = jnp.asarray(cfg.nonlinear_dt, dtype=dtype)
    steps = int(cfg.nonlinear_steps)
    times = dt * jnp.arange(steps + 1, dtype=dtype)

    # Start from the larger of a small seed and 35% of the fixed point
    # E* = 2 gamma / alpha of the envelope equation.
    equilibrium_energy = 2.0 * growth / jnp.maximum(saturation, jnp.asarray(1.0e-12, dtype=dtype))
    seed = jnp.asarray(1.0e-3, dtype=dtype) * (
        1.0 + 0.30 * controls[2] ** 2 + 0.15 * controls[1] ** 2
    )
    energy0 = jnp.maximum(seed, 0.35 * equilibrium_energy)
    zero = jnp.asarray(0.0, dtype=dtype)

    def envelope_rhs(energy: jnp.ndarray) -> jnp.ndarray:
        return 2.0 * growth * energy - saturation * energy**2

    def advance(energy: jnp.ndarray, _unused: None) -> tuple[jnp.ndarray, jnp.ndarray]:
        # RK2 step; the predictor and the update are both clipped at zero.
        k1 = envelope_rhs(energy)
        predictor = jnp.maximum(energy + dt * k1, zero)
        k2 = envelope_rhs(predictor)
        updated = jnp.maximum(energy + 0.5 * dt * (k1 + k2), zero)
        return updated, updated

    _, later_energies = jax.lax.scan(advance, energy0, None, length=steps)
    energy = jnp.concatenate([jnp.asarray([energy0], dtype=dtype), later_energies])
    return times, drive_weight * energy
[docs]
def qa_low_turbulence_window_metrics(
    times: jnp.ndarray,
    heat_flux: jnp.ndarray,
    *,
    tail_fraction: float = 0.50,
    eps: float = 1.0e-12,
) -> dict[str, jnp.ndarray]:
    """Return differentiable late-window heat-flux statistics.

    The window covers roughly the last ``tail_fraction`` of the samples. Its
    start index is clamped to ``[0, n - 2]``, so a trace with two or more
    samples always contributes at least two of them.

    The standard deviation uses ``sqrt(var + eps)`` so the Jacobian remains
    finite when a long reduced trace has fully saturated and the late-window
    variance is numerically zero.

    Args:
        times: One-dimensional sample times.
        heat_flux: One-dimensional heat-flux samples, same length as ``times``.
        tail_fraction: Fraction of the trace, counted from the end, to analyse.
        eps: Floor used inside the square root and in every denominator.

    Returns:
        Dict with the window ``mean``, regularized ``std``, ``cv``
        (``std / |mean|``), ``trend`` (``|slope| * span / |mean|``), the
        least-squares ``slope`` and the integer ``start_index``.

    Raises:
        ValueError: If the inputs are not one-dimensional, differ in length,
            or are empty.
    """
    t = jnp.asarray(times)
    q = jnp.asarray(heat_flux)
    if int(t.ndim) != 1 or int(q.ndim) != 1 or int(t.shape[0]) != int(q.shape[0]):
        raise ValueError("times and heat_flux must be one-dimensional arrays with matching length")
    n = int(q.shape[0])
    if n == 0:
        # An empty window has no mean or span; fail with a clear message
        # rather than returning NaNs or tripping over ``tw[-1]`` below.
        raise ValueError("times and heat_flux must contain at least one sample")

    start = max(0, min(n - 2, int(round((1.0 - float(tail_fraction)) * n))))
    tw = t[start:]
    qw = q[start:]
    eps_arr = jnp.asarray(eps, dtype=qw.dtype)

    mean = jnp.mean(qw)
    variance = jnp.mean((qw - mean) ** 2)
    std = jnp.sqrt(variance + eps_arr)

    # Least-squares slope of the window samples against time.
    centered_t = tw - jnp.mean(tw)
    denom = jnp.maximum(jnp.sum(centered_t**2), eps_arr)
    slope = jnp.sum(centered_t * (qw - mean)) / denom

    # Express drift and spread relative to the window mean.
    span = jnp.maximum(tw[-1] - tw[0], eps_arr)
    mean_scale = jnp.maximum(jnp.abs(mean), eps_arr)
    trend = jnp.abs(slope) * span / mean_scale
    cv = std / mean_scale
    return {
        "mean": mean,
        "std": std,
        "cv": cv,
        "trend": trend,
        "slope": slope,
        "start_index": jnp.asarray(start),
    }
[docs]
def qa_low_turbulence_observables(
    params: jnp.ndarray | Sequence[float],
    config: QALowTurbulenceConfig | None = None,
    *,
    density_gradient: float | None = None,
    temperature_gradient: float | None = None,
) -> dict[str, jnp.ndarray]:
    """Collect the reduced QA constraint and ITG observables for one design.

    Nine entries are copied from the core feature map. The remaining three
    are the late-window mean, CV and trend of the reduced heat-flux trace
    evaluated at the requested (or configured) gradients.
    """
    cfg = config or QALowTurbulenceConfig()
    features = _qa_low_turbulence_core(params, cfg)
    trace_times, trace_flux = qa_low_turbulence_heat_flux_trace(
        params,
        cfg,
        density_gradient=density_gradient,
        temperature_gradient=temperature_gradient,
    )
    late_window = qa_low_turbulence_window_metrics(
        trace_times,
        trace_flux,
        tail_fraction=cfg.nonlinear_tail_fraction,
    )

    core_keys = (
        "aspect",
        "mean_iota",
        "iota_floor_violation",
        "iota_operating_floor_violation",
        "qa_residual",
        "growth_rate",
        "kperp_eff2",
        "linear_heat_flux_weight",
        "quasilinear_heat_flux",
    )
    observables = {key: features[key] for key in core_keys}
    observables["nonlinear_heat_flux_mean"] = late_window["mean"]
    observables["nonlinear_heat_flux_cv"] = late_window["cv"]
    observables["nonlinear_heat_flux_trend"] = late_window["trend"]
    return observables
[docs]
def qa_low_turbulence_observable_vector(
    params: jnp.ndarray | Sequence[float],
    config: QALowTurbulenceConfig | None = None,
) -> jnp.ndarray:
    """Stack the observables in the order of ``QA_LOW_TURBULENCE_OBSERVABLE_NAMES``."""
    observables = qa_low_turbulence_observables(params, config)
    ordered = [observables[name] for name in QA_LOW_TURBULENCE_OBSERVABLE_NAMES]
    return jnp.asarray(ordered)
[docs]
def qa_low_turbulence_residual_names(
    *, includes_nonlinear_heat_flux: bool,
) -> tuple[str, ...]:
    """Name each entry of the weighted residual vector, in order.

    Five constraint residuals come first, followed by one regularization
    entry per name in ``PARAMETER_NAMES``. The heat-flux residual is appended
    only when ``includes_nonlinear_heat_flux`` is true.
    """
    constraint_names = (
        "aspect_constraint",
        "minimum_iota_floor",
        "operating_iota_floor",
        "quasisymmetry_residual",
        "qa_helical_shaping_amplitude",
    )
    regularization_names = tuple(f"regularization_{name}" for name in PARAMETER_NAMES)
    names = constraint_names + regularization_names
    if not includes_nonlinear_heat_flux:
        return names
    return names + ("reduced_nonlinear_heat_flux",)
[docs]
def qa_low_turbulence_residual_vector(
    params: jnp.ndarray | Sequence[float],
    config: QALowTurbulenceConfig | None = None,
    *,
    includes_nonlinear_heat_flux: bool,
) -> jnp.ndarray:
    """Assemble the weighted residuals of the aspect-6 QA objective.

    Each residual is scaled by the square root of its weight, so the squared
    norm of the vector is a weighted sum of squares. The optional last entry
    is ``sqrt(max(nonlinear_weight * late-window mean heat flux, 0))``.

    Args:
        params: Control vector accepted by ``_validate_params``.
        config: Comparison settings; defaults are used when omitted.
        includes_nonlinear_heat_flux: Append the heat-flux residual when true.

    Returns:
        One-dimensional residual vector, ordered as reported by
        ``qa_low_turbulence_residual_names``.
    """
    cfg = config or QALowTurbulenceConfig()
    controls = _validate_params(params)
    observables = qa_low_turbulence_observables(controls, cfg)
    dtype = controls.dtype

    def root_weight(weight: float) -> jnp.ndarray:
        return jnp.sqrt(jnp.asarray(weight, dtype=dtype))

    relative_aspect_error = (observables["aspect"] - cfg.target_aspect) / cfg.target_aspect
    constraint_residuals = [
        root_weight(cfg.aspect_weight) * relative_aspect_error,
        root_weight(cfg.iota_floor_weight) * observables["iota_floor_violation"],
        root_weight(cfg.iota_operating_weight) * observables["iota_operating_floor_violation"],
        root_weight(cfg.qa_weight) * observables["qa_residual"],
        # The helical mismatch is not among the observables, so it is read
        # from the core feature map directly.
        root_weight(cfg.helical_shaping_weight)
        * _qa_low_turbulence_core(controls, cfg)["helical_mismatch"],
    ]
    pieces = [
        jnp.asarray(constraint_residuals, dtype=dtype),
        root_weight(cfg.regularization) * controls,
    ]
    if includes_nonlinear_heat_flux:
        weighted_flux = (
            jnp.asarray(cfg.nonlinear_weight, dtype=dtype) * observables["nonlinear_heat_flux_mean"]
        )
        flux_residual = jnp.sqrt(jnp.maximum(weighted_flux, jnp.asarray(0.0, dtype=dtype)))
        pieces.append(jnp.asarray([flux_residual], dtype=dtype))
    return jnp.concatenate(pieces)
[docs]
def qa_low_turbulence_objective(
    params: jnp.ndarray | Sequence[float],
    config: QALowTurbulenceConfig | None = None,
    *,
    includes_nonlinear_heat_flux: bool,
) -> jnp.ndarray:
    """Return the scalar objective: the squared norm of the weighted residuals."""
    weighted = qa_low_turbulence_residual_vector(
        params,
        config,
        includes_nonlinear_heat_flux=includes_nonlinear_heat_flux,
    )
    return jnp.dot(weighted, weighted)
def _history_row(
    step: int,
    params: jnp.ndarray,
    objective: jnp.ndarray,
    grad: jnp.ndarray,
    config: QALowTurbulenceConfig,
) -> dict[str, Any]:
    """Build one JSON-friendly optimizer history record.

    The record holds the step index, objective value, gradient 2-norm, the
    parameters and the full observable vector evaluated at ``params``, all
    converted to built-in Python numbers.
    """
    observable_values = np.asarray(
        qa_low_turbulence_observable_vector(params, config), dtype=float
    )
    parameter_values = np.asarray(params, dtype=float)
    return {
        "step": int(step),
        "objective": float(objective),
        "gradient_norm": float(jnp.linalg.norm(grad)),
        "params": list(map(float, parameter_values)),
        "observables": list(map(float, observable_values)),
    }
[docs]
def qa_low_turbulence_observable_sensitivity_report(
    params: jnp.ndarray | Sequence[float],
    config: QALowTurbulenceConfig | None = None,
    *,
    finite_difference_workers: int = 1,
) -> dict[str, Any]:
    """Gate the controls-to-observables Jacobian against finite differences.

    This check is stricter than the scalar objective gate. It differentiates
    the whole reduced observable vector, including the long-window nonlinear
    heat-flux mean, CV and trend, and compares the JAX Jacobian with the
    finite-difference estimate from ``autodiff_finite_difference_report``.

    Returns:
        The report from ``autodiff_finite_difference_report`` extended with
        ``observable_names``, ``parameter_names``, ``kind`` and
        ``claim_level`` entries.
    """
    cfg = config or QALowTurbulenceConfig()
    controls = _validate_params(params)
    step, rtol, atol = _fd_gate_tolerances(cfg.fd_step)

    def observable_map(x: jnp.ndarray) -> jnp.ndarray:
        return qa_low_turbulence_observable_vector(x, cfg)

    report = autodiff_finite_difference_report(
        observable_map,
        controls,
        step=step,
        rtol=rtol,
        atol=atol,
        workers=finite_difference_workers,
    )
    report.update(
        {
            "observable_names": list(QA_LOW_TURBULENCE_OBSERVABLE_NAMES),
            "parameter_names": list(PARAMETER_NAMES),
            "kind": "qa_low_turbulence_observable_sensitivity_report",
            "claim_level": (
                "full_reduced_controls_to_linear_quasilinear_nonlinear_observable_"
                "differentiability_gate"
            ),
        }
    )
    return report
def _sensitivity_reports(
    params: jnp.ndarray,
    config: QALowTurbulenceConfig,
    *,
    includes_nonlinear_heat_flux: bool,
    finite_difference_workers: int,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any], dict[str, Any]]:
    """Run every differentiability gate and the covariance diagnostic at ``params``.

    Returns:
        ``(scalar_gate, residual_gate, observable_gate, covariance)``. The
        covariance is computed from the AD residual Jacobian stored in the
        residual gate and from the residual vector at ``params``.
    """
    step, rtol, atol = _fd_gate_tolerances(config.fd_step)
    fd_options = {
        "step": step,
        "rtol": rtol,
        "atol": atol,
        "workers": finite_difference_workers,
    }

    def objective_fn(x: jnp.ndarray) -> jnp.ndarray:
        return qa_low_turbulence_objective(
            x,
            config,
            includes_nonlinear_heat_flux=includes_nonlinear_heat_flux,
        )

    def residual_fn(x: jnp.ndarray) -> jnp.ndarray:
        return qa_low_turbulence_residual_vector(
            x,
            config,
            includes_nonlinear_heat_flux=includes_nonlinear_heat_flux,
        )

    scalar_gate = autodiff_finite_difference_report(objective_fn, params, **fd_options)
    residual_gate = autodiff_finite_difference_report(residual_fn, params, **fd_options)

    jacobian = np.asarray(residual_gate["jacobian_ad"], dtype=float)
    residual_values = np.asarray(residual_fn(params), dtype=float)
    covariance = covariance_diagnostics(jacobian, residual_values, regularization=1.0e-8)
    covariance["residual_names"] = list(
        qa_low_turbulence_residual_names(
            includes_nonlinear_heat_flux=includes_nonlinear_heat_flux,
        )
    )
    covariance["source"] = "qa_low_turbulence_weighted_residuals"

    observable_gate = qa_low_turbulence_observable_sensitivity_report(
        params,
        config,
        finite_difference_workers=finite_difference_workers,
    )
    return scalar_gate, residual_gate, observable_gate, covariance
[docs]
def optimize_qa_low_turbulence(
    *,
    includes_nonlinear_heat_flux: bool,
    config: QALowTurbulenceConfig | None = None,
    initial_params: jnp.ndarray | Sequence[float] | None = None,
    finite_difference_workers: int = 1,
) -> QALowTurbulenceResult:
    """Optimize one reduced QA low-turbulence design with Adam.

    Runs ``config.steps`` bias-corrected Adam updates on the scalar objective.
    A history row is logged at step 0, every fifth step and the final step.
    The differentiability gates and covariance diagnostics are then evaluated
    at the final parameters.

    Each objective/gradient evaluation integrates the full heat-flux trace,
    so every parameter vector is evaluated only once: the pair computed for
    a history row is reused for the next update and for the final objective,
    and the initial/final observables are taken from the history rows.

    Args:
        includes_nonlinear_heat_flux: Add the heat-flux residual to the objective.
        config: Comparison settings; defaults are used when omitted.
        initial_params: Optional seed; the shared default seed is used when omitted.
        finite_difference_workers: Forwarded to the finite-difference gates.

    Returns:
        The populated ``QALowTurbulenceResult``.
    """
    cfg = config or QALowTurbulenceConfig()
    initial_p = (
        default_qa_low_turbulence_initial_params()
        if initial_params is None
        else _validate_params(initial_params)
    )
    n_steps = int(cfg.steps)
    p = jnp.asarray(initial_p)
    m = jnp.zeros_like(p)
    v = jnp.zeros_like(p)
    beta1 = 0.9
    beta2 = 0.999
    eps = 1.0e-8
    grad_fn = jax.value_and_grad(
        lambda x: qa_low_turbulence_objective(
            x,
            cfg,
            includes_nonlinear_heat_flux=includes_nonlinear_heat_flux,
        )
    )

    objective0, grad0 = grad_fn(p)
    history: list[dict[str, Any]] = [_history_row(0, p, objective0, grad0, cfg)]

    # ``current`` caches the (objective, gradient) pair at the current ``p``;
    # it is ``None`` whenever ``p`` has moved since the last evaluation.
    current = (objective0, grad0)
    for step in range(1, n_steps + 1):
        if current is None:
            current = grad_fn(p)
        _, grad = current
        m = beta1 * m + (1.0 - beta1) * grad
        v = beta2 * v + (1.0 - beta2) * (grad * grad)
        mhat = m / (1.0 - beta1**step)
        vhat = v / (1.0 - beta2**step)
        p = p - cfg.learning_rate * mhat / (jnp.sqrt(vhat) + eps)
        current = None
        if step % 5 == 0 or step == n_steps:
            current = grad_fn(p)
            history.append(_history_row(step, p, current[0], current[1], cfg))

    final_objective, final_grad = grad_fn(p) if current is None else current
    if history[-1]["step"] != n_steps:
        history.append(_history_row(n_steps, p, final_objective, final_grad, cfg))

    scalar_gate, residual_gate, observable_gate, covariance = _sensitivity_reports(
        p,
        cfg,
        includes_nonlinear_heat_flux=includes_nonlinear_heat_flux,
        finite_difference_workers=finite_difference_workers,
    )
    design_name = (
        "qa_plus_nonlinear_heat_flux"
        if includes_nonlinear_heat_flux
        else "qa_constraints"
    )
    return QALowTurbulenceResult(
        design_name=design_name,
        includes_nonlinear_heat_flux=includes_nonlinear_heat_flux,
        parameter_names=tuple(PARAMETER_NAMES),
        observable_names=tuple(QA_LOW_TURBULENCE_OBSERVABLE_NAMES),
        initial_params=tuple(float(x) for x in np.asarray(initial_p, dtype=float)),
        final_params=tuple(float(x) for x in np.asarray(p, dtype=float)),
        initial_objective=float(objective0),
        final_objective=float(final_objective),
        # The first and last history rows were evaluated at the initial and
        # final parameters, so their observables are reused here.
        initial_observables=tuple(history[0]["observables"]),
        final_observables=tuple(history[-1]["observables"]),
        history=tuple(history),
        residual_gradient_gate=residual_gate,
        scalar_gradient_gate=scalar_gate,
        observable_gradient_gate=observable_gate,
        covariance=covariance,
        config=asdict(cfg),
    )
[docs]
def reduced_boundary_surface(
    params: jnp.ndarray | Sequence[float],
    config: QALowTurbulenceConfig | None = None,
) -> dict[str, Any]:
    """Build a reduced max-mode-1 boundary surface for visualization.

    The surface is sampled on a ``surface_ntheta`` x ``surface_nzeta`` grid
    with the poloidal angle along the first axis. The major radius is fixed
    at 1 and the minor radius follows from the reduced aspect ratio.

    Returns:
        Dict of nested lists ``x``, ``y``, ``z`` plus the ``theta``/``zeta``
        axes, the two visual helical amplitudes and a scope note.
    """
    cfg = config or QALowTurbulenceConfig()
    controls = np.asarray(_validate_params(params), dtype=float)
    minor_shift, elong_shift, ripple, shear_shift = (float(value) for value in controls)
    aspect = float(_qa_low_turbulence_core(controls.tolist(), cfg)["aspect"])

    theta = np.linspace(0.0, 2.0 * np.pi, int(cfg.surface_ntheta), endpoint=False)
    zeta = np.linspace(0.0, 2.0 * np.pi, int(cfg.surface_nzeta), endpoint=False)
    tt, zz = np.meshgrid(theta, zeta, indexing="ij")

    major_radius = 1.0
    minor_radius = major_radius / max(aspect, 1.0e-6)
    elongation = 1.0 + 0.26 * elong_shift
    nfp = int(cfg.n_field_periods)

    # The ripple control drives both the axis excursion and the helical term.
    visual_helical_radial = 1.35 * ripple
    visual_helical_vertical = 1.05 * ripple
    axis_radius = major_radius * (1.0 + 0.22 * ripple * np.cos(nfp * zz))
    axis_height = minor_radius * 1.10 * ripple * np.sin(nfp * zz)

    radial_shape = (
        np.cos(tt)
        + visual_helical_radial * np.cos(tt - nfp * zz)
        + 0.030 * shear_shift * np.cos(2.0 * tt)
        + 0.060 * minor_shift * np.cos(tt + nfp * zz)
    )
    vertical_shape = (
        elongation * np.sin(tt)
        + visual_helical_vertical * np.sin(tt - nfp * zz)
        + 0.040 * shear_shift * np.sin(2.0 * tt)
    )
    radius = axis_radius + minor_radius * radial_shape
    height = axis_height + minor_radius * vertical_shape

    return {
        "theta": theta.tolist(),
        "zeta": zeta.tolist(),
        "x": (radius * np.cos(zz)).tolist(),
        "y": (radius * np.sin(zz)).tolist(),
        "z": height.tolist(),
        "visual_helical_radial_amplitude": float(visual_helical_radial),
        "visual_helical_vertical_amplitude": float(visual_helical_vertical),
        "reduced_boundary_scope": "max-mode-1 visualization, not a solved VMEC equilibrium",
    }
[docs]
def reduced_lcfs_bmag(
    params: jnp.ndarray | Sequence[float],
    config: QALowTurbulenceConfig | None = None,
) -> dict[str, Any]:
    """Build a synthetic LCFS ``|B|`` map for QA visualization.

    The map is a unit background plus four cosine harmonics. Their
    amplitudes are set by fixed coefficients, the elongation and ripple
    controls, and the reduced QA residual.

    Returns:
        Dict with the ``theta``/``zeta`` axes, the nested-list ``bmag`` grid
        (poloidal angle along the first axis) and a scope note.
    """
    cfg = config or QALowTurbulenceConfig()
    controls = np.asarray(_validate_params(params), dtype=float)
    features = _qa_low_turbulence_core(controls.tolist(), cfg)

    theta = np.linspace(0.0, 2.0 * np.pi, int(cfg.surface_ntheta), endpoint=False)
    zeta = np.linspace(0.0, 2.0 * np.pi, int(cfg.surface_nzeta), endpoint=False)
    tt, zz = np.meshgrid(theta, zeta, indexing="ij")

    nfp = int(cfg.n_field_periods)
    ripple_amp = float(controls[2])
    elong_amp = float(controls[1])
    qa_amp = float(features["qa_residual"])

    field_strength = (
        1.0
        + 0.055 * np.cos(tt)
        + 0.018 * elong_amp * np.cos(2.0 * tt)
        + 0.220 * ripple_amp * np.cos(tt - nfp * zz)
        + 0.030 * qa_amp * np.cos(2.0 * tt - nfp * zz)
    )
    return {
        "theta": theta.tolist(),
        "zeta": zeta.tolist(),
        "bmag": field_strength.tolist(),
        "reduced_bmag_scope": "synthetic LCFS |B| map from reduced QA controls",
    }
def _window_mean_for_gradient(
    params: Sequence[float],
    config: QALowTurbulenceConfig,
    *,
    density_gradient: float,
    temperature_gradient: float,
) -> tuple[float, float, float, float]:
    """Evaluate the late-window statistics at one gradient point.

    Returns:
        ``(mean, cv, trend, growth_rate)`` as Python floats. The first three
        come from the late window of the heat-flux trace; the growth rate
        comes from the core feature map and does not depend on the gradients.
    """
    trace_times, trace_flux = qa_low_turbulence_heat_flux_trace(
        params,
        config,
        density_gradient=density_gradient,
        temperature_gradient=temperature_gradient,
    )
    late_window = qa_low_turbulence_window_metrics(
        trace_times,
        trace_flux,
        tail_fraction=config.nonlinear_tail_fraction,
    )
    features = _qa_low_turbulence_core(params, config)
    mean, cv, trend = (float(late_window[key]) for key in ("mean", "cv", "trend"))
    return mean, cv, trend, float(features["growth_rate"])
def _scan_density_gradient(
    params: Sequence[float],
    config: QALowTurbulenceConfig,
) -> dict[str, Any]:
    """Scan ``a/L_n`` at fixed ``a/L_T`` and fit a line to the mean heat flux.

    Returns:
        Dict with the scan axis, the fixed temperature gradient, per-point
        late-window mean/CV/trend and growth rate, and the slope of a
        degree-1 ``numpy.polyfit`` of mean heat flux against ``a/L_n``.
    """
    gradients = np.asarray(config.scan_density_gradients, dtype=float)
    rows = [
        _window_mean_for_gradient(
            params,
            config,
            density_gradient=float(aln),
            temperature_gradient=config.fixed_temperature_gradient,
        )
        for aln in gradients
    ]
    means = [row[0] for row in rows]
    cvs = [row[1] for row in rows]
    trends = [row[2] for row in rows]
    gammas = [row[3] for row in rows]

    fit = np.polyfit(gradients, np.asarray(means, dtype=float), deg=1)
    return {
        "density_gradient_axis": gradients.tolist(),
        "fixed_temperature_gradient": float(config.fixed_temperature_gradient),
        "heat_flux_mean": means,
        "heat_flux_cv": cvs,
        "heat_flux_trend": trends,
        "growth_rate": gammas,
        "linear_slope_dQ_d_a_over_Ln": float(fit[0]),
    }
def _long_window_convergence_gate(
times: np.ndarray,
heat_flux: np.ndarray,
window: dict[str, Any],
config: QALowTurbulenceConfig,
) -> dict[str, Any]:
"""Return explicit convergence checks for the fixed-gradient trace."""
t = np.asarray(times, dtype=float)
q = np.asarray(heat_flux, dtype=float)
start = int(window["start_index"])
tail = q[start:]
if tail.size < 4:
raise ValueError("long-window convergence gate requires at least four late-window samples")
split = max(1, tail.size // 2)
first_mean = float(np.mean(tail[:split]))
second_mean = float(np.mean(tail[split:]))
full_mean = float(np.mean(tail))
denom = max(abs(full_mean), 1.0e-14)
half_mean_rel_change = abs(second_mean - first_mean) / denom
running_mean = np.cumsum(tail) / np.arange(1, tail.size + 1, dtype=float)
checkpoint = max(0, int(round(0.75 * (tail.size - 1))))
running_mean_rel_change = abs(float(running_mean[-1]) - float(running_mean[checkpoint])) / denom
tmax = float(t[-1])
passed = bool(
tmax >= float(config.long_window_min_time)
and float(window["cv"]) <= float(config.long_window_max_cv)
and float(window["trend"]) <= float(config.long_window_max_trend)
and half_mean_rel_change <= float(config.long_window_max_half_mean_rel_change)
)
return {
"passed": passed,
"tmax": tmax,
"minimum_tmax": float(config.long_window_min_time),
"first_half_mean": first_mean,
"second_half_mean": second_mean,
"half_window_relative_mean_change": float(half_mean_rel_change),
"running_mean_checkpoint_fraction": 0.75,
"running_mean_relative_change_since_checkpoint": float(running_mean_rel_change),
"max_cv": float(config.long_window_max_cv),
"max_trend": float(config.long_window_max_trend),
"max_half_window_relative_mean_change": float(config.long_window_max_half_mean_rel_change),
}
def _fixed_trace_payload(params: Sequence[float], config: QALowTurbulenceConfig) -> dict[str, Any]:
    """Return the JSON-ready heat-flux trace at the fixed gradient point.

    The payload holds the full trace, its late-window statistics and the
    result of the long-window convergence gate.
    """
    trace_times, trace_flux = qa_low_turbulence_heat_flux_trace(
        params,
        config,
        density_gradient=config.fixed_density_gradient,
        temperature_gradient=config.fixed_temperature_gradient,
    )
    late_window = qa_low_turbulence_window_metrics(
        trace_times,
        trace_flux,
        tail_fraction=config.nonlinear_tail_fraction,
    )
    times_np = np.asarray(trace_times, dtype=float)
    flux_np = np.asarray(trace_flux, dtype=float)

    # Convert the JAX scalars to floats; the start index stays an integer.
    window_payload = {name: float(stat) for name, stat in late_window.items()}
    window_payload["start_index"] = int(late_window["start_index"])

    gate = _long_window_convergence_gate(times_np, flux_np, window_payload, config)
    return {
        "density_gradient": float(config.fixed_density_gradient),
        "temperature_gradient": float(config.fixed_temperature_gradient),
        "trace_kind": "smooth_reduced_nonlinear_envelope_not_full_turbulent_gk",
        "trace_equation": "dE/dt = 2 gamma E - alpha E^2; Q_env = W E",
        "times": list(map(float, times_np)),
        "heat_flux": list(map(float, flux_np)),
        "window": window_payload,
        "long_window_convergence": gate,
    }
[docs]
def qa_low_turbulence_comparison_payload(
    config: QALowTurbulenceConfig | None = None,
    *,
    finite_difference_workers: int = 1,
) -> dict[str, Any]:
    """Build the full JSON-ready aspect-6 QA low-turbulence comparison.

    The control design is optimized first, then the transport-aware design,
    both from the default seed. For each final design the payload adds a
    density-gradient scan, the fixed-gradient trace, and the reduced surface
    and ``|B|`` visualizations. The overall ``passed`` flag requires all of:
    every AD-vs-FD gate, the geometric constraints, the long-window
    convergence gates, and a heat-flux mean for the transport design that is
    at most 95% of the control design's.
    """
    cfg = config or QALowTurbulenceConfig()
    control, transport = (
        optimize_qa_low_turbulence(
            includes_nonlinear_heat_flux=with_heat_flux,
            config=cfg,
            finite_difference_workers=finite_difference_workers,
        )
        for with_heat_flux in (False, True)
    )
    results = [control.to_dict(), transport.to_dict()]

    obs_index = {
        name: position for position, name in enumerate(QA_LOW_TURBULENCE_OBSERVABLE_NAMES)
    }
    flux_slot = obs_index["nonlinear_heat_flux_mean"]
    control_q = float(control.final_observables[flux_slot])
    transport_q = float(transport.final_observables[flux_slot])
    reduction = 1.0 - transport_q / max(control_q, 1.0e-14)

    def design_entry(result: dict[str, Any]) -> dict[str, Any]:
        # Per-design diagnostics evaluated at the optimized parameters.
        final = result["final_params"]
        return {
            "design_name": result["design_name"],
            "final_params": final,
            "final_observables": result["final_observables"],
            "density_gradient_scan": _scan_density_gradient(final, cfg),
            "fixed_gradient_trace": _fixed_trace_payload(final, cfg),
            "surface": reduced_boundary_surface(final, cfg),
            "lcfs_bmag": reduced_lcfs_bmag(final, cfg),
        }

    def gate_passed_for_all(gate_key: str) -> bool:
        return all(bool(result[gate_key]["passed"]) for result in results)

    def meets_constraints(observables: list[float]) -> bool:
        aspect_error = abs(observables[obs_index["aspect"]] - cfg.target_aspect)
        return (
            aspect_error / cfg.target_aspect < 2.5e-2
            and observables[obs_index["mean_iota"]] >= cfg.iota_operating_floor - 2.0e-3
            and observables[obs_index["qa_residual"]] < 2.5e-2
        )

    design_payloads = [design_entry(result) for result in results]

    scalar_gates_ok = gate_passed_for_all("scalar_gradient_gate")
    residual_gates_ok = gate_passed_for_all("residual_gradient_gate")
    observable_gates_ok = gate_passed_for_all("observable_gradient_gate")
    all_gates_passed = scalar_gates_ok and residual_gates_ok and observable_gates_ok

    long_window_gates_passed = all(
        bool(design["fixed_gradient_trace"]["long_window_convergence"]["passed"])
        for design in design_payloads
    )
    constraints_passed = all(
        meets_constraints(result["final_observables"]) for result in results
    )
    transport_passed = bool(transport_q <= 0.95 * control_q)
    passed = bool(
        all_gates_passed and constraints_passed and transport_passed and long_window_gates_passed
    )

    comparison_metrics = {
        "control_design_heat_flux_mean": control_q,
        "transport_design_heat_flux_mean": transport_q,
        "relative_heat_flux_reduction_at_fixed_gradients": float(reduction),
        "constraints_passed": bool(constraints_passed),
        "transport_reduction_gate_passed": transport_passed,
        "long_window_gates_passed": bool(long_window_gates_passed),
        "ad_fd_gates_passed": bool(all_gates_passed),
        "passed": passed,
        "reduced_differentiable_plumbing_passed": bool(all_gates_passed),
        "full_vmec_nonlinear_differentiable_plumbing_passed": False,
    }
    differentiable_plumbing = {
        "stages": [
            "reduced QA controls",
            "geometry constraints and reduced LCFS visualization",
            "linear ITG feature map",
            "quasilinear mixing-length diagnostic",
            "long-window differentiable nonlinear heat-flux envelope",
            "weighted optimization residuals",
            "scalar, residual, and observable AD-vs-FD gates",
        ],
        "all_scalar_objective_gates_passed": scalar_gates_ok,
        "all_residual_jacobian_gates_passed": residual_gates_ok,
        "all_observable_jacobian_gates_passed": observable_gates_ok,
        "passed": bool(all_gates_passed),
    }
    model_equations = {
        "objective": (
            "||r||^2 with aspect, minimum-iota, operating-iota, QA, QA-compatible "
            "helical-shaping, regularization, and optional sqrt(weight * late-window reduced "
            "nonlinear heat flux) residuals"
        ),
        "nonlinear_envelope": "dE/dt = 2 gamma E - alpha E^2; Q_env = W_i E; fixed-step RK2",
        "gradient_scan": "fixed a/L_T while scanning a/L_n and refitting late-window Q_env means",
    }
    scope_notes = [
        "The surface and |B| maps are reduced max-mode-1 visualizations, not solved VMEC equilibria.",
        "The nonlinear heat-flux objective is a differentiable envelope used for optimization plumbing.",
        "Production nonlinear claims still require long post-transient replicated SPECTRAX-GK windows.",
    ]

    return {
        "kind": "qa_low_turbulence_comparison",
        "claim_level": (
            "reduced_differentiable_qa_low_turbulence_comparison_"
            "not_full_vmec_nonlinear_transport_optimization"
        ),
        "target_aspect": float(cfg.target_aspect),
        "minimum_iota": float(cfg.min_iota),
        "operating_iota_floor": float(cfg.iota_operating_floor),
        "fixed_density_gradient": float(cfg.fixed_density_gradient),
        "fixed_temperature_gradient": float(cfg.fixed_temperature_gradient),
        "parameter_names": list(PARAMETER_NAMES),
        "observable_names": list(QA_LOW_TURBULENCE_OBSERVABLE_NAMES),
        "results": results,
        "designs": design_payloads,
        "comparison_metrics": comparison_metrics,
        "differentiable_plumbing": differentiable_plumbing,
        "model_equations": model_equations,
        "config": asdict(cfg),
        "scope_notes": scope_notes,
    }