# Source code for spectraxgk.basis

"""Hermite and Laguerre basis utilities."""

from __future__ import annotations

import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def hermite_physicists(x: jnp.ndarray, n_max: int) -> jnp.ndarray:
    """Physicists' Hermite polynomials H_n(x) for n=0..n_max.

    Weight: exp(-x**2).

    Recurrence:
        H_0 = 1
        H_1 = 2x
        H_{n+1} = 2x H_n - 2n H_{n-1}

    Parameters
    ----------
    x:
        Evaluation points of any shape. Integer inputs are promoted to the
        default floating dtype; floating and complex inputs keep their dtype.
    n_max:
        Highest degree to evaluate (inclusive).

    Returns
    -------
    jnp.ndarray
        Array of shape ``(n_max + 1,) + x.shape``; slice ``n`` holds H_n(x).

    Raises
    ------
    ValueError
        If ``n_max < 0``.
    """
    x = jnp.asarray(x)
    if n_max < 0:
        raise ValueError("n_max must be >= 0")
    if n_max == 0:
        return jnp.expand_dims(jnp.ones_like(x), axis=0)
    # lax.scan requires the carry dtype to be identical on every iteration.
    # With integer x, H_0 = ones_like(x) would be int while H_1 = 2.0 * x is
    # float, so promote first.
    if not jnp.issubdtype(x.dtype, jnp.inexact):
        x = x.astype(float)
    h0 = jnp.ones_like(x)
    h1 = 2.0 * x

    def step(carry, n):
        h_prev, h_curr = carry
        h_next = 2.0 * x * h_curr - 2.0 * n * h_prev
        return (h_curr, h_next), h_next

    # Cast the degree index to x's dtype. Otherwise an int64 index (x64 mode)
    # or a float16 x would promote h_next to a wider dtype than the carry.
    degrees = jnp.arange(1, n_max).astype(x.dtype)
    _, tail = jax.lax.scan(step, (h0, h1), degrees)
    return jnp.concatenate([h0[None, ...], h1[None, ...], tail], axis=0)
def hermite_normed(x: jnp.ndarray, n_max: int) -> jnp.ndarray:
    """Orthonormal Hermite polynomials with respect to the weight exp(-x**2).

    psi_n = H_n(x) / sqrt(2**n * n! * sqrt(pi))

    With this scaling, the integral of psi_m * psi_n * exp(-x**2) over the
    real line is the Kronecker delta delta_{mn}.

    Parameters
    ----------
    x:
        Evaluation points of any shape (scalar, 1-D, or N-D).
    n_max:
        Highest degree to evaluate (inclusive). Must be >= 0; validated by
        ``hermite_physicists``.

    Returns
    -------
    jnp.ndarray
        Array of shape ``(n_max + 1,) + x.shape``; slice ``n`` holds psi_n(x).
    """
    h = hermite_physicists(x, n_max)
    n = jnp.arange(0, n_max + 1)
    # Compute the norm in log space so 2**n * n! does not overflow for large n.
    log_norm = 0.5 * (n * jnp.log(2.0) + gammaln(n + 1) + 0.5 * jnp.log(jnp.pi))
    norm = jnp.exp(log_norm)
    # Shape the norm as (n_max + 1, 1, ..., 1) so it broadcasts along the
    # degree axis for any x.ndim. A fixed norm[:, None] is only correct for
    # 1-D x: a scalar x would yield an (n+1, n+1) outer division.
    norm = norm.reshape((-1,) + (1,) * (h.ndim - 1))
    return h / norm
def laguerre(x: jnp.ndarray, l_max: int) -> jnp.ndarray:
    """Laguerre polynomials L_l(x) for l=0..l_max.

    Weight: exp(-x).

    Recurrence:
        L_0 = 1
        L_1 = 1 - x
        (l+1) L_{l+1} = (2l+1-x) L_l - l L_{l-1}

    Parameters
    ----------
    x:
        Evaluation points of any shape. Integer inputs are promoted to the
        default floating dtype; floating and complex inputs keep their dtype.
    l_max:
        Highest degree to evaluate (inclusive).

    Returns
    -------
    jnp.ndarray
        Array of shape ``(l_max + 1,) + x.shape``; slice ``l`` holds L_l(x).

    Raises
    ------
    ValueError
        If ``l_max < 0``.
    """
    x = jnp.asarray(x)
    if l_max < 0:
        raise ValueError("l_max must be >= 0")
    if l_max == 0:
        return jnp.expand_dims(jnp.ones_like(x), axis=0)
    # lax.scan requires a fixed carry dtype. With integer x, L_0 would be int
    # while L_1 = 1.0 - x is float, so promote first.
    if not jnp.issubdtype(x.dtype, jnp.inexact):
        x = x.astype(float)
    l0 = jnp.ones_like(x)
    l1 = 1.0 - x

    def step(carry, ell_idx):
        l_prev, l_curr = carry
        l_next = ((2.0 * ell_idx + 1.0 - x) * l_curr - ell_idx * l_prev) / (ell_idx + 1.0)
        return (l_curr, l_next), l_next

    # Cast the degree index to x's dtype so an int64 index (x64 mode) or a
    # float16 x cannot widen l_next beyond the carry dtype.
    degrees = jnp.arange(1, l_max).astype(x.dtype)
    _, tail = jax.lax.scan(step, (l0, l1), degrees)
    return jnp.concatenate([l0[None, ...], l1[None, ...], tail], axis=0)
def hermite_ladder_coeffs(n_max: int) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Return sqrt(n+1) and sqrt(n) arrays for Hermite ladder operators.

    Both arrays have length ``n_max + 1`` and are indexed by n = 0..n_max.
    The first is the sqrt(n+1) coefficient (conventionally the raising
    operator's) and the second is sqrt(n) (conventionally the lowering
    operator's).

    Raises
    ------
    ValueError
        If ``n_max < 0``.
    """
    if n_max < 0:
        raise ValueError("n_max must be >= 0")
    levels = jnp.arange(n_max + 1)
    lower = jnp.sqrt(levels)
    upper = jnp.sqrt(levels + 1.0)
    return upper, lower