# Source code for spectraxgk.quasilinear

"""Quasilinear transport diagnostics from linear gyrokinetic states.

The routines in this module compute linear heat and particle flux weights
from an eigenstate or late-time linear state.  Saturation rules are kept
explicitly separate from the linear weights so calibration and uncertainty
metadata can be audited case by case.
"""

from __future__ import annotations

from dataclasses import asdict, dataclass
from typing import Any, Iterable, Sequence

import jax.numpy as jnp
import numpy as np

from spectraxgk.diagnostics import (
    gx_Wphi,
    gx_heat_flux_species,
    gx_particle_flux_species,
    gx_volume_factors,
)
from spectraxgk.geometry import FluxTubeGeometryLike
from spectraxgk.grids import SpectralGrid
from spectraxgk.linear import LinearCache, LinearParams
from spectraxgk.terms.assembly import compute_fields_cached
from spectraxgk.terms.config import TermConfig


# Amplitude normalizations accepted by ``phi_norm2``; any other key raises
# ``ValueError`` there.
_SUPPORTED_NORMALIZATIONS = {"phi_rms", "phi_midplane", "field_energy"}
# Saturation-rule names accepted by ``saturation_amplitude2``.  In that
# function "lapillonne_2011" shares the "mixing_length" branch and
# "abs_growth_mixing_length" shares the "absolute_growth_mixing_length" branch.
_SUPPORTED_RULES = {
    "none",
    "mixing_length",
    "lapillonne_2011",
    "linear_weight",
    "absolute_growth_mixing_length",
    "abs_growth_mixing_length",
}
# Output modes accepted by ``compute_quasilinear_from_linear_state``; saturated
# fluxes are only filled in for "saturated".
_SUPPORTED_MODES = {"weights", "saturated"}


[docs] @dataclass(frozen=True) class QuasilinearTransportResult: """JSON-friendly quasilinear diagnostic payload for one linear mode.""" ky: float gamma: float omega: float mode: str saturation_rule: str amplitude_normalization: str channels: tuple[str, ...] kperp_average: str kperp_eff2: float phi_norm2: float amplitude2: float | None heat_flux_weight_species: tuple[float, ...] particle_flux_weight_species: tuple[float, ...] saturated_heat_flux_species: tuple[float, ...] | None saturated_particle_flux_species: tuple[float, ...] | None species: tuple[str, ...] metadata: dict[str, Any]
[docs] def to_dict(self) -> dict[str, Any]: """Return a stable JSON-serializable representation.""" payload = asdict(self) payload["channels"] = list(self.channels) payload["species"] = list(self.species) payload["heat_flux_weight_species"] = list(self.heat_flux_weight_species) payload["particle_flux_weight_species"] = list(self.particle_flux_weight_species) if self.saturated_heat_flux_species is not None: payload["saturated_heat_flux_species"] = list(self.saturated_heat_flux_species) if self.saturated_particle_flux_species is not None: payload["saturated_particle_flux_species"] = list(self.saturated_particle_flux_species) payload["heat_flux_weight_total"] = float(sum(self.heat_flux_weight_species)) payload["particle_flux_weight_total"] = float(sum(self.particle_flux_weight_species)) if self.saturated_heat_flux_species is not None: payload["saturated_heat_flux_total"] = float(sum(self.saturated_heat_flux_species)) else: payload["saturated_heat_flux_total"] = None if self.saturated_particle_flux_species is not None: payload["saturated_particle_flux_total"] = float(sum(self.saturated_particle_flux_species)) else: payload["saturated_particle_flux_total"] = None return payload
[docs] def normalize_quasilinear_channels(channels: Iterable[str] | str) -> tuple[str, ...]: """Normalize and validate quasilinear field channels.""" values: tuple[str, ...] if isinstance(channels, str): values = (channels,) else: values = tuple(str(ch).strip().lower() for ch in channels) values = tuple(dict.fromkeys(ch for ch in values if ch)) if not values: values = ("es",) unsupported = [ch for ch in values if ch != "es"] if unsupported: raise NotImplementedError( "Only electrostatic quasilinear flux channels are validated so far; " f"unsupported channels: {unsupported}" ) return values
def spectral_phi_weights(
    phi: jnp.ndarray,
    cache: LinearCache,
    vol_fac: jnp.ndarray,
    *,
    use_dealias: bool = True,
) -> jnp.ndarray:
    """Return ``(ky, kx, z)`` weights used for ``|phi|^2`` averages.

    Each ``ky`` row gets a factor of 2 when ``cache.ky`` has no negative
    entries and that ``ky`` is nonzero, and 1 otherwise.  With ``use_dealias``
    the factor is also multiplied by ``cache.dealias_mask``.  The result is
    ``|phi|^2`` times that factor times ``vol_fac`` along the last axis.
    """
    ky_values = jnp.asarray(cache.ky)
    # A grid holding negative ky is weighted uniformly; otherwise nonzero ky
    # rows count twice.
    uniform = jnp.logical_or(jnp.any(ky_values < 0.0), ky_values == 0.0)
    row_factor = jnp.where(uniform, 1.0, 2.0)
    # Spread the per-ky factor across the kx axis.
    plane_factor = row_factor[:, None] * jnp.ones((1, cache.kx.size), dtype=row_factor.dtype)
    if use_dealias:
        plane_factor = plane_factor * cache.dealias_mask.astype(plane_factor.dtype)
    intensity = jnp.abs(phi) ** 2
    return intensity * plane_factor[:, :, None] * vol_fac[None, None, :]
def effective_kperp2(
    phi: jnp.ndarray,
    cache: LinearCache,
    vol_fac: jnp.ndarray,
    *,
    use_dealias: bool = True,
    eps: float = 1.0e-30,
) -> jnp.ndarray:
    """Compute ``<k_perp^2 |phi|^2>/<|phi|^2>`` for a linear mode.

    The averaging weights come from ``spectral_phi_weights``.  The denominator
    is floored at ``eps`` so a vanishing ``phi`` does not divide by zero.
    """
    phi_weights = spectral_phi_weights(phi, cache, vol_fac, use_dealias=use_dealias)
    total_weight = jnp.sum(phi_weights)
    weighted_kperp2 = jnp.sum(cache.kperp2 * phi_weights)
    floor = jnp.asarray(eps, dtype=total_weight.dtype)
    return weighted_kperp2 / jnp.maximum(total_weight, floor)
def phi_norm2(
    phi: jnp.ndarray,
    cache: LinearCache,
    params: LinearParams,
    vol_fac: jnp.ndarray,
    *,
    normalization: str = "phi_rms",
    use_dealias: bool = True,
    eps: float = 1.0e-30,
) -> jnp.ndarray:
    """Return the amplitude normalization used for quasilinear weights.

    ``normalization`` is matched case-insensitively after stripping:

    * ``"phi_rms"``: sum of ``spectral_phi_weights``.
    * ``"phi_midplane"``: maximum of ``|phi|^2`` on the slice at index
      ``phi.shape[-1] // 2`` of the last axis.
    * ``"field_energy"``: value returned by ``gx_Wphi``.

    The result is floored at ``eps``.

    Raises:
        ValueError: if ``normalization`` is not a supported key.
    """
    key = normalization.strip().lower()
    if key not in _SUPPORTED_NORMALIZATIONS:
        raise ValueError(f"Unknown quasilinear amplitude normalization '{normalization}'")

    # The floor uses the real dtype of phi so it never promotes the result to
    # complex.
    floor = jnp.asarray(eps, dtype=jnp.real(phi).dtype)

    if key == "phi_rms":
        raw = jnp.sum(spectral_phi_weights(phi, cache, vol_fac, use_dealias=use_dealias))
    elif key == "phi_midplane":
        mid = int(phi.shape[-1] // 2)
        raw = jnp.max(jnp.abs(phi[:, :, mid]) ** 2)
    else:
        # Only "field_energy" remains after the membership check above.
        raw = gx_Wphi(phi, cache, params, vol_fac, use_dealias=use_dealias)
    return jnp.maximum(raw, floor)
def saturation_amplitude2(
    *,
    gamma: float,
    kperp_eff2_value: float,
    rule: str,
    csat: float = 1.0,
    gamma_floor: float = 0.0,
    include_stable_modes: bool = False,
) -> float | None:
    """Return the squared amplitude implied by a named saturation rule.

    Rules (matched case-insensitively after stripping):

    * ``"none"``: returns ``None``.
    * ``"linear_weight"``: returns ``csat``.
    * ``"absolute_growth_mixing_length"`` / ``"abs_growth_mixing_length"``:
      returns ``csat * |gamma| / kperp_eff2_value`` (``gamma_floor`` is not
      applied).
    * ``"mixing_length"`` / ``"lapillonne_2011"``: returns
      ``csat * (gamma - gamma_floor) / kperp_eff2_value``, with the drive
      clipped at zero unless ``include_stable_modes`` is true.

    For every rule other than ``"none"``, a non-positive or non-finite
    ``kperp_eff2_value`` yields ``0.0``.

    Raises:
        NotImplementedError: if ``rule`` is not a supported name.
    """
    key = rule.strip().lower()
    if key not in _SUPPORTED_RULES:
        raise NotImplementedError(f"Quasilinear saturation rule '{rule}' is not implemented")
    if key == "none":
        return None

    # Guard the division below.  NaN fails the ``> 0.0`` test, so it is rejected
    # together with non-positive and infinite values.
    usable_kperp = kperp_eff2_value > 0.0 and np.isfinite(kperp_eff2_value)
    if not usable_kperp:
        return 0.0

    net_drive = float(gamma) - float(gamma_floor)
    if not include_stable_modes:
        net_drive = max(net_drive, 0.0)

    if key == "linear_weight":
        return float(csat)
    if key in {"absolute_growth_mixing_length", "abs_growth_mixing_length"}:
        return float(csat) * abs(float(gamma)) / float(kperp_eff2_value)
    if key in {"mixing_length", "lapillonne_2011"}:
        return float(csat) * net_drive / float(kperp_eff2_value)
    # Fallback for a supported name that has no formula above.
    return None
[docs] def mixing_length_amplitude2_jax( gamma: jnp.ndarray | float, kperp_eff2_value: jnp.ndarray | float, *, csat: float = 1.0, gamma_floor: float = 0.0, include_stable_modes: bool = False, eps: float = 1.0e-30, ) -> jnp.ndarray: """JAX-differentiable mixing-length squared-amplitude rule.""" gamma_arr = jnp.asarray(gamma) kperp_arr = jnp.asarray(kperp_eff2_value) drive = gamma_arr - jnp.asarray(gamma_floor, dtype=gamma_arr.dtype) if not include_stable_modes: drive = jnp.maximum(drive, jnp.asarray(0.0, dtype=gamma_arr.dtype)) denom = jnp.maximum(kperp_arr, jnp.asarray(eps, dtype=kperp_arr.dtype)) return jnp.asarray(csat, dtype=gamma_arr.dtype) * drive / denom
def saturated_flux_from_linear_weight(
    linear_flux_weight: jnp.ndarray | float,
    gamma: jnp.ndarray | float,
    kperp_eff2_value: jnp.ndarray | float,
    *,
    csat: float = 1.0,
    gamma_floor: float = 0.0,
    include_stable_modes: bool = False,
) -> jnp.ndarray:
    """Return a differentiable mixing-length saturated flux estimate.

    The estimate is ``linear_flux_weight`` multiplied by the squared amplitude
    from ``mixing_length_amplitude2_jax``, which is called with its default
    ``eps``.
    """
    rule_options = {
        "csat": csat,
        "gamma_floor": gamma_floor,
        "include_stable_modes": include_stable_modes,
    }
    squared_amplitude = mixing_length_amplitude2_jax(gamma, kperp_eff2_value, **rule_options)
    weight = jnp.asarray(linear_flux_weight)
    return weight * squared_amplitude
[docs] def quasilinear_feature_objective( features: jnp.ndarray | Sequence[float], *, rule: str = "mixing_length", csat: float = 1.0, gamma_floor: float = 0.0, include_stable_modes: bool = False, ) -> jnp.ndarray: """Differentiable objective from ``[gamma, kperp_eff2, flux_weight]``. This helper is intentionally small: it is the reduced objective used by derivative validation tests and optimization examples once a linear scan has produced quasilinear weights. """ x = jnp.asarray(features) if x.shape[-1] != 3: raise ValueError("features must end with [gamma, kperp_eff2, flux_weight]") rule_key = rule.strip().lower() if rule_key == "linear_weight": return jnp.asarray(csat, dtype=x.dtype) * x[..., 2] if rule_key in {"absolute_growth_mixing_length", "abs_growth_mixing_length"}: denom = jnp.maximum(x[..., 1], jnp.asarray(1.0e-30, dtype=x.dtype)) return jnp.asarray(csat, dtype=x.dtype) * jnp.abs(x[..., 0]) * x[..., 2] / denom if rule_key not in {"mixing_length", "lapillonne_2011"}: raise NotImplementedError(f"Quasilinear feature rule '{rule}' is not implemented") return saturated_flux_from_linear_weight( x[..., 2], x[..., 0], x[..., 1], csat=csat, gamma_floor=gamma_floor, include_stable_modes=include_stable_modes, )
[docs] def shape_aware_power_law_objective( features: jnp.ndarray | Sequence[float], ky: jnp.ndarray | Sequence[float] | float, *, exponent: jnp.ndarray | float, csat: float = 1.0, ky_ref: float | None = None, eps: float = 1.0e-30, ) -> jnp.ndarray: """Differentiable shape-aware linear-weight objective. ``features`` must end with ``[gamma, kperp_eff2, flux_weight]``. The current low-dimensional shape model intentionally uses only the linear heat-flux weight and a power-law envelope in ``ky``: ``Q = C_sat * flux_weight * (ky / ky_ref)**exponent``. Growth-rate dependence is left to separately validated rules. This helper exists so the shape-aware saturation diagnostics and future optimization examples use one differentiable objective rather than plotting-only formulas. """ dtype = jnp.result_type(features, ky, exponent, jnp.float32) x = jnp.asarray(features, dtype=dtype) if x.shape[-1] != 3: raise ValueError("features must end with [gamma, kperp_eff2, flux_weight]") ky_arr = jnp.asarray(ky, dtype=dtype) eps_arr = jnp.asarray(eps, dtype=dtype) positive_ky = jnp.maximum(ky_arr, eps_arr) if ky_ref is None: ref = jnp.exp(jnp.mean(jnp.log(positive_ky))) else: ref = jnp.maximum(jnp.asarray(ky_ref, dtype=dtype), eps_arr) envelope = (positive_ky / ref) ** jnp.asarray(exponent, dtype=dtype) return jnp.asarray(csat, dtype=dtype) * x[..., 2] * envelope
def compute_quasilinear_from_linear_state(
    state: jnp.ndarray | np.ndarray,
    *,
    cache: LinearCache,
    grid: SpectralGrid,
    geom: FluxTubeGeometryLike,
    params: LinearParams,
    ky: float,
    gamma: float,
    omega: float,
    terms: TermConfig | None = None,
    mode: str = "weights",
    saturation_rule: str = "none",
    amplitude_normalization: str = "phi_rms",
    kperp_average: str = "phi_weighted",
    csat: float = 1.0,
    gamma_floor: float = 0.0,
    include_stable_modes: bool = False,
    channels: Sequence[str] | str = ("es",),
    species_names: Sequence[str] | None = None,
    use_dealias: bool = True,
    flux_scale: float = 1.0,
    metadata: dict[str, Any] | None = None,
) -> QuasilinearTransportResult:
    """Compute quasilinear transport weights from a linear state.

    The returned heat and particle flux weights are divided by the selected
    mode-amplitude normalization, so they are invariant under complex phase
    rotations and real amplitude rescalings of the eigenstate.

    Saturated fluxes are filled in only when ``mode`` is ``"saturated"`` and
    the saturation rule yields an amplitude; otherwise they stay ``None``.
    ``species_names`` is replaced by ``s0, s1, ...`` when it is empty or its
    length does not match the number of flux entries.

    Raises:
        ValueError: if ``mode`` is not a supported mode.
        NotImplementedError: if ``kperp_average`` is not ``"phi_weighted"``;
            also propagated from channel and saturation-rule validation.
    """
    # Validate the cheap string options before any field solve.
    mode_key = mode.strip().lower()
    if mode_key not in _SUPPORTED_MODES:
        raise ValueError(f"Unknown quasilinear mode '{mode}'")
    channels_use = normalize_quasilinear_channels(channels)
    kperp_key = kperp_average.strip().lower()
    if kperp_key != "phi_weighted":
        raise NotImplementedError("Only phi_weighted kperp averaging is validated so far")

    # Solve for phi; the apar and bpar slots are passed as zeros shaped like phi.
    dist = jnp.asarray(state)
    phi = compute_fields_cached(dist, cache, params, terms=terms).phi
    no_field = jnp.zeros_like(phi)
    vol_fac, flux_fac = gx_volume_factors(geom, grid)

    norm2 = phi_norm2(
        phi,
        cache,
        params,
        vol_fac,
        normalization=amplitude_normalization,
        use_dealias=use_dealias,
    )
    kperp_eff = effective_kperp2(phi, cache, vol_fac, use_dealias=use_dealias)

    # Both flux diagnostics take the same arguments.
    flux_args = (dist, phi, no_field, no_field, cache, grid, params, flux_fac)
    flux_options = {"use_dealias": use_dealias, "flux_scale": flux_scale}
    heat = gx_heat_flux_species(*flux_args, **flux_options)
    particle = gx_particle_flux_species(*flux_args, **flux_options)

    # Normalize by the mode amplitude and flatten to host-side float arrays.
    heat_np = np.asarray(jnp.real(heat / norm2), dtype=float).reshape(-1)
    particle_np = np.asarray(jnp.real(particle / norm2), dtype=float).reshape(-1)

    default_species = tuple(f"s{i}" for i in range(heat_np.size))
    species = tuple(species_names or default_species)
    if len(species) != heat_np.size:
        species = default_species

    kperp_eff2_float = float(np.asarray(kperp_eff))
    amp2 = saturation_amplitude2(
        gamma=gamma,
        kperp_eff2_value=kperp_eff2_float,
        rule=saturation_rule,
        csat=csat,
        gamma_floor=gamma_floor,
        include_stable_modes=include_stable_modes,
    )

    saturated_heat = None
    saturated_particle = None
    if mode_key == "saturated" and amp2 is not None:
        saturated_heat = tuple(float(v) for v in heat_np * amp2)
        saturated_particle = tuple(float(v) for v in particle_np * amp2)

    # Caller-supplied metadata wins; defaults only fill in missing keys.
    meta = dict(metadata or {})
    claim = "linear_weights" if amp2 is None else "uncalibrated_saturation_rule"
    meta.setdefault("claim_level", claim)
    meta.setdefault("field_channels_validated", list(channels_use))
    meta.setdefault("electromagnetic_channels", "disabled_until_validated")

    return QuasilinearTransportResult(
        ky=float(ky),
        gamma=float(gamma),
        omega=float(omega),
        mode=mode_key,
        saturation_rule=saturation_rule.strip().lower(),
        amplitude_normalization=amplitude_normalization.strip().lower(),
        channels=channels_use,
        kperp_average=kperp_key,
        kperp_eff2=kperp_eff2_float,
        phi_norm2=float(np.asarray(norm2)),
        amplitude2=None if amp2 is None else float(amp2),
        heat_flux_weight_species=tuple(float(v) for v in heat_np),
        particle_flux_weight_species=tuple(float(v) for v in particle_np),
        saturated_heat_flux_species=saturated_heat,
        saturated_particle_flux_species=saturated_particle,
        species=species,
        metadata=meta,
    )