"""Shared helpers for benchmark entry points.
This module keeps pure policies, reference-data loaders, small result containers,
and initializer builders out of the large benchmark runner module while preserving
``spectraxgk.benchmarks`` as the compatibility import surface.
"""
from __future__ import annotations
from dataclasses import dataclass, replace
from importlib import resources
from typing import Sequence
import warnings
import jax.numpy as jnp
import numpy as np
from spectraxgk.analysis import (
ModeSelection,
extract_mode_time_series,
fit_growth_rate_auto_with_stats,
)
from spectraxgk.config import InitializationConfig, KineticElectronBaseCase
from spectraxgk.geometry import FluxTubeGeometryLike
from spectraxgk.grids import SpectralGrid
from spectraxgk.linear import LinearParams
from spectraxgk.linear_krylov import KrylovConfig
from spectraxgk.normalization import apply_diagnostic_normalization
from spectraxgk.species import Species, build_linear_params
REFERENCE_NU_HYPER_L = 0.0
REFERENCE_NU_HYPER_M = 1.0
REFERENCE_P_HYPER_L = 6.0
REFERENCE_P_HYPER_M = 20.0
REFERENCE_DAMP_ENDS_AMP = 0.1
REFERENCE_DAMP_ENDS_WIDTHFRAC = 1.0 / 8.0
KBM_GX_SOLVER_LOCK: tuple[tuple[float, str], ...] = (
(0.10, "gx_time"),
(0.30, "gx_time"),
(0.40, "gx_time"),
)
KBM_GX_SOLVER_LOCK_TOL = 0.03
def _gx_p_hyper_m(nhermite: int | None) -> float:
if nhermite is None:
return REFERENCE_P_HYPER_M
return float(min(REFERENCE_P_HYPER_M, max(int(nhermite) // 2, 1)))
def _apply_gx_hypercollisions(
params: LinearParams, *, nhermite: int | None = None
) -> LinearParams:
return replace(
params,
nu_hyper=0.0,
nu_hyper_l=REFERENCE_NU_HYPER_L,
nu_hyper_m=REFERENCE_NU_HYPER_M,
p_hyper_l=REFERENCE_P_HYPER_L,
p_hyper_m=_gx_p_hyper_m(nhermite),
hypercollisions_const=0.0,
hypercollisions_kz=1.0,
)
def _gx_linked_end_damping(gx_reference: bool) -> tuple[float, float]:
if gx_reference:
return REFERENCE_DAMP_ENDS_AMP, REFERENCE_DAMP_ENDS_WIDTHFRAC
return 0.0, 0.0
def _midplane_index(grid: SpectralGrid) -> int:
"""Return GX-style midplane index for growth-rate diagnostics."""
if grid.z.size <= 1:
return 0
idx = int(grid.z.size // 2 + 1)
return min(idx, int(grid.z.size) - 1)
[docs]
def select_kbm_solver_auto(solver: str, *, ky_target: float, gx_reference: bool) -> str:
"""Return deterministic KBM solver choice for auto mode."""
solver_key = solver.strip().lower()
if solver_key != "auto":
return solver_key
if not gx_reference:
return "time"
ky_abs = abs(float(ky_target))
for ky_ref, solver_ref in KBM_GX_SOLVER_LOCK:
if abs(ky_abs - ky_ref) <= KBM_GX_SOLVER_LOCK_TOL:
return solver_ref
return "gx_time"
def _select_fit_signal(
phi_t: np.ndarray,
density_t: np.ndarray | None,
sel: ModeSelection,
*,
fit_signal: str,
mode_method: str,
fallback: bool = True,
) -> np.ndarray:
def _extract(arr: np.ndarray) -> np.ndarray:
return extract_mode_time_series(arr, sel, method=mode_method)
def _is_valid(arr: np.ndarray) -> bool:
finite = np.isfinite(arr)
return int(np.count_nonzero(finite)) >= 2
if fit_signal == "phi":
signal = _extract(phi_t)
if fallback and not _is_valid(signal) and density_t is not None:
alt = _extract(density_t)
if _is_valid(alt):
return alt
if not _is_valid(signal):
warnings.warn(
"Fit signal has insufficient finite samples; falling back to zeros.",
RuntimeWarning,
)
return np.zeros(phi_t.shape[0], dtype=np.complex128)
return signal
if fit_signal == "density":
if density_t is None:
raise ValueError("density_t must be provided when fit_signal='density'")
signal = _extract(density_t)
if fallback and not _is_valid(signal):
alt = _extract(phi_t)
if _is_valid(alt):
return alt
if not _is_valid(signal):
warnings.warn(
"Fit signal has insufficient finite samples; falling back to zeros.",
RuntimeWarning,
)
return np.zeros(phi_t.shape[0], dtype=np.complex128)
return signal
raise ValueError("fit_signal must be 'phi' or 'density'")
def _score_fit_signal_auto(
t: np.ndarray,
signal: np.ndarray,
*,
tmin: float | None,
tmax: float | None,
window_fraction: float,
min_points: int,
start_fraction: float,
growth_weight: float,
require_positive: bool,
min_amp_fraction: float,
max_amp_fraction: float,
window_method: str,
max_fraction: float,
end_fraction: float,
num_windows: int,
phase_weight: float,
length_weight: float,
min_r2: float,
late_penalty: float,
min_slope: float | None,
min_slope_frac: float,
slope_var_weight: float,
) -> tuple[float, float, float]:
"""Score a candidate fit signal using auto-window stats."""
try:
gamma, omega, _tmin, _tmax, r2_log, r2_phase = fit_growth_rate_auto_with_stats(
t,
signal,
tmin=tmin,
tmax=tmax,
window_fraction=window_fraction,
min_points=min_points,
start_fraction=start_fraction,
growth_weight=growth_weight,
require_positive=require_positive,
min_amp_fraction=min_amp_fraction,
max_amp_fraction=max_amp_fraction,
window_method=window_method,
max_fraction=max_fraction,
end_fraction=end_fraction,
num_windows=num_windows,
phase_weight=phase_weight,
length_weight=length_weight,
min_r2=min_r2,
late_penalty=late_penalty,
min_slope=min_slope,
min_slope_frac=min_slope_frac,
slope_var_weight=slope_var_weight,
)
except ValueError:
return 0.0, 0.0, -np.inf
if not np.isfinite(gamma) or not np.isfinite(omega):
return gamma, omega, -np.inf
if require_positive and gamma <= 0.0:
return gamma, omega, -np.inf
if r2_log < min_r2:
return gamma, omega, -np.inf
score = float(r2_log + phase_weight * r2_phase + growth_weight * gamma)
return gamma, omega, score
def _select_fit_signal_auto(
t: np.ndarray,
phi_t: np.ndarray,
density_t: np.ndarray | None,
sel: ModeSelection,
*,
mode_method: str,
tmin: float | None,
tmax: float | None,
window_fraction: float,
min_points: int,
start_fraction: float,
growth_weight: float,
require_positive: bool,
min_amp_fraction: float,
max_amp_fraction: float,
window_method: str,
max_fraction: float,
end_fraction: float,
num_windows: int,
phase_weight: float,
length_weight: float,
min_r2: float,
late_penalty: float,
min_slope: float | None,
min_slope_frac: float,
slope_var_weight: float,
) -> tuple[np.ndarray, str, float, float]:
"""Choose between phi/density signals based on fit quality."""
phi_signal = extract_mode_time_series(phi_t, sel, method=mode_method)
gamma_phi, omega_phi, score_phi = _score_fit_signal_auto(
t,
phi_signal,
tmin=tmin,
tmax=tmax,
window_fraction=window_fraction,
min_points=min_points,
start_fraction=start_fraction,
growth_weight=growth_weight,
require_positive=require_positive,
min_amp_fraction=min_amp_fraction,
max_amp_fraction=max_amp_fraction,
window_method=window_method,
max_fraction=max_fraction,
end_fraction=end_fraction,
num_windows=num_windows,
phase_weight=phase_weight,
length_weight=length_weight,
min_r2=min_r2,
late_penalty=late_penalty,
min_slope=min_slope,
min_slope_frac=min_slope_frac,
slope_var_weight=slope_var_weight,
)
best_signal = phi_signal
best_name = "phi"
best_gamma = gamma_phi
best_omega = omega_phi
best_score = score_phi
if density_t is not None:
density_signal = extract_mode_time_series(density_t, sel, method=mode_method)
gamma_den, omega_den, score_den = _score_fit_signal_auto(
t,
density_signal,
tmin=tmin,
tmax=tmax,
window_fraction=window_fraction,
min_points=min_points,
start_fraction=start_fraction,
growth_weight=growth_weight,
require_positive=require_positive,
min_amp_fraction=min_amp_fraction,
max_amp_fraction=max_amp_fraction,
window_method=window_method,
max_fraction=max_fraction,
end_fraction=end_fraction,
num_windows=num_windows,
phase_weight=phase_weight,
length_weight=length_weight,
min_r2=min_r2,
late_penalty=late_penalty,
min_slope=min_slope,
min_slope_frac=min_slope_frac,
slope_var_weight=slope_var_weight,
)
if score_den > best_score:
best_signal = density_signal
best_name = "density"
best_gamma = gamma_den
best_omega = omega_den
best_score = score_den
return best_signal, best_name, float(best_gamma), float(best_omega)
def _extract_mode_only_signal(
source: np.ndarray,
*,
local_idx: int,
species_index: int | None = None,
) -> np.ndarray:
"""Extract a 1D time trace from reduced mode-only outputs."""
arr = np.asarray(source)
if arr.ndim == 0:
return np.asarray([arr], dtype=np.complex128)
if arr.ndim == 1:
return arr
# Some save modes return (t, species, ky). Select requested species first.
if species_index is not None and arr.ndim >= 3 and arr.shape[1] > 0:
idx = min(max(int(species_index), 0), arr.shape[1] - 1)
arr = arr[:, idx, ...]
if arr.ndim == 2:
idx = min(max(int(local_idx), 0), arr.shape[1] - 1)
return arr[:, idx]
# Final fallback: flatten non-time axes and select one column.
arr2 = arr.reshape(arr.shape[0], -1)
idx = min(max(int(local_idx), 0), arr2.shape[1] - 1)
return arr2[:, idx]
def _is_array_like(value) -> bool:
return isinstance(value, (list, tuple, np.ndarray))
def _iter_ky_batches(
ky_values: np.ndarray,
*,
ky_batch: int,
fixed_batch_shape: bool,
):
"""Yield ky batches with optional edge padding for fixed-shape compilation."""
n = int(len(ky_values))
if ky_batch <= 1:
for idx in range(n):
ky = float(ky_values[idx])
yield idx, np.asarray([ky], dtype=float), 1
return
for start in range(0, n, ky_batch):
raw = np.asarray(ky_values[start : start + ky_batch], dtype=float)
valid = int(raw.size)
if valid == 0:
continue
if fixed_batch_shape and valid < ky_batch:
pad = np.full((ky_batch - valid,), raw[-1], dtype=float)
batch = np.concatenate([raw, pad], axis=0)
else:
batch = raw
yield start, batch, valid
def _resolve_streaming_window(
t_total: float,
tmin: float | None,
tmax: float | None,
start_fraction: float,
window_fraction: float,
end_fraction: float,
) -> tuple[float, float]:
if tmin is not None and tmax is not None:
return float(tmin), float(tmax)
t_start = float(start_fraction) * t_total
t_end = float(end_fraction) * t_total
t_end = min(t_end, t_start + float(window_fraction) * t_total)
if t_end <= t_start:
t_end = t_total
return t_start, t_end
def _normalize_growth_rate(
gamma: float,
omega: float,
params: LinearParams,
diagnostic_norm: str,
) -> tuple[float, float]:
return apply_diagnostic_normalization(
gamma,
omega,
rho_star=float(np.asarray(params.rho_star)),
diagnostic_norm=diagnostic_norm,
)
def _build_gaussian_profile(
z: np.ndarray,
*,
kx: float,
ky: float,
s_hat: float,
init_cfg: InitializationConfig,
) -> np.ndarray:
if ky == 0.0:
return np.zeros_like(z)
theta0 = kx / (s_hat * ky)
envelope = (
init_cfg.gaussian_envelope_constant
+ init_cfg.gaussian_envelope_sine * np.sin(z - theta0)
)
width = init_cfg.gaussian_width
if width <= 0.0:
raise ValueError("gaussian_width must be > 0")
return envelope * np.exp(-(((z - theta0) / width) ** 2))
def _build_initial_condition(
grid: SpectralGrid,
geom: FluxTubeGeometryLike,
*,
ky_index: int | Sequence[int] | np.ndarray,
kx_index: int,
Nl: int,
Nm: int,
init_cfg: InitializationConfig,
) -> jnp.ndarray:
init_field = init_cfg.init_field.lower()
field_map = {
"density": (0, 0),
"upar": (0, 1),
"tpar": (0, 2),
"tperp": (1, 0),
"qpar": (0, 3),
"qperp": (1, 1),
}
# GX scales some moments when init_field="all" (see moments.cu).
all_scales = {
"density": 1.0,
"upar": 1.0,
"tpar": 1.0 / np.sqrt(2.0),
"tperp": 1.0,
"qpar": 1.0 / np.sqrt(6.0),
"qperp": 1.0,
}
if init_field != "all" and init_field not in field_map:
raise ValueError(
"init_field must be one of {'density','upar','tpar','tperp','qpar','qperp','all'}"
)
G0 = np.zeros((Nl, Nm, grid.ky.size, grid.kx.size, grid.z.size), dtype=np.complex64)
amp = float(init_cfg.init_amp)
ky_idx = np.atleast_1d(np.asarray(ky_index, dtype=int))
for ky_i in ky_idx:
if init_cfg.gaussian_init:
profile = _build_gaussian_profile(
np.asarray(grid.z),
kx=float(grid.kx[kx_index]),
ky=float(grid.ky[ky_i]),
s_hat=geom.s_hat,
init_cfg=init_cfg,
)
init_vals = amp * profile * (1.0 + 1.0j)
else:
init_vals = amp * (1.0 + 1.0j) * np.ones_like(grid.z)
if grid.ky[ky_i] != 0.0:
if init_field == "all":
for field_name, (l_idx, m_idx) in field_map.items():
if l_idx < Nl and m_idx < Nm:
scale = all_scales.get(field_name, 1.0)
G0[l_idx, m_idx, ky_i, kx_index, :] = init_vals * scale
else:
l_idx, m_idx = field_map[init_field]
if l_idx >= Nl or m_idx >= Nm:
raise ValueError("init_field moment exceeds (Nl, Nm) resolution")
G0[l_idx, m_idx, ky_i, kx_index, :] = init_vals
return jnp.asarray(G0)
def _kinetic_reference_init_cfg(
init_cfg: InitializationConfig, *, gx_reference: bool
) -> InitializationConfig:
"""Restore the historical kinetic benchmark seed on the GX-reference path.
Older kinetic parity runs seeded a constant electron-density moment rather than
the newer tiny Gaussian default. Preserve explicit user overrides by only
replacing the exact current kinetic default init.
"""
if not gx_reference:
return init_cfg
kinetic_default_init = KineticElectronBaseCase().init
if init_cfg != kinetic_default_init:
return init_cfg
return InitializationConfig(
init_field="density",
init_amp=1.0e-3,
init_single=True,
random_seed=kinetic_default_init.random_seed,
gaussian_init=False,
gaussian_width=kinetic_default_init.gaussian_width,
gaussian_envelope_constant=kinetic_default_init.gaussian_envelope_constant,
gaussian_envelope_sine=kinetic_default_init.gaussian_envelope_sine,
kpar_init=kinetic_default_init.kpar_init,
init_file=kinetic_default_init.init_file,
init_file_scale=kinetic_default_init.init_file_scale,
init_file_mode=kinetic_default_init.init_file_mode,
init_electrons_only=kinetic_default_init.init_electrons_only,
)
def _kbm_use_multi_target_krylov(
kcfg: KrylovConfig,
targets: Sequence[float] | None,
*,
shift: complex | None,
) -> bool:
"""Return whether KBM benchmark helpers should sweep target factors."""
if targets is None:
return False
if kcfg.mode_family.strip().lower() != "kbm":
return False
if kcfg.method.strip().lower() != "shift_invert":
return False
if shift is not None:
return False
if kcfg.shift_selection.strip().lower() == "shift":
return False
return True
[docs]
@dataclass(frozen=True)
class CycloneReference:
ky: np.ndarray
omega: np.ndarray
gamma: np.ndarray
[docs]
@dataclass(frozen=True)
class CycloneRunResult:
t: np.ndarray
phi_t: np.ndarray
gamma: float
omega: float
ky: float
selection: ModeSelection
[docs]
@dataclass(frozen=True)
class CycloneScanResult:
ky: np.ndarray
gamma: np.ndarray
omega: np.ndarray
[docs]
@dataclass(frozen=True)
class CycloneComparison:
ky: float
gamma: float
omega: float
gamma_ref: float
omega_ref: float
rel_gamma: float
rel_omega: float
[docs]
@dataclass(frozen=True)
class LinearRunResult:
t: np.ndarray
phi_t: np.ndarray
gamma: float
omega: float
ky: float
selection: ModeSelection
gamma_t: np.ndarray | None = None
omega_t: np.ndarray | None = None
[docs]
@dataclass(frozen=True)
class LinearScanResult:
ky: np.ndarray
gamma: np.ndarray
omega: np.ndarray
[docs]
def load_cyclone_reference() -> CycloneReference:
"""Load Cyclone base case reference data (adiabatic electrons)."""
data_path = resources.files("spectraxgk").joinpath(
"data", "cyclone_reference_adiabatic.csv"
)
arr = np.loadtxt(str(data_path), delimiter=",", skiprows=1)
ky = arr[:, 0]
omega = arr[:, 1]
gamma = arr[:, 2]
return CycloneReference(ky=ky, omega=omega, gamma=gamma)
def _load_reference_with_header(filename: str) -> CycloneReference:
"""Load reference CSVs with columns ky,gamma,omega."""
data_path = resources.files("spectraxgk").joinpath("data", filename)
arr = np.genfromtxt(str(data_path), delimiter=",", names=True, dtype=float)
ky = np.atleast_1d(np.asarray(arr["ky"], dtype=float))
gamma = np.atleast_1d(np.asarray(arr["gamma"], dtype=float))
omega = np.atleast_1d(np.asarray(arr["omega"], dtype=float))
return CycloneReference(ky=ky, omega=omega, gamma=gamma)
[docs]
def load_cyclone_reference_kinetic() -> CycloneReference:
"""Load Cyclone base case reference data (kinetic electrons)."""
data_path = resources.files("spectraxgk").joinpath(
"data", "cyclone_reference_kinetic.csv"
)
arr = np.loadtxt(str(data_path), delimiter=",", skiprows=1)
ky = arr[:, 0]
omega = arr[:, 1]
gamma = arr[:, 2]
return CycloneReference(ky=ky, omega=omega, gamma=gamma)
[docs]
def load_kbm_reference() -> CycloneReference:
"""Load KBM reference data (finite beta, kinetic electrons)."""
data_path = resources.files("spectraxgk").joinpath("data", "kbm_reference.csv")
arr = np.loadtxt(str(data_path), delimiter=",", skiprows=1)
ky = arr[:, 0]
omega = arr[:, 1]
gamma = arr[:, 2]
return CycloneReference(ky=ky, omega=omega, gamma=gamma)
[docs]
def load_etg_reference() -> CycloneReference:
"""Load GX-backed ETG reference data for the tracked two-species ETG lane."""
data_path = resources.files("spectraxgk").joinpath("data", "etg_reference.csv")
arr = np.loadtxt(str(data_path), delimiter=",", skiprows=1)
ky = arr[:, 0]
omega = arr[:, 1]
gamma = arr[:, 2]
return CycloneReference(ky=ky, omega=omega, gamma=gamma)
[docs]
def load_tem_reference() -> CycloneReference:
"""Load the provisional TEM reference digitized from the literature.
This lane is not backed by a GX reference dump. It remains an extended
stress case while the literature case definition is being reconstructed.
"""
data_path = resources.files("spectraxgk").joinpath("data", "tem_reference.csv")
arr = np.loadtxt(str(data_path), delimiter=",", skiprows=1)
ky = arr[:, 0]
omega = arr[:, 1]
gamma = arr[:, 2]
return CycloneReference(ky=ky, omega=omega, gamma=gamma)
[docs]
def compare_cyclone_to_reference(
result: CycloneRunResult, reference: CycloneReference
) -> CycloneComparison:
"""Compare a Cyclone run result against the reference data set."""
idx = int(np.argmin(np.abs(reference.ky - result.ky)))
gamma_ref = float(reference.gamma[idx])
omega_ref = float(reference.omega[idx])
rel_gamma = (result.gamma - gamma_ref) / gamma_ref if gamma_ref != 0.0 else np.nan
rel_omega = (result.omega - omega_ref) / omega_ref if omega_ref != 0.0 else np.nan
return CycloneComparison(
ky=float(reference.ky[idx]),
gamma=result.gamma,
omega=result.omega,
gamma_ref=gamma_ref,
omega_ref=omega_ref,
rel_gamma=rel_gamma,
rel_omega=rel_omega,
)
def _two_species_params(
model,
*,
kpar_scale: float,
omega_d_scale: float,
omega_star_scale: float,
rho_star: float,
beta_override: float | None = None,
fapar_override: float | None = None,
apar_beta_scale: float | None = None,
ampere_g0_scale: float | None = None,
bpar_beta_scale: float | None = None,
damp_ends_amp: float | None = None,
damp_ends_widthfrac: float | None = None,
nhermite: int | None = None,
) -> LinearParams:
"""Build LinearParams for a two-species kinetic model (ions + electrons)."""
mass_ratio = float(model.mass_ratio)
if mass_ratio <= 0.0:
raise ValueError("mass_ratio must be > 0")
Te_over_Ti = float(model.Te_over_Ti)
if Te_over_Ti <= 0.0:
raise ValueError("Te_over_Ti must be > 0")
ion_fprim_raw = getattr(model, "R_over_Lni", None)
ele_fprim_raw = getattr(model, "R_over_Lne", None)
ion_fprim = (
float(model.R_over_Ln) if ion_fprim_raw is None else float(ion_fprim_raw)
)
ele_fprim = (
float(model.R_over_Ln) if ele_fprim_raw is None else float(ele_fprim_raw)
)
nu_i = float(getattr(model, "nu_i", 0.0))
nu_e = float(getattr(model, "nu_e", 0.0))
beta = float(getattr(model, "beta", 1.0e-5))
if beta_override is not None:
beta = float(beta_override)
ion = Species(
charge=1.0,
mass=1.0,
density=1.0,
temperature=1.0,
tprim=float(getattr(model, "R_over_LTi", model.R_over_LTe)),
fprim=ion_fprim,
nu=nu_i,
)
electron = Species(
charge=-1.0,
mass=1.0 / mass_ratio,
density=1.0,
temperature=Te_over_Ti,
tprim=float(model.R_over_LTe),
fprim=ele_fprim,
nu=nu_e,
)
params = build_linear_params(
[ion, electron],
tau_e=0.0,
kpar_scale=kpar_scale,
omega_d_scale=omega_d_scale,
omega_star_scale=omega_star_scale,
rho_star=rho_star,
beta=beta,
fapar=1.0 if beta > 0.0 else 0.0,
apar_beta_scale=0.5 if apar_beta_scale is None else float(apar_beta_scale),
ampere_g0_scale=0.5 if ampere_g0_scale is None else float(ampere_g0_scale),
bpar_beta_scale=0.5 if bpar_beta_scale is None else float(bpar_beta_scale),
)
params = _apply_gx_hypercollisions(params, nhermite=nhermite)
if fapar_override is not None:
params = replace(params, fapar=float(fapar_override))
if damp_ends_amp is not None:
params = replace(params, damp_ends_amp=float(damp_ends_amp))
if damp_ends_widthfrac is not None:
params = replace(params, damp_ends_widthfrac=float(damp_ends_widthfrac))
return params
def _electron_only_params(
model,
*,
kpar_scale: float,
omega_d_scale: float,
omega_star_scale: float,
rho_star: float,
beta_override: float | None = None,
fapar_override: float | None = None,
apar_beta_scale: float | None = None,
ampere_g0_scale: float | None = None,
bpar_beta_scale: float | None = None,
damp_ends_amp: float | None = None,
damp_ends_widthfrac: float | None = None,
nhermite: int | None = None,
) -> LinearParams:
"""Build LinearParams for a single kinetic electron species + Boltzmann ions."""
mass_ratio = float(model.mass_ratio)
if mass_ratio <= 0.0:
raise ValueError("mass_ratio must be > 0")
Te_over_Ti = float(model.Te_over_Ti)
if Te_over_Ti <= 0.0:
raise ValueError("Te_over_Ti must be > 0")
nu_e = float(getattr(model, "nu_e", 0.0))
beta = float(getattr(model, "beta", 1.0e-5))
if beta_override is not None:
beta = float(beta_override)
electron = Species(
charge=-1.0,
mass=1.0 / mass_ratio,
density=1.0,
temperature=Te_over_Ti,
tprim=float(model.R_over_LTe),
fprim=float(model.R_over_Ln),
nu=nu_e,
)
params = build_linear_params(
[electron],
tau_e=Te_over_Ti,
kpar_scale=kpar_scale,
omega_d_scale=omega_d_scale,
omega_star_scale=omega_star_scale,
rho_star=rho_star,
beta=beta,
fapar=1.0 if beta > 0.0 else 0.0,
apar_beta_scale=0.5 if apar_beta_scale is None else float(apar_beta_scale),
ampere_g0_scale=0.5 if ampere_g0_scale is None else float(ampere_g0_scale),
bpar_beta_scale=0.5 if bpar_beta_scale is None else float(bpar_beta_scale),
)
params = _apply_gx_hypercollisions(params, nhermite=nhermite)
if fapar_override is not None:
params = replace(params, fapar=float(fapar_override))
if damp_ends_amp is not None:
params = replace(params, damp_ends_amp=float(damp_ends_amp))
if damp_ends_widthfrac is not None:
params = replace(params, damp_ends_widthfrac=float(damp_ends_widthfrac))
return params