"""Linked-boundary FFT maps and damping profiles for linear operators."""
from __future__ import annotations
import jax.numpy as jnp
import numpy as np
__all__ = [
"_build_linked_end_damping_profile",
"_build_linked_fft_maps",
"_signed_to_index",
]
def _signed_to_index(idx: int, n: int) -> int:
half = (n + 1) // 2
if 0 <= idx < half:
return idx
if half <= idx + n < n:
return idx + n
return -1
[docs]
def _build_linked_fft_maps(
kx: np.ndarray,
ky: np.ndarray,
y0: float,
jtwist: int,
dz: float,
nz: int,
real_dtype: jnp.dtype,
ky_mode: np.ndarray | None = None,
) -> tuple[tuple[jnp.ndarray, ...], tuple[jnp.ndarray, ...]]:
"""Construct GX-style linked FFT index maps for the parallel derivative."""
ny = ky.size
nx = kx.size
if ky_mode is not None:
naky = int(np.asarray(ky_mode, dtype=int).reshape(-1).size)
else:
naky = 1 + (ny - 1) // 3
if nx < 4:
nakx = nx
else:
nakx = 1 + 2 * ((nx - 1) // 3)
if nakx <= 0 or naky <= 0:
return (), ()
nshift = nx - nakx
idx_left = -np.ones((naky, nakx), dtype=int)
idx_right = -np.ones((naky, nakx), dtype=int)
ky_mode_arr: np.ndarray | None = None
if ky_mode is not None:
ky_mode_arr = np.asarray(ky_mode, dtype=int).reshape(-1)
for idx in range(nakx):
idx0 = idx if idx < (nakx + 1) // 2 else idx - nakx
for idy in range(naky):
idy_mode = int(ky_mode_arr[idy]) if ky_mode_arr is not None else idy
if idy_mode == 0:
idx_l = idx0
idx_r = idx0
else:
idx_l = idx0 + idy_mode * jtwist
idx_r = idx0 - idy_mode * jtwist
idx_left[idy, idx] = _signed_to_index(idx_l, nakx)
idx_right[idy, idx] = _signed_to_index(idx_r, nakx)
links_l = np.zeros((naky, nakx), dtype=int)
links_r = np.zeros((naky, nakx), dtype=int)
for idx in range(nakx):
for idy in range(naky):
idx_star = idx
while idx_star != idx_left[idy, idx_star] and idx_left[idy, idx_star] >= 0:
links_l[idy, idx] += 1
idx_star = idx_left[idy, idx_star]
idx_star = idx
while idx_star != idx_right[idy, idx_star] and idx_right[idy, idx_star] >= 0:
links_r[idy, idx] += 1
idx_star = idx_right[idy, idx_star]
n_k = np.zeros(naky * nakx, dtype=int)
k = 0
for idx in range(nakx):
for idy in range(naky):
n_k[k] = 1 + links_l[idy, idx] + links_r[idy, idx]
k += 1
n_k_sorted = np.sort(n_k)
unique_vals = np.unique(n_k_sorted)
n_links = unique_vals.astype(int)
n_chains = np.zeros_like(n_links)
for i, val in enumerate(n_links):
count = int(np.sum(n_k_sorted == val))
n_chains[i] = count // val if val > 0 else 0
linked_indices: list[jnp.ndarray] = []
linked_kz: list[jnp.ndarray] = []
for nlinks_val, nchains_val in zip(n_links, n_chains):
if nlinks_val <= 0 or nchains_val <= 0:
continue
link_kx = np.zeros((nchains_val, nlinks_val), dtype=np.int32)
link_ky = np.zeros((nchains_val, nlinks_val), dtype=np.int32)
n = 0
for idy in range(naky):
for idx in range(nakx):
np_k = 1 + links_l[idy, idx] + links_r[idy, idx]
if np_k != nlinks_val:
continue
p = links_l[idy, idx]
if p != 0:
continue
idx0 = idx if idx < (nakx + 1) // 2 else idx + nshift
link_ky[n, 0] = idy
link_kx[n, 0] = idx0
idx_r = idx
for p in range(1, nlinks_val):
idx_r = idx_right[idy, idx_r]
link_ky[n, p] = idy
if idx_r < (nakx + 1) // 2:
link_kx[n, p] = idx_r
else:
link_kx[n, p] = idx_r + nshift
n += 1
idx_flat = link_ky + ny * link_kx
linked_indices.append(jnp.asarray(idx_flat, dtype=jnp.int32))
nzL = int(nlinks_val) * int(nz)
kz_linked = 2.0 * np.pi * np.fft.fftfreq(nzL, d=float(dz))
linked_kz.append(jnp.asarray(kz_linked, dtype=real_dtype))
return tuple(linked_indices), tuple(linked_kz)
[docs]
def _build_linked_end_damping_profile(
*,
linked_indices: tuple[jnp.ndarray, ...],
ny: int,
nx: int,
nz: int,
widthfrac: float,
ky_mode: np.ndarray | None = None,
) -> np.ndarray:
"""Construct the GX linked-boundary damping profile on the full FFT grid."""
profile = np.zeros((ny, nx, nz), dtype=float)
if not linked_indices or widthfrac <= 0.0 or ny <= 0 or nx <= 0 or nz <= 0:
return profile
ky_mode_arr: np.ndarray | None = None
if ky_mode is not None:
ky_mode_arr = np.asarray(ky_mode, dtype=np.int32).reshape(-1)
if ky_mode_arr.size < ny:
raise ValueError("ky_mode must have at least ny entries for linked end damping")
if nx > 1:
kx_neg = np.concatenate(([0], np.arange(nx - 1, 0, -1, dtype=np.int32)))
else:
kx_neg = np.asarray([0], dtype=np.int32)
for idx_map_j in linked_indices:
idx_map = np.asarray(idx_map_j, dtype=np.int32)
if idx_map.ndim != 2 or idx_map.size == 0:
continue
nlinks = int(idx_map.shape[1])
width = int(nz * nlinks * float(widthfrac))
if width <= 0:
continue
chain_extent = nz * nlinks
for chain in idx_map:
for p, idx_flat in enumerate(chain):
ky_idx = int(idx_flat % ny)
kx_idx = int(idx_flat // ny)
ky_phys = int(ky_mode_arr[ky_idx]) if ky_mode_arr is not None else ky_idx
if ky_phys == 0:
continue
if ky_mode_arr is not None:
mirror_matches = np.flatnonzero(ky_mode_arr == -ky_phys)
mirror_ky = int(mirror_matches[0]) if mirror_matches.size else ky_idx
else:
mirror_ky = (-ky_idx) % ny
mirror_kx = int(kx_neg[kx_idx])
for idz in range(nz):
idzp = idz + nz * p
nu = 0.0
if idzp <= width:
x = float(idzp) / float(width)
nu = 1.0 - 2.0 * x * x / (1.0 + x**4)
elif idzp >= chain_extent - width:
x = float(chain_extent - idzp) / float(width)
nu = 1.0 - 2.0 * x * x / (1.0 + x**4)
profile[ky_idx, kx_idx, idz] = nu
if mirror_ky != ky_idx:
profile[mirror_ky, mirror_kx, idz] = nu
return profile