Source code for spectraxgk.sharded_integrators

"""Sharded fixed-step integrators for multi-device scaling experiments."""

from __future__ import annotations

from typing import Any

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.pjit import pjit

from spectraxgk.linear import LinearCache, LinearParams, LinearTerms, linear_rhs_cached
from spectraxgk.nonlinear import _make_hermitian_projector, nonlinear_rhs_cached
from spectraxgk.terms.config import FieldState, TermConfig


# Lower-case names of the explicit time-stepping schemes accepted by
# ``_validate_explicit_method``. The error message in that function lists the
# same names by hand, and ``integrate_nonlinear_sharded`` dispatches on them,
# so all three places must be kept in sync when a scheme is added or removed.
_EXPLICIT_METHODS = {"euler", "rk2", "rk3", "rk3_gx", "rk3_classic", "rk4", "sspx3"}


def _dt_array(dt: float, state_dtype: jnp.dtype) -> jnp.ndarray:
    """Return ``dt`` as an array whose dtype is the real counterpart of ``state_dtype``.

    The real dtype is obtained by taking ``jnp.real`` of an empty scalar of
    ``state_dtype`` and reading the dtype of the result.
    """
    probe = jnp.empty((), dtype=state_dtype)
    real_dtype = jnp.real(probe).dtype
    return jnp.asarray(dt, dtype=real_dtype)


def _validate_steps(steps: int) -> None:
    """Raise ``ValueError`` when fewer than one step is requested."""
    too_few = steps < 1
    if too_few:
        raise ValueError("steps must be >= 1")


def _validate_explicit_method(method: str) -> str:
    """Normalize an explicit-method name and check that it is supported.

    The input is converted with ``str``, stripped of surrounding whitespace and
    lower-cased before being looked up in ``_EXPLICIT_METHODS``.

    Returns:
        The normalized method name.

    Raises:
        ValueError: If the normalized name is not in ``_EXPLICIT_METHODS``.
    """
    normalized = str(method).strip().lower()
    if normalized in _EXPLICIT_METHODS:
        return normalized
    raise ValueError(
        "method must be one of {'euler', 'rk2', 'rk3', 'rk3_gx', 'rk3_classic', 'rk4', 'sspx3'}"
    )


def integrate_linear_sharded(
    G0: jnp.ndarray,
    cache: LinearCache,
    params: LinearParams,
    *,
    dt: float,
    steps: int,
    terms: LinearTerms | None = None,
    state_sharding: Any | None = None,
) -> jnp.ndarray:
    """Integrate the linear system with a pjit-sharded RK2 loop.

    This is intentionally minimal: it returns the final state only and avoids
    saving time histories to focus on strong scaling of the RHS.

    Args:
        G0: Initial state. It is cast to ``jnp.result_type(G0, jnp.complex64)``.
        cache: Linear cache forwarded to ``linear_rhs_cached``.
        params: Linear parameters forwarded to ``linear_rhs_cached``.
        dt: Fixed step size; converted to a real scalar matching the state dtype.
        steps: Number of RK2 steps to take; must be >= 1.
        terms: Linear term selection. A default ``LinearTerms()`` is used when
            ``None``.
        state_sharding: Sharding applied to the state through
            ``jax.lax.with_sharding_constraint`` and as the ``pjit`` input and
            output sharding. ``None`` disables the explicit constraints.

    Returns:
        The state after ``steps`` steps.

    Raises:
        ValueError: If ``steps`` is less than one.
    """
    terms = LinearTerms() if terms is None else terms
    _validate_steps(steps)

    complex_dtype = jnp.result_type(G0, jnp.complex64)
    state0 = jnp.asarray(G0, dtype=complex_dtype)
    step_size = _dt_array(dt, complex_dtype)

    def constrain(arr: jnp.ndarray) -> jnp.ndarray:
        # Identity when no sharding was requested.
        if state_sharding is not None:
            return jax.lax.with_sharding_constraint(arr, state_sharding)
        return arr

    def rhs(arr: jnp.ndarray) -> jnp.ndarray:
        # Only the time derivative is needed; the second return value is dropped.
        deriv, _ = linear_rhs_cached(arr, cache, params, terms=terms, dt=step_size)
        return deriv

    def midpoint_step(carry, _):
        current = constrain(carry)
        # Midpoint RK2: half step with the initial slope, full step with the
        # slope evaluated at the midpoint.
        midpoint = current + 0.5 * step_size * rhs(current)
        advanced = current + step_size * rhs(midpoint)
        return constrain(advanced), None

    def scan_all(initial):
        final, _ = jax.lax.scan(midpoint_step, initial, xs=None, length=steps)
        return final

    sharded_run = pjit(
        scan_all,
        in_shardings=state_sharding,
        out_shardings=state_sharding,
    )

    if state_sharding is not None:
        state0 = jax.device_put(state0, state_sharding)
    return sharded_run(constrain(state0))
def integrate_nonlinear_sharded(
    G0: jnp.ndarray,
    cache: LinearCache,
    params: LinearParams,
    *,
    dt: float,
    steps: int,
    method: str = "rk2",
    terms: TermConfig | None = None,
    state_sharding: Any | None = None,
    gx_real_fft: bool = True,
    laguerre_mode: str = "grid",
    return_fields: bool = True,
) -> tuple[jnp.ndarray, FieldState] | jnp.ndarray:
    """Integrate the nonlinear system with an explicit pjit-sharded scan.

    The state array can be partitioned along a ``resolve_state_sharding`` axis
    such as ``ky`` or ``kx``. This is a diagnostic whole-state sharding
    primitive for identity gates and profiler localization. It is not a
    production nonlinear domain decomposition or speedup claim until the exact
    workload has communication-complete identity, conservation, transport, and
    profiler gates. Domain-sharding identity reports are metadata gates only;
    they do not authorize routing through this whole-state integrator.

    Args:
        G0: Initial state. It is cast to ``jnp.result_type(G0, jnp.complex64)``.
        cache: Linear cache forwarded to ``nonlinear_rhs_cached``. Its ``ky``
            and ``kx`` attributes are read to build the Hermitian projector
            when ``gx_real_fft`` is true.
        params: Parameters forwarded to ``nonlinear_rhs_cached``.
        dt: Fixed step size; converted to a real scalar matching the state dtype.
        steps: Number of time steps; must be >= 1.
        method: Explicit scheme name, matched case-insensitively after
            stripping whitespace. One of ``euler``, ``rk2``, ``rk3``,
            ``rk3_gx``, ``rk3_classic``, ``rk4`` or ``sspx3``.
        terms: Term configuration. A default ``TermConfig()`` is used when
            ``None``.
        state_sharding: Sharding applied to the state through
            ``jax.lax.with_sharding_constraint`` and as the ``pjit`` input
            sharding. ``None`` disables the explicit constraints.
        gx_real_fft: Forwarded to ``nonlinear_rhs_cached``. When true, every
            stage state is also passed through the Hermitian projector.
        laguerre_mode: Forwarded to ``nonlinear_rhs_cached``.
        return_fields: When true, also return the fields evaluated at the
            post-step state of every step, stacked by ``jax.lax.scan``.

    Returns:
        ``(G_final, fields)`` when ``return_fields`` is true, otherwise only
        ``G_final``.

    Raises:
        ValueError: If ``steps`` is less than one or ``method`` is not a
            supported explicit scheme.
    """
    _validate_steps(steps)
    method_key = _validate_explicit_method(method)
    # Explicit None check (matching integrate_linear_sharded) so the default
    # does not depend on the truthiness of a caller-supplied configuration.
    term_cfg = TermConfig() if terms is None else terms
    state_dtype = jnp.result_type(G0, jnp.complex64)
    G0 = jnp.asarray(G0, dtype=state_dtype)
    dt_val = _dt_array(dt, state_dtype)
    projector = (
        _make_hermitian_projector(np.asarray(cache.ky), int(np.asarray(cache.kx).size))
        if gx_real_fft
        else None
    )

    def _maybe_shard(state: jnp.ndarray) -> jnp.ndarray:
        if state_sharding is None:
            return state
        return jax.lax.with_sharding_constraint(state, state_sharding)

    def _project(state: jnp.ndarray) -> jnp.ndarray:
        if projector is None:
            return state
        return jnp.asarray(projector(state), dtype=state_dtype)

    def _rhs(state: jnp.ndarray) -> tuple[jnp.ndarray, FieldState]:
        dG, fields = nonlinear_rhs_cached(
            state,
            cache,
            params,
            term_cfg,
            gx_real_fft=gx_real_fft,
            laguerre_mode=laguerre_mode,
        )
        return jnp.asarray(dG, dtype=state_dtype), fields

    def _stage(state: jnp.ndarray, increment: jnp.ndarray, scale: float) -> jnp.ndarray:
        # Projected and sharded intermediate state: state + scale * dt * increment.
        return _maybe_shard(
            _project(state + jnp.asarray(scale, dtype=dt_val.dtype) * dt_val * increment)
        )

    def _advance(G: jnp.ndarray) -> jnp.ndarray:
        """Take one explicit step of the selected scheme and return the new state."""
        G = _maybe_shard(_project(G))
        k1, _ = _rhs(G)
        if method_key == "euler":
            G_next = G + dt_val * k1
        elif method_key == "rk2":
            k2, _ = _rhs(_stage(G, k1, 0.5))
            G_next = G + dt_val * k2
        elif method_key in {"rk3", "rk3_gx"}:
            k2, _ = _rhs(_stage(G, k1, 1.0 / 3.0))
            k3, _ = _rhs(_stage(G, k2, 2.0 / 3.0))
            G3 = _stage(G, k3, 0.75)
            G_next = G3 + 0.25 * dt_val * k1
        elif method_key == "rk4":
            k2, _ = _rhs(_stage(G, k1, 0.5))
            k3, _ = _rhs(_stage(G, k2, 0.5))
            k4, _ = _rhs(_stage(G, k3, 1.0))
            G_next = G + (dt_val / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        else:
            # Remaining validated names: "rk3_classic" and "sspx3".
            # NOTE(review): both names use the same three-stage convex-combination
            # update here; confirm that "sspx3" is not meant to be a distinct scheme.
            G1 = _stage(G, k1, 1.0)
            k2, _ = _rhs(G1)
            G2 = _maybe_shard(_project(0.75 * G + 0.25 * (G1 + dt_val * k2)))
            k3, _ = _rhs(G2)
            G_next = (1.0 / 3.0) * G + (2.0 / 3.0) * (G2 + dt_val * k3)
        return _maybe_shard(_project(jnp.asarray(G_next, dtype=state_dtype)))

    def step_with_fields(G: jnp.ndarray, _):
        G_next = _advance(G)
        # Extra RHS evaluation so the emitted fields belong to the post-step state.
        _dG_next, fields_next = _rhs(G_next)
        return G_next, fields_next

    def step_final_only(G: jnp.ndarray, _):
        # No per-step output: skips the extra RHS evaluation and the stacked
        # field history that the final-only caller would discard anyway.
        return _advance(G), None

    def run_with_fields(G_init):
        G_final, fields_t = jax.lax.scan(step_with_fields, G_init, xs=None, length=steps)
        return G_final, fields_t

    def run_final_only(G_init):
        G_final, _ = jax.lax.scan(step_final_only, G_init, xs=None, length=steps)
        return G_final

    if state_sharding is not None:
        G0 = jax.device_put(G0, state_sharding)
    G0 = _maybe_shard(G0)

    if return_fields:
        return pjit(run_with_fields, in_shardings=state_sharding, out_shardings=None)(G0)
    return pjit(run_final_only, in_shardings=state_sharding, out_shardings=state_sharding)(G0)