"""Spectral grid utilities for flux-tube geometry."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence
import jax
import jax.numpy as jnp
import numpy as np
from spectraxgk.config import GridConfig
[docs]
@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class SpectralGrid:
    """Immutable flux-tube spectral grid registered as a JAX pytree.

    Every array-valued field (including ``ky_mode`` when it is set) is a
    pytree *leaf*, so grids can be passed as arguments to ``jax.jit`` /
    ``jax.vmap``. Plain scalar/string metadata travels in the static
    auxiliary data, which JAX requires to be hashable and comparable.
    """

    kx: jnp.ndarray  # 1D kx values (one per column of the 2D arrays)
    ky: jnp.ndarray  # 1D ky values (one per row of the 2D arrays)
    z: jnp.ndarray  # parallel-coordinate samples
    kx_grid: jnp.ndarray  # 2D kx mesh, ky along axis 0 and kx along axis 1
    ky_grid: jnp.ndarray  # 2D ky mesh with the same layout as kx_grid
    dealias_mask: jnp.ndarray  # boolean keep-mask with the same row/column layout
    y0: float  # y-box scale; build_spectral_grid uses Ly = 2*pi*y0
    x0: float  # x-box scale; build_spectral_grid uses Lx = 2*pi*x0
    boundary: str  # metadata forwarded from the grid config
    jtwist: int | None  # metadata forwarded from the grid config
    non_twist: bool  # metadata forwarded from the grid config
    kxfac: float  # metadata forwarded from the grid config
    # Integer ky mode numbers rint(ky * y0) set by the select_* helpers;
    # None on a freshly built full grid.
    ky_mode: jnp.ndarray | None = None

    def tree_flatten(self):
        """Split the grid into array leaves and hashable static metadata."""
        # ky_mode is an array (or None), so it must be a child: JAX hashes
        # aux_data when caching traced functions and jax arrays are
        # unhashable. A None child simply contributes no leaves.
        children = (
            self.kx,
            self.ky,
            self.z,
            self.kx_grid,
            self.ky_grid,
            self.dealias_mask,
            self.ky_mode,
        )
        aux_data = (
            self.y0,
            self.x0,
            self.boundary,
            self.jtwist,
            self.non_twist,
            self.kxfac,
        )
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a grid from the output of :meth:`tree_flatten`."""
        kx, ky, z, kx_grid, ky_grid, dealias_mask, ky_mode = children
        y0, x0, boundary, jtwist, non_twist, kxfac = aux_data
        return cls(
            kx=kx,
            ky=ky,
            z=z,
            kx_grid=kx_grid,
            ky_grid=ky_grid,
            dealias_mask=dealias_mask,
            y0=y0,
            x0=x0,
            boundary=boundary,
            jtwist=jtwist,
            non_twist=non_twist,
            kxfac=kxfac,
            ky_mode=ky_mode,
        )
def _fftfreq_phys(n: int, L: float) -> jnp.ndarray:
    """Return angular wave numbers for ``n`` FFT samples over a period ``L``.

    The samples are spaced ``L / n`` apart, so the result is ``2*pi`` times
    ``fftfreq`` for that spacing. The non-zero entries are integer multiples
    of ``2*pi / L``, in FFT order: non-negative frequencies first, then the
    negative ones.
    """
    spacing = L / n
    cycles_per_length = jnp.fft.fftfreq(n, d=spacing)
    return (2.0 * jnp.pi) * cycles_per_length
[docs]
def twothirds_mask(Ny: int, Nx: int) -> jnp.ndarray:
    """Return the ``(Ny, Nx)`` boolean 2/3-rule dealiasing mask.

    An entry is ``True`` only when both of its normalized FFT frequencies
    ``k / N`` lie strictly inside ``|k / N| < 1/3``.
    """
    cutoff = 1.0 / 3.0

    def _inside(n: int) -> jnp.ndarray:
        # Strict inequality: like GX, the boundary shell |k/N| == 1/3 is
        # dropped and only the strict 2/3 interior is kept.
        return jnp.abs(jnp.fft.fftfreq(n)) < cutoff

    rows_kept = _inside(Ny)
    cols_kept = _inside(Nx)
    return jnp.logical_and(rows_kept[:, None], cols_kept[None, :])
[docs]
def real_fft_unique_ky(ky: jnp.ndarray) -> jnp.ndarray:
    """Return the non-negative ``ky`` block kept by a real (half-spectrum) FFT.

    A 1D input is used as-is. For a 2D (or higher) input, column 0 is taken,
    which matches a mesh laid out with ky along axis 0. The first
    ``N // 2 + 1`` entries are kept, and their absolute value is returned, so
    the even-``N`` Nyquist entry comes out positive.

    Raises:
        ValueError: if ``ky`` is a 0-d scalar.
    """
    values = jnp.asarray(ky)
    if values.ndim == 0:
        raise ValueError("ky must be at least 1D")
    column = values[:, 0] if values.ndim > 1 else values
    n_unique = int(column.shape[0]) // 2 + 1
    return jnp.abs(column[:n_unique])
[docs]
def real_fft_ordered_kx(kx: jnp.ndarray) -> jnp.ndarray:
    """Return ``kx`` in the ordering expected by real-FFT nonlinear kernels.

    A 1D input is used as-is. For a 2D (or higher) input, row 0 is taken,
    which matches a mesh laid out with kx along axis 1. The FFT ordering is
    kept. For an even, non-zero length, only the Nyquist entry at index
    ``N // 2`` is replaced by its absolute value.

    Raises:
        ValueError: if ``kx`` is a 0-d scalar.
    """
    values = jnp.asarray(kx)
    if values.ndim == 0:
        raise ValueError("kx must be at least 1D")
    row = values[0, :] if values.ndim > 1 else values
    count = int(row.shape[0])
    if count > 0 and count % 2 == 0:
        nyquist = count // 2
        return row.at[nyquist].set(jnp.abs(row[nyquist]))
    return row
[docs]
def real_fft_mesh(
    kx_grid: jnp.ndarray,
    ky_grid: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Build the compressed real-FFT wave-number vectors and their 2D meshes.

    Args:
        kx_grid: kx values (1D, or a mesh whose row 0 holds kx).
        ky_grid: ky values (1D, or a mesh whose column 0 holds ky).

    Returns:
        ``(kx, ky, kx_mesh, ky_mesh)``. ``kx`` comes from
        ``real_fft_ordered_kx`` and ``ky`` from ``real_fft_unique_ky``. Both
        meshes have ky along axis 0 and kx along axis 1.
    """
    kx_vals = real_fft_ordered_kx(kx_grid)
    ky_vals = real_fft_unique_ky(ky_grid)
    # indexing="ij": ky varies along axis 0 and kx along axis 1, the same
    # layout build_spectral_grid uses for the full grid.
    ky_mesh, kx_mesh = jnp.meshgrid(ky_vals, kx_vals, indexing="ij")
    return kx_vals, ky_vals, kx_mesh, ky_mesh


# GX-prefixed names bound to the same real-FFT helper functions.
gx_real_fft_ky, gx_real_fft_kx, gx_real_fft_mesh = (
    real_fft_unique_ky,
    real_fft_ordered_kx,
    real_fft_mesh,
)
def build_spectral_grid(cfg: GridConfig) -> SpectralGrid:
    """Construct the full ``SpectralGrid`` described by ``cfg``.

    Perpendicular box:
      * ``Lx = cfg.Lx`` and ``x0 = Lx / (2*pi)``.
      * If ``cfg.y0`` is set, ``Ly = 2*pi*y0``. Otherwise ``Ly = cfg.Ly`` and
        ``y0 = Ly / (2*pi)``.

    Parallel coordinate:
      * When ``cfg.ntheta`` is set, ``ntheta * zp`` points are sampled on
        ``[-pi*zp, pi*zp)``. ``zp`` is ``cfg.zp`` when given, otherwise
        ``2*nperiod - 1`` when ``cfg.nperiod`` is given, otherwise 1.
      * Otherwise ``cfg.Nz`` points are sampled on ``[cfg.z_min, cfg.z_max)``.

    The kx/ky meshes have ky along axis 0 and kx along axis 1. The grid
    carries the two-thirds dealiasing mask and ``ky_mode=None``.
    """
    Lx = cfg.Lx
    if cfg.y0 is None:
        Ly = cfg.Ly
        y0 = float(Ly) / (2.0 * jnp.pi)
    else:
        Ly = 2.0 * jnp.pi * cfg.y0
        y0 = float(cfg.y0)
    x0 = float(Lx) / (2.0 * jnp.pi)

    # Number of poloidal turns; only consulted when ntheta drives the z grid.
    zp = cfg.zp
    if zp is None and cfg.nperiod is not None:
        zp = 2 * cfg.nperiod - 1
    elif zp is None and cfg.ntheta is not None:
        zp = 1

    Nz = cfg.Nz
    if cfg.ntheta is None:
        z_min, z_max = cfg.z_min, cfg.z_max
    else:
        turns = 1 if zp is None else zp
        Nz = int(cfg.ntheta) * int(turns)
        half_span = jnp.pi * float(turns)
        z_min, z_max = -half_span, half_span

    kx = _fftfreq_phys(cfg.Nx, Lx)
    ky = _fftfreq_phys(cfg.Ny, Ly)
    ky_grid, kx_grid = jnp.meshgrid(ky, kx, indexing="ij")
    return SpectralGrid(
        kx=kx,
        ky=ky,
        z=jnp.linspace(z_min, z_max, Nz, endpoint=False),
        kx_grid=kx_grid,
        ky_grid=ky_grid,
        dealias_mask=twothirds_mask(cfg.Ny, cfg.Nx),
        y0=y0,
        x0=x0,
        boundary=str(cfg.boundary),
        jtwist=cfg.jtwist,
        non_twist=bool(cfg.non_twist),
        kxfac=float(cfg.kxfac),
        ky_mode=None,
    )
[docs]
def select_ky_grid(
    grid: SpectralGrid,
    ky_index: int | jnp.ndarray | np.ndarray | Sequence[int],
) -> SpectralGrid:
    """Slice ``grid`` down to the requested ky row(s) for a linear solve.

    A scalar ``ky_index`` is promoted to a single selected row. ``kx`` and
    ``z`` are kept as-is. ``ky``, ``kx_grid`` and ``ky_grid`` are reduced to
    the chosen rows, and ``ky_mode`` holds the rounded ``ky * y0`` as int32.

    The parent's two-thirds mask exists for nonlinear FFT products. A linear
    ky scan must not drop a selected high-ky mode just because that row would
    be dealiased in a nonlinear convolution. The sliced grid therefore
    carries an all-True mask of the selected shape.
    """
    rows = jnp.atleast_1d(jnp.asarray(ky_index, dtype=jnp.int32))

    def _pick(values: jnp.ndarray) -> jnp.ndarray:
        return jnp.take(values, rows, axis=0)

    ky = _pick(grid.ky)
    kx_grid = _pick(grid.kx_grid)
    ky_grid = _pick(grid.ky_grid)
    keep_all = jnp.ones_like(_pick(grid.dealias_mask), dtype=bool)
    return SpectralGrid(
        kx=grid.kx,
        ky=ky,
        z=grid.z,
        kx_grid=kx_grid,
        ky_grid=ky_grid,
        dealias_mask=keep_all,
        y0=grid.y0,
        x0=grid.x0,
        boundary=grid.boundary,
        jtwist=grid.jtwist,
        non_twist=grid.non_twist,
        kxfac=grid.kxfac,
        ky_mode=jnp.rint(ky * grid.y0).astype(jnp.int32),
    )
[docs]
def select_real_fft_ky_grid(
    grid: SpectralGrid,
    ky_values: jnp.ndarray | np.ndarray | Sequence[float],
) -> SpectralGrid:
    """Return a real-FFT view of ``grid`` restricted to the given ky values.

    ``ky_values`` is cast to ``grid.ky``'s dtype and used directly as the new
    ``ky``. ``kx`` is re-ordered with ``real_fft_ordered_kx``. The meshes are
    broadcast to ``(len(ky_values), len(kx))`` with ky along axis 0. The mask
    is the first ``len(ky_values)`` rows of the parent mask. ``ky_mode`` is
    the rounded ``ky * y0`` as int32.

    Raises:
        ValueError: if ``ky_values`` is not a non-empty 1D array, or is longer
            than ``grid.ky``.
    """
    ky_vals = jnp.asarray(ky_values, dtype=grid.ky.dtype)
    if ky_vals.ndim != 1 or ky_vals.size == 0:
        raise ValueError("ky_values must be a non-empty 1D array")
    nky = int(ky_vals.shape[0])
    if nky > int(grid.ky.shape[0]):
        raise ValueError("ky_values length cannot exceed the full grid ky length")

    kx_vals = real_fft_ordered_kx(grid.kx)
    mesh_shape = (nky, kx_vals.shape[0])
    leading_rows = jnp.arange(nky, dtype=jnp.int32)
    return SpectralGrid(
        kx=kx_vals,
        ky=ky_vals,
        z=grid.z,
        kx_grid=jnp.broadcast_to(kx_vals[None, :], mesh_shape),
        ky_grid=jnp.broadcast_to(ky_vals[:, None], mesh_shape),
        dealias_mask=jnp.take(grid.dealias_mask, leading_rows, axis=0),
        y0=grid.y0,
        x0=grid.x0,
        boundary=grid.boundary,
        jtwist=grid.jtwist,
        non_twist=grid.non_twist,
        kxfac=grid.kxfac,
        ky_mode=jnp.rint(ky_vals * grid.y0).astype(jnp.int32),
    )


# GX-prefixed name bound to the same function.
select_gx_real_fft_ky_grid = select_real_fft_ky_grid