Source code for spectraxgk.linear

"""Linear electrostatic gyrokinetic building blocks (Hermite-Laguerre)."""

from __future__ import annotations

from functools import partial
from typing import Any, Callable

import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.sparse.linalg import gmres

from spectraxgk.geometry import FluxTubeGeometryLike
from spectraxgk.grids import SpectralGrid
from spectraxgk.linear_linked import (
    _build_linked_end_damping_profile,  # noqa: F401 - legacy private helper re-export
    _build_linked_fft_maps,  # noqa: F401 - legacy private helper re-export
    _signed_to_index,  # noqa: F401 - legacy private helper re-export
)
from spectraxgk.linear_moments import (
    apply_hermite_v,  # noqa: F401 - legacy public helper re-export
    apply_hermite_v2,  # noqa: F401 - legacy public helper re-export
    apply_laguerre_x,  # noqa: F401 - legacy public helper re-export
    build_H,  # noqa: F401 - legacy public helper re-export
    compute_b,  # noqa: F401 - legacy public helper re-export
    diamagnetic_drive_coeffs,  # noqa: F401 - legacy public helper re-export
    energy_operator,  # noqa: F401 - legacy public helper re-export
    grad_z_periodic,  # noqa: F401 - legacy public helper re-export
    lenard_bernstein_eigenvalues,  # noqa: F401 - legacy public helper re-export
    quasineutrality_phi,  # noqa: F401 - legacy public helper re-export
    shift_axis,  # noqa: F401 - legacy public helper re-export
    streaming_term,  # noqa: F401 - legacy public helper re-export
)
from spectraxgk.linear_cache import (
    LinearCache,
    _build_end_damping_profile_array,  # noqa: F401 - legacy private helper re-export
    _build_gyroaverage_cache_arrays,  # noqa: F401 - legacy private helper re-export
    _build_low_rank_moment_cache_arrays,  # noqa: F401 - legacy private helper re-export
    _numpy_dtype_for_jax,  # noqa: F401 - legacy private helper re-export
    build_linear_cache,
    collision_damping,
    hypercollision_damping,
)
from spectraxgk.linear_params import (
    LinearParams,
    LinearTerms,
    Preconditioner,  # noqa: F401 - legacy public type alias re-export
    PreconditionerSpec,
    _as_species_array,
    _check_nonnegative,  # noqa: F401 - legacy private helper re-export
    _check_positive,  # noqa: F401 - legacy private helper re-export
    _is_tracer,  # noqa: F401 - legacy private helper re-export
    _resolve_implicit_preconditioner,
    _x64_enabled,
    linear_terms_to_term_config,
    term_config_to_linear_terms,  # noqa: F401 - legacy public helper re-export
)
from spectraxgk.linear_parallel import (
    _FUSED_ELECTROSTATIC_SLICE_KERNEL_CACHE,  # noqa: F401 - legacy private helper re-export
    _electrostatic_streaming_field_rhs,  # noqa: F401 - legacy private helper re-export
    _is_electrostatic_field_terms,
    _is_electrostatic_slice_terms,  # noqa: F401 - legacy private helper re-export
    _is_streaming_only_terms,  # noqa: F401 - legacy private helper re-export
    _linear_rhs_electrostatic_slices_velocity_sharded_fused,  # noqa: F401 - legacy private helper re-export
    _resolve_parallel_devices,  # noqa: F401 - legacy private helper re-export
    _streaming_electrostatic_from_phi_velocity_sharded,  # noqa: F401 - legacy private helper re-export
    linear_rhs_electrostatic_slices_velocity_sharded,  # noqa: F401 - legacy public helper re-export
    linear_rhs_parallel_cached,
    linear_rhs_streaming_electrostatic_velocity_sharded,  # noqa: F401 - legacy public helper re-export
    linear_rhs_streaming_velocity_sharded,  # noqa: F401 - legacy public helper re-export
)


_SSPX3_ADT = float((1.0 / 6.0) ** (1.0 / 3.0))
_SSPX3_WGTFAC = float((9.0 - 2.0 * (6.0 ** (2.0 / 3.0))) ** 0.5)
_SSPX3_W1 = 0.5 * (_SSPX3_WGTFAC - 1.0)
_SSPX3_W2 = 0.5 * ((6.0 ** (2.0 / 3.0)) - 1.0 - _SSPX3_WGTFAC)
_SSPX3_W3 = (1.0 / _SSPX3_ADT) - 1.0 - _SSPX3_W2 * (_SSPX3_W1 + 1.0)


[docs] def linear_rhs( G: jnp.ndarray, grid: SpectralGrid, geom: FluxTubeGeometryLike, params: LinearParams, terms: LinearTerms | None = None, *, dt: jnp.ndarray | float | None = None, ) -> tuple[jnp.ndarray, jnp.ndarray]: """Compute the linear RHS and electrostatic potential. Parameters ---------- G : jnp.ndarray Laguerre-Hermite moments with shape (Nl, Nm, Ny, Nx, Nz). grid : SpectralGrid Flux-tube spectral grid. geom : SAlphaGeometry Analytic s-alpha geometry. params : LinearParams Physical and normalization parameters. """ if G.ndim == 5: Nl, Nm = G.shape[0], G.shape[1] elif G.ndim == 6: Nl, Nm = G.shape[1], G.shape[2] else: raise ValueError( "G must have shape (Nl, Nm, Ny, Nx, Nz) or (Ns, Nl, Nm, Ny, Nx, Nz)" ) cache = build_linear_cache(grid, geom, params, Nl, Nm) return linear_rhs_cached(G, cache, params, terms=terms, dt=dt)
[docs] def linear_rhs_cached( G: jnp.ndarray, cache: LinearCache, params: LinearParams, terms: LinearTerms | None = None, *, use_jit: bool = True, use_custom_vjp: bool = True, dt: jnp.ndarray | float | None = None, force_electrostatic_fields: bool = False, ) -> tuple[jnp.ndarray, jnp.ndarray]: """Compute the linear RHS using precomputed geometry arrays.""" from spectraxgk.terms.assembly import ( assemble_rhs_cached, assemble_rhs_cached_electrostatic_jit, assemble_rhs_cached_jit, ) term_cfg = linear_terms_to_term_config(terms) if use_jit: rhs_fn = ( assemble_rhs_cached_electrostatic_jit if force_electrostatic_fields else assemble_rhs_cached_jit ) dG, fields = rhs_fn(G, cache, params, term_cfg, dt) else: dG, fields = assemble_rhs_cached( G, cache, params, terms=term_cfg, use_custom_vjp=use_custom_vjp, dt=dt, force_electrostatic_fields=force_electrostatic_fields, ) return dG, fields.phi
def _integrate_linear_cached_impl( G0: jnp.ndarray, cache: LinearCache, params: LinearParams, dt: float, steps: int, method: str = "rk4", checkpoint: bool = False, terms: LinearTerms | None = None, sample_stride: int = 1, show_progress: bool = False, parallel: Any | None = None, force_electrostatic_fields: bool = False, ) -> tuple[jnp.ndarray, jnp.ndarray]: """Time integrate the linear system using cached geometry arrays.""" if method not in {"euler", "rk2", "rk4", "imex", "imex2", "sspx3"}: raise ValueError( "method must be one of {'euler', 'rk2', 'rk4', 'imex', 'imex2', 'sspx3'}" ) if terms is None: terms = LinearTerms() base_dtype = jnp.complex128 if _x64_enabled() else jnp.complex64 state_dtype = jnp.result_type(G0, base_dtype) G0 = jnp.asarray(G0, dtype=state_dtype) real_dtype = jnp.real(jnp.empty((), dtype=state_dtype)).dtype dt_val = jnp.asarray(dt, dtype=real_dtype) hyper_damp = hypercollision_damping(cache, params, real_dtype) if G0.ndim == 5 and hyper_damp.ndim == 6: hyper_damp = hyper_damp[0] damping = ( collision_damping(cache, params, real_dtype, squeeze_species=(G0.ndim == 5)) + hyper_damp ) damping = damping.astype(real_dtype) parallel_strategy = ( "serial" if parallel is None else str(getattr(parallel, "strategy", "serial")).lower().replace("-", "_") ) def rhs(G_in: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]: if parallel_strategy == "serial": return linear_rhs_cached( G_in, cache, params, terms=terms, dt=dt_val, force_electrostatic_fields=force_electrostatic_fields, ) return linear_rhs_parallel_cached( G_in, cache, params, terms=terms, parallel=parallel, dt=dt_val ) def advance(G): dG, _phi = rhs(G) if method == "imex": dG_explicit = dG + damping * G return (G + dt_val * dG_explicit) / (1.0 + dt_val * damping) if method == "imex2": dG_explicit = dG + damping * G G_half = (G + 0.5 * dt_val * dG_explicit) / (1.0 + 0.5 * dt_val * damping) dG_half, _phi = rhs(G_half) dG_half_exp = dG_half + damping * G_half return (G + dt_val * dG_half_exp) / (1.0 + dt_val * damping) if method == "euler": return G + dt_val * dG if method == "rk2": k1 = dG k2, _ = rhs(G + 0.5 * dt_val * k1) return G + dt_val * k2 if method == "sspx3": def _euler_step(G_state: jnp.ndarray) -> jnp.ndarray: dG_state, _ = rhs(G_state) return G_state + (_SSPX3_ADT * dt_val) * dG_state G1 = _euler_step(G) G2_euler = _euler_step(G1) G2 = (1.0 - _SSPX3_W1) * G + (_SSPX3_W1 - 1.0) * G1 + G2_euler G3 = _euler_step(G2) return ( (1.0 - _SSPX3_W2 - _SSPX3_W3) * G + _SSPX3_W3 * G1 + (_SSPX3_W2 - 1.0) * G2 + G3 ) k1 = dG k2, _ = rhs(G + 0.5 * dt_val * k1) k3, _ = rhs(G + 0.5 * dt_val * k2) k4, _ = rhs(G + dt_val * k3) return G + (dt_val / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4) def step(G, idx): G_new = advance(G) _dG_new, phi_new = rhs(G_new) if show_progress: from spectraxgk.utils.callbacks import print_callback, should_emit_progress sim_time = (idx + 1) * dt_val sim_total = jnp.asarray(steps, dtype=dt_val.dtype) * dt_val phi_max = jnp.max(jnp.abs(phi_new)) G_new = jax.lax.cond( should_emit_progress(idx, steps), lambda state: print_callback( state, idx, steps, 0.0, 0.0, phi_max, 0.0, sim_time, sim_total, metric_labels=("|phi|_max", "|n|_max"), ), lambda state: state, G_new, ) return G_new, phi_new step_fn = jax.checkpoint(step) if checkpoint else step indices = jnp.arange(steps) if sample_stride <= 1: return jax.lax.scan(step_fn, G0, indices) def sample_step(G, idx): def inner_step(i, state): return advance(state) G_out = jax.lax.fori_loop(0, sample_stride, inner_step, G) _dG_out, phi_out = rhs(G_out) if show_progress: from spectraxgk.utils.callbacks import print_callback, should_emit_progress completed_idx = jnp.minimum((idx + 1) * sample_stride, steps) - 1 sim_time = jnp.minimum((idx + 1) * sample_stride, steps) * dt_val sim_total = jnp.asarray(steps, dtype=dt_val.dtype) * dt_val phi_max = jnp.max(jnp.abs(phi_out)) G_out = jax.lax.cond( should_emit_progress(completed_idx, steps), lambda state: print_callback( state, completed_idx, steps, 0.0, 0.0, phi_max, 0.0, sim_time, sim_total, metric_labels=("|phi|_max", "|n|_max"), ), lambda state: state, G_out, ) return G_out, phi_out num_samples = steps // sample_stride sample_indices = jnp.arange(num_samples) return jax.lax.scan(sample_step, G0, sample_indices) @partial( jax.jit, static_argnames=( "steps", "method", "checkpoint", "sample_stride", "show_progress", "force_electrostatic_fields", ), ) def _integrate_linear_cached( G0: jnp.ndarray, cache: LinearCache, params: LinearParams, dt: float, steps: int, method: str = "rk4", checkpoint: bool = False, terms: LinearTerms | None = None, sample_stride: int = 1, show_progress: bool = False, force_electrostatic_fields: bool = False, ) -> tuple[jnp.ndarray, jnp.ndarray]: return _integrate_linear_cached_impl( G0, cache, params, dt, steps, method=method, checkpoint=checkpoint, terms=terms, sample_stride=sample_stride, show_progress=show_progress, force_electrostatic_fields=force_electrostatic_fields, ) @partial( jax.jit, static_argnames=( "steps", "method", "checkpoint", "sample_stride", "show_progress", "force_electrostatic_fields", ), donate_argnums=(0,), ) def _integrate_linear_cached_donate( G0: jnp.ndarray, cache: LinearCache, params: LinearParams, dt: float, steps: int, method: str = "rk4", checkpoint: bool = False, terms: LinearTerms | None = None, sample_stride: int = 1, show_progress: bool = False, force_electrostatic_fields: bool = False, ) -> tuple[jnp.ndarray, jnp.ndarray]: return _integrate_linear_cached_impl( G0, cache, params, dt, steps, method=method, checkpoint=checkpoint, terms=terms, sample_stride=sample_stride, show_progress=show_progress, force_electrostatic_fields=force_electrostatic_fields, ) def _build_implicit_operator( G0: jnp.ndarray, cache: LinearCache, params: LinearParams, dt: float, terms: LinearTerms | None, implicit_preconditioner: PreconditionerSpec, ) -> tuple[ jnp.ndarray, tuple[int, ...], int, jnp.ndarray, Callable[[jnp.ndarray], jnp.ndarray], Callable[[jnp.ndarray], jnp.ndarray], bool, ]: if terms is None: terms = LinearTerms() base_dtype = jnp.complex128 if _x64_enabled() else jnp.complex64 state_dtype = jnp.result_type(G0, base_dtype) G = jnp.asarray(G0, dtype=state_dtype) real_dtype = jnp.real(jnp.empty((), dtype=state_dtype)).dtype dt_val = jnp.asarray(dt, dtype=real_dtype) squeeze_species = False if G.ndim == 5: G = G[None, ...] squeeze_species = True shape = G.shape size = int(np.prod(np.asarray(shape))) ns = shape[0] hyper_damp = hypercollision_damping(cache, params, real_dtype) damping = ( collision_damping(cache, params, real_dtype, squeeze_species=False) + hyper_damp ) damping = damping.astype(real_dtype) ell = cache.l.astype(real_dtype) m = cache.m.astype(real_dtype) cv_d = cache.cv_d.astype(real_dtype) gb_d = cache.gb_d.astype(real_dtype) bgrad = cache.bgrad.astype(real_dtype) w_mirror = jnp.asarray(terms.mirror, dtype=real_dtype) w_curv = jnp.asarray(terms.curvature, dtype=real_dtype) w_gradb = jnp.asarray(terms.gradb, dtype=real_dtype) diag = jnp.zeros_like(damping, dtype=state_dtype) imag = jnp.asarray(1j, dtype=state_dtype) tz = _as_species_array(params.tz, ns, "tz").astype(real_dtype) vth = _as_species_array(params.vth, ns, "vth").astype(real_dtype) tz_b = tz[:, None, None, None, None, None] vth_b = vth[:, None, None, None, None, None] omega_d_scale = jnp.asarray(params.omega_d_scale, dtype=real_dtype) diag = diag - imag * tz_b * omega_d_scale * ( w_curv * cv_d[None, None, None, ...] * (2.0 * m + 1.0) + w_gradb * gb_d[None, None, None, ...] * (2.0 * ell + 1.0) ) bgrad = bgrad[None, None, None, None, None, :] mirror_diag = vth_b * (2.0 * ell + 1.0) * (2.0 * m + 1.0) mirror_weight = 0.2 diag = diag - w_mirror * mirror_weight * bgrad * mirror_diag precond_full = 1.0 / (1.0 + dt_val * damping - dt_val * diag) precond_full = precond_full.astype(G.dtype) precond_damp = (1.0 / (1.0 + dt_val * damping)).astype(G.dtype) kpar = params.kpar_scale * cache.kz.astype(real_dtype) w_stream = jnp.asarray(terms.streaming, dtype=real_dtype) kpar_b = kpar[None, None, None, None, None, :] precond_pas = 1.0 / ( 1.0 + dt_val * damping - dt_val * diag + imag * dt_val * w_stream * vth_b * kpar_b ) precond_pas = precond_pas.astype(G.dtype) resolved_precond = _resolve_implicit_preconditioner(implicit_preconditioner) sqrt_m_line = cache.sqrt_m_ladder.reshape(-1).astype(real_dtype) sqrt_p_line = cache.sqrt_p.reshape(-1).astype(real_dtype) def _solve_hermite_lines_fft( x: jnp.ndarray, *, kz: jnp.ndarray, ) -> jnp.ndarray: """Invert (I - dt*L_stream) approximately via FFT(z) + tridiagonal(m).""" x_hat = jnp.fft.fft(x, axis=-1) x_hat_mlast = jnp.moveaxis(x_hat, 2, -1) # (..., Nz, Nm) coeff = ( (dt_val * w_stream * jnp.asarray(params.kpar_scale, dtype=real_dtype)) * vth[:, None, None, None, None] * (imag * kz)[None, None, None, None, :] ) coeff = coeff[..., None] # (Ns, 1, 1, 1, Nz, 1) dl = coeff * sqrt_m_line du = coeff * sqrt_p_line du = du.at[..., -1].set(jnp.asarray(0.0, dtype=du.dtype)) d = jnp.ones_like(du) batch_shape = x_hat_mlast.shape dl = jnp.broadcast_to(dl, batch_shape) d = jnp.broadcast_to(d, batch_shape) du = jnp.broadcast_to(du, batch_shape) y_hat_mlast = jax.lax.linalg.tridiagonal_solve( dl, d, du, x_hat_mlast[..., None] )[..., 0] y_hat = jnp.moveaxis(y_hat_mlast, -1, 2) return jnp.fft.ifft(y_hat, axis=-1) def _solve_hermite_lines_linked(x: jnp.ndarray) -> jnp.ndarray: """Linked-FFT variant of the Hermite-line streaming preconditioner.""" if not cache.linked_indices: return _solve_hermite_lines_fft(x, kz=cache.kz) Ny = x.shape[-3] Nx = x.shape[-2] Nz = x.shape[-1] lead_shape = x.shape[:-3] x_flat = x.reshape(*lead_shape, Ny * Nx, Nz) y_flat = jnp.zeros_like(x_flat) def _scatter_unique( target: jnp.ndarray, idx_flat: jnp.ndarray, updates: jnp.ndarray ) -> jnp.ndarray: idx = jnp.asarray(idx_flat, dtype=jnp.int32) target_t = jnp.moveaxis(target, -2, 0) updates_t = jnp.moveaxis(updates, -2, 0) idx = idx[:, None] dnums = jax.lax.ScatterDimensionNumbers( update_window_dims=tuple(range(1, updates_t.ndim)), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), ) out_t = jax.lax.scatter( target_t, idx, updates_t, dnums, unique_indices=True, ) return jnp.moveaxis(out_t, 0, -2) for idx_map, kz_link in zip(cache.linked_indices, cache.linked_kz): nChains, nLinks = idx_map.shape idx_flat = idx_map.reshape(-1) x_link = jnp.take(x_flat, idx_flat, axis=-2) x_link = x_link.reshape(*lead_shape, nChains, nLinks * Nz) x_hat = jnp.fft.fft(x_link, axis=-1) x_hat_mlast = jnp.moveaxis(x_hat, 2, -1) # (Ns, Nl, nChains, nfreq, Nm) coeff = ( (dt_val * w_stream * jnp.asarray(params.kpar_scale, dtype=real_dtype)) * vth[:, None, None, None] * (imag * kz_link)[None, None, None, :] ) coeff = coeff[..., None] # (Ns, 1, 1, nfreq, 1) dl = coeff * sqrt_m_line du = coeff * sqrt_p_line du = du.at[..., -1].set(jnp.asarray(0.0, dtype=du.dtype)) d = jnp.ones_like(du) batch_shape = x_hat_mlast.shape dl = jnp.broadcast_to(dl, batch_shape) d = jnp.broadcast_to(d, batch_shape) du = jnp.broadcast_to(du, batch_shape) y_hat_mlast = jax.lax.linalg.tridiagonal_solve( dl, d, du, x_hat_mlast[..., None] )[..., 0] y_hat = jnp.moveaxis(y_hat_mlast, -1, 2) y_link = jnp.fft.ifft(y_hat, axis=-1) y_link = y_link.reshape(*lead_shape, nChains * nLinks, Nz) y_flat = _scatter_unique(y_flat, idx_flat, y_link) return y_flat.reshape(*lead_shape, Ny, Nx, Nz) def apply_precond_full(x_flat: jnp.ndarray) -> jnp.ndarray: x = x_flat.reshape(shape) return (x * precond_full).reshape(size) def apply_precond_damp(x_flat: jnp.ndarray) -> jnp.ndarray: x = x_flat.reshape(shape) return (x * precond_damp).reshape(size) def apply_precond_pas(x_flat: jnp.ndarray) -> jnp.ndarray: x = x_flat.reshape(shape) return (x * precond_pas).reshape(size) def _project_kx_coarse(x: jnp.ndarray) -> jnp.ndarray: """Coarse-space projection/prolongation for twist/shift coupling. For periodic grids this reduces to the mean over kx. For linked grids we average within each linked (ky, kx) chain so the coarse correction does not destroy the linked coupling structure. """ if not cache.use_twist_shift or not cache.linked_indices: x_mean = jnp.mean(x, axis=4, keepdims=True) return jnp.broadcast_to(x_mean, x.shape) Ny = x.shape[-3] Nx = x.shape[-2] Nz = x.shape[-1] lead_shape = x.shape[:-3] x_flat = x.reshape(*lead_shape, Ny * Nx, Nz) y_flat = jnp.zeros_like(x_flat) def _scatter_unique( target: jnp.ndarray, idx_flat: jnp.ndarray, updates: jnp.ndarray ) -> jnp.ndarray: idx = jnp.asarray(idx_flat, dtype=jnp.int32) target_t = jnp.moveaxis(target, -2, 0) updates_t = jnp.moveaxis(updates, -2, 0) idx = idx[:, None] dnums = jax.lax.ScatterDimensionNumbers( update_window_dims=tuple(range(1, updates_t.ndim)), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), ) out_t = jax.lax.scatter( target_t, idx, updates_t, dnums, unique_indices=True, ) return jnp.moveaxis(out_t, 0, -2) for idx_map in cache.linked_indices: nChains, nLinks = idx_map.shape idx_flat = idx_map.reshape(-1) x_link = jnp.take(x_flat, idx_flat, axis=-2) x_link = x_link.reshape(*lead_shape, nChains, nLinks, Nz) x_mean = jnp.mean(x_link, axis=-2, keepdims=True) x_mean = jnp.broadcast_to(x_mean, x_link.shape) x_updates = x_mean.reshape(*lead_shape, nChains * nLinks, Nz) y_flat = _scatter_unique(y_flat, idx_flat, x_updates) return y_flat.reshape(*lead_shape, Ny, Nx, Nz) def apply_precond_pas_coarse(x_flat: jnp.ndarray) -> jnp.ndarray: """PAS line + kx-coarse correction (additive Schur-style).""" x = x_flat.reshape(shape) x_line = x * precond_pas x_coarse = _project_kx_coarse(x) * precond_pas x_line_coarse = _project_kx_coarse(x_line) x_out = x_line + (x_coarse - x_line_coarse) return x_out.reshape(size) def apply_precond_hermite_line(x_flat: jnp.ndarray) -> jnp.ndarray: x = x_flat.reshape(shape) x = x * precond_full x = ( _solve_hermite_lines_linked(x) if cache.use_twist_shift else _solve_hermite_lines_fft(x, kz=cache.kz) ) return x.reshape(size) def apply_precond_hermite_line_coarse(x_flat: jnp.ndarray) -> jnp.ndarray: x = x_flat.reshape(shape) x_line = apply_precond_hermite_line(x.reshape(size)).reshape(shape) x_coarse_in = _project_kx_coarse(x) x_coarse_full = apply_precond_hermite_line(x_coarse_in.reshape(size)).reshape( shape ) x_line_coarse_full = _project_kx_coarse(x_line) return (x_line + (x_coarse_full - x_line_coarse_full)).reshape(size) def apply_identity(x_flat: jnp.ndarray) -> jnp.ndarray: return x_flat precond_op: Callable[[jnp.ndarray], jnp.ndarray] if callable(resolved_precond): precond_op = resolved_precond else: key = resolved_precond or "auto" if key in {"auto", "diag", "diagonal", "physics", "block"}: precond_op = apply_precond_full elif key in {"damping", "collisional", "hyper"}: precond_op = apply_precond_damp elif key in {"pas", "pas-line", "pas_line"}: precond_op = apply_precond_pas elif key in {"pas-coarse", "pas_schur", "block-schur", "schur", "pas-hybrid"}: precond_op = apply_precond_pas_coarse elif key in { "hermite-line", "hermite_line", "hermite", "streaming-line", "streaming_line", }: precond_op = apply_precond_hermite_line elif key in { "hermite-line-coarse", "hermite_line_coarse", "hermite_coarse", "streaming-line-coarse", }: precond_op = apply_precond_hermite_line_coarse elif key in {"identity", "none", "off"}: precond_op = apply_identity else: raise ValueError(f"Unknown implicit_preconditioner '{resolved_precond}'") def matvec(x_flat: jnp.ndarray) -> jnp.ndarray: x = x_flat.reshape(shape) dG, _phi = linear_rhs_cached( x, cache, params, terms=terms, use_jit=False, use_custom_vjp=False, dt=dt_val, ) return (x - dt_val * dG).reshape(size) return G, shape, size, dt_val, precond_op, matvec, squeeze_species def _integrate_linear_implicit_cached( G0: jnp.ndarray, cache: LinearCache, params: LinearParams, dt: float, steps: int, *, terms: LinearTerms | None = None, implicit_tol: float = 1.0e-6, implicit_maxiter: int = 200, implicit_iters: int = 3, implicit_relax: float = 0.7, implicit_restart: int = 20, implicit_solve_method: str = "batched", implicit_preconditioner: PreconditionerSpec = None, checkpoint: bool = False, sample_stride: int = 1, ) -> tuple[jnp.ndarray, jnp.ndarray]: """Implicit linear integrator using GMRES with a diagonal preconditioner.""" if terms is None: terms = LinearTerms() if sample_stride < 1: raise ValueError("sample_stride must be >= 1") if steps % sample_stride != 0: raise ValueError("steps must be divisible by sample_stride") G, shape, size, dt_val, precond_op, matvec, squeeze_species = ( _build_implicit_operator(G0, cache, params, dt, terms, implicit_preconditioner) ) def fixed_point(G_in: jnp.ndarray, G_rhs: jnp.ndarray) -> jnp.ndarray: def body(_i, g): dG, _phi = linear_rhs_cached( g, cache, params, terms=terms, use_jit=False, use_custom_vjp=False, dt=dt_val, ) g_next = G_rhs + dt_val * dG return (1.0 - implicit_relax) * g + implicit_relax * g_next return jax.lax.fori_loop(0, max(int(implicit_iters), 0), body, G_in) def solve_step(G_in: jnp.ndarray) -> jnp.ndarray: G_guess = fixed_point(G_in, G_in) sol, _info = gmres( matvec, G_in.reshape(size), x0=G_guess.reshape(size), tol=implicit_tol, maxiter=implicit_maxiter, restart=implicit_restart, M=precond_op, solve_method=implicit_solve_method, ) return sol.reshape(shape) def step(G_in, _): G_new = solve_step(G_in) _dG_new, phi_new = linear_rhs_cached( G_new, cache, params, terms=terms, use_jit=False, use_custom_vjp=False, dt=dt_val, ) return G_new, phi_new step_fn = jax.checkpoint(step) if checkpoint else step if sample_stride <= 1: G_out, phi_t = jax.lax.scan(step_fn, G, None, length=steps) else: def sample_step(G_in, _): def inner_step(_i, g): return solve_step(g) G_out_local = jax.lax.fori_loop(0, sample_stride, inner_step, G_in) _dG_out, phi_out = linear_rhs_cached( G_out_local, cache, params, terms=terms, use_jit=False, use_custom_vjp=False, dt=dt_val, ) return G_out_local, phi_out num_samples = steps // sample_stride G_out, phi_t = jax.lax.scan(sample_step, G, None, length=num_samples) G_out = G_out[0] if squeeze_species else G_out return G_out, phi_t
[docs] def integrate_linear( G0: jnp.ndarray, grid: SpectralGrid, geom: FluxTubeGeometryLike, params: LinearParams, dt: float, steps: int, method: str = "rk4", cache: LinearCache | None = None, implicit_tol: float = 1.0e-6, implicit_maxiter: int = 200, implicit_iters: int = 3, implicit_relax: float = 0.7, implicit_restart: int = 20, implicit_solve_method: str = "batched", implicit_preconditioner: PreconditionerSpec = None, terms: LinearTerms | None = None, checkpoint: bool = False, sample_stride: int = 1, donate: bool = False, show_progress: bool = False, parallel: Any | None = None, ) -> tuple[jnp.ndarray, jnp.ndarray]: """Time integrate the linear system using a fixed-step scheme.""" if terms is None: terms = LinearTerms() if sample_stride < 1: raise ValueError("sample_stride must be >= 1") if steps % sample_stride != 0: raise ValueError("steps must be divisible by sample_stride") if cache is None: if G0.ndim == 5: Nl, Nm = G0.shape[0], G0.shape[1] elif G0.ndim == 6: Nl, Nm = G0.shape[1], G0.shape[2] else: raise ValueError( "G0 must have shape (Nl, Nm, Ny, Nx, Nz) or (Ns, Nl, Nm, Ny, Nx, Nz)" ) cache = build_linear_cache(grid, geom, params, Nl, Nm) if method == "semi-implicit": method = "imex" parallel_strategy = ( "serial" if parallel is None else str(getattr(parallel, "strategy", "serial")).lower().replace("-", "_") ) force_electrostatic_fields = _is_electrostatic_field_terms(terms) if method == "implicit": if parallel_strategy != "serial": raise NotImplementedError( "parallel linear integration currently supports only explicit fixed-step methods" ) return _integrate_linear_implicit_cached( G0, cache, params, dt=dt, steps=steps, terms=terms, implicit_tol=implicit_tol, implicit_maxiter=implicit_maxiter, implicit_iters=implicit_iters, implicit_relax=implicit_relax, implicit_restart=implicit_restart, implicit_solve_method=implicit_solve_method, implicit_preconditioner=implicit_preconditioner, checkpoint=checkpoint, sample_stride=sample_stride, ) if parallel_strategy != "serial": if donate: raise NotImplementedError( "parallel linear integration does not currently support donated input buffers" ) return _integrate_linear_cached_impl( G0, cache, params, dt, steps, method=method, checkpoint=checkpoint, terms=terms, sample_stride=sample_stride, show_progress=show_progress, parallel=parallel, force_electrostatic_fields=force_electrostatic_fields, ) integrator = _integrate_linear_cached_donate if donate else _integrate_linear_cached return integrator( G0, cache, params, dt, steps, method=method, checkpoint=checkpoint, terms=terms, sample_stride=sample_stride, show_progress=show_progress, force_electrostatic_fields=force_electrostatic_fields, )
[docs] def integrate_linear_diagnostics( G0: jnp.ndarray, grid: SpectralGrid, geom: FluxTubeGeometryLike, params: LinearParams, dt: float, steps: int, *, method: str = "rk4", cache: LinearCache | None = None, terms: LinearTerms | None = None, sample_stride: int = 1, species_index: int | None = 0, record_hl_energy: bool = False, show_progress: bool = False, ) -> ( tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] | tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray] ): """Integrate and return (G_out, phi_t, density_t) for diagnostics.""" if terms is None: terms = LinearTerms() if sample_stride < 1: raise ValueError("sample_stride must be >= 1") if steps % sample_stride != 0: raise ValueError("steps must be divisible by sample_stride") if cache is None: if G0.ndim == 5: Nl, Nm = G0.shape[0], G0.shape[1] elif G0.ndim == 6: Nl, Nm = G0.shape[1], G0.shape[2] else: raise ValueError( "G0 must have shape (Nl, Nm, Ny, Nx, Nz) or (Ns, Nl, Nm, Ny, Nx, Nz)" ) cache = build_linear_cache(grid, geom, params, Nl, Nm) base_dtype = jnp.complex128 if _x64_enabled() else jnp.complex64 state_dtype = jnp.result_type(G0, base_dtype) G0 = jnp.asarray(G0, dtype=state_dtype) real_dtype = jnp.real(jnp.empty((), dtype=state_dtype)).dtype dt_val = jnp.asarray(dt, dtype=real_dtype) hyper_damp = hypercollision_damping(cache, params, real_dtype) if G0.ndim == 5 and hyper_damp.ndim == 6: hyper_damp = hyper_damp[0] damping = ( collision_damping(cache, params, real_dtype, squeeze_species=(G0.ndim == 5)) + hyper_damp ) damping = damping.astype(real_dtype) def advance(G_in: jnp.ndarray) -> jnp.ndarray: dG, _phi = linear_rhs_cached( G_in, cache, params, terms=terms, use_jit=False, dt=dt_val ) if method == "imex": dG_explicit = dG + damping * G_in return (G_in + dt_val * dG_explicit) / (1.0 + dt_val * damping) if method == "imex2": dG_explicit = dG + damping * G_in G_half = (G_in + 0.5 * dt_val * dG_explicit) / ( 1.0 + 0.5 * dt_val * damping ) dG_half, _phi = linear_rhs_cached( G_half, cache, params, terms=terms, use_jit=False, dt=dt_val ) dG_half_exp = dG_half + damping * G_half return (G_in + dt_val * dG_half_exp) / (1.0 + dt_val * damping) if method == "euler": return G_in + dt_val * dG if method == "rk2": k1 = dG k2, _ = linear_rhs_cached( G_in + 0.5 * dt_val * k1, cache, params, terms=terms, use_jit=False, dt=dt_val, ) return G_in + dt_val * k2 if method == "sspx3": def _euler_step(G_state: jnp.ndarray) -> jnp.ndarray: dG_state, _phi_state = linear_rhs_cached( G_state, cache, params, terms=terms, use_jit=False, dt=dt_val, ) return G_state + (_SSPX3_ADT * dt_val) * dG_state G1 = _euler_step(G_in) G2_euler = _euler_step(G1) G2 = (1.0 - _SSPX3_W1) * G_in + (_SSPX3_W1 - 1.0) * G1 + G2_euler G3 = _euler_step(G2) return ( (1.0 - _SSPX3_W2 - _SSPX3_W3) * G_in + _SSPX3_W3 * G1 + (_SSPX3_W2 - 1.0) * G2 + G3 ) if method == "rk4": k1 = dG k2, _ = linear_rhs_cached( G_in + 0.5 * dt_val * k1, cache, params, terms=terms, use_jit=False, dt=dt_val, ) k3, _ = linear_rhs_cached( G_in + 0.5 * dt_val * k2, cache, params, terms=terms, use_jit=False, dt=dt_val, ) k4, _ = linear_rhs_cached( G_in + dt_val * k3, cache, params, terms=terms, use_jit=False, dt=dt_val ) return G_in + (dt_val / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4) raise ValueError(f"Unsupported method '{method}'") def density_from_G(G_in: jnp.ndarray) -> jnp.ndarray: Jl = cache.Jl if G_in.ndim == 5: if Jl.ndim == 5: Jl_s = Jl[0] else: Jl_s = Jl return jnp.sum(Jl_s * G_in[:, 0, ...], axis=0) if Jl.ndim == 5: if species_index is None: return jnp.sum(jnp.sum(Jl * G_in[:, :, 0, ...], axis=1), axis=0) Jl_s = Jl[int(species_index)] return jnp.sum(Jl_s * G_in[int(species_index), :, 0, ...], axis=0) if species_index is None: return jnp.sum(jnp.sum(Jl[None, ...] * G_in[:, :, 0, ...], axis=1), axis=0) return jnp.sum(Jl * G_in[int(species_index), :, 0, ...], axis=0) def hl_energy_from_G(G_in: jnp.ndarray) -> jnp.ndarray: if G_in.ndim == 5: return jnp.sum(jnp.abs(G_in) ** 2, axis=(2, 3, 4)) return jnp.sum(jnp.abs(G_in) ** 2, axis=(0, 3, 4, 5)) def step(G_in, idx): G_out = advance(G_in) _dG, phi = linear_rhs_cached( G_out, cache, params, terms=terms, use_jit=False, dt=dt_val ) density = density_from_G(G_out) if show_progress: from spectraxgk.utils.callbacks import print_callback, should_emit_progress sim_time = (idx + 1) * dt_val sim_total = jnp.asarray(steps, dtype=dt_val.dtype) * dt_val phi_max = jnp.max(jnp.abs(phi)) density_max = jnp.max(jnp.abs(density)) G_out = jax.lax.cond( should_emit_progress(idx, steps), lambda state: print_callback( state, idx, steps, 0.0, 0.0, phi_max, density_max, sim_time, sim_total, metric_labels=("|phi|_max", "|n|_max"), ), lambda state: state, G_out, ) if record_hl_energy: hl_energy = hl_energy_from_G(G_out) return G_out, (phi, density, hl_energy) return G_out, (phi, density) if sample_stride <= 1: indices = jnp.arange(steps) G_out, outputs = jax.lax.scan(step, G0, indices) else: def sample_step(G_in, idx): def inner_step(_i, g): return advance(g) G_out_local = jax.lax.fori_loop(0, sample_stride, inner_step, G_in) _dG, phi_out = linear_rhs_cached( G_out_local, cache, params, terms=terms, use_jit=False, dt=dt_val ) density_out = density_from_G(G_out_local) if show_progress: from spectraxgk.utils.callbacks import ( print_callback, should_emit_progress, ) completed_idx = jnp.minimum((idx + 1) * sample_stride, steps) - 1 sim_time = jnp.minimum((idx + 1) * sample_stride, steps) * dt_val sim_total = jnp.asarray(steps, dtype=dt_val.dtype) * dt_val phi_max = jnp.max(jnp.abs(phi_out)) density_max = jnp.max(jnp.abs(density_out)) G_out_local = jax.lax.cond( should_emit_progress(completed_idx, steps), lambda state: print_callback( state, completed_idx, steps, 0.0, 0.0, phi_max, density_max, sim_time, sim_total, metric_labels=("|phi|_max", "|n|_max"), ), lambda state: state, G_out_local, ) if record_hl_energy: hl_out = hl_energy_from_G(G_out_local) return G_out_local, (phi_out, density_out, hl_out) return G_out_local, (phi_out, density_out) num_samples = steps // sample_stride sample_indices = jnp.arange(num_samples) G_out, outputs = jax.lax.scan(sample_step, G0, sample_indices) if record_hl_energy: phi_t, density_t, hl_t = outputs return G_out, phi_t, density_t, hl_t phi_t, density_t = outputs return G_out, phi_t, density_t