# Source code for spectraxgk.zonal_objective

"""Reduced zonal-flow objectives for differentiable stellarator optimization.

This module is intentionally backend-free.  Production callers should build
zonal-response metrics from VMEC/Boozer/SPECTRAX-GK rows, then pass the metric
tensors here for reduction, finite-difference checks, and UQ diagnostics.
"""

from __future__ import annotations

from collections.abc import Callable, Iterable, Mapping, Sequence
from dataclasses import asdict, dataclass
from typing import Any, Literal

import jax.numpy as jnp
import numpy as np

from spectraxgk.stellarator_objective_portfolio import (
    PortfolioReduction,
    aggregate_objective_portfolio,
    objective_portfolio_sensitivity_report,
)


# Names of the objective columns, in the order in which
# ``zonal_flow_objective_rows`` stacks them along the last axis and in which
# ``ZonalFlowObjectiveConfig.objective_weights`` lists the matching weights.
ZONAL_FLOW_OBJECTIVE_NAMES = (
    "inverse_residual",
    "damping_rate",
    "growth_over_residual",
    "recurrence_amplitude",
)

# Policy for records without a finite damping rate: "fail" raises a
# ValueError, "zero" substitutes 0.0 and counts the record as missing.
MissingDampingPolicy = Literal["fail", "zero"]


@dataclass(frozen=True)
class ZonalFlowObjectiveConfig:
    """Weights and residual floor of the minimizable zonal-flow objective.

    A large residual zonal response is rewarded by minimizing
    ``1 / residual``.  Collisionless damping, linear growth that is not
    screened by the residual, and late-time recurrence/envelope amplitude are
    penalized.  Nonlinear heat-flux suppression remains a separate holdout
    gate.

    Raises:
        ValueError: if any weight is negative or non-finite, if all weights
            are zero, or if ``residual_floor`` is not finite and positive.
    """

    residual_weight: float = 1.0
    damping_weight: float = 1.0
    growth_over_residual_weight: float = 0.0
    recurrence_weight: float = 0.0
    residual_floor: float = 1.0e-6

    def __post_init__(self) -> None:
        # Convert every weight up front so a non-numeric field fails before
        # any range check runs.
        weights = tuple(
            float(raw)
            for raw in (
                self.residual_weight,
                self.damping_weight,
                self.growth_over_residual_weight,
                self.recurrence_weight,
            )
        )
        if not all(np.isfinite(weight) and weight >= 0.0 for weight in weights):
            raise ValueError("zonal-flow objective weights must be finite and non-negative")
        total = float(sum(weights))
        if total <= 0.0:
            raise ValueError("at least one zonal-flow objective weight must be positive")
        floor = float(self.residual_floor)
        if floor <= 0.0 or not np.isfinite(floor):
            raise ValueError("residual_floor must be finite and positive")

    def objective_weights(self) -> jnp.ndarray:
        """Return the four objective-column weights as a float32 array.

        The entries follow the objective column order (residual, damping,
        growth-over-residual, recurrence) and are returned exactly as
        configured, without rescaling.
        """
        columns = (
            self.residual_weight,
            self.damping_weight,
            self.growth_over_residual_weight,
            self.recurrence_weight,
        )
        return jnp.asarray(columns, dtype=jnp.float32)

    def to_dict(self) -> dict[str, object]:
        """Return the configuration fields plus the objective names as a dict."""
        return {**asdict(self), "objective_names": list(ZONAL_FLOW_OBJECTIVE_NAMES)}
def _metric_tensor(value: Any, *, name: str, strictly_positive: bool = False) -> jnp.ndarray: array = jnp.asarray(value) if int(array.ndim) != 3: raise ValueError(f"{name} must have shape (n_surface, n_alpha, n_kx)") if any(int(size) < 1 for size in array.shape): raise ValueError(f"{name} dimensions must all be positive") if jnp.issubdtype(array.dtype, jnp.complexfloating) or not jnp.issubdtype(array.dtype, jnp.number): raise TypeError(f"{name} must be a real numeric array") try: concrete = np.asarray(array, dtype=float) except Exception as exc: # pragma: no cover - triggered only under JAX tracing. class_name = type(exc).__name__ if "Tracer" in class_name or "Concretization" in class_name: return array raise if not np.all(np.isfinite(concrete)): raise ValueError(f"{name} must be finite") if strictly_positive and np.any(concrete <= 0.0): raise ValueError(f"{name} must be strictly positive") return array def _first_present(record: Mapping[str, Any], keys: Sequence[str]) -> tuple[str | None, Any]: for key in keys: if key in record: return key, record[key] return None, None def _optional_float(value: Any, *, field: str) -> float | None: if value is None: return None if isinstance(value, str): stripped = value.strip() if stripped.lower() in {"", "nan", "none", "null"}: return None value = stripped try: scalar = float(value) except (TypeError, ValueError) as exc: raise ValueError(f"{field} must be numeric when present") from exc if not np.isfinite(scalar): return None return scalar def _required_float( record: Mapping[str, Any], keys: Sequence[str], *, field: str, default: float | None = None, ) -> float: key, raw = _first_present(record, keys) value = _optional_float(default if key is None else raw, field=field) if value is None: raise ValueError(f"record is missing finite {field}; tried keys {list(keys)}") return value def _optional_metric( record: Mapping[str, Any], keys: Sequence[str], *, field: str, default: float | None = None, ) -> float | None: key, raw = 
_first_present(record, keys) if key is None: return default return _optional_float(raw, field=field) def _axis_index(values: list[float]) -> dict[float, int]: return {value: index for index, value in enumerate(values)} def _finite_metric_tensor_from_records( records: Iterable[Mapping[str, Any]], *, surface_keys: Sequence[str], alpha_keys: Sequence[str], kx_keys: Sequence[str], residual_keys: Sequence[str], damping_keys: Sequence[str], linear_growth_keys: Sequence[str], recurrence_keys: Sequence[str], missing_damping_policy: MissingDampingPolicy, ) -> tuple[ list[float], list[float], list[float], np.ndarray, np.ndarray, np.ndarray, np.ndarray, list[dict[str, float]], int, int, ]: if missing_damping_policy not in {"fail", "zero"}: raise ValueError("missing_damping_policy must be 'fail' or 'zero'") normalized: list[dict[str, float]] = [] missing_damping_count = 0 missing_recurrence_count = 0 for record in records: surface = _optional_metric(record, surface_keys, field="surface", default=0.0) alpha = _optional_metric(record, alpha_keys, field="alpha", default=0.0) kx = _required_float(record, kx_keys, field="kx") residual_value = _required_float(record, residual_keys, field="residual_level") if residual_value <= 0.0: raise ValueError("residual_level must be strictly positive in every record") damping_value = _optional_metric(record, damping_keys, field="damping_rate") if damping_value is None: missing_damping_count += 1 if missing_damping_policy == "fail": raise ValueError("record is missing finite damping_rate") damping_value = 0.0 growth_value = _optional_metric(record, linear_growth_keys, field="linear_growth_rate", default=0.0) recurrence_value = _optional_metric(record, recurrence_keys, field="recurrence_amplitude") if recurrence_value is None: missing_recurrence_count += 1 recurrence_value = 0.0 normalized.append( { "surface": float(surface if surface is not None else 0.0), "alpha": float(alpha if alpha is not None else 0.0), "kx": float(kx), "residual_level": 
float(residual_value), "damping_rate": float(damping_value), "linear_growth_rate": float(growth_value if growth_value is not None else 0.0), "recurrence_amplitude": float(recurrence_value), } ) if not normalized: raise ValueError("at least one zonal-flow objective record is required") surfaces = sorted({row["surface"] for row in normalized}) alphas = sorted({row["alpha"] for row in normalized}) kx_values = sorted({row["kx"] for row in normalized}) surface_index = _axis_index(surfaces) alpha_index = _axis_index(alphas) kx_index = _axis_index(kx_values) shape = (len(surfaces), len(alphas), len(kx_values)) residual_tensor = np.full(shape, np.nan, dtype=float) damping_tensor = np.full(shape, np.nan, dtype=float) growth_tensor = np.full(shape, np.nan, dtype=float) recurrence_tensor = np.full(shape, np.nan, dtype=float) seen: set[tuple[int, int, int]] = set() for row in normalized: index = (surface_index[row["surface"]], alpha_index[row["alpha"]], kx_index[row["kx"]]) if index in seen: raise ValueError( "duplicate zonal-flow objective record for " f"surface={row['surface']}, alpha={row['alpha']}, kx={row['kx']}" ) seen.add(index) residual_tensor[index] = row["residual_level"] damping_tensor[index] = row["damping_rate"] growth_tensor[index] = row["linear_growth_rate"] recurrence_tensor[index] = row["recurrence_amplitude"] for name, tensor in ( ("residual_level", residual_tensor), ("damping_rate", damping_tensor), ("linear_growth_rate", growth_tensor), ("recurrence_amplitude", recurrence_tensor), ): if not np.all(np.isfinite(tensor)): raise ValueError(f"records do not form a complete finite tensor for {name}") return ( surfaces, alphas, kx_values, residual_tensor, damping_tensor, growth_tensor, recurrence_tensor, normalized, missing_damping_count, missing_recurrence_count, )
def zonal_flow_objective_rows(
    *,
    residual_level: Any,
    damping_rate: Any,
    linear_growth_rate: Any | None = None,
    recurrence_amplitude: Any | None = None,
    config: ZonalFlowObjectiveConfig | None = None,
) -> jnp.ndarray:
    """Build the objective rows, shaped ``(surface, alpha, kx, objective)``.

    ``residual_level`` is the late-time residual normalized to the initial
    zonal potential; a larger residual lowers the first column.
    ``damping_rate`` should be positive for decaying GAM/zonal envelopes.
    ``linear_growth_rate`` is optional and acts as a suppression-relevance
    metric, so strong ITG growth with a weak residual is penalized.
    ``recurrence_amplitude`` is optional and should be a non-negative
    late-envelope or moment-tail amplitude.

    The columns are ``1 / residual``, damping, ``growth / residual`` and
    recurrence.  The residual is floored at ``config.residual_floor`` before
    dividing, and negative damping, growth and recurrence are clipped to zero.

    Raises:
        ValueError: if the metric tensors cannot be broadcast together.
    """
    cfg = config if config else ZonalFlowObjectiveConfig()

    residual = _metric_tensor(residual_level, name="residual_level", strictly_positive=True)
    damping = _metric_tensor(damping_rate, name="damping_rate")
    # The optional metrics default to all-zero tensors shaped like the residual.
    if linear_growth_rate is None:
        growth = jnp.zeros_like(residual)
    else:
        growth = _metric_tensor(linear_growth_rate, name="linear_growth_rate")
    if recurrence_amplitude is None:
        recurrence = jnp.zeros_like(residual)
    else:
        recurrence = _metric_tensor(recurrence_amplitude, name="recurrence_amplitude")

    try:
        broadcast = jnp.broadcast_arrays(residual, damping, growth, recurrence)
    except ValueError as exc:
        raise ValueError("all zonal-flow metric tensors must be broadcast-compatible") from exc
    residual, damping, growth, recurrence = broadcast

    floor = jnp.asarray(float(cfg.residual_floor), dtype=residual.dtype)
    floored_residual = jnp.maximum(residual, floor)
    inverse_residual = 1.0 / floored_residual
    damping_column = jnp.maximum(damping, 0.0)
    growth_over_residual = jnp.maximum(growth, 0.0) / floored_residual
    recurrence_column = jnp.maximum(recurrence, 0.0)

    rows = jnp.stack(
        (inverse_residual, damping_column, growth_over_residual, recurrence_column),
        axis=-1,
    )
    # Promote against a Python-float array so the result is floating point.
    output_dtype = jnp.result_type(rows, jnp.asarray(1.0))
    return rows.astype(output_dtype)
def zonal_flow_reduced_objective(
    *,
    residual_level: Any,
    damping_rate: Any,
    linear_growth_rate: Any | None = None,
    recurrence_amplitude: Any | None = None,
    config: ZonalFlowObjectiveConfig | None = None,
    sample_weights: Any | None = None,
    surface_weights: Any | None = None,
    alpha_weights: Any | None = None,
    ky_weights: Any | None = None,
    reduction: PortfolioReduction = "weighted_mean",
) -> jnp.ndarray:
    """Collapse zonal-flow metric tensors into one differentiable scalar.

    The metrics are first turned into objective rows by
    ``zonal_flow_objective_rows``.  The rows, the sample/axis weights, the
    configured objective weights and ``reduction`` are then handed to
    ``aggregate_objective_portfolio``, whose result is returned unchanged.
    """
    cfg = config if config else ZonalFlowObjectiveConfig()
    metric_inputs = {
        "residual_level": residual_level,
        "damping_rate": damping_rate,
        "linear_growth_rate": linear_growth_rate,
        "recurrence_amplitude": recurrence_amplitude,
    }
    objective_rows = zonal_flow_objective_rows(config=cfg, **metric_inputs)
    column_weights = cfg.objective_weights()
    return aggregate_objective_portfolio(
        objective_rows,
        sample_weights=sample_weights,
        surface_weights=surface_weights,
        alpha_weights=alpha_weights,
        ky_weights=ky_weights,
        objective_weights=column_weights,
        reduction=reduction,
    )
def zonal_flow_objective_artifact_from_records(
    records: Iterable[Mapping[str, Any]],
    *,
    config: ZonalFlowObjectiveConfig | None = None,
    surface_keys: Sequence[str] = ("surface", "surface_index", "torflux"),
    alpha_keys: Sequence[str] = ("alpha", "field_line_label"),
    kx_keys: Sequence[str] = ("kx", "kx_target", "kx_rhoi"),
    residual_keys: Sequence[str] = ("residual_level", "spectrax_residual"),
    damping_keys: Sequence[str] = ("damping_rate", "gam_damping_rate"),
    linear_growth_keys: Sequence[str] = ("linear_growth_rate", "growth_rate", "gamma"),
    recurrence_keys: Sequence[str] = (
        "recurrence_amplitude",
        "tail_std_ratio",
        "residual_std",
        "tail_std",
    ),
    missing_damping_policy: MissingDampingPolicy = "fail",
    claim_level: str | None = None,
    source_paths: Sequence[str] | None = None,
    reduction: PortfolioReduction = "weighted_mean",
) -> dict[str, object]:
    """Build a strict JSON-friendly zonal-flow objective artifact.

    The input is a table of validated zonal-response metrics.  Rows are mapped
    onto the shared ``(surface, alpha, kx)`` portfolio tensor used by the
    stellarator objective stack.  Missing damping rates fail by default
    because a promoted zonal-flow optimization claim must know the damping
    convention.  Diagnostic artifacts can set
    ``missing_damping_policy='zero'`` to produce rows while carrying an
    explicit ``promotion_ready=False`` flag.

    Args:
        records: one mapping per ``(surface, alpha, kx)`` sample.
        config: objective weights and residual floor; defaults to
            ``ZonalFlowObjectiveConfig()``.
        surface_keys, alpha_keys, kx_keys, residual_keys, damping_keys,
        linear_growth_keys, recurrence_keys: accepted column aliases for each
            field, in lookup order.
        missing_damping_policy: ``"fail"`` or ``"zero"``.
        claim_level: explicit claim label; when empty, one is derived from
            ``promotion_ready``.
        source_paths: provenance paths copied into the artifact.  A single
            string is treated as one path.
        reduction: reduction mode forwarded to the reducer.

    Returns:
        A dict holding the axes, metric tensors, objective rows, reduced
        objective, per-record row table, and promotion bookkeeping.
        ``promotion_ready`` is true only when no damping rate or recurrence
        amplitude had to be filled in.

    Raises:
        ValueError: propagated from record parsing (missing or invalid
            fields, duplicates, incomplete tensor) and from metric validation.
    """
    cfg = config or ZonalFlowObjectiveConfig()
    (
        surfaces,
        alphas,
        kx_values,
        residual,
        damping,
        growth,
        recurrence,
        normalized,
        missing_damping_count,
        missing_recurrence_count,
    ) = _finite_metric_tensor_from_records(
        records,
        surface_keys=surface_keys,
        alpha_keys=alpha_keys,
        kx_keys=kx_keys,
        residual_keys=residual_keys,
        damping_keys=damping_keys,
        linear_growth_keys=linear_growth_keys,
        recurrence_keys=recurrence_keys,
        missing_damping_policy=missing_damping_policy,
    )
    rows = zonal_flow_objective_rows(
        residual_level=residual,
        damping_rate=damping,
        linear_growth_rate=growth,
        recurrence_amplitude=recurrence,
        config=cfg,
    )
    reduced = zonal_flow_reduced_objective(
        residual_level=residual,
        damping_rate=damping,
        linear_growth_rate=growth,
        recurrence_amplitude=recurrence,
        config=cfg,
        reduction=reduction,
    )
    rows_np = np.asarray(rows, dtype=float)
    weights = np.asarray(cfg.objective_weights(), dtype=float)
    # The config guarantees a positive weight sum, so this cannot divide by zero.
    normalized_weights = weights / float(np.sum(weights))

    # Dict lookups replace one linear ``list.index`` scan per axis per record.
    surface_index = _axis_index(surfaces)
    alpha_index = _axis_index(alphas)
    kx_index = _axis_index(kx_values)
    row_table: list[dict[str, float | str]] = []
    for item in normalized:
        objective_row = rows_np[
            surface_index[item["surface"]],
            alpha_index[item["alpha"]],
            kx_index[item["kx"]],
            :,
        ]
        row_table.append(
            {
                "surface": item["surface"],
                "alpha": item["alpha"],
                "kx": item["kx"],
                "residual_level": item["residual_level"],
                "damping_rate": item["damping_rate"],
                "linear_growth_rate": item["linear_growth_rate"],
                "recurrence_amplitude": item["recurrence_amplitude"],
                "inverse_residual": float(objective_row[0]),
                "growth_over_residual": float(objective_row[2]),
                "sample_objective": float(np.dot(objective_row, normalized_weights)),
            }
        )

    promotion_ready = missing_damping_count == 0 and missing_recurrence_count == 0
    payload_claim = claim_level or (
        "promotable_zonal_flow_objective_rows"
        if promotion_ready
        else "diagnostic_zonal_flow_objective_rows_not_promoted_to_optimization_claim"
    )

    # A bare string is itself a sequence of characters; keep it as one path
    # instead of exploding it into single letters.
    if source_paths is None:
        source_path_list: list[str] = []
    elif isinstance(source_paths, str):
        source_path_list = [source_paths]
    else:
        source_path_list = list(source_paths)

    return {
        "kind": "zonal_flow_objective_artifact",
        "claim_level": payload_claim,
        "promotion_ready": bool(promotion_ready),
        "objective_names": list(ZONAL_FLOW_OBJECTIVE_NAMES),
        "objective_config": cfg.to_dict(),
        "missing_damping_policy": missing_damping_policy,
        "missing_damping_count": int(missing_damping_count),
        "missing_recurrence_count": int(missing_recurrence_count),
        "axes": {
            "surface": [float(value) for value in surfaces],
            "alpha": [float(value) for value in alphas],
            "kx": [float(value) for value in kx_values],
        },
        "sample_count": int(len(normalized)),
        "metrics": {
            "residual_level": residual.tolist(),
            "damping_rate": damping.tolist(),
            "linear_growth_rate": growth.tolist(),
            "recurrence_amplitude": recurrence.tolist(),
        },
        "objective_rows": rows_np.tolist(),
        "reduced_objective": float(np.asarray(reduced)),
        "row_table": row_table,
        "source_paths": source_path_list,
        "reduction": reduction,
    }
def _metric_mapping_rows( metrics: Mapping[str, Any], *, config: ZonalFlowObjectiveConfig, ) -> jnp.ndarray: if "residual_level" not in metrics or "damping_rate" not in metrics: raise ValueError("metric_fn must return residual_level and damping_rate") return zonal_flow_objective_rows( residual_level=metrics["residual_level"], damping_rate=metrics["damping_rate"], linear_growth_rate=metrics.get("linear_growth_rate"), recurrence_amplitude=metrics.get("recurrence_amplitude"), config=config, )
def zonal_flow_objective_sensitivity_report(
    metric_fn: Callable[[jnp.ndarray], Mapping[str, Any]],
    params: Any,
    *,
    config: ZonalFlowObjectiveConfig | None = None,
    sample_weights: Any | None = None,
    surface_weights: Any | None = None,
    alpha_weights: Any | None = None,
    ky_weights: Any | None = None,
    reduction: PortfolioReduction = "weighted_mean",
    step: float = 1.0e-4,
    rtol: float = 1.0e-4,
    atol: float = 1.0e-6,
    min_rank: int | None = None,
    condition_number_limit: float = 1.0e8,
    covariance_regularization: float = 1.0e-9,
    workers: int = 1,
    parallel_executor: str = "thread",
) -> dict[str, object]:
    """Run the AD/FD, row-Jacobian, and UQ gate for a zonal-flow objective map.

    ``metric_fn`` maps parameters to a metric mapping, which is converted to
    objective rows.  That row function and every tuning argument are
    forwarded to ``objective_portfolio_sensitivity_report``.  The returned
    report is then tagged with the zonal-flow kind, objective config,
    objective names and claim level.
    """
    cfg = config if config else ZonalFlowObjectiveConfig()

    def rows_from_params(x: jnp.ndarray) -> jnp.ndarray:
        # Compose the user metric map with the row builder under one config.
        return _metric_mapping_rows(metric_fn(x), config=cfg)

    gate_options = {
        "sample_weights": sample_weights,
        "surface_weights": surface_weights,
        "alpha_weights": alpha_weights,
        "ky_weights": ky_weights,
        "objective_weights": cfg.objective_weights(),
        "reduction": reduction,
        "step": step,
        "rtol": rtol,
        "atol": atol,
        "min_rank": min_rank,
        "condition_number_limit": condition_number_limit,
        "covariance_regularization": covariance_regularization,
        "workers": workers,
        "parallel_executor": parallel_executor,
    }
    report = objective_portfolio_sensitivity_report(rows_from_params, params, **gate_options)
    report["kind"] = "zonal_flow_objective_sensitivity_report"
    report["objective_config"] = cfg.to_dict()
    report["objective_names"] = list(ZONAL_FLOW_OBJECTIVE_NAMES)
    report["claim_level"] = (
        "reduced_zonal_flow_objective_gradient_gate_not_nonlinear_turbulence_suppression_claim"
    )
    return report