Source code for spectraxgk.nonlinear_helpers

"""Helper policies and operators for nonlinear gyrokinetic drivers."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable

import jax.numpy as jnp
import jax
import numpy as np

from spectraxgk.grids import SpectralGrid, real_fft_mesh
from spectraxgk.linear import (
    LinearCache,
    LinearParams,
    _build_implicit_operator,
    collision_damping as _base_collision_damping,
    hypercollision_damping,
    term_config_to_linear_terms,
)
from spectraxgk.terms.config import FieldState, TermConfig
from spectraxgk.terms.nonlinear import _broadcast_grid, _ifft2_xy

# Names exported by ``from spectraxgk.nonlinear_helpers import *``.  The
# underscore-prefixed helpers are listed explicitly; without this list a star
# import would skip them because of their leading underscore.
__all__ = [
    "IMEXLinearOperator",
    "_apply_collision_split",
    "_collision_damping",
    "_gx_nonlinear_omega_components",
    "_gx_omega_mode_mask",
    "_make_fixed_mode_projector",
    "_make_hermitian_projector",
    "build_nonlinear_imex_operator",
]


@dataclass(frozen=True)
class IMEXLinearOperator:
    """Reusable matrix-free linear operator for nonlinear IMEX solves.

    An immutable bundle of the values produced by the implicit-operator
    builder, so a nonlinear IMEX driver can hold on to them between steps.
    Within this module, instances are created by
    ``build_nonlinear_imex_operator``.
    """

    # dtype of the state array the operator was built from.
    state_dtype: jnp.dtype
    # State shape as reported by the implicit-operator builder.
    shape: tuple[int, ...]
    # Time-step value used to build the operator (as returned by the builder).
    dt_val: jnp.ndarray
    # Optional preconditioner application.  NOTE(review): presumably None when
    # no preconditioner was requested; confirm against _build_implicit_operator.
    precond_op: Callable[[jnp.ndarray], jnp.ndarray] | None
    # Matrix-free application of the implicit linear operator to an array.
    matvec: Callable[[jnp.ndarray], jnp.ndarray]
    # Flag returned by the builder; presumably whether the species axis was
    # squeezed out of the state.  TODO confirm semantics.
    squeeze_species: bool
[docs] def _make_hermitian_projector( ky_vals: np.ndarray, nx: int ) -> Callable[[jnp.ndarray], jnp.ndarray]: """Project full-ky states onto the GX real-FFT Hermitian manifold.""" ny_full = int(ky_vals.size) nyc = ny_full // 2 + 1 use_hermitian = nyc > 2 and bool(np.any(np.asarray(ky_vals) < 0.0)) if not use_hermitian: return lambda G_state: G_state neg_hi = nyc - 1 if (ny_full % 2 == 0) else nyc if nx > 1: kx_neg = jnp.asarray( np.concatenate(([0], np.arange(nx - 1, 0, -1))), dtype=jnp.int32 ) else: kx_neg = None def project(G_state: jnp.ndarray) -> jnp.ndarray: pos = G_state[..., :nyc, :, :] neg = jnp.conj(pos[..., 1:neg_hi, :, :])[..., ::-1, :, :] if kx_neg is not None: neg = neg[..., kx_neg, :] return jnp.concatenate([pos, neg], axis=-3) return project
def _gx_nonlinear_omega_components(
    fields: FieldState,
    grid: SpectralGrid,
    cache: LinearCache,
    *,
    gx_real_fft: bool,
    kx_max: float,
    ky_max: float,
    kxfac: float,
    vpar_max: float,
    muB_max: float,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """GX-style nonlinear x/y CFL frequency components from grad(phi,apar,bpar).

    Each field's x/y derivative (``i*kx*f`` and ``i*ky*f`` in spectral space)
    is taken to real space and rescaled by ``Nx*Ny``.  The absolute values
    are then accumulated:

        sum_d = |d phi| + vpar_max * |d apar| + muB_max * |d bpar|

    The apar/bpar terms are included only when those fields are not None.

    Returns:
        ``(omega_x, omega_y)`` in the real dtype of ``phi``, where

        omega_x = 0.5 * |kxfac| * kx_max * max(sum_dy)
        omega_y = 0.5 * |kxfac| * ky_max * max(sum_dx)

        The x frequency uses the y-derivative and vice versa, consistent
        with an E x B-type drift.

    With ``gx_real_fft`` the first ``ny//2+1`` ky rows (as sized by
    ``real_fft_mesh``) are transformed with ``irfft2``.  Otherwise
    ``_ifft2_xy`` is applied to the full spectrum.  On non-CPU backends the
    two derivatives of each field share one stacked FFT call.
    """
    phi = fields.phi
    real_dtype = jnp.real(jnp.empty((), dtype=phi.dtype)).dtype
    imag_unit = jnp.asarray(1j, dtype=phi.dtype)
    n_kx = grid.kx.size
    n_ky = grid.ky.size
    # Multiply inverse transforms by Nx*Ny to undo their 1/N normalisation.
    ifft_scale = jnp.asarray(float(n_ky * n_kx), dtype=real_dtype)
    stack_derivatives = jax.default_backend() != "cpu"

    if gx_real_fft:
        _, ky_vals, kx_spec, ky_spec = real_fft_mesh(cache.kx_grid, cache.ky_grid)
        n_rows = int(ky_vals.shape[0])

        def select_modes(field: jnp.ndarray) -> jnp.ndarray:
            # Keep only the rFFT-unique ky rows.
            return field[:n_rows, :, :]

        def to_real_space(spec: jnp.ndarray) -> jnp.ndarray:
            # axes=(-2, -3): kx is the complex axis, ky the half-spectrum axis.
            return jnp.fft.irfft2(spec, s=(n_kx, n_ky), axes=(-2, -3))

    else:
        kx_spec, ky_spec = cache.kx_grid, cache.ky_grid

        def select_modes(field: jnp.ndarray) -> jnp.ndarray:
            return field

        to_real_space = _ifft2_xy

    kx_b = _broadcast_grid(kx_spec, phi.ndim)
    ky_b = _broadcast_grid(ky_spec, phi.ndim)

    def gradient(field: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Real-space (d/dx, d/dy) of a spectral field, rescaled by Nx*Ny."""
        spec = select_modes(field)
        d_dx = imag_unit * kx_b * spec
        d_dy = imag_unit * ky_b * spec
        if stack_derivatives:
            both = to_real_space(jnp.stack([d_dx, d_dy], axis=0)) * ifft_scale
            return both[0], both[1]
        return to_real_space(d_dx) * ifft_scale, to_real_space(d_dy) * ifft_scale

    grad_x, grad_y = gradient(phi)
    sum_dx = jnp.abs(grad_x)
    sum_dy = jnp.abs(grad_y)
    for extra, weight in ((fields.apar, vpar_max), (fields.bpar, muB_max)):
        if extra is None:
            continue
        extra_x, extra_y = gradient(extra)
        sum_dx = sum_dx + weight * jnp.abs(extra_x)
        sum_dy = sum_dy + weight * jnp.abs(extra_y)

    half = jnp.asarray(0.5, dtype=real_dtype)
    abs_kxfac = jnp.abs(jnp.asarray(kxfac, dtype=real_dtype))
    omega_x = (
        abs_kxfac * jnp.asarray(kx_max, dtype=real_dtype) * jnp.max(sum_dy) * half
    )
    omega_y = (
        abs_kxfac * jnp.asarray(ky_max, dtype=real_dtype) * jnp.max(sum_dx) * half
    )
    return (
        jnp.asarray(omega_x, dtype=real_dtype),
        jnp.asarray(omega_y, dtype=real_dtype),
    )
[docs] def _gx_omega_mode_mask( grid: SpectralGrid, cache: LinearCache, *, gx_real_fft: bool, ) -> jnp.ndarray: """Mask used to reduce mode-wise GX omega/gamma diagnostics.""" ny = int(grid.ky.size) nx = int(grid.kx.size) if gx_real_fft and bool(np.any(np.asarray(grid.ky) < 0.0)): # Full-ky SPECTRAX layout stores the rFFT-unique modes in the first # Ny//2+1 entries, including the Nyquist row when Ny is even. ky_unique = jnp.arange(ny, dtype=jnp.int32)[:, None] < (ny // 2 + 1) else: ky_unique = jnp.asarray(cache.ky)[:, None] >= 0.0 return jnp.asarray(grid.dealias_mask, dtype=bool) & jnp.broadcast_to( ky_unique, (ny, nx) )
def _collision_damping(
    cache: LinearCache,
    params: LinearParams,
    term_cfg: TermConfig,
    real_dtype: jnp.dtype,
    *,
    squeeze_species: bool,
) -> jnp.ndarray:
    """Assemble collision + hypercollision damping for operator splitting.

    The result is
    ``term_cfg.collisions * collision + term_cfg.hypercollisions * hyper``,
    cast to ``real_dtype``.  When ``squeeze_species`` is set and the
    hypercollision array is 6-D, its leading axis is indexed at 0 first
    (presumably the species axis, matching the squeezed collision term).
    """
    coll = _base_collision_damping(
        cache, params, real_dtype, squeeze_species=squeeze_species
    )
    hyper = hypercollision_damping(cache, params, real_dtype)
    if squeeze_species and hyper.ndim == 6:
        hyper = hyper[0]
    weight_coll = jnp.asarray(term_cfg.collisions, dtype=real_dtype)
    weight_hyper = jnp.asarray(term_cfg.hypercollisions, dtype=real_dtype)
    total = weight_coll * coll + weight_hyper * hyper
    return total.astype(real_dtype)
[docs] def _apply_collision_split( G: jnp.ndarray, damping: jnp.ndarray, dt_local: jnp.ndarray, scheme: str, ) -> jnp.ndarray: """Apply a diagonal collision/hypercollision split update.""" scheme_key = scheme.strip().lower() if scheme_key in {"implicit", "imex"}: return G / (1.0 + dt_local * damping) if scheme_key in {"exp", "sts", "rkc", "rkc2"}: # For diagonal collision operators the exponential update is exact and # behaves like a stabilized explicit (STS/RKC) limit. return G * jnp.exp(-dt_local * damping) raise ValueError("collision_scheme must be one of {'implicit', 'exp', 'sts', 'rkc'}")
[docs] def _make_fixed_mode_projector( fixed_state: jnp.ndarray | None, *, ky_index: int | None, kx_index: int | None, ) -> Callable[[jnp.ndarray], jnp.ndarray] | None: """Return a projector that keeps one Fourier mode equal to ``fixed_state``.""" if fixed_state is None or ky_index is None or kx_index is None: return None ky_i = int(ky_index) kx_i = int(kx_index) fixed_block = jnp.asarray(fixed_state)[..., ky_i : ky_i + 1, kx_i : kx_i + 1, :] def project(G_state: jnp.ndarray) -> jnp.ndarray: return G_state.at[..., ky_i : ky_i + 1, kx_i : kx_i + 1, :].set( fixed_block ) return project
def build_nonlinear_imex_operator(
    G0: jnp.ndarray,
    cache: LinearCache,
    params: LinearParams,
    dt: float,
    *,
    terms: TermConfig | None = None,
    implicit_preconditioner: str | None = None,
    gx_real_fft: bool = True,
) -> IMEXLinearOperator:
    """Build the matrix-free linear operator used by nonlinear IMEX.

    The term configuration (``terms``, or ``TermConfig()`` when falsy) is
    converted to linear terms.  ``_build_implicit_operator`` then assembles
    the operator, and its outputs are packaged into an immutable
    ``IMEXLinearOperator``.

    Args:
        G0: Initial state used to build the operator.
        cache: Precomputed linear-operator cache.
        params: Linear physics parameters.
        dt: Time step.
        terms: Term configuration; defaults to ``TermConfig()``.
        implicit_preconditioner: Optional preconditioner name, forwarded to
            the builder.
        gx_real_fft: Accepted for call-site compatibility; ignored here.
    """
    del gx_real_fft  # Not used by the linear operator.
    cfg = terms if terms else TermConfig()
    built = _build_implicit_operator(
        G0,
        cache,
        params,
        dt,
        term_config_to_linear_terms(cfg),
        implicit_preconditioner,
    )
    state, shape, _size, dt_val, precond_op, matvec, squeeze_species = built
    return IMEXLinearOperator(
        state_dtype=state.dtype,
        shape=shape,
        dt_val=dt_val,
        precond_op=precond_op,
        matvec=matvec,
        squeeze_species=squeeze_species,
    )