"""Gyroaveraging coefficients for Laguerre velocity space."""
from __future__ import annotations
import jax.numpy as jnp
from jax.scipy.special import gammaln, i0e
import math
import numpy as np
[docs]
def gamma0(b: jnp.ndarray) -> jnp.ndarray:
    """Return Gamma_0(b) = exp(-b) * I_0(b), elementwise.

    Args:
        b: Argument array (any shape). Meant for b >= 0; ``i0e`` scales by
            exp(-|b|), so for negative b the result is exp(-|b|) * I_0(b).

    Returns:
        Array with the same shape as ``b``.
    """
    # i0e folds the exp(-|b|) factor into the Bessel evaluation itself, so
    # the (rapidly growing) I_0(b) is never formed and large b cannot overflow.
    return i0e(jnp.asarray(b))
[docs]
def bessel_j0(x: jnp.ndarray) -> jnp.ndarray:
    """Return the Bessel function J0(x), elementwise (GX-compatible).

    Uses a rational approximation for |x| < 8. For |x| >= 8 it uses the
    Hankel asymptotic form sqrt(2/(pi|x|)) * (cos(xx) P - z sin(xx) Q), with
    xx = |x| - pi/4 and z = 8/|x|. The coefficients are the classic ones from
    Numerical Recipes' ``bessj0``.

    Each branch is evaluated only on inputs inside its own domain (the JAX
    "double where" pattern). ``jnp.where`` discards the unselected value in
    the forward pass, but reverse-mode autodiff still propagates through both
    branches. The asymptotic branch overflows as |x| -> 0, and its inf/NaN
    partials would otherwise turn the gradient at x = 0 into NaN.

    Args:
        x: Real argument (any shape).

    Returns:
        Approximation of J0(x) with the same shape as ``x``.
    """
    x = jnp.asarray(x)
    ax = jnp.abs(x)
    small = ax < 8.0

    # Safe per-branch inputs: 0 is harmless for the rational form, and 8 is
    # the smallest argument the asymptotic form is used for.
    x_s = jnp.where(small, x, 0.0)
    ax_l = jnp.where(small, 8.0, ax)

    # |x| < 8: rational approximation in y = x^2.
    y = x_s * x_s
    r = 57568490574.0 + y * (
        -13362590354.0
        + y * (651619640.7 + y * (-11214424.18 + y * (77392.33017 + y * -184.9052456)))
    )
    s = 57568490411.0 + y * (
        1029532985.0
        + y * (9494680.718 + y * (59272.64853 + y * (267.8532712 + y)))
    )
    res_small = r / s

    # |x| >= 8: asymptotic expansion; 0.785398164 = pi/4, 0.636619772 = 2/pi.
    z = 8.0 / ax_l
    y2 = z * z
    xx = ax_l - 0.785398164
    p = 1.0 + y2 * (
        -0.1098628627e-2
        + y2 * (0.2734510407e-4 + y2 * (-0.2073370639e-5 + y2 * 0.2093887211e-6))
    )
    q = -0.1562499995e-1 + y2 * (
        0.1430488765e-3
        + y2 * (-0.6911147651e-5 + y2 * (0.7621095161e-6 + y2 * -0.934945152e-7))
    )
    res_large = jnp.sqrt(0.636619772 / ax_l) * (jnp.cos(xx) * p - z * jnp.sin(xx) * q)

    # For finite x the selected branch is always finite. Non-finite x
    # (inf/NaN) propagates NaN through the asymptotic branch.
    return jnp.where(small, res_small, res_large)
[docs]
def bessel_j1(x: jnp.ndarray) -> jnp.ndarray:
    """Return the Bessel function J1(x), elementwise (GX-compatible).

    Uses a rational approximation x * R(x^2) for |x| < 8. For |x| >= 8 it
    uses the Hankel asymptotic form sqrt(2/(pi|x|)) * (cos(xx) P - z sin(xx) Q),
    with xx = |x| - 3pi/4 and z = 8/|x|; the sign is restored for negative x
    since J1 is odd. The coefficients are the classic ones from Numerical
    Recipes' ``bessj1``.

    Each branch is evaluated only on inputs inside its own domain (the JAX
    "double where" pattern). ``jnp.where`` discards the unselected value in
    the forward pass, but reverse-mode autodiff still propagates through both
    branches. The asymptotic branch overflows as |x| -> 0, and its inf/NaN
    partials would otherwise turn the gradient at x = 0 into NaN.

    Args:
        x: Real argument (any shape).

    Returns:
        Approximation of J1(x) with the same shape as ``x``.
    """
    x = jnp.asarray(x)
    ax = jnp.abs(x)
    small = ax < 8.0

    # Safe per-branch inputs: 0 is harmless for the rational form, and 8 is
    # the smallest argument the asymptotic form is used for.
    x_s = jnp.where(small, x, 0.0)
    ax_l = jnp.where(small, 8.0, ax)

    # |x| < 8: x times a rational approximation in y = x^2.
    y = x_s * x_s
    r = 72362614232.0 + y * (
        -7895059235.0
        + y * (242396853.1 + y * (-2972611.439 + y * (15704.48260 + y * -30.16036606)))
    )
    s = 144725228442.0 + y * (
        2300535178.0
        + y * (18583304.74 + y * (99447.43394 + y * (376.9991397 + y)))
    )
    res_small = x_s * (r / s)

    # |x| >= 8: asymptotic expansion; 2.356194491 = 3*pi/4, 0.636619772 = 2/pi.
    z = 8.0 / ax_l
    y2 = z * z
    xx = ax_l - 2.356194491
    p = 1.0 + y2 * (
        0.183105e-2
        + y2 * (-0.3516396496e-4 + y2 * (0.2457520174e-5 + y2 * -0.240337019e-6))
    )
    q = 0.04687499995 + y2 * (
        -0.2002690873e-3
        + y2 * (0.8449199096e-5 + y2 * (-0.88228987e-6 + y2 * 0.105787412e-6))
    )
    res_large = jnp.sqrt(0.636619772 / ax_l) * (jnp.cos(xx) * p - z * jnp.sin(xx) * q)
    # J1 is odd: the expansion is in |x|, so flip the sign for negative x.
    res_large = jnp.where(x < 0.0, -res_large, res_large)

    # For finite x the selected branch is always finite. Non-finite x
    # (inf/NaN) propagates NaN through the asymptotic branch.
    return jnp.where(small, res_small, res_large)
[docs]
def gx_factorial(m: jnp.ndarray) -> jnp.ndarray:
    """Return GX's single-precision factorial approximation of ``m!``.

    Inputs with ``int(m) <= 6`` (truncation toward zero) use an exact lookup
    table, so ``m <= 0`` maps to 1. Larger inputs use Stirling's formula with
    two correction terms:
    ``sqrt(2*pi*m) * m**m * exp(-m) * (1 + 1/(12m) + 1/(288m^2))``.

    Integer (or bool) inputs are evaluated in the default floating dtype.

    Note: in float32, ``m**m`` overflows to inf for ``m >= 27``, even though
    ``m!`` itself is representable up to ``34!``. The direct formula is kept
    as-is so the results match the reference expression.

    Args:
        m: Array of (typically integer-valued) arguments, any shape.

    Returns:
        Floating array of the same shape as ``m``.
    """
    m_arr = jnp.asarray(m)
    if not jnp.issubdtype(m_arr.dtype, jnp.floating):
        # Work in floating point. With an integer dtype, pi would be truncated
        # to 3, and m**m would be computed (and overflow) in integer arithmetic.
        m_arr = m_arr.astype(jnp.result_type(float))
    dtype = m_arr.dtype
    exact = jnp.asarray([1.0, 1.0, 2.0, 6.0, 24.0, 120.0, 720.0], dtype=dtype)
    m_int = m_arr.astype(jnp.int32)
    m_clamped = jnp.clip(m_int, 0, exact.shape[0] - 1)
    # Keep the Stirling branch finite for m <= 0; those entries come from the table.
    m_safe = jnp.where(m_arr > 0, m_arr, jnp.asarray(1.0, dtype=dtype))
    stirling = (
        jnp.sqrt(2.0 * jnp.asarray(jnp.pi, dtype=dtype) * m_safe)
        * (m_safe**m_safe)
        * jnp.exp(-m_safe)
        * (1.0 + 1.0 / (12.0 * m_safe) + 1.0 / (288.0 * m_safe * m_safe))
    )
    return jnp.where(m_int <= 6, exact[m_clamped], stirling)
[docs]
def J_l_all(b: jnp.ndarray, l_max: int) -> jnp.ndarray:
    """Return the gyroaveraging coefficients J_l(b) for l = 0..l_max.

    Follows the GX Laguerre-Hermite convention:
    ``J_l(b) = (-1)**l * (b/2)**l * exp(-b/2) / l!``, evaluated in log space
    for stability. At ``b == 0`` the exact limit (1 for l = 0, 0 otherwise) is
    returned. ``b`` is expected to be non-negative; for ``b < 0`` the
    ``(b/2)**l`` factor is dropped, so results there are not meaningful.

    Args:
        b: Argument array (any shape).
        l_max: Highest Laguerre index to compute.

    Returns:
        Array of shape ``(l_max + 1,) + b.shape``; the leading axis is ``l``.

    Raises:
        ValueError: If ``l_max`` is negative.
    """
    if l_max < 0:
        raise ValueError("l_max must be >= 0")
    b = jnp.asarray(b)
    n_l = l_max + 1
    # Put the Laguerre index on a new leading axis that broadcasts against b.
    ell = jnp.arange(n_l, dtype=b.dtype).reshape((n_l,) + (1,) * b.ndim)
    x = 0.5 * b[None, ...]
    parity = jnp.where(ell % 2 == 0, 1.0, -1.0)
    # Replace non-positive arguments by 1 so the log stays finite; b == 0 is
    # patched explicitly below.
    log_x = jnp.log(jnp.where(x > 0.0, x, 1.0))
    coeff = parity * jnp.exp(ell * log_x - gammaln(ell + 1.0) - x)
    at_origin = b[None, ...] == 0.0
    return jnp.where(at_origin, jnp.where(ell == 0, 1.0, 0.0), coeff)
[docs]
def sum_Jl2(b: jnp.ndarray, l_max: int) -> jnp.ndarray:
    """Return the truncated sum over l = 0..l_max of J_l(b)**2.

    With ``J_l`` as produced by ``J_l_all``, the infinite sum equals
    Gamma_0(b) = exp(-b) * I_0(b) for b >= 0. Comparing against ``gamma0``
    therefore shows how well ``l_max`` Laguerre terms resolve the
    gyroaverage.

    Args:
        b: Argument array (any shape).
        l_max: Highest Laguerre index included in the sum.

    Returns:
        Array with the same shape as ``b``.
    """
    coeffs = J_l_all(b, l_max)
    return (coeffs * coeffs).sum(axis=0)
[docs]
def gx_laguerre_nj(nl: int) -> int:
    """Return GX's default number of Laguerre quadrature points.

    The rule is ``(3 * nl) // 2 - 1``, floored at a minimum of one point.
    ``nl`` is presumably the number of Laguerre moments; confirm against callers.
    """
    candidate = (3 * nl) // 2 - 1
    return candidate if candidate > 1 else 1