# Source code for spectraxgk.linear_moments

"""Hermite-Laguerre moment and field-coupling primitives for linear operators."""

from __future__ import annotations

import jax
import jax.numpy as jnp

from spectraxgk.basis import hermite_ladder_coeffs
from spectraxgk.geometry import FluxTubeGeometryLike
from spectraxgk.grids import SpectralGrid
from spectraxgk.linear_params import _check_nonnegative, _check_positive

# Names exported by ``from spectraxgk.linear_moments import *``; the list is
# kept in alphabetical order.
__all__ = [
    "apply_hermite_v",
    "apply_hermite_v2",
    "apply_laguerre_x",
    "build_H",
    "compute_b",
    "diamagnetic_drive_coeffs",
    "energy_operator",
    "grad_z_periodic",
    "lenard_bernstein_eigenvalues",
    "quasineutrality_phi",
    "shift_axis",
    "streaming_term",
]


def grad_z_periodic(f: jnp.ndarray, dz: float | jnp.ndarray) -> jnp.ndarray:
    """Differentiate ``f`` along its last axis with an FFT, assuming periodicity.

    Args:
        f: Array whose last axis is the coordinate being differentiated.
        dz: Sample spacing along that axis; validated by ``_check_positive``.

    Returns:
        The inverse FFT of ``(1j * kz) * FFT(f)``, with the same shape as
        ``f``. The result is complex-valued because it comes straight out of
        ``jnp.fft.ifft``.
    """
    _check_positive(dz, "dz")
    num_points = f.shape[-1]
    # Cast the spacing to the real dtype underlying ``f``.
    spacing = jnp.asarray(dz, dtype=jnp.real(f).dtype)
    # Angular wavenumbers matching the FFT bin ordering.
    wavenumbers = 2.0 * jnp.pi * jnp.fft.fftfreq(num_points, d=spacing)
    spectrum = jnp.fft.fft(f, axis=-1)
    return jnp.fft.ifft((1j * wavenumbers) * spectrum, axis=-1)
def compute_b(
    grid: SpectralGrid, geom: FluxTubeGeometryLike, rho: float
) -> jnp.ndarray:
    """Return ``b = rho**2 * k_perp^2`` on the ``(ky, kx, z)`` grid.

    The perpendicular wavenumber is obtained from ``geom.k_perp2``; the
    original documentation describes this as the s-alpha geometry case.

    Args:
        grid: Provides the 1-D arrays ``kx``, ``ky`` and ``z``.
        geom: Geometry object exposing ``k_perp2(kx, ky, theta)``.
        rho: Scale factor; validated by ``_check_positive``.

    Returns:
        ``rho * rho`` times whatever ``geom.k_perp2`` returns for the
        broadcast inputs.
    """
    _check_positive(rho, "rho")
    # Lay the three coordinates out on separate axes so they broadcast as
    # (ky, kx, z).
    ky_axis = grid.ky[:, None, None]
    kx_axis = grid.kx[None, :, None]
    theta_axis = grid.z[None, None, :]
    return (rho * rho) * geom.k_perp2(kx_axis, ky_axis, theta_axis)
def lenard_bernstein_eigenvalues(
    Nl: int, Nm: int, nu_hermite: float, nu_laguerre: float
) -> jnp.ndarray:
    """Build the diagonal Lenard-Bernstein rate table in Hermite-Laguerre space.

    Args:
        Nl: Number of Laguerre indices (rows of the result).
        Nm: Number of Hermite indices (columns of the result).
        nu_hermite: Rate multiplying the Hermite index ``m``.
        nu_laguerre: Rate multiplying the Laguerre index ``ell``.

    Returns:
        Array of shape ``(Nl, Nm)`` whose ``[ell, m]`` entry is
        ``nu_laguerre * ell + nu_hermite * m``.
    """
    laguerre_part = nu_laguerre * jnp.arange(Nl)[:, None]
    hermite_part = nu_hermite * jnp.arange(Nm)[None, :]
    return laguerre_part + hermite_part
def apply_hermite_v(G: jnp.ndarray) -> jnp.ndarray:
    """Multiply Hermite coefficients by v_parallel using the ladder relation.

    The Hermite index is taken to be axis ``-4`` of ``G``. Each output entry
    combines its two Hermite neighbours, with entries beyond either end of the
    axis treated as zero (see ``shift_axis``).

    Args:
        G: Coefficient array with the Hermite index on axis ``-4``.

    Returns:
        Array with the same shape as ``G``.
    """
    hermite_axis = -4
    num_hermite = G.shape[hermite_axis]

    # NOTE(review): presumably sqrt(m + 1) and sqrt(m) — confirm against
    # ``hermite_ladder_coeffs``.
    raise_coeff, lower_coeff = hermite_ladder_coeffs(num_hermite - 1)

    # Shape that broadcasts a 1-D coefficient vector along the Hermite axis.
    bcast_shape = [1] * G.ndim
    bcast_shape[hermite_axis] = num_hermite
    raise_coeff = raise_coeff[:num_hermite].reshape(bcast_shape)
    lower_coeff = lower_coeff[:num_hermite].reshape(bcast_shape)

    from_above = shift_axis(G, 1, hermite_axis)   # entry m holds G[m + 1]
    from_below = shift_axis(G, -1, hermite_axis)  # entry m holds G[m - 1]
    return raise_coeff * from_above + lower_coeff * from_below
def apply_hermite_v2(G: jnp.ndarray) -> jnp.ndarray:
    """Multiply Hermite coefficients by v_parallel squared.

    Implemented as two successive applications of ``apply_hermite_v``.

    Args:
        G: Coefficient array with the Hermite index on axis ``-4``.

    Returns:
        Array with the same shape as ``G``.
    """
    once = apply_hermite_v(G)
    return apply_hermite_v(once)
def apply_laguerre_x(G: jnp.ndarray) -> jnp.ndarray:
    """Multiply Laguerre coefficients by the perpendicular energy variable.

    The Laguerre index is taken to be axis ``-5`` of ``G``. For index ``ell``
    the result is
    ``(2*ell + 1) * G[ell] - (ell + 1) * G[ell + 1] - ell * G[ell - 1]``,
    with out-of-range neighbours treated as zero (see ``shift_axis``).

    Args:
        G: Coefficient array with the Laguerre index on axis ``-5``.

    Returns:
        Array with the same shape as ``G``.
    """
    laguerre_axis = -5
    count = G.shape[laguerre_axis]

    # Laguerre index laid out so that it broadcasts along its own axis.
    bcast_shape = [1] * G.ndim
    bcast_shape[laguerre_axis] = count
    index = jnp.arange(count).reshape(bcast_shape)

    from_above = shift_axis(G, 1, laguerre_axis)   # entry ell holds G[ell + 1]
    from_below = shift_axis(G, -1, laguerre_axis)  # entry ell holds G[ell - 1]

    diagonal = (2.0 * index + 1.0) * G
    upper = (index + 1.0) * from_above
    lower = index * from_below
    return diagonal - upper - lower
def shift_axis(arr: jnp.ndarray, offset: int, axis: int) -> jnp.ndarray:
    """Shift ``arr`` along ``axis``, filling vacated entries with zeros.

    The result satisfies ``out[i] = arr[i + offset]`` along ``axis`` wherever
    ``i + offset`` is a valid index, and zero elsewhere (no wrap-around).

    Args:
        arr: Input array.
        offset: Signed shift. Zero returns ``arr`` itself; a magnitude at
            least as large as the axis length yields all zeros.
        axis: Axis to shift along; negative values count from the end.

    Returns:
        Array with the same shape and dtype as ``arr``.
    """
    if offset == 0:
        return arr

    axis = axis % arr.ndim
    length = arr.shape[axis]
    blank = jnp.zeros_like(arr)
    if abs(offset) >= length:
        # Everything is shifted out of range.
        return blank

    if offset > 0:
        # Drop the first ``offset`` entries; the remainder moves to the front.
        kept = jax.lax.slice_in_dim(arr, offset, length, axis=axis)
        insert_at = 0
    else:
        # Drop the last ``-offset`` entries; the remainder moves toward the end.
        kept = jax.lax.slice_in_dim(arr, 0, length + offset, axis=axis)
        insert_at = -offset

    origin = [0] * arr.ndim
    origin[axis] = insert_at
    return jax.lax.dynamic_update_slice(blank, kept, origin)
def energy_operator(
    G: jnp.ndarray, coeff_const: float, coeff_par: float, coeff_perp: float
) -> jnp.ndarray:
    """Apply a weighted energy operator in Hermite-Laguerre space.

    The result is
    ``coeff_const * G + coeff_par * v_par^2 G + coeff_perp * x G``, where the
    ``v_par^2`` and ``x`` multiplications are performed by
    ``apply_hermite_v2`` and ``apply_laguerre_x`` respectively.

    Args:
        G: Hermite-Laguerre coefficient array.
        coeff_const: Weight of the identity term.
        coeff_par: Weight of the parallel-velocity-squared term.
        coeff_perp: Weight of the perpendicular-energy term.

    Returns:
        Array with the same shape as ``G``.
    """
    constant_part = coeff_const * G
    parallel_part = coeff_par * apply_hermite_v2(G)
    perpendicular_part = coeff_perp * apply_laguerre_x(G)
    return constant_part + parallel_part + perpendicular_part
def diamagnetic_drive_coeffs(
    Nl: int,
    Nm: int,
    eta_i: jnp.ndarray,
    coeff_const: float,
    coeff_par: float,
    coeff_perp: float,
) -> jnp.ndarray:
    """Return velocity-space coefficients for ``1 + eta_i * (E - 3/2)``.

    The energy operator ``E`` is applied to the unit vector at
    ``(ell, m) = (0, 0)``, so the result is the Hermite-Laguerre expansion of
    the drive factor acting on that lowest moment.

    Args:
        Nl: Number of Laguerre indices.
        Nm: Number of Hermite indices.
        eta_i: Multiplier of the ``(E - 3/2)`` term.
            NOTE(review): only the leading element of the three trailing
            singleton axes is returned, so a scalar ``eta_i`` appears to be
            expected — confirm with callers.
        coeff_const: Forwarded to ``energy_operator``.
        coeff_par: Forwarded to ``energy_operator``.
        coeff_perp: Forwarded to ``energy_operator``.

    Returns:
        Array indexed by ``(ell, m)``.
    """
    # Unit vector in moment space, with three trailing singleton axes so the
    # Hermite/Laguerre operators find their indices on axes -4 and -5.
    lowest_moment = (0, 0, 0, 0, 0)
    unit = jnp.zeros((Nl, Nm, 1, 1, 1)).at[lowest_moment].set(1.0)

    energy_of_unit = energy_operator(unit, coeff_const, coeff_par, coeff_perp)
    drive = unit + eta_i * (energy_of_unit - 1.5 * unit)
    return drive[:, :, 0, 0, 0]
def quasineutrality_phi(
    G: jnp.ndarray,
    Jl: jnp.ndarray,
    tau_e: float | jnp.ndarray,
    charge: jnp.ndarray,
    density: jnp.ndarray,
    tz: jnp.ndarray,
) -> jnp.ndarray:
    """Solve electrostatic quasineutrality for phi with optional adiabatic closure.

    Args:
        G: Moment array; axis 0 is summed over as the species axis, axis 1 is
            summed against ``Jl``, and index 0 of axis 2 is the moment used.
        Jl: Weights multiplied against the ``m = 0`` slice of ``G`` and
            summed over axis 1.
        tau_e: Additive term in the denominator; validated by
            ``_check_nonnegative``.
        charge: Per-species 1-D array.
        density: Per-species 1-D array.
        tz: Per-species 1-D array; its reciprocal is used, with zero entries
            mapped to zero rather than infinity.

    Returns:
        ``numerator / denominator``, where points with an exactly zero
        denominator return zero (the denominator is replaced by infinity).
    """
    _check_nonnegative(tau_e, "tau_e")

    # Per-species weight, broadcast over the three trailing axes.
    species_weight = density[:, None, None, None] * charge[:, None, None, None]

    hermite_zero = G[:, :, 0, ...]
    numerator = jnp.sum(
        species_weight * jnp.sum(Jl * hermite_zero, axis=1),
        axis=0,
    )

    gamma0 = jnp.sum(Jl * Jl, axis=1)
    # Reciprocal of tz, guarded so tz == 0 contributes nothing.
    inv_tz = jnp.where(tz == 0.0, 0.0, 1.0 / tz)
    denominator = tau_e + jnp.sum(
        species_weight * inv_tz[:, None, None, None] * (1.0 - gamma0),
        axis=0,
    )

    # Dividing by infinity yields zero instead of NaN/inf where the
    # denominator vanishes.
    guarded = jnp.where(denominator == 0.0, jnp.inf, denominator)
    return numerator / guarded
def build_H(
    G: jnp.ndarray,
    Jl: jnp.ndarray,
    phi: jnp.ndarray,
    tz: jnp.ndarray,
    apar: jnp.ndarray | None = None,
    vth: jnp.ndarray | None = None,
    bpar: jnp.ndarray | None = None,
    JlB: jnp.ndarray | None = None,
) -> jnp.ndarray:
    """Map G -> H for mirror/curvature/grad-B/collision terms.

    GX builds H by adding the field terms for m=0 (phi, Bpar) and the
    A_parallel term for m=1, while the streaming term applies its own
    pre-derivative field contributions. We mirror that behavior here.

    Args:
        G: Moment array, either 6-D with a leading species axis or 5-D for a
            single species. The Hermite index is on axis ``-4``.
        Jl: Weights for the ``phi`` and ``apar`` terms, 5-D with a leading
            species axis or 4-D without one.
        phi: Field multiplied by ``Jl`` and added to the ``m = 0`` moment.
        tz: Scalar or per-species 1-D array; its reciprocal scales the ``phi``
            and ``apar`` terms, with zero entries mapped to zero.
        apar: Optional field subtracted from the ``m = 1`` moment.
        vth: Scalar or per-species 1-D array scaling the ``apar`` term.
            Required when ``apar`` is given.
        bpar: Optional field added to the ``m = 0`` moment.
        JlB: Weights for the ``bpar`` term, 5-D with a leading species axis or
            4-D without one. Required when ``bpar`` is given.

    Returns:
        ``H`` with the same number of dimensions as the ``G`` passed in (the
        species axis is dropped again if ``G`` was 5-D).

    Raises:
        ValueError: If ``apar`` is given without ``vth``, or ``bpar`` is given
            without ``JlB``.
    """
    # Work internally with an explicit species axis; remember whether the
    # caller omitted it so the result can be squeezed back.
    squeeze_species = G.ndim == 5
    if squeeze_species:
        G = G[None, ...]
    if Jl.ndim == 4:
        Jl = Jl[None, ...]

    tz_arr = jnp.asarray(tz)
    if tz_arr.ndim == 0:
        tz_arr = tz_arr[None]
    zt_arr = jnp.where(tz_arr == 0.0, 0.0, 1.0 / tz_arr)
    zt_bcast = zt_arr[:, None, None, None, None]

    # Masks selecting a single Hermite index, broadcastable against G.
    Nm = G.shape[-4]
    hermite_index = jnp.arange(Nm, dtype=jnp.int32).reshape((1, 1, Nm, 1, 1, 1))
    m0_mask = (hermite_index == 0).astype(G.dtype)

    # The field terms have no Hermite axis; insert a singleton one at axis 2.
    phi_term = (zt_bcast * Jl * phi)[:, :, None, ...]
    H = G + m0_mask * phi_term

    if apar is not None:
        if vth is None:
            raise ValueError("vth must be provided when apar is supplied")
        m1_mask = (hermite_index == 1).astype(G.dtype)
        vth_arr = jnp.asarray(vth)
        if vth_arr.ndim == 0:
            vth_arr = vth_arr[None]
        apar_term = (
            zt_bcast * vth_arr[:, None, None, None, None] * Jl * apar
        )[:, :, None, ...]
        H = H - m1_mask * apar_term

    if bpar is not None:
        if JlB is None:
            raise ValueError("JlB must be provided when bpar is supplied")
        # Promote a species-less JlB exactly like Jl above; otherwise the
        # singleton Hermite axis inserted below lands on the wrong axis and
        # the term mis-broadcasts against H.
        if JlB.ndim == 4:
            JlB = JlB[None, ...]
        bpar_term = (JlB * bpar)[:, :, None, ...]
        H = H + m0_mask * bpar_term

    return H[0] if squeeze_species else H
def streaming_term(
    H: jnp.ndarray, dz: float | jnp.ndarray, vth: float | jnp.ndarray
) -> jnp.ndarray:
    """Streaming term: Hermite ladder applied to the periodic z derivative of H.

    The derivative along the last axis comes from ``grad_z_periodic``. The two
    Hermite neighbours (axis ``-4``) of that derivative are then combined with
    the ladder coefficients and scaled by ``vth``.

    Args:
        H: Array with the Hermite index on axis ``-4`` and z on the last axis.
        dz: Grid spacing forwarded to ``grad_z_periodic``.
        vth: Scalar or 1-D array; validated by ``_check_positive`` and
            broadcast along axis 0 of ``H``.

    Returns:
        ``vth`` times the laddered derivative, broadcast against ``H``.
    """
    _check_positive(vth, "vth")
    dH_dz = grad_z_periodic(H, dz)

    hermite_axis = -4
    num_hermite = H.shape[hermite_axis]

    # NOTE(review): presumably sqrt(m + 1) and sqrt(m) — confirm against
    # ``hermite_ladder_coeffs``.
    raise_coeff, lower_coeff = hermite_ladder_coeffs(num_hermite - 1)
    bcast_shape = [1] * H.ndim
    bcast_shape[hermite_axis] = num_hermite
    raise_coeff = raise_coeff[:num_hermite].reshape(bcast_shape)
    lower_coeff = lower_coeff[:num_hermite].reshape(bcast_shape)

    # Zero-pad one entry on each side of the Hermite axis so both neighbours
    # can be read off as plain slices of the padded array.
    pad_widths = [(0, 0)] * H.ndim
    pad_widths[hermite_axis] = (1, 1)
    padded = jnp.pad(dH_dz, pad_widths)

    def _hermite_window(window: slice) -> jnp.ndarray:
        index = [slice(None)] * H.ndim
        index[hermite_axis] = window
        return padded[tuple(index)]

    from_above = _hermite_window(slice(2, None))  # entry m holds dH_dz[m + 1]
    from_below = _hermite_window(slice(0, -2))    # entry m holds dH_dz[m - 1]
    ladder = raise_coeff * from_above + lower_coeff * from_below

    # Broadcast vth along the leading axis of H.
    speed = jnp.asarray(vth)
    if speed.ndim == 0:
        speed = speed[None]
    speed_shape = [1] * H.ndim
    speed_shape[0] = speed.shape[0]
    return speed.reshape(speed_shape) * ladder