# Source code for spectraxgk.linear_parallel

"""Velocity-parallel linear RHS helpers."""

from __future__ import annotations

from typing import Any

import jax
import jax.numpy as jnp
import numpy as np

from spectraxgk.linear_cache import LinearCache
from spectraxgk.linear_moments import build_H
from spectraxgk.linear_params import LinearParams, LinearTerms, _as_species_array

# Names exported by ``from spectraxgk.linear_parallel import *``.  The list
# deliberately includes underscore-prefixed helpers (and the kernel cache
# below), which a star-import would otherwise skip.
__all__ = [
    "_FUSED_ELECTROSTATIC_SLICE_KERNEL_CACHE",
    "_electrostatic_streaming_field_rhs",
    "_is_electrostatic_field_terms",
    "_is_electrostatic_slice_terms",
    "_is_streaming_only_terms",
    "_linear_rhs_electrostatic_slices_velocity_sharded_fused",
    "_resolve_parallel_devices",
    "_streaming_electrostatic_from_phi_velocity_sharded",
    "linear_rhs_electrostatic_slices_velocity_sharded",
    "linear_rhs_parallel_cached",
    "linear_rhs_streaming_electrostatic_velocity_sharded",
    "linear_rhs_streaming_velocity_sharded",
]

# Process-wide memo of compiled fused kernels.  It is filled by
# ``_linear_rhs_electrostatic_slices_velocity_sharded_fused``: the key is the
# tuple that function builds (state shape/dtype, ``id()`` of the cache and
# params objects, term weights, device names, mesh axis name) and the value is
# the ``(jitted shard-mapped kernel, NamedSharding)`` pair.
# NOTE(review): nothing in this file evicts entries, so the dict grows with
# every distinct key for the lifetime of the process.
_FUSED_ELECTROSTATIC_SLICE_KERNEL_CACHE: dict[tuple[Any, ...], tuple[Any, Any]] = {}


def _is_streaming_only_terms(terms: LinearTerms | None) -> bool:
    """Return True when streaming is the only enabled linear term.

    ``terms=None`` is treated as a default-constructed ``LinearTerms``.  The
    streaming weight must be exactly 1.0 and every other weight checked here
    must be exactly 0.0.
    """
    weights = LinearTerms() if terms is None else terms
    if float(weights.streaming) != 1.0:
        return False
    # Checked in this order; the scan stops at the first nonzero weight.
    must_be_off = (
        "mirror",
        "curvature",
        "gradb",
        "diamagnetic",
        "collisions",
        "hypercollisions",
        "hyperdiffusion",
        "end_damping",
        "apar",
        "bpar",
    )
    return all(float(getattr(weights, name)) == 0.0 for name in must_be_off)


def _is_electrostatic_slice_terms(terms: LinearTerms | None) -> bool:
    """Return True when the collision, damping, and EM weights are all zero.

    ``terms=None`` is treated as a default-constructed ``LinearTerms``.  The
    streaming, mirror, curvature, grad-B, and diamagnetic weights are not
    inspected, so they may take any value.
    """
    weights = LinearTerms() if terms is None else terms
    excluded = (
        "collisions",
        "hypercollisions",
        "hyperdiffusion",
        "end_damping",
        "apar",
        "bpar",
    )
    for name in excluded:
        if float(getattr(weights, name)) != 0.0:
            return False
    return True


def _is_electrostatic_field_terms(terms: LinearTerms | None) -> bool:
    """Return True when both electromagnetic weights (apar, bpar) are zero.

    ``terms=None`` is treated as a default-constructed ``LinearTerms``.
    """
    weights = LinearTerms() if terms is None else terms
    has_em_field = float(weights.apar) != 0.0 or float(weights.bpar) != 0.0
    return not has_em_field


def _resolve_parallel_devices(
    *, num_devices: int | None = None, devices: Any | None = None
) -> list[Any]:
    """Resolve the device list used by the opt-in parallel diagnostics.

    Args:
        num_devices: Optional device count.  Without ``devices`` it selects
            the first ``num_devices`` entries of ``jax.devices()``; with
            ``devices`` it must equal the length of that list.
        devices: Optional explicit iterable of devices; copied into a list.

    Returns:
        A non-empty list of devices.

    Raises:
        ValueError: If ``num_devices`` is below 1 or exceeds the available
            device count (no explicit list), if it disagrees with the explicit
            list length, or if the resolved list is empty.
    """
    requested = None if num_devices is None else int(num_devices)

    if devices is not None:
        resolved = list(devices)
        if requested is not None and requested != len(resolved):
            raise ValueError("num_devices must match the explicit devices list length")
    else:
        resolved = list(jax.devices())
        if requested is not None:
            if requested < 1:
                raise ValueError("num_devices must be >= 1")
            if requested > len(resolved):
                raise ValueError(
                    f"requested {requested} devices, but only {len(resolved)} are available"
                )
            resolved = resolved[:requested]

    if not resolved:
        raise ValueError("at least one device is required")
    return resolved
def linear_rhs_streaming_velocity_sharded(
    G: jnp.ndarray,
    cache: LinearCache,
    params: LinearParams,
    *,
    num_devices: int | None = None,
    devices: Any | None = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Evaluate the streaming-only linear RHS through the Hermite shard-map path.

    This diagnostic route is deliberately narrower than
    :func:`linear_rhs_cached`: only the velocity-space streaming operator is
    applied and the returned electrostatic potential is identically zero.  It
    gates the future production velocity decomposition before field solves,
    drifts, collisions, and nonlinear terms are exposed through the runtime
    path.

    Args:
        G: State with 5 or 6 dimensions.
        cache: Linear cache; only ``cache.kz`` is read here.
        params: Linear parameters; only ``params.vth`` is read here.
        num_devices: Optional device count (see ``_resolve_parallel_devices``).
        devices: Optional explicit device list.

    Returns:
        ``(dG, phi)`` where ``dG`` is the negated sharded streaming result and
        ``phi`` is a zero array shaped like the last three axes of ``G``.

    Raises:
        ValueError: If ``G`` is neither 5- nor 6-dimensional.
    """
    from spectraxgk.velocity_sharding import (
        build_velocity_sharding_plan,
        periodic_streaming_shard_map,
    )

    state = jnp.asarray(G)
    if not 5 <= state.ndim <= 6:
        raise ValueError(
            "G must have shape (Nl, Nm, Ny, Nx, Nz) or (Ns, Nl, Nm, Ny, Nx, Nz)"
        )

    mesh_devices = _resolve_parallel_devices(num_devices=num_devices, devices=devices)
    plan = build_velocity_sharding_plan(
        state.shape, num_devices=len(mesh_devices), axes=("hermite",)
    )
    streamed = periodic_streaming_shard_map(
        state, plan, kz=cache.kz, vth=params.vth, devices=mesh_devices
    )
    zero_phi = jnp.zeros(state.shape[-3:], dtype=state.dtype)
    return -streamed, zero_phi
def _streaming_electrostatic_from_phi_velocity_sharded(
    arr: jnp.ndarray,
    cache: LinearCache,
    params: LinearParams,
    *,
    phi: jnp.ndarray,
    plan: Any,
    devices: Any,
) -> jnp.ndarray:
    """Apply electrostatic streaming given an already-solved potential.

    The result is the sum of two pieces: the negated sharded particle
    streaming of ``arr``, and ``params.kpar_scale`` times the periodic
    z-derivative of the electrostatic field drive built from ``phi``.

    Args:
        arr: State array; a leading axis of length one is prepended before the
            field drive is built and stripped again from the field piece.
        cache: Linear cache; ``cache.kz`` and ``cache.Jl`` are read.
        params: Linear parameters; ``vth``, ``tz``, and ``kpar_scale`` are read.
        phi: Precomputed electrostatic potential.
        plan: Velocity sharding plan forwarded to the shard-map operator.
        devices: Devices forwarded to the shard-map operator.
    """
    from spectraxgk.terms.operators import grad_z_periodic as operator_grad_z_periodic
    from spectraxgk.velocity_sharding import periodic_streaming_shard_map

    particle_part = -periodic_streaming_shard_map(
        arr, plan, kz=cache.kz, vth=params.vth, devices=devices
    )

    real_dtype = jnp.real(arr).dtype
    drive = _electrostatic_streaming_field_rhs(
        arr[None, ...],
        phi=phi,
        Jl=cache.Jl,
        tz=_as_species_array(params.tz, 1, "tz").astype(real_dtype),
        vth=_as_species_array(params.vth, 1, "vth").astype(real_dtype),
    )
    scale = jnp.asarray(params.kpar_scale, dtype=real_dtype)
    field_part = scale * operator_grad_z_periodic(drive, kz=cache.kz)

    # Index 0 removes the length-one leading axis added for the drive builder.
    return particle_part + field_part[0]
def _electrostatic_streaming_field_rhs(
    G6: jnp.ndarray,
    *,
    phi: jnp.ndarray,
    Jl: jnp.ndarray,
    tz: jnp.ndarray,
    vth: jnp.ndarray,
) -> jnp.ndarray:
    """Build the pre-derivative electrostatic streaming field term.

    The drive ``-(1 / tz) * vth * Jl * phi`` is placed in the ``m == 1`` slot
    of axis 2 and every other slot along that axis is zero.  Entries with
    ``tz == 0`` contribute zero.

    Args:
        G6: Six-dimensional state; only its dtype and the length of axis 2 are
            used.
        phi: Potential; two leading singleton axes are prepended before it is
            broadcast against ``Jl``.
        Jl: Coefficients multiplied with ``phi``; must broadcast against the
            five-axis ``tz``/``vth`` factors and ``phi``.
        tz: One-dimensional array indexed along the leading axis.
        vth: One-dimensional array indexed along the leading axis.

    Returns:
        The drive with a new axis inserted at position 2 (length
        ``G6.shape[2]``), cast through ``G6.dtype`` by the ``m == 1`` mask.
    """
    n_hermite = G6.shape[2]
    m_idx = jnp.arange(n_hermite, dtype=jnp.int32)[None, None, :, None, None, None]

    # Safe reciprocal: dividing by a substituted 1.0 where tz == 0 keeps inf
    # out of the unselected ``where`` branch.  The forward value is unchanged,
    # but reverse-mode gradients with respect to ``tz`` stay finite instead of
    # turning into 0 * inf = NaN.
    tz_is_zero = tz == 0.0
    tz_safe = jnp.where(tz_is_zero, 1.0, tz)
    zt = jnp.where(tz_is_zero, 0.0, 1.0 / tz_safe)

    zt5 = zt[:, None, None, None, None]
    vth5 = vth[:, None, None, None, None]
    phi_s = phi[None, None, ...]
    drive_m1 = -zt5 * vth5 * Jl * phi_s
    return (m_idx == 1).astype(G6.dtype) * drive_m1[:, :, None, ...]
def linear_rhs_streaming_electrostatic_velocity_sharded(
    G: jnp.ndarray,
    cache: LinearCache,
    params: LinearParams,
    *,
    num_devices: int | None = None,
    devices: Any | None = None,
    use_custom_vjp: bool = True,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Evaluate the electrostatic streaming RHS with Hermite-sharded streaming.

    ``phi`` comes from the sharded electrostatic field solve; the Hermite
    velocity-sharded particle-streaming operator is then applied and the
    GX-style electrostatic streaming field term is added.  The route is
    restricted to periodic field-line grids and excludes electromagnetic
    fields by construction.

    Args:
        G: Single-species five-dimensional state.
        cache: Linear cache (``Jl``, ``mask0``, ``kz`` are read along the way).
        params: Linear parameters for the field solve and streaming.
        num_devices: Optional device count (see ``_resolve_parallel_devices``).
        devices: Optional explicit device list.
        use_custom_vjp: Accepted for signature compatibility; the body does
            not read it.

    Returns:
        ``(dG, phi)``.

    Raises:
        NotImplementedError: If ``G`` is not five-dimensional, or if
            ``cache.use_twist_shift`` is present and truthy.
    """
    from spectraxgk.velocity_sharding import (
        build_velocity_sharding_plan,
        electrostatic_phi_shard_map,
    )

    state = jnp.asarray(G)
    if state.ndim != 5:
        raise NotImplementedError(
            "velocity-sharded electrostatic streaming currently supports single-species 5D states"
        )
    twist_shift = bool(getattr(cache, "use_twist_shift", False))
    if twist_shift:
        raise NotImplementedError(
            "velocity-sharded electrostatic streaming currently requires a periodic z grid"
        )

    mesh_devices = _resolve_parallel_devices(num_devices=num_devices, devices=devices)
    plan = build_velocity_sharding_plan(
        state.shape, num_devices=len(mesh_devices), axes=("hermite",)
    )
    phi = electrostatic_phi_shard_map(
        state,
        plan,
        Jl=cache.Jl,
        tau_e=params.tau_e,
        charge=params.charge_sign,
        density=params.density,
        tz=params.tz,
        mask0=cache.mask0,
        devices=mesh_devices,
    )
    rhs = _streaming_electrostatic_from_phi_velocity_sharded(
        state, cache, params, phi=phi, plan=plan, devices=mesh_devices
    )
    return rhs, phi
def _linear_rhs_electrostatic_slices_velocity_sharded_fused(
    arr: jnp.ndarray,
    cache: LinearCache,
    params: LinearParams,
    term_weights: LinearTerms,
    *,
    plan: Any,
    devices: Any,
    axis_name: str = "m",
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Fuse the current single-species periodic electrostatic shard-map route.

    One jitted ``shard_map`` kernel, sharded along the Hermite (``m``) axis of
    the ``(l, m, ky, kx, z)`` state, solves for ``phi`` and evaluates the
    weighted streaming, mirror, curvature, grad-B, and diamagnetic terms.
    Compiled kernels are memoised in
    ``_FUSED_ELECTROSTATIC_SLICE_KERNEL_CACHE``.

    Args:
        arr: Five-dimensional state ordered ``(l, m, ky, kx, z)``.
        cache: Linear cache supplying geometry and Laguerre/Hermite tables.
        params: Linear parameters; the first entry of each per-species
            quantity is used.
        term_weights: Weights applied to the individual terms.
        plan: Sharding plan; ``plan.chunks["m"]`` and ``plan.active_axes`` are
            read.
        devices: Devices to build the mesh from; the first ``m_chunks`` are
            used.
        axis_name: Mesh axis name used for the collectives.

    Returns:
        ``(rhs, phi)`` as produced by the compiled kernel.

    Raises:
        ValueError: If there is at most one Hermite chunk, the Hermite length
            is not divisible by the chunk count, or too few devices are given.
        NotImplementedError: If an axis other than ``"m"`` is active, or if the
            curvature term is enabled with fewer than two Hermite modes per
            chunk (the +/-2 halo exchange only reaches the adjacent shard).
    """
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    from spectraxgk.terms.operators import (
        grad_z_periodic as operator_grad_z_periodic,
        shift_axis as operator_shift_axis,
    )

    dims = ("l", "m", "ky", "kx", "z")
    m_axis = dims.index("m")
    m_chunks = int(plan.chunks.get("m", 1))
    if m_chunks <= 1:
        raise ValueError("fused Hermite route requires more than one Hermite chunk")
    if int(arr.shape[m_axis]) % m_chunks != 0:
        raise ValueError("Hermite dimension must divide evenly across Hermite chunks")
    active_non_hermite = tuple(
        active_axis for active_axis in plan.active_axes if active_axis != "m"
    )
    if active_non_hermite:
        raise NotImplementedError(
            "fused electrostatic slice route currently supports only an active 'm' axis"
        )
    device_list = list(devices)
    if len(device_list) < m_chunks:
        raise ValueError("not enough devices for the requested Hermite decomposition")

    local_m = int(arr.shape[m_axis]) // m_chunks
    # shift_m() exchanges halos with the directly adjacent shard only.  With a
    # single mode per shard the m +/- 2 neighbours live two shards away, so the
    # curvature stencil would silently read the wrong modes.
    if local_m < 2 and float(term_weights.curvature) != 0.0:
        raise NotImplementedError(
            "fused electrostatic slice route needs at least two Hermite modes per "
            "chunk when the curvature term is enabled"
        )

    cache_key = (
        "electrostatic_linear_slices_fused",
        tuple(int(x) for x in arr.shape),
        str(arr.dtype),
        id(cache),
        id(params),
        float(term_weights.streaming),
        float(term_weights.mirror),
        float(term_weights.curvature),
        float(term_weights.gradb),
        float(term_weights.diamagnetic),
        tuple(str(device) for device in device_list[:m_chunks]),
        axis_name,
    )
    cached = _FUSED_ELECTROSTATIC_SLICE_KERNEL_CACHE.get(cache_key)
    if cached is not None:
        # Cache hit: skip rebuilding the mesh and all closure constants.
        mapped, sharding = cached
        return mapped(jax.device_put(arr, sharding))

    # id() values are only unique among live objects.  The kernel bakes in
    # constants derived from ``cache`` and ``params``, so both are pinned in the
    # kernel closure: while the entry exists their ids cannot be recycled by a
    # different object that would then be served this stale kernel.
    key_owners = (cache, params)

    mesh = Mesh(np.asarray(device_list[:m_chunks]), (axis_name,))
    spec_list: list[str | None] = [None] * arr.ndim
    spec_list[m_axis] = axis_name
    state_spec = PartitionSpec(*spec_list)
    phi_spec = PartitionSpec(None, None, None)
    sharding = NamedSharding(mesh, state_spec)
    local_m_index = jnp.arange(local_m, dtype=jnp.int32).reshape((1, local_m, 1, 1, 1))
    # ppermute (source, destination) pairs: "prev" forwards data to the next
    # higher shard, "next" to the next lower shard.
    prev_pairs = tuple((idx, idx + 1) for idx in range(m_chunks - 1))
    next_pairs = tuple((idx, idx - 1) for idx in range(1, m_chunks))

    real_dtype = jnp.real(arr).dtype
    jl = jnp.asarray(cache.Jl)
    if jl.ndim == 5:
        jl = jl[0]
    charge_s = jnp.asarray(params.charge_sign, dtype=real_dtype).reshape(-1)[0]
    density_s = jnp.asarray(params.density, dtype=real_dtype).reshape(-1)[0]
    tau = jnp.asarray(params.tau_e, dtype=real_dtype)
    tz_s = jnp.asarray(params.tz, dtype=real_dtype).reshape(-1)[0]
    zt = jnp.where(tz_s == 0.0, 0.0, 1.0 / tz_s)
    vth_s = jnp.asarray(params.vth, dtype=real_dtype).reshape(-1)[0]
    g0 = jnp.sum(jl * jl, axis=0)
    # Field-solve denominator; exact zeros become inf so the quotient is 0.
    den = tau + density_s * charge_s * zt * (1.0 - g0)
    den_safe = jnp.where(den == 0.0, jnp.inf, den)
    mask0 = None if cache.mask0 is None else jnp.asarray(cache.mask0)
    ell = jnp.arange(arr.shape[0], dtype=real_dtype).reshape((arr.shape[0], 1, 1, 1, 1))
    ell_p1 = ell + 1.0
    bgrad = jnp.asarray(cache.bgrad, dtype=real_dtype).reshape(
        (1, 1, 1, 1, int(jnp.asarray(cache.bgrad).shape[-1]))
    )
    cv = jnp.asarray(cache.cv_d, dtype=real_dtype).reshape(
        (1, 1) + tuple(jnp.asarray(cache.cv_d).shape)
    )
    gb = jnp.asarray(cache.gb_d, dtype=real_dtype).reshape(
        (1, 1) + tuple(jnp.asarray(cache.gb_d).shape)
    )
    omega_d_scale = jnp.asarray(params.omega_d_scale, dtype=real_dtype)
    kpar_scale = jnp.asarray(params.kpar_scale, dtype=real_dtype)
    imag = jnp.asarray(1j, dtype=arr.dtype)
    omega_star = (
        imag
        * jnp.asarray(params.omega_star_scale, dtype=real_dtype)
        * jnp.asarray(cache.ky, dtype=real_dtype)
    )
    omega_star_s = omega_star.reshape((1, omega_star.shape[0], 1, 1))
    tprim_s = jnp.asarray(params.R_over_LTi, dtype=real_dtype).reshape(-1)[0]
    fprim_s = jnp.asarray(params.R_over_Ln, dtype=real_dtype).reshape(-1)[0]
    jl_m1 = operator_shift_axis(jl, -1, axis=0)
    jl_p1 = operator_shift_axis(jl, 1, axis=0)
    l4 = jnp.asarray(cache.l4, dtype=real_dtype).reshape((arr.shape[0], 1, 1, 1))
    w_streaming = jnp.asarray(term_weights.streaming, dtype=real_dtype)
    w_mirror = jnp.asarray(term_weights.mirror, dtype=real_dtype)
    w_curv = jnp.asarray(term_weights.curvature, dtype=real_dtype)
    w_gradb = jnp.asarray(term_weights.gradb, dtype=real_dtype)
    w_diamag = jnp.asarray(term_weights.diamagnetic, dtype=real_dtype)

    def shift_m(local, *, offset: int):
        """Shift the local Hermite block by ``offset`` with a neighbour halo."""
        depth = abs(int(offset))
        if depth == 0:
            return local
        if offset < 0:
            # Result[m] = input[m - depth]: receive the tail of the lower shard.
            boundary = local[:, -depth:, ...]
            received = jax.lax.ppermute(boundary, axis_name, prev_pairs)
            return jnp.concatenate([received, local[:, :-depth, ...]], axis=1)
        # Result[m] = input[m + depth]: receive the head of the upper shard.
        boundary = local[:, :depth, ...]
        received = jax.lax.ppermute(boundary, axis_name, next_pairs)
        return jnp.concatenate([local[:, depth:, ...], received], axis=1)

    def fused(local):
        """Per-shard kernel: field solve plus all weighted linear slices."""
        # Deliberate closure capture; see the ``key_owners`` note above.
        _pinned = key_owners  # noqa: F841

        global_m = jax.lax.axis_index(axis_name) * local_m + local_m_index
        global_m_real = global_m.astype(real_dtype)

        # Field solve: only the shard holding m == 0 contributes; psum shares it.
        m0 = (global_m == 0).astype(local.dtype)
        local_gm0 = jnp.sum(local * m0, axis=1)
        local_nbar = density_s * charge_s * jnp.sum(jl * local_gm0, axis=0)
        phi = jax.lax.psum(local_nbar, axis_name) / den_safe
        if mask0 is not None:
            phi = jnp.where(mask0, 0.0, phi)

        dlocal_dz = operator_grad_z_periodic(local, kz=cache.kz)
        lower = shift_m(dlocal_dz, offset=-1)
        upper = shift_m(dlocal_dz, offset=1)
        streaming = -vth_s * (
            jnp.sqrt(global_m_real + 1.0) * upper + jnp.sqrt(global_m_real) * lower
        )
        field_drive_m1 = (global_m == 1).astype(local.dtype) * (-zt * vth_s * jl * phi)[
            :, None, ...
        ]
        streaming = streaming + kpar_scale * operator_grad_z_periodic(
            field_drive_m1, kz=cache.kz
        )

        h = local + (global_m == 0).astype(local.dtype) * (zt * jl * phi)[:, None, ...]
        h_m_p1 = shift_m(h, offset=1)
        h_m_m1 = shift_m(h, offset=-1)
        mirror_term = (
            -jnp.sqrt(global_m_real + 1.0) * ell_p1 * h_m_p1
            - jnp.sqrt(global_m_real + 1.0) * ell * operator_shift_axis(h_m_p1, -1, axis=0)
            + jnp.sqrt(global_m_real) * ell * h_m_m1
            + jnp.sqrt(global_m_real) * ell_p1 * operator_shift_axis(h_m_m1, 1, axis=0)
        )
        mirror = -vth_s * bgrad * mirror_term

        h_m_p2 = shift_m(h, offset=2)
        h_m_m2 = shift_m(h, offset=-2)
        curv_term = (
            jnp.sqrt((global_m_real + 1.0) * (global_m_real + 2.0)) * h_m_p2
            + (2.0 * global_m_real + 1.0) * h
            + jnp.sqrt(global_m_real * (global_m_real - 1.0)) * h_m_m2
        )
        gradb_term = (
            (ell + 1.0) * operator_shift_axis(h, 1, axis=0)
            + (2.0 * ell + 1.0) * h
            + ell * operator_shift_axis(h, -1, axis=0)
        )
        curvature = -(imag * tz_s * omega_d_scale * cv) * curv_term
        gradb = -(imag * tz_s * omega_d_scale * gb) * gradb_term

        drive_m0 = (
            omega_star_s
            * phi
            * (
                jl_m1 * (l4 * tprim_s)
                + jl * (fprim_s + 2.0 * l4 * tprim_s)
                + jl_p1 * ((l4 + 1.0) * tprim_s)
            )
        )
        drive_m2 = (
            omega_star_s
            * phi
            * jl
            * (tprim_s / jnp.sqrt(jnp.asarray(2.0, dtype=real_dtype)))
        )
        diamagnetic = (global_m == 0).astype(local.dtype) * drive_m0[:, None, ...]
        diamagnetic = (
            diamagnetic + (global_m == 2).astype(local.dtype) * drive_m2[:, None, ...]
        )

        rhs = (
            w_streaming * streaming
            + w_mirror * mirror
            + w_curv * curvature
            + w_gradb * gradb
        )
        rhs = rhs + w_diamag * diamagnetic
        return rhs, phi

    mapped = jax.jit(
        jax.shard_map(
            fused,
            mesh=mesh,
            in_specs=state_spec,
            out_specs=(state_spec, phi_spec),
            axis_names={axis_name},
        )
    )
    _FUSED_ELECTROSTATIC_SLICE_KERNEL_CACHE[cache_key] = (mapped, sharding)
    return mapped(jax.device_put(arr, sharding))
def linear_rhs_electrostatic_slices_velocity_sharded(
    G: jnp.ndarray,
    cache: LinearCache,
    params: LinearParams,
    terms: LinearTerms | None = None,
    *,
    num_devices: int | None = None,
    devices: Any | None = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Compute gated electrostatic streaming, drift, and diamagnetic slices.

    With more than one resolved device the work is delegated to the fused
    shard-map kernel.  With a single device, ``phi`` is solved first and each
    term whose weight is nonzero is added in turn: streaming, mirror,
    curvature/grad-B, then diamagnetic.

    Args:
        G: Single-species five-dimensional state.
        cache: Linear cache.
        params: Linear parameters.
        terms: Term weights; ``None`` means a default ``LinearTerms``.
        num_devices: Optional device count (see ``_resolve_parallel_devices``).
        devices: Optional explicit device list.

    Returns:
        ``(dG, phi)``.

    Raises:
        NotImplementedError: If the weights fail
            ``_is_electrostatic_slice_terms``, if ``G`` is not
            five-dimensional, or if ``cache.use_twist_shift`` is present and
            truthy.
    """
    from spectraxgk.velocity_sharding import (
        build_velocity_sharding_plan,
        curvature_gradb_drift_shard_map,
        diamagnetic_drive_shard_map,
        electrostatic_phi_shard_map,
        mirror_drift_shard_map,
    )

    term_weights = LinearTerms() if terms is None else terms
    if not _is_electrostatic_slice_terms(term_weights):
        raise NotImplementedError(
            "electrostatic slice route allows only electrostatic linear terms"
        )
    state = jnp.asarray(G)
    if state.ndim != 5:
        raise NotImplementedError(
            "velocity-sharded electrostatic slice route currently supports single-species 5D states"
        )
    if bool(getattr(cache, "use_twist_shift", False)):
        raise NotImplementedError(
            "velocity-sharded electrostatic slice route currently requires a periodic z grid"
        )

    mesh_devices = _resolve_parallel_devices(num_devices=num_devices, devices=devices)
    plan = build_velocity_sharding_plan(
        state.shape, num_devices=len(mesh_devices), axes=("hermite",)
    )
    if len(mesh_devices) > 1:
        return _linear_rhs_electrostatic_slices_velocity_sharded_fused(
            state,
            cache,
            params,
            term_weights,
            plan=plan,
            devices=mesh_devices,
        )

    real_dtype = jnp.real(state).dtype

    def as_weight(value):
        """Cast a term weight to the real dtype of the state."""
        return jnp.asarray(value, dtype=real_dtype)

    phi = electrostatic_phi_shard_map(
        state,
        plan,
        Jl=cache.Jl,
        tau_e=params.tau_e,
        charge=params.charge_sign,
        density=params.density,
        tz=params.tz,
        mask0=cache.mask0,
        devices=mesh_devices,
    )

    total = jnp.zeros_like(state)

    if float(term_weights.streaming) != 0.0:
        streaming = _streaming_electrostatic_from_phi_velocity_sharded(
            state,
            cache,
            params,
            phi=phi,
            plan=plan,
            devices=mesh_devices,
        )
        total = total + as_weight(term_weights.streaming) * streaming

    # H is built unconditionally, even when no drift term ends up using it.
    H = build_H(state, cache.Jl, phi, jnp.asarray([params.tz], dtype=real_dtype))

    if float(term_weights.mirror) != 0.0:
        total = total + mirror_drift_shard_map(
            H,
            plan,
            vth=jnp.asarray([params.vth], dtype=real_dtype),
            bgrad=cache.bgrad,
            ell=cache.l,
            sqrt_m=cache.sqrt_m,
            sqrt_m_p1=cache.sqrt_m_p1,
            weight=as_weight(term_weights.mirror),
            devices=mesh_devices,
        )

    wants_curvature = float(term_weights.curvature) != 0.0
    if wants_curvature or float(term_weights.gradb) != 0.0:
        total = total + curvature_gradb_drift_shard_map(
            H,
            plan,
            tz=jnp.asarray([params.tz], dtype=real_dtype),
            omega_d_scale=params.omega_d_scale,
            cv_d=cache.cv_d,
            gb_d=cache.gb_d,
            ell=cache.l,
            m=cache.m,
            weight_curv=as_weight(term_weights.curvature),
            weight_gradb=as_weight(term_weights.gradb),
            devices=mesh_devices,
        )

    if float(term_weights.diamagnetic) != 0.0:
        total = total + diamagnetic_drive_shard_map(
            state,
            plan,
            phi=phi,
            Jl=cache.Jl,
            l4=cache.l4,
            tprim=params.R_over_LTi,
            fprim=params.R_over_Ln,
            omega_star_scale=params.omega_star_scale,
            ky=cache.ky,
            weight=as_weight(term_weights.diamagnetic),
            devices=mesh_devices,
        )

    return total, phi
def linear_rhs_parallel_cached(
    G: jnp.ndarray,
    cache: LinearCache,
    params: LinearParams,
    terms: LinearTerms | None = None,
    *,
    parallel: Any | None = None,
    use_jit: bool = True,
    use_custom_vjp: bool = True,
    dt: jnp.ndarray | float | None = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Compute the linear RHS with an explicit, disabled-by-default parallel route.

    ``parallel=None`` and ``parallel.strategy="serial"`` are exact aliases for
    :func:`linear_rhs_cached`.  The non-serial velocity routes are opt-in,
    Hermite-axis-only identity gates.  ``backend="auto"`` selects the most
    complete currently gated electrostatic route when the term set is
    eligible; otherwise callers must request a narrower explicit backend.

    ``parallel`` is read by attribute (``strategy``, ``backend``, ``axis``,
    ``num_devices``); the string options are lower-cased and ``-`` is treated
    as ``_``.

    Raises:
        NotImplementedError: For an unsupported strategy/backend combination,
            a non-Hermite axis, or a term set the selected backend rejects.
    """
    import spectraxgk.linear as linear_compat

    def option(name: str, default: str) -> str:
        """Read and normalise a string option from ``parallel``."""
        return str(getattr(parallel, name, default)).lower().replace("-", "_")

    unsupported = (
        "parallel linear RHS currently supports only strategy='velocity' with gated electrostatic backends"
    )

    strategy = "serial" if parallel is None else option("strategy", "serial")
    if strategy == "serial":
        return linear_compat.linear_rhs_cached(
            G,
            cache,
            params,
            terms=terms,
            use_jit=use_jit,
            use_custom_vjp=use_custom_vjp,
            dt=dt,
        )
    if strategy != "velocity":
        raise NotImplementedError(unsupported)

    backend = option("backend", "auto")
    on_hermite_axis = option("axis", "hermite") in {"m", "hermite"}
    num_devices = getattr(parallel, "num_devices", None)

    if backend == "auto":
        if not on_hermite_axis:
            raise NotImplementedError(
                "velocity sharding currently supports only the Hermite axis"
            )
        if not _is_electrostatic_slice_terms(terms):
            raise NotImplementedError(
                "backend='auto' can only select gated electrostatic velocity routes; "
                "disable collision/EM/end-damping terms or request an explicit backend"
            )
        # Falls through to the electrostatic-slices branch below.
        backend = "electrostatic_linear_slices"

    if backend in {"streaming_only", "linear_streaming_only"}:
        if not on_hermite_axis:
            raise NotImplementedError(
                "streaming-only velocity sharding currently supports only the Hermite axis"
            )
        if not _is_streaming_only_terms(terms):
            raise NotImplementedError(
                "velocity streaming route requires streaming-only LinearTerms"
            )
        return linear_compat.linear_rhs_streaming_velocity_sharded(
            G,
            cache,
            params,
            num_devices=num_devices,
        )

    if backend in {"streaming_electrostatic", "linear_streaming_electrostatic"}:
        if not on_hermite_axis:
            raise NotImplementedError(
                "electrostatic streaming velocity sharding currently supports only the Hermite axis"
            )
        if not _is_streaming_only_terms(terms):
            raise NotImplementedError(
                "electrostatic velocity streaming route requires streaming-only LinearTerms"
            )
        return linear_compat.linear_rhs_streaming_electrostatic_velocity_sharded(
            G,
            cache,
            params,
            num_devices=num_devices,
            use_custom_vjp=use_custom_vjp,
        )

    if backend in {"electrostatic_linear_slices", "linear_electrostatic_slices"}:
        if not on_hermite_axis:
            raise NotImplementedError(
                "electrostatic slice velocity sharding currently supports only the Hermite axis"
            )
        if not _is_electrostatic_slice_terms(terms):
            raise NotImplementedError(
                "electrostatic slice route requires collision/EM terms to be disabled"
            )
        return linear_compat.linear_rhs_electrostatic_slices_velocity_sharded(
            G,
            cache,
            params,
            terms=terms,
            num_devices=num_devices,
        )

    raise NotImplementedError(unsupported)