# Source code for spectraxgk.geometry

"""Analytic flux-tube geometry for the Cyclone base case."""

from __future__ import annotations

from dataclasses import dataclass, replace
import math
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np

from spectraxgk.config import GeometryConfig, GridConfig

ZERO_SHAT_THRESHOLD = 1.0e-5


def zero_shear_enabled(
    s_hat: float,
    *,
    zero_shat: bool = False,
    threshold: float = ZERO_SHAT_THRESHOLD,
) -> bool:
    """Report whether the geometry should be treated as shearless.

    Args:
        s_hat: Magnetic shear value.
        zero_shat: Explicit request for the zero-shear state; when truthy the
            shear value is not inspected at all.
        threshold: Magnitude strictly below which ``s_hat`` counts as zero.

    Returns:
        ``True`` when zero shear was requested or ``|s_hat| < threshold``.
    """
    if bool(zero_shat):
        return True
    shear_magnitude = abs(float(s_hat))
    return shear_magnitude < float(threshold)
def effective_boundary(
    boundary: str,
    *,
    s_hat: float,
    zero_shat: bool = False,
    threshold: float = ZERO_SHAT_THRESHOLD,
) -> str:
    """Resolve the boundary label once zero-shear promotion is applied.

    Args:
        boundary: Requested boundary label.
        s_hat: Magnetic shear value.
        zero_shat: Explicit request for the zero-shear state.
        threshold: Shear magnitude below which the geometry counts as shearless.

    Returns:
        ``"periodic"`` when ``zero_shear_enabled`` reports a shearless
        geometry, otherwise ``str(boundary)``.
    """
    promoted = zero_shear_enabled(s_hat, zero_shat=zero_shat, threshold=threshold)
    return "periodic" if promoted else str(boundary)
# Aliases exposing the same two callables under ``gx_``-prefixed names.
gx_zero_shat_enabled = zero_shear_enabled
gx_effective_boundary = effective_boundary
@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class SAlphaGeometry:
    """Simple s-alpha geometry with circular concentric flux surfaces.

    The instance is registered as a JAX pytree whose leaves are all nine
    dataclass fields, in declaration order; there is no auxiliary data.

    ``k_perp2`` and ``drift_components`` normalise the radial wave number as
    ``kx / s_hat`` and fall back to plain ``kx`` at exactly zero shear.  The
    shear-weighted coefficients (``gds21``, ``gds22``, ``cv0``, ``gb0``) are
    weighted by the same factor (``s_hat``, or 1 at zero shear) so that the
    ``kx`` contributions survive in the shearless case:  ``k_perp2`` stays
    equal to ``kx_effective**2 + ky**2`` before the optional ``bmag`` scaling.
    This matches the ``gds22 = 1`` zero-shear convention of ``SlabGeometry``.
    """

    q: float
    s_hat: float
    epsilon: float
    R0: float = 1.0
    B0: float = 1.0
    alpha: float = 0.0
    drift_scale: float = 1.0
    kperp2_bmag: bool = True
    bessel_bmag_power: float = 0.0

    def tree_flatten(self):
        """Return ``(children, aux_data)`` for JAX pytree registration."""
        children = (
            self.q,
            self.s_hat,
            self.epsilon,
            self.R0,
            self.B0,
            self.alpha,
            self.drift_scale,
            self.kperp2_bmag,
            self.bessel_bmag_power,
        )
        return children, None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from the leaves produced by ``tree_flatten``."""
        return cls(*children)

    @staticmethod
    def from_config(cfg: GeometryConfig) -> "SAlphaGeometry":
        """Build the geometry from a config, storing ``s_hat = 0`` when shearless."""
        zero_shat = zero_shear_enabled(cfg.s_hat, zero_shat=cfg.zero_shat)
        return SAlphaGeometry(
            q=cfg.q,
            s_hat=0.0 if zero_shat else cfg.s_hat,
            epsilon=cfg.epsilon,
            R0=cfg.R0,
            B0=cfg.B0,
            alpha=cfg.alpha,
            drift_scale=cfg.drift_scale,
            kperp2_bmag=cfg.kperp2_bmag,
            bessel_bmag_power=cfg.bessel_bmag_power,
        )

    def _shear_norm(self) -> jnp.ndarray:
        """Return the divisor applied to ``kx``: ``s_hat``, or 1 at zero shear."""
        s_hat = jnp.asarray(self.s_hat)
        return jnp.where(s_hat == 0.0, 1.0, s_hat)

    def kx_effective(self, kx0: jnp.ndarray, ky: jnp.ndarray, theta: jnp.ndarray) -> jnp.ndarray:
        """Field-aligned kx(theta) with s-alpha shear shift."""
        shear = self.s_hat * theta - self.alpha * jnp.sin(theta)
        return kx0 - shear * ky

    def bmag(self, theta: jnp.ndarray) -> jnp.ndarray:
        """Magnetic field strength for circular s-alpha geometry."""
        return 1.0 / (1.0 + self.epsilon * jnp.cos(theta))

    def gradpar(self) -> float:
        """Parallel gradient factor for s-alpha geometry (constant for equal-arc)."""
        return float(abs(1.0 / (self.q * self.R0)))

    def bgrad(self, theta: jnp.ndarray) -> jnp.ndarray:
        """Magnetic field gradient term used in mirror force."""
        bmag = self.bmag(theta)
        return self.gradpar() * self.epsilon * jnp.sin(theta) * bmag

    def metric_coeffs(self, theta: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Metric coefficients (gds2, gds21, gds22) for s-alpha geometry.

        ``gds22`` is returned as a zero-dimensional array; ``gds2`` and
        ``gds21`` follow the shape of ``theta``.  At exactly zero shear the
        shear weight is 1, giving ``gds21 = -shear`` and ``gds22 = 1``.
        """
        shear = self.s_hat * theta - self.alpha * jnp.sin(theta)
        norm = self._shear_norm()
        gds2 = 1.0 + shear * shear
        gds21 = -norm * shear
        gds22 = norm * norm
        return gds2, gds21, gds22

    def k_perp2(self, kx0: jnp.ndarray, ky: jnp.ndarray, theta: jnp.ndarray) -> jnp.ndarray:
        """Perpendicular wave-number squared for s-alpha geometry.

        The result is divided by ``bmag(theta)**2`` when ``kperp2_bmag`` is set.
        """
        gds2, gds21, gds22 = self.metric_coeffs(theta)
        # kx / s_hat, or plain kx when the shear is exactly zero.
        kx_hat = kx0 / self._shear_norm()
        kperp2 = ky * (ky * gds2 + 2.0 * kx_hat * gds21) + (kx_hat * kx_hat) * gds22
        if self.kperp2_bmag:
            bmag_inv = 1.0 / self.bmag(theta)
            return kperp2 * (bmag_inv * bmag_inv)
        return kperp2

    def drift_coeffs(
        self, theta: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Curvature and grad-B drift coefficients for s-alpha geometry.

        Returns ``(cv, gb, cv0, gb0)``; ``gb`` equals ``cv`` and ``gb0`` equals
        ``cv0``.  ``cv0`` carries the same shear weight as ``gds21`` so the
        ``kx`` drift term is kept at exactly zero shear.
        """
        shear = self.s_hat * theta - self.alpha * jnp.sin(theta)
        base = jnp.cos(theta) + shear * jnp.sin(theta)
        scale = jnp.asarray(self.drift_scale)
        cv = scale * base / self.R0
        gb = cv
        cv0 = scale * (-self._shear_norm() * jnp.sin(theta)) / self.R0
        gb0 = cv0
        return cv, gb, cv0, gb0

    def drift_components(
        self, kx: jnp.ndarray, ky: jnp.ndarray, theta: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Return cv_d and gb_d drift components in (ky, kx, theta)."""
        kx0 = kx[None, :, None]
        ky0 = ky[:, None, None]
        theta0 = theta[None, None, :]
        cv, gb, cv0, gb0 = self.drift_coeffs(theta0)
        kx_hat = kx0 / self._shear_norm()
        cv_d = ky0 * cv + kx_hat * cv0
        gb_d = ky0 * gb + kx_hat * gb0
        return cv_d, gb_d

    def omega_d(self, kx: jnp.ndarray, ky: jnp.ndarray, theta: jnp.ndarray) -> jnp.ndarray:
        """Magnetic drift frequency for s-alpha geometry."""
        cv_d, gb_d = self.drift_components(kx, ky, theta)
        return cv_d + gb_d
@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class SlabGeometry:
    """GX slab geometry contract.

    Unit field strength, no mirror term and no magnetic drifts; only the
    linear shear ``s_hat * theta`` enters the metric coefficients.

    The instance is a JAX pytree: numeric fields are leaves, while ``z0`` and
    ``zero_shat`` travel as auxiliary data.  The auxiliary data is a tuple
    because JAX needs it to be hashable in order to cache on the tree
    structure (for example when the geometry is a ``jax.jit`` argument).
    """

    s_hat: float = 0.0
    z0: float | None = None
    q: float = 1.0
    epsilon: float = 0.0
    R0: float = 1.0
    B0: float = 1.0
    alpha: float = 0.0
    drift_scale: float = 0.0
    kperp2_bmag: bool = True
    bessel_bmag_power: float = 0.0
    zero_shat: bool = False

    def tree_flatten(self):
        """Return ``(children, aux_data)`` with ``aux_data = (z0, zero_shat)``."""
        children = (
            self.s_hat,
            self.q,
            self.epsilon,
            self.R0,
            self.B0,
            self.alpha,
            self.drift_scale,
            self.kperp2_bmag,
            self.bessel_bmag_power,
        )
        return children, (self.z0, self.zero_shat)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from ``tree_flatten`` output.

        A mapping with ``"z0"`` and ``"zero_shat"`` keys is still accepted for
        auxiliary data produced by the earlier dict-based layout.
        """
        if isinstance(aux_data, dict):
            z0 = aux_data["z0"]
            zero_shat = aux_data["zero_shat"]
        else:
            z0, zero_shat = aux_data
        (
            s_hat,
            q,
            epsilon,
            R0,
            B0,
            alpha,
            drift_scale,
            kperp2_bmag,
            bessel_bmag_power,
        ) = children
        return cls(
            s_hat=s_hat,
            z0=z0,
            q=q,
            epsilon=epsilon,
            R0=R0,
            B0=B0,
            alpha=alpha,
            drift_scale=drift_scale,
            kperp2_bmag=kperp2_bmag,
            bessel_bmag_power=bessel_bmag_power,
            zero_shat=zero_shat,
        )

    @staticmethod
    def from_config(cfg: GeometryConfig) -> "SlabGeometry":
        """Build the slab geometry from a config.

        ``q``, ``epsilon``, ``alpha`` and ``drift_scale`` are fixed to
        ``1, 0, 0, 0``; the shear is stored as 0 when the config is shearless.
        """
        zero_shat = zero_shear_enabled(cfg.s_hat, zero_shat=cfg.zero_shat)
        shat = 0.0 if zero_shat else float(cfg.s_hat)
        return SlabGeometry(
            s_hat=shat,
            z0=cfg.z0,
            q=1.0,
            epsilon=0.0,
            R0=cfg.R0,
            B0=cfg.B0,
            alpha=0.0,
            drift_scale=0.0,
            kperp2_bmag=cfg.kperp2_bmag,
            bessel_bmag_power=cfg.bessel_bmag_power,
            zero_shat=zero_shat,
        )

    def kx_effective(self, kx0: jnp.ndarray, ky: jnp.ndarray, theta: jnp.ndarray) -> jnp.ndarray:
        """Return ``kx0 - s_hat * theta * ky``."""
        shear = jnp.asarray(self.s_hat) * theta
        return kx0 - shear * ky

    def bmag(self, theta: jnp.ndarray) -> jnp.ndarray:
        """Return a field strength of one at every ``theta``."""
        return jnp.ones_like(jnp.asarray(theta))

    def gradpar(self) -> float:
        """Return ``1 / z0`` when ``z0`` is set and positive, otherwise 1."""
        if self.z0 is not None and float(self.z0) > 0.0:
            return float(1.0 / float(self.z0))
        return 1.0

    def bgrad(self, theta: jnp.ndarray) -> jnp.ndarray:
        """Return a zero mirror term at every ``theta``."""
        return jnp.zeros_like(jnp.asarray(theta))

    def metric_coeffs(self, theta: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Return ``(gds2, gds21, gds22)``, each shaped like ``theta``.

        ``gds22`` is ``s_hat**2``, or 1 at exactly zero shear.  The selection
        uses ``jnp.where`` rather than a Python branch so the method also works
        while ``s_hat`` is a traced pytree leaf.
        """
        theta_arr = jnp.asarray(theta)
        s_hat = jnp.asarray(self.s_hat)
        shear = s_hat * theta_arr
        gds2 = 1.0 + shear * shear
        gds21 = -s_hat * shear
        gds22_value = jnp.where(s_hat == 0.0, 1.0, s_hat * s_hat)
        gds22 = jnp.full_like(theta_arr, gds22_value)
        return gds2, gds21, gds22

    def k_perp2(self, kx0: jnp.ndarray, ky: jnp.ndarray, theta: jnp.ndarray) -> jnp.ndarray:
        """Return the perpendicular wave number squared (no ``bmag`` scaling)."""
        gds2, gds21, gds22 = self.metric_coeffs(theta)
        s_hat = jnp.asarray(self.s_hat)
        s_hat_safe = jnp.where(s_hat == 0.0, 1.0, s_hat)
        kx_hat = kx0 / s_hat_safe
        kx_hat = jnp.where(s_hat == 0.0, kx0, kx_hat)
        kperp2 = ky * (ky * gds2 + 2.0 * kx_hat * gds21) + (kx_hat * kx_hat) * gds22
        return kperp2

    def drift_coeffs(
        self, theta: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Return four zero arrays shaped like ``theta`` (cv, gb, cv0, gb0)."""
        theta_arr = jnp.asarray(theta)
        zeros = jnp.zeros_like(theta_arr)
        return zeros, zeros, zeros, zeros

    def drift_components(
        self, kx: jnp.ndarray, ky: jnp.ndarray, theta: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Return two zero arrays of shape ``(len(ky), len(kx), len(theta))``."""
        kx0 = jnp.asarray(kx)
        ky0 = jnp.asarray(ky)
        theta0 = jnp.asarray(theta)
        zeros = jnp.zeros((ky0.shape[0], kx0.shape[0], theta0.shape[0]), dtype=theta0.dtype)
        return zeros, zeros

    def omega_d(self, kx: jnp.ndarray, ky: jnp.ndarray, theta: jnp.ndarray) -> jnp.ndarray:
        """Return a zero drift frequency of shape ``(len(ky), len(kx), len(theta))``."""
        kx0 = jnp.asarray(kx)
        ky0 = jnp.asarray(ky)
        theta0 = jnp.asarray(theta)
        return jnp.zeros((ky0.shape[0], kx0.shape[0], theta0.shape[0]), dtype=theta0.dtype)
@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class FluxTubeGeometryData:
    """Sampled flux-tube geometry contract for solver-ready metric profiles.

    Every ``*_profile`` field is a one-dimensional array sampled on ``theta``.
    The accessor methods take a ``theta`` argument only to validate it against
    the stored grid and to broadcast the stored profile to its shape; no
    interpolation is performed.

    The instance is a JAX pytree.  ``source_model`` and
    ``theta_closed_interval`` travel as auxiliary data, stored in a tuple
    because JAX needs the auxiliary data to be hashable in order to cache on
    the tree structure (for example when the geometry is a ``jax.jit``
    argument).  All other fields are leaves, in declaration order.
    """

    theta: jnp.ndarray
    gradpar_value: float
    bmag_profile: jnp.ndarray
    bgrad_profile: jnp.ndarray
    gds2_profile: jnp.ndarray
    gds21_profile: jnp.ndarray
    gds22_profile: jnp.ndarray
    cv_profile: jnp.ndarray
    gb_profile: jnp.ndarray
    cv0_profile: jnp.ndarray
    gb0_profile: jnp.ndarray
    jacobian_profile: jnp.ndarray
    grho_profile: jnp.ndarray
    q: float
    s_hat: float
    epsilon: float
    R0: float
    B0: float = 1.0
    alpha: float = 0.0
    drift_scale: float = 1.0
    kxfac: float = 1.0
    theta_scale: float = 1.0
    nfp: int = 1
    kperp2_bmag: bool = True
    bessel_bmag_power: float = 0.0
    source_model: str = "sampled"
    theta_closed_interval: bool = False

    def tree_flatten(self):
        """Return ``(children, aux_data)`` for JAX pytree registration.

        ``aux_data`` is ``(source_model, theta_closed_interval)``.
        """
        children = (
            self.theta,
            self.gradpar_value,
            self.bmag_profile,
            self.bgrad_profile,
            self.gds2_profile,
            self.gds21_profile,
            self.gds22_profile,
            self.cv_profile,
            self.gb_profile,
            self.cv0_profile,
            self.gb0_profile,
            self.jacobian_profile,
            self.grho_profile,
            self.q,
            self.s_hat,
            self.epsilon,
            self.R0,
            self.B0,
            self.alpha,
            self.drift_scale,
            self.kxfac,
            self.theta_scale,
            self.nfp,
            self.kperp2_bmag,
            self.bessel_bmag_power,
        )
        return children, (self.source_model, self.theta_closed_interval)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from ``tree_flatten`` output.

        A mapping with ``"source_model"`` and ``"theta_closed_interval"`` keys
        is still accepted for auxiliary data produced by the earlier
        dict-based layout.
        """
        if isinstance(aux_data, dict):
            source_model = aux_data["source_model"]
            theta_closed_interval = aux_data["theta_closed_interval"]
        else:
            source_model, theta_closed_interval = aux_data
        return cls(
            *children,
            source_model=source_model,
            theta_closed_interval=theta_closed_interval,
        )

    def _theta_matches(self, theta: jnp.ndarray) -> jnp.ndarray:
        """Validate ``theta`` against the stored grid and return it as an array.

        The last-axis length is always checked.  Values are compared (first
        row only, for multi-dimensional input) with ``rtol = atol = 1e-6``
        unless either side is a JAX tracer, in which case concrete values are
        unavailable and the value check is skipped.

        Raises:
            ValueError: on a length mismatch or a value mismatch.
        """
        theta_arr = jnp.asarray(theta)
        if theta_arr.shape[-1] != self.theta.shape[0]:
            raise ValueError("theta must have the same last dimension as the sampled geometry grid")
        if isinstance(theta_arr, jax.core.Tracer):
            return theta_arr
        # The stored grid is itself traced when the geometry is a jit argument.
        if isinstance(self.theta, jax.core.Tracer):
            return theta_arr
        theta_line = theta_arr if theta_arr.ndim == 1 else theta_arr.reshape(-1, theta_arr.shape[-1])[0]
        # Slicing a concrete array inside a traced function can yield a tracer.
        if isinstance(theta_line, jax.core.Tracer):
            return theta_arr
        if not np.allclose(np.asarray(theta_line), np.asarray(self.theta), rtol=1.0e-6, atol=1.0e-6):
            raise ValueError("theta does not match the sampled geometry grid")
        return theta_arr

    def _broadcast_profile(self, theta: jnp.ndarray, profile: jnp.ndarray) -> jnp.ndarray:
        """Return ``profile`` as-is for 1-D ``theta``, else broadcast to its shape."""
        theta_arr = self._theta_matches(theta)
        if theta_arr.ndim == 1:
            return profile
        shape = (1,) * (theta_arr.ndim - 1) + (profile.shape[0],)
        return jnp.broadcast_to(profile.reshape(shape), theta_arr.shape)

    def trim_terminal_theta_point(self) -> "FluxTubeGeometryData":
        """Return a copy without the terminal theta sample.

        GX `*.eik.nc` files commonly store a closed theta interval, while the
        spectral solver uses the matching open interval with the terminal point
        excluded. Trimming keeps the imported coefficients aligned with the
        runtime grid without changing the physical extent.

        Raises:
            ValueError: if the grid has fewer than two samples.
        """
        if self.theta.shape[0] < 2:
            raise ValueError("Cannot trim the terminal point from a geometry grid with fewer than two samples")
        return replace(
            self,
            theta=self.theta[:-1],
            bmag_profile=self.bmag_profile[:-1],
            bgrad_profile=self.bgrad_profile[:-1],
            gds2_profile=self.gds2_profile[:-1],
            gds21_profile=self.gds21_profile[:-1],
            gds22_profile=self.gds22_profile[:-1],
            cv_profile=self.cv_profile[:-1],
            gb_profile=self.gb_profile[:-1],
            cv0_profile=self.cv0_profile[:-1],
            gb0_profile=self.gb0_profile[:-1],
            jacobian_profile=self.jacobian_profile[:-1],
            grho_profile=self.grho_profile[:-1],
            theta_closed_interval=False,
        )

    def gradpar(self) -> float:
        """Return the stored parallel-gradient factor as a Python float."""
        return float(self.gradpar_value)

    def bmag(self, theta: jnp.ndarray) -> jnp.ndarray:
        """Return the stored ``bmag`` profile broadcast to ``theta``."""
        return self._broadcast_profile(theta, self.bmag_profile)

    def bgrad(self, theta: jnp.ndarray) -> jnp.ndarray:
        """Return the stored ``bgrad`` profile broadcast to ``theta``."""
        return self._broadcast_profile(theta, self.bgrad_profile)

    def metric_coeffs(self, theta: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Return the stored ``(gds2, gds21, gds22)`` profiles broadcast to ``theta``."""
        return (
            self._broadcast_profile(theta, self.gds2_profile),
            self._broadcast_profile(theta, self.gds21_profile),
            self._broadcast_profile(theta, self.gds22_profile),
        )

    def drift_coeffs(
        self,
        theta: jnp.ndarray,
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Return the stored ``(cv, gb, cv0, gb0)`` profiles broadcast to ``theta``."""
        return (
            self._broadcast_profile(theta, self.cv_profile),
            self._broadcast_profile(theta, self.gb_profile),
            self._broadcast_profile(theta, self.cv0_profile),
            self._broadcast_profile(theta, self.gb0_profile),
        )

    def jacobian(self, theta: jnp.ndarray) -> jnp.ndarray:
        """Return the stored Jacobian profile broadcast to ``theta``."""
        return self._broadcast_profile(theta, self.jacobian_profile)

    def grho(self, theta: jnp.ndarray) -> jnp.ndarray:
        """Return the stored ``grho`` profile broadcast to ``theta``."""
        return self._broadcast_profile(theta, self.grho_profile)

    def k_perp2(self, kx0: jnp.ndarray, ky: jnp.ndarray, theta: jnp.ndarray) -> jnp.ndarray:
        """Return the perpendicular wave number squared from the stored metrics.

        ``kx0`` is divided by ``s_hat`` (left unscaled at exactly zero shear);
        the result is divided by ``bmag**2`` when ``kperp2_bmag`` is set.
        """
        gds2, gds21, gds22 = self.metric_coeffs(theta)
        s_hat = jnp.asarray(self.s_hat)
        s_hat_safe = jnp.where(s_hat == 0.0, 1.0, s_hat)
        kx_hat = kx0 / s_hat_safe
        kx_hat = jnp.where(s_hat == 0.0, kx0, kx_hat)
        kperp2 = ky * (ky * gds2 + 2.0 * kx_hat * gds21) + (kx_hat * kx_hat) * gds22
        if self.kperp2_bmag:
            bmag_inv = 1.0 / self.bmag(theta)
            return kperp2 * (bmag_inv * bmag_inv)
        return kperp2

    def drift_components(
        self,
        kx: jnp.ndarray,
        ky: jnp.ndarray,
        theta: jnp.ndarray,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Return ``(cv_d, gb_d)`` on a ``(ky, kx, theta)`` grid."""
        kx0 = kx[None, :, None]
        ky0 = ky[:, None, None]
        theta0 = theta[None, None, :]
        cv, gb, cv0, gb0 = self.drift_coeffs(theta0)
        s_hat = jnp.asarray(self.s_hat)
        s_hat_safe = jnp.where(s_hat == 0.0, 1.0, s_hat)
        kx_hat = kx0 / s_hat_safe
        kx_hat = jnp.where(s_hat == 0.0, kx0, kx_hat)
        cv_d = ky0 * cv + kx_hat * cv0
        gb_d = ky0 * gb + kx_hat * gb0
        return cv_d, gb_d

    def omega_d(self, kx: jnp.ndarray, ky: jnp.ndarray, theta: jnp.ndarray) -> jnp.ndarray:
        """Return the sum of the curvature and grad-B drift components."""
        cv_d, gb_d = self.drift_components(kx, ky, theta)
        return cv_d + gb_d
def sample_flux_tube_geometry(geom: SAlphaGeometry | SlabGeometry, theta: jnp.ndarray) -> FluxTubeGeometryData:
    """Evaluate an analytic geometry on ``theta`` and package the profiles.

    The Jacobian is ``1 / (|gradpar| * bmag)`` and ``grho`` is filled with
    ones.  A zero-dimensional ``gds22`` is expanded to the shape of ``theta``.
    ``source_model`` is ``"slab"`` for a ``SlabGeometry`` and ``"s-alpha"``
    otherwise; the result is always marked as an open theta interval.
    """
    grid = jnp.asarray(theta)

    metric_gds2, metric_gds21, metric_gds22 = geom.metric_coeffs(grid)
    if not metric_gds22.ndim:
        # Scalar gds22: replicate it along the theta grid.
        metric_gds22 = jnp.full_like(grid, metric_gds22)

    curv, gradb, curv0, gradb0 = geom.drift_coeffs(grid)
    field_strength = geom.bmag(grid)
    par_grad = float(geom.gradpar())
    jac = 1.0 / (jnp.abs(jnp.asarray(par_grad)) * field_strength)

    if isinstance(geom, SlabGeometry):
        label = "slab"
    else:
        label = "s-alpha"

    return FluxTubeGeometryData(
        theta=grid,
        gradpar_value=par_grad,
        bmag_profile=field_strength,
        bgrad_profile=geom.bgrad(grid),
        gds2_profile=metric_gds2,
        gds21_profile=metric_gds21,
        gds22_profile=metric_gds22,
        cv_profile=curv,
        gb_profile=gradb,
        cv0_profile=curv0,
        gb0_profile=gradb0,
        jacobian_profile=jac,
        grho_profile=jnp.ones_like(grid),
        q=float(geom.q),
        s_hat=float(geom.s_hat),
        epsilon=float(geom.epsilon),
        R0=float(geom.R0),
        B0=float(geom.B0),
        alpha=float(geom.alpha),
        drift_scale=float(geom.drift_scale),
        kxfac=1.0,
        theta_scale=1.0,
        nfp=1,
        kperp2_bmag=bool(geom.kperp2_bmag),
        bessel_bmag_power=float(geom.bessel_bmag_power),
        source_model=label,
        theta_closed_interval=False,
    )
def _periodic_spectral_derivative(values: np.ndarray, spacing: float) -> np.ndarray:
    """Return the periodic spectral derivative of a uniform 1D profile.

    Args:
        values: One-dimensional samples on a uniform periodic grid.
        spacing: Grid spacing between consecutive samples.

    Returns:
        The real part of the FFT-based derivative, cast to ``values.dtype``.
        Profiles with fewer than two samples yield zeros.

    Raises:
        ValueError: if ``values`` is not one-dimensional.
    """
    if values.ndim != 1:
        raise ValueError("Periodic spectral derivatives require a one-dimensional profile")
    if values.size < 2:
        return np.zeros_like(values)
    k = 2.0 * np.pi * np.fft.fftfreq(values.size, d=spacing)
    deriv_hat = 1j * k * np.fft.fft(values)
    return np.fft.ifft(deriv_hat).real.astype(values.dtype, copy=False)


def _gx_bgrad_from_bmag(theta: np.ndarray, bmag: np.ndarray, gradpar_val: float, *, closed: bool) -> np.ndarray:
    """Reconstruct GX's mirror term from ``bmag`` on the solver theta grid.

    The result is ``gradpar_val * d(bmag)/d(theta) / bmag`` using a periodic
    spectral derivative, with ``bmag`` clipped from below at ``1e-30``.

    Args:
        theta: One-dimensional theta grid; the spacing is taken from its first
            two samples.
        bmag: Field strength sampled on ``theta``.
        gradpar_val: Parallel-gradient factor multiplying the derivative.
        closed: Whether the last sample duplicates the periodic endpoint.  If
            so it is dropped before differentiating and the first derivative
            value is appended afterwards, so the output keeps the input length.

    Returns:
        An array with the same shape as ``bmag``.  Grids that leave fewer than
        two independent samples (including a closed two-point grid) yield
        zeros.

    Raises:
        ValueError: if the inputs are not one-dimensional or differ in shape.
    """
    if theta.ndim != 1 or bmag.ndim != 1:
        raise ValueError("GX bgrad reconstruction expects one-dimensional theta and bmag profiles")
    if theta.shape != bmag.shape:
        raise ValueError("theta and bmag must have the same shape for GX bgrad reconstruction")
    if theta.size < 2:
        return np.zeros_like(bmag)
    if closed:
        work_theta = theta[:-1]
        work_bmag = bmag[:-1]
    else:
        work_theta = theta
        work_bmag = bmag
    if work_theta.size < 2:
        # A closed two-point grid holds a single independent sample, so there
        # is no spacing to read and the periodic derivative is zero.
        return np.zeros_like(bmag)
    spacing = float(work_theta[1] - work_theta[0])
    d_bmag = _periodic_spectral_derivative(work_bmag, spacing)
    bgrad = float(gradpar_val) * d_bmag / np.clip(work_bmag, 1.0e-30, None)
    if closed:
        return np.concatenate([bgrad, bgrad[:1]])
    return bgrad
def load_gx_geometry_netcdf(path: str | Path) -> FluxTubeGeometryData:
    """Load sampled geometry from a GX-style NetCDF file.

    Two layouts are recognised:

    * Grouped output, identified by the presence of both a ``Geometry`` and a
      ``Grids`` group.  ``theta`` is read from ``Grids``, everything else from
      ``Geometry``, and the theta grid is treated as an open interval.
    * Root-level files, where all variables are read from the root group.
      Whether theta includes a terminal endpoint is inferred from the
      profiles, the four drift profiles are multiplied by 0.5, and the
      Jacobian is rebuilt as ``1 / |drhodpsi * gradpar * bmag|``.

    ``bgrad`` is read from the file only for grouped output that provides it;
    otherwise it is reconstructed from ``bmag``.

    Raises:
        ImportError: if ``netCDF4`` cannot be imported.
        KeyError: if a required variable is absent.
        ValueError: if a variable does not have the expected dimensionality.
    """
    try:
        # Imported lazily so the module itself does not require netCDF4.
        from netCDF4 import Dataset
    except ImportError as exc:  # pragma: no cover - optional import
        raise ImportError("netCDF4 is required to load GX geometry NetCDF files") from exc

    def _read_scalar(variables, *names: str, default: float | None = None) -> float:
        # Return the first variable found among ``names`` as a float.  A 1-D
        # variable is accepted only when all of its entries are close to the
        # first one.  Falls back to ``default`` when none of the names exists,
        # or raises KeyError (with the first name) if no default was given.
        for name in names:
            if name in variables:
                arr = np.asarray(variables[name][:], dtype=float)
                if arr.ndim == 0:
                    return float(arr)
                if arr.ndim == 1 and np.allclose(arr, arr[0]):
                    return float(arr[0])
                raise ValueError(f"GX geometry variable '{name}' must be scalar or constant on theta")
        if default is None:
            raise KeyError(names[0])
        return float(default)

    def _read_profile(variables, *names: str) -> np.ndarray:
        # Return the first variable found among ``names`` as a 1-D float array;
        # raises KeyError (with the first name) when none of them exists.
        for name in names:
            if name in variables:
                arr = np.asarray(variables[name][:], dtype=float)
                if arr.ndim != 1:
                    raise ValueError(f"GX geometry variable '{name}' must be one-dimensional on theta")
                return arr
        raise KeyError(names[0])

    def _infer_root_theta_closed_interval(theta: np.ndarray, variables) -> bool:
        """Infer whether a root-level GX ``*.eik.nc`` file includes a terminal theta endpoint.

        VMEC-style ``*.eik.nc`` files often include the periodic terminal point,
        while GX's Miller helper writes an already-open theta grid. Root-level
        files therefore cannot be treated as closed intervals unconditionally.

        The heuristic counts, among the listed profiles that are present and
        have the same length as ``theta``, how many have a last sample equal
        to the first within ``max(1e-10, 1e-6 * scale)``; the grid is reported
        as closed when at least half of them (rounded up) match.
        """
        if theta.ndim != 1 or theta.size < 2:
            return False
        profile_names = (
            "bmag",
            "gds2",
            "gds21",
            "gds22",
            "cvdrift",
            "gbdrift",
            "grho",
        )
        matches = 0
        checked = 0
        for name in profile_names:
            if name not in variables:
                continue
            arr = np.asarray(variables[name][:], dtype=float)
            if arr.ndim != 1 or arr.size != theta.size:
                continue
            checked += 1
            # Tolerance scales with the profile magnitude, but never below 1.
            scale = max(float(np.nanmax(np.abs(arr))), 1.0)
            if abs(float(arr[-1] - arr[0])) <= max(1.0e-10, 1.0e-6 * scale):
                matches += 1
        if checked == 0:
            return False
        # ``checked // 2 + checked % 2`` is ceil(checked / 2).
        return matches >= max(1, checked // 2 + checked % 2)

    root = Dataset(Path(path), "r")
    try:
        is_grouped_gx_output = "Geometry" in root.groups and "Grids" in root.groups
        if is_grouped_gx_output:
            geom_vars = root.groups["Geometry"].variables
            grid_vars = root.groups["Grids"].variables
            theta = _read_profile(grid_vars, "theta")
            theta_closed_interval = False
        else:
            geom_vars = root.variables
            grid_vars = root.variables
            theta = _read_profile(root.variables, "theta")
            theta_closed_interval = _infer_root_theta_closed_interval(np.asarray(theta, dtype=float), root.variables)
        gradpar_val = _read_scalar(geom_vars, "gradpar")
        bmag = _read_profile(geom_vars, "bmag")
        drhodpsi = _read_scalar(geom_vars, "drhodpsi", default=1.0)
        if is_grouped_gx_output and "bgrad" in geom_vars:
            bgrad = _read_profile(geom_vars, "bgrad")
        else:
            # No usable stored mirror term: rebuild it from bmag.
            bgrad = _gx_bgrad_from_bmag(
                np.asarray(theta, dtype=float),
                np.asarray(bmag, dtype=float),
                gradpar_val,
                closed=theta_closed_interval,
            )
        if is_grouped_gx_output:
            cvdrift = _read_profile(geom_vars, "cvdrift")
            gbdrift = _read_profile(geom_vars, "gbdrift")
            cvdrift0 = _read_profile(geom_vars, "cvdrift0")
            gbdrift0 = _read_profile(geom_vars, "gbdrift0")
            jacobian = _read_profile(geom_vars, "jacobian", "jacob")
        else:
            # Root-level VMEC ``*.eik.nc`` files carry the pre-GX drift
            # normalization and a Jacobian that GX replaces at load time.
            cvdrift = 0.5 * _read_profile(geom_vars, "cvdrift")
            gbdrift = 0.5 * _read_profile(geom_vars, "gbdrift")
            cvdrift0 = 0.5 * _read_profile(geom_vars, "cvdrift0")
            gbdrift0 = 0.5 * _read_profile(geom_vars, "gbdrift0")
            jacobian = 1.0 / np.abs(float(drhodpsi) * float(gradpar_val) * np.asarray(bmag, dtype=float))
        rmaj = _read_scalar(geom_vars, "rmaj", "Rmaj", default=1.0)
        aminor = _read_scalar(geom_vars, "aminor", default=0.0)
        # epsilon = aminor / rmaj, guarded against a zero major radius.
        epsilon = aminor / rmaj if abs(rmaj) > 0.0 else 0.0
        return FluxTubeGeometryData(
            theta=jnp.asarray(theta),
            gradpar_value=gradpar_val,
            bmag_profile=jnp.asarray(bmag),
            bgrad_profile=jnp.asarray(bgrad),
            gds2_profile=jnp.asarray(_read_profile(geom_vars, "gds2")),
            gds21_profile=jnp.asarray(_read_profile(geom_vars, "gds21")),
            gds22_profile=jnp.asarray(_read_profile(geom_vars, "gds22")),
            cv_profile=jnp.asarray(cvdrift),
            gb_profile=jnp.asarray(gbdrift),
            cv0_profile=jnp.asarray(cvdrift0),
            gb0_profile=jnp.asarray(gbdrift0),
            jacobian_profile=jnp.asarray(jacobian),
            grho_profile=jnp.asarray(_read_profile(geom_vars, "grho")),
            q=_read_scalar(geom_vars, "q", default=0.0),
            s_hat=_read_scalar(geom_vars, "shat", default=0.0),
            epsilon=float(epsilon),
            R0=float(rmaj),
            B0=1.0,
            alpha=_read_scalar(geom_vars, "alpha", default=0.0),
            drift_scale=1.0,
            kxfac=_read_scalar(geom_vars, "kxfac", default=1.0),
            theta_scale=_read_scalar(geom_vars, "theta_scale", "scale", default=1.0),
            nfp=int(round(_read_scalar(geom_vars, "nfp", default=1.0))),
            kperp2_bmag=True,
            bessel_bmag_power=0.0,
            source_model="gx-netcdf",
            theta_closed_interval=theta_closed_interval,
        )
    finally:
        # Always release the file handle, including on read errors.
        root.close()
# Union of the three geometry classes defined in this module.
FluxTubeGeometryLike = SAlphaGeometry | SlabGeometry | FluxTubeGeometryData
def build_flux_tube_geometry(cfg: GeometryConfig) -> FluxTubeGeometryLike:
    """Construct the geometry object selected by ``cfg.model``.

    The model name is stripped, lower-cased and has underscores replaced by
    hyphens before being matched.  Analytic names map to ``SAlphaGeometry``,
    ``slab`` maps to ``SlabGeometry``, and the NetCDF/eik names load
    ``cfg.geometry_file`` through ``load_gx_geometry_netcdf``.

    Raises:
        ValueError: if the model name is not recognised, or if an imported
            model is selected without ``cfg.geometry_file``.
    """
    analytic_names = {"s-alpha", "salpha", "analytic"}
    imported_names = {
        "gx-netcdf",
        "gx-nc",
        "netcdf",
        "nc",
        "gx-eik",
        "eik",
        "vmec-eik",
        "desc-eik",
    }

    key = str(cfg.model).strip().lower().replace("_", "-")

    if key in analytic_names:
        return SAlphaGeometry.from_config(cfg)
    if key == "slab":
        return SlabGeometry.from_config(cfg)
    if key not in imported_names:
        raise ValueError(
            "geometry.model must be one of "
            "{'s-alpha', 'slab', 'gx-netcdf', 'gx-eik', 'vmec-eik', 'desc-eik'}"
        )

    source_file = cfg.geometry_file
    if source_file is None:
        raise ValueError("geometry.geometry_file must be set for imported NetCDF/eik geometry")
    return load_gx_geometry_netcdf(source_file)
def ensure_flux_tube_geometry_data(
    geom: FluxTubeGeometryLike,
    theta: jnp.ndarray,
) -> FluxTubeGeometryData:
    """Produce sampled geometry data matching ``theta``.

    Analytic geometries are sampled on ``theta``.  Pre-sampled data is
    returned unchanged when its grid matches; when it holds exactly one more
    sample than ``theta`` the terminal point is trimmed and the match is
    re-checked.

    Raises:
        ValueError: if pre-sampled data cannot be matched to ``theta``.
    """
    if not isinstance(geom, FluxTubeGeometryData):
        return sample_flux_tube_geometry(geom, theta)

    try:
        geom._theta_matches(theta)
    except ValueError as exc:
        requested_len = jnp.asarray(theta).shape[-1]
        has_one_extra_sample = geom.theta.shape[0] == requested_len + 1
        if not has_one_extra_sample:
            raise exc
        shortened = geom.trim_terminal_theta_point()
        shortened._theta_matches(theta)
        return shortened
    return geom
def twist_shift_params(
    geom: FluxTubeGeometryLike,
    grid: GridConfig,
) -> tuple[int, float]:
    """Compute the `(jtwist, x0)` pair used for twist-and-shift grid defaults.

    The geometric factor is ``2 * s_hat * gds21 / gds22`` (zero when ``gds22``
    is zero).  For sampled geometry the metric values are taken from the first
    profile sample; for analytic geometry they are evaluated at the lower
    theta bound implied by the grid.  ``jtwist`` comes from ``grid.jtwist``
    when set, otherwise from rounding the factor, and is never zero.  ``x0``
    equals ``y0`` when the factor is zero, else ``y0 * |jtwist| / |factor|``.
    """
    if grid.y0 is not None:
        y0 = float(grid.y0)
    else:
        y0 = float(grid.Ly) / (2.0 * jnp.pi)

    if isinstance(geom, FluxTubeGeometryData):
        gds21_val = float(np.asarray(geom.gds21_profile[0]))
        gds22_val = float(np.asarray(geom.gds22_profile[0]))
        shat = float(geom.s_hat)
    else:
        # Lower end of the theta domain implied by the grid settings.
        if grid.ntheta is None:
            theta_min = float(grid.z_min)
        else:
            if grid.zp is not None:
                periods = int(grid.zp)
            elif grid.nperiod is not None:
                periods = 2 * int(grid.nperiod) - 1
            else:
                periods = 1
            theta_min = -jnp.pi * float(periods)
        lower = float(theta_min)
        shat = float(geom.s_hat)

        if isinstance(geom, SAlphaGeometry):
            local_shear = shat * lower - float(geom.alpha) * math.sin(lower)
            gds21_val = -shat * local_shear
            gds22_val = shat * shat
        elif isinstance(geom, SlabGeometry):
            local_shear = shat * lower
            gds21_val = -shat * local_shear
            gds22_val = 1.0 if shat == 0.0 else shat * shat
        else:
            # Any other geometry: query its metric coefficients directly.
            _unused, raw_gds21, raw_gds22 = geom.metric_coeffs(np.asarray([lower], dtype=float))
            gds21_val = float(np.asarray(raw_gds21, dtype=float).reshape(-1)[0])
            gds22_np = np.asarray(raw_gds22, dtype=float)
            if gds22_np.ndim > 0:
                gds22_val = float(gds22_np.reshape(-1)[0])
            else:
                gds22_val = float(gds22_np)

    if gds22_val != 0.0:
        geo_factor = 2.0 * shat * gds21_val / gds22_val
    else:
        geo_factor = 0.0

    jtwist = int(round(geo_factor)) if grid.jtwist is None else int(grid.jtwist)
    if jtwist == 0:
        jtwist = 1

    if geo_factor == 0.0:
        return jtwist, y0
    return jtwist, y0 * abs(jtwist) / abs(geo_factor)
# Alias exposing ``twist_shift_params`` under a ``gx_``-prefixed name.
gx_twist_shift_params = twist_shift_params
def apply_geometry_grid_defaults(
    geom: FluxTubeGeometryLike,
    grid: GridConfig,
) -> GridConfig:
    """Return a grid config adjusted to the GX-aligned defaults of ``geom``.

    For sampled geometry the parallel grid (``Nz``, ``z_min``, ``z_max``) is
    taken from the stored theta samples, ``ntheta``/``nperiod``/``zp`` are
    cleared, and ``kxfac`` is copied from the geometry when the grid still has
    the value 1.  The boundary is then passed through zero-shear promotion,
    and either the twist-and-shift parameters are filled in (``linked`` or
    ``fix aspect`` boundaries without ``non_twist``) or ``jtwist`` is cleared
    (periodic boundary on a shearless geometry).

    Raises:
        ValueError: if sampled geometry has a theta grid that is not
            one-dimensional with at least two points.
    """
    updated = grid

    if isinstance(geom, FluxTubeGeometryData):
        samples = np.asarray(geom.theta, dtype=float)
        if samples.ndim != 1 or samples.size < 2:
            raise ValueError("Imported GX geometry theta grid must be one-dimensional with at least two points")
        lower = float(samples[0])
        if geom.theta_closed_interval:
            # The endpoint is stored: it is the upper bound, not a grid point.
            n_points = int(samples.size - 1)
            upper = float(samples[-1])
        else:
            # Open grid: the upper bound lies one spacing past the last sample.
            step = float(samples[1] - samples[0])
            n_points = int(samples.size)
            upper = float(samples[-1] + step)
        updated = replace(
            updated,
            Nz=n_points,
            z_min=lower,
            z_max=upper,
            ntheta=None,
            nperiod=None,
            zp=None,
        )
        if float(updated.kxfac) == 1.0:
            updated = replace(updated, kxfac=float(geom.kxfac))

    shear = float(getattr(geom, "s_hat", 0.0))
    shearless_flag = bool(getattr(geom, "zero_shat", False))
    requested = str(updated.boundary).lower()
    boundary = effective_boundary(requested, s_hat=shear, zero_shat=shearless_flag)
    if boundary != requested:
        updated = replace(updated, boundary=boundary, jtwist=None)

    uses_twist_shift = boundary in {"linked", "fix aspect"} and not bool(updated.non_twist)
    if uses_twist_shift:
        jtwist, x0 = twist_shift_params(geom, updated)
        updated = replace(updated, Lx=2.0 * np.pi * x0, jtwist=jtwist)
    elif boundary == "periodic" and zero_shear_enabled(shear, zero_shat=shearless_flag):
        # GX zero-shear promotion switches the lane onto the periodic
        # grad-parallel operator, so any linked-FFT metadata must be cleared.
        updated = replace(updated, jtwist=None)
    return updated
# Alias exposing ``apply_geometry_grid_defaults`` under a ``gx``-infixed name.
apply_gx_geometry_grid_defaults = apply_geometry_grid_defaults