# Source code for spectraxgk.linear_params

"""Linear parameter, term-toggle, and validation policy helpers."""

from __future__ import annotations

from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, Callable

import jax
import jax.numpy as jnp
import numpy as np

if TYPE_CHECKING:
    from spectraxgk.terms.config import TermConfig

# Names exported by a star-import of this module. Listing underscore-prefixed
# helpers here makes ``import *`` export them, which it would otherwise skip.
__all__ = [
    "LinearParams",
    "LinearTerms",
    "Preconditioner",
    "PreconditionerSpec",
    "_as_species_array",
    "_check_nonnegative",
    "_check_positive",
    "_is_tracer",
    "_resolve_implicit_preconditioner",
    "_x64_enabled",
    "linear_terms_to_term_config",
    "term_config_to_linear_terms",
]


[docs] @jax.tree_util.register_pytree_node_class @dataclass(frozen=True) class LinearParams: """Parameters for the linear gyrokinetic operator (supports multi-species arrays).""" charge_sign: float | jnp.ndarray = 1.0 density: float | jnp.ndarray = 1.0 mass: float | jnp.ndarray = 1.0 temp: float | jnp.ndarray = 1.0 tau_e: float = 1.0 vth: float | jnp.ndarray = 1.0 rho: float | jnp.ndarray = 1.0 kpar_scale: float = 1.0 R_over_Ln: float | jnp.ndarray = 2.2 R_over_LTi: float | jnp.ndarray = 6.9 R_over_LTe: float | jnp.ndarray = 0.0 omega_d_scale: float = 1.0 omega_star_scale: float = 1.0 energy_const: float = 0.0 energy_par_coef: float = 0.5 energy_perp_coef: float = 1.0 nu: float | jnp.ndarray = 0.0 nu_hermite: float = 1.0 nu_laguerre: float = 2.0 nu_hyper: float = 0.0 p_hyper: float = 4.0 nu_hyper_l: float = 0.0 nu_hyper_m: float = 1.0 nu_hyper_lm: float = 0.0 p_hyper_l: float = 6.0 p_hyper_m: float = 20.0 p_hyper_lm: float = 6.0 hypercollisions_const: float = 1.0 hypercollisions_kz: float = 0.0 D_hyper: float = 0.0 p_hyper_kperp: float = 2.0 damp_ends_widthfrac: float | jnp.ndarray = 0.125 damp_ends_amp: float | jnp.ndarray = 0.1 tz: float | jnp.ndarray = 1.0 rho_star: float = 1.0 beta: float = 0.0 fapar: float = 0.0 apar_beta_scale: float = 0.5 ampere_g0_scale: float = 0.5 bpar_beta_scale: float = 0.5 def tree_flatten(self): children = ( self.charge_sign, self.density, self.mass, self.temp, self.tau_e, self.vth, self.rho, self.kpar_scale, self.R_over_Ln, self.R_over_LTi, self.R_over_LTe, self.omega_d_scale, self.omega_star_scale, self.energy_const, self.energy_par_coef, self.energy_perp_coef, self.nu, self.nu_hermite, self.nu_laguerre, self.nu_hyper, self.p_hyper, self.nu_hyper_l, self.nu_hyper_m, self.nu_hyper_lm, self.p_hyper_l, self.p_hyper_m, self.p_hyper_lm, self.hypercollisions_const, self.hypercollisions_kz, self.D_hyper, self.p_hyper_kperp, self.damp_ends_widthfrac, self.damp_ends_amp, self.tz, self.rho_star, self.beta, self.fapar, self.apar_beta_scale, 
self.ampere_g0_scale, self.bpar_beta_scale, ) return children, None @classmethod def tree_unflatten(cls, aux_data, children): return cls(*children)
@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class LinearTerms:
    """Switches for linear-operator components (1.0 = on, 0.0 = off).

    Registered as a JAX pytree whose leaves are the switch values in field
    declaration order.
    """

    streaming: float = 1.0
    mirror: float = 1.0
    curvature: float = 1.0
    gradb: float = 1.0
    diamagnetic: float = 1.0
    collisions: float = 1.0
    hypercollisions: float = 1.0
    hyperdiffusion: float = 0.0
    end_damping: float = 1.0
    apar: float = 1.0
    bpar: float = 1.0

    def tree_flatten(self):
        """Return every switch as a pytree leaf, in declaration order.

        The leaves come from :func:`dataclasses.fields` so they cannot drift
        from the field list that the positional ``cls(*children)`` in
        :meth:`tree_unflatten` relies on.
        """
        children = tuple(getattr(self, field.name) for field in fields(self))
        return children, None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from leaves produced by :meth:`tree_flatten`."""
        return cls(*children)
def linear_terms_to_term_config(
    terms: LinearTerms | None,
    *,
    nonlinear: float = 0.0,
) -> TermConfig:
    """Build a modular :class:`TermConfig` from :class:`LinearTerms` weights.

    Args:
        terms: Linear term weights; ``None`` means a default ``LinearTerms()``.
        nonlinear: Passed through as the ``nonlinear`` weight of the result.

    Returns:
        A ``TermConfig`` whose linear weights are copied by name from ``terms``.
    """
    from spectraxgk.terms.config import TermConfig

    if terms is None:
        terms = LinearTerms()

    shared_names = (
        "streaming",
        "mirror",
        "curvature",
        "gradb",
        "diamagnetic",
        "collisions",
        "hypercollisions",
        "hyperdiffusion",
        "end_damping",
        "apar",
        "bpar",
    )
    copied = {key: getattr(terms, key) for key in shared_names}
    return TermConfig(**copied, nonlinear=nonlinear)
def term_config_to_linear_terms(term_cfg: TermConfig | None) -> LinearTerms:
    """Extract the linear-only weights of a modular :class:`TermConfig`.

    Args:
        term_cfg: Source configuration; ``None`` means a default ``TermConfig()``.

    Returns:
        A ``LinearTerms`` whose switches are copied by name from ``term_cfg``.
        Any other attributes of the config are ignored.
    """
    from spectraxgk.terms.config import TermConfig

    source = TermConfig() if term_cfg is None else term_cfg

    linear_names = (
        "streaming",
        "mirror",
        "curvature",
        "gradb",
        "diamagnetic",
        "collisions",
        "hypercollisions",
        "hyperdiffusion",
        "end_damping",
        "apar",
        "bpar",
    )
    return LinearTerms(**{key: getattr(source, key) for key in linear_names})
def _is_tracer(x) -> bool: return isinstance(x, jax.core.Tracer) def _x64_enabled() -> bool: return bool(getattr(jax.config, "jax_enable_x64", False)) def _check_positive(x, name: str) -> None: arr = jnp.asarray(x) if _is_tracer(x) or _is_tracer(arr): return if arr.ndim == 0: if float(arr) <= 0.0: raise ValueError(f"{name} must be > 0") return if np.any(np.asarray(arr) <= 0.0): raise ValueError(f"{name} must be > 0") def _check_nonnegative(x, name: str) -> None: arr = jnp.asarray(x) if _is_tracer(x) or _is_tracer(arr): return if arr.ndim == 0: if float(arr) < 0.0: raise ValueError(f"{name} must be >= 0") return if np.any(np.asarray(arr) < 0.0): raise ValueError(f"{name} must be >= 0")
[docs] def _as_species_array(value: float | jnp.ndarray, ns: int, name: str) -> jnp.ndarray: """Ensure a parameter is a 1D array of length ns for multi-species handling.""" arr = jnp.asarray(value) if arr.ndim == 0: arr = arr[None] if arr.size == 1: return jnp.broadcast_to(arr, (ns,)) if int(arr.size) != int(ns): raise ValueError(f"{name} must have length {ns} (got {arr.size})") return arr
Preconditioner = Callable[[jnp.ndarray], jnp.ndarray] PreconditionerSpec = Preconditioner | str | None def _resolve_implicit_preconditioner(preconditioner: PreconditionerSpec) -> PreconditionerSpec: if preconditioner is None: return "auto" if isinstance(preconditioner, str): return preconditioner.strip().lower() return preconditioner