"""Benchmark utilities for linear Cyclone base case comparisons."""
from __future__ import annotations
from dataclasses import replace
from typing import Callable, Sequence
import numpy as np
import jax.numpy as jnp
from spectraxgk.analysis import (
ModeSelection,
ModeSelectionBatch,
extract_mode_time_series,
fit_growth_rate,
fit_growth_rate_auto,
gx_growth_rate_from_omega_series,
gx_growth_rate_from_phi,
select_ky_index,
)
from spectraxgk.config import (
CycloneBaseCase,
ETGBaseCase,
InitializationConfig,
KineticElectronBaseCase,
KBMBaseCase,
TEMBaseCase,
TimeConfig,
resolve_cfl_fac,
)
from spectraxgk.geometry import (
SAlphaGeometry,
apply_geometry_grid_defaults,
build_flux_tube_geometry,
)
from spectraxgk.grids import build_spectral_grid, select_ky_grid
from spectraxgk.diffrax_integrators import (
integrate_linear_diffrax,
integrate_linear_diffrax_streaming,
)
from spectraxgk.gx_integrators import (
ExplicitTimeConfig,
integrate_linear_gx,
integrate_linear_gx_diagnostics,
)
from spectraxgk.benchmark_defaults import (
CYCLONE_KRYLOV_DEFAULT,
CYCLONE_OMEGA_D_SCALE,
CYCLONE_OMEGA_STAR_SCALE,
CYCLONE_RHO_STAR,
ETG_KRYLOV_DEFAULT,
ETG_OMEGA_D_SCALE,
ETG_OMEGA_STAR_SCALE,
ETG_RHO_STAR,
KBM_KRYLOV_DEFAULT,
KBM_OMEGA_D_SCALE,
KBM_OMEGA_STAR_SCALE,
KBM_RHO_STAR,
KINETIC_KRYLOV_DEFAULT,
KINETIC_KRYLOV_GX_REFERENCE,
Kinetic_OMEGA_D_SCALE,
Kinetic_OMEGA_STAR_SCALE,
Kinetic_RHO_STAR,
TEM_KRYLOV_DEFAULT,
TEM_OMEGA_D_SCALE,
TEM_OMEGA_STAR_SCALE,
TEM_RHO_STAR,
)
from spectraxgk.linear import (
LinearParams,
LinearTerms,
build_linear_cache,
integrate_linear,
integrate_linear_diagnostics,
linear_terms_to_term_config,
)
from spectraxgk.linear_krylov import KrylovConfig, dominant_eigenpair
from spectraxgk.runners import integrate_linear_from_config
from spectraxgk.terms.assembly import compute_fields_cached
from spectraxgk.benchmark_helpers import (
KBM_GX_SOLVER_LOCK,
KBM_GX_SOLVER_LOCK_TOL,
REFERENCE_DAMP_ENDS_AMP,
REFERENCE_DAMP_ENDS_WIDTHFRAC,
REFERENCE_NU_HYPER_L,
REFERENCE_NU_HYPER_M,
REFERENCE_P_HYPER_L,
REFERENCE_P_HYPER_M,
CycloneComparison,
CycloneReference,
CycloneRunResult,
CycloneScanResult,
LinearRunResult,
LinearScanResult,
_apply_gx_hypercollisions,
_build_gaussian_profile,
_build_initial_condition,
_electron_only_params,
_extract_mode_only_signal,
_gx_linked_end_damping,
_gx_p_hyper_m,
_is_array_like,
_iter_ky_batches,
_kbm_use_multi_target_krylov,
_kinetic_reference_init_cfg,
_load_reference_with_header,
_midplane_index,
_normalize_growth_rate,
_resolve_streaming_window,
_score_fit_signal_auto,
_select_fit_signal,
_select_fit_signal_auto,
_two_species_params,
compare_cyclone_to_reference,
load_cyclone_reference,
load_cyclone_reference_kinetic,
load_etg_reference,
load_kbm_reference,
load_tem_reference,
select_kbm_solver_auto,
)
from spectraxgk.benchmark_scan import (
ScanFitWindowPolicy,
apply_auto_fit_scan_policy,
indexed_float_value,
indexed_scan_value,
normalize_fit_signal,
normalize_solver_key,
resolve_scan_mode_method,
should_use_ky_batch,
scan_window_valid,
)
# Public names of this module. Besides the ``run_*`` entry points, the list
# re-exports names imported above from ``spectraxgk.benchmark_defaults`` and
# ``spectraxgk.benchmark_helpers`` -- including underscore-prefixed helpers --
# so they can be imported from this module as well.
__all__ = [
# Default constants (Krylov configs, scale factors, reference damping and
# hypercollision settings) re-exported from the defaults/helpers modules.
"CYCLONE_KRYLOV_DEFAULT",
"CYCLONE_OMEGA_D_SCALE",
"CYCLONE_OMEGA_STAR_SCALE",
"CYCLONE_RHO_STAR",
"ETG_KRYLOV_DEFAULT",
"ETG_OMEGA_D_SCALE",
"ETG_OMEGA_STAR_SCALE",
"ETG_RHO_STAR",
"KBM_KRYLOV_DEFAULT",
"KBM_GX_SOLVER_LOCK",
"KBM_GX_SOLVER_LOCK_TOL",
"KBM_OMEGA_D_SCALE",
"KBM_OMEGA_STAR_SCALE",
"KBM_RHO_STAR",
"KINETIC_KRYLOV_DEFAULT",
"KINETIC_KRYLOV_GX_REFERENCE",
"Kinetic_OMEGA_D_SCALE",
"Kinetic_OMEGA_STAR_SCALE",
"Kinetic_RHO_STAR",
"REFERENCE_DAMP_ENDS_AMP",
"REFERENCE_DAMP_ENDS_WIDTHFRAC",
"REFERENCE_NU_HYPER_L",
"REFERENCE_NU_HYPER_M",
"REFERENCE_P_HYPER_L",
"REFERENCE_P_HYPER_M",
"TEM_KRYLOV_DEFAULT",
"TEM_OMEGA_D_SCALE",
"TEM_OMEGA_STAR_SCALE",
"TEM_RHO_STAR",
# Result/reference container types from ``spectraxgk.benchmark_helpers``.
"CycloneComparison",
"CycloneReference",
"CycloneRunResult",
"CycloneScanResult",
"LinearRunResult",
"LinearScanResult",
# Underscore-prefixed helpers from ``spectraxgk.benchmark_helpers``; listed
# here so that they remain importable from this module.
"_apply_gx_hypercollisions",
"_build_gaussian_profile",
"_build_initial_condition",
"_electron_only_params",
"_extract_mode_only_signal",
"_gx_linked_end_damping",
"_gx_p_hyper_m",
"_is_array_like",
"_iter_ky_batches",
"_kbm_use_multi_target_krylov",
"_kinetic_reference_init_cfg",
"_load_reference_with_header",
"_midplane_index",
"_normalize_growth_rate",
"_resolve_streaming_window",
"_score_fit_signal_auto",
"_select_fit_signal",
"_select_fit_signal_auto",
"_two_species_params",
# Reference comparison and reference-data loaders.
"compare_cyclone_to_reference",
"load_cyclone_reference",
"load_cyclone_reference_kinetic",
"load_etg_reference",
"load_kbm_reference",
"load_tem_reference",
# Benchmark entry points (those not visible nearby are presumably defined
# further down in this module -- verify when editing).
"run_cyclone_linear",
"run_cyclone_scan",
"run_etg_linear",
"run_etg_scan",
"run_kbm_beta_scan",
"run_kbm_linear",
"run_kbm_scan",
"run_kinetic_linear",
"run_kinetic_scan",
"run_tem_linear",
"run_tem_scan",
"select_kbm_solver_auto",
]
[docs]
def run_cyclone_linear(
    ky_target: float = 0.3,
    Nl: int = 6,
    Nm: int = 12,
    dt: float = 0.01,
    steps: int = 800,
    method: str = "rk4",
    params: LinearParams | None = None,
    cfg: CycloneBaseCase | None = None,
    time_cfg: TimeConfig | None = None,
    solver: str = "auto",
    krylov_cfg: KrylovConfig | None = None,
    tmin: float | None = None,
    tmax: float | None = None,
    auto_window: bool = True,
    window_fraction: float = 0.4,
    min_points: int = 40,
    start_fraction: float = 0.2,
    growth_weight: float = 1.0,
    require_positive: bool = True,
    min_amp_fraction: float = 0.0,
    max_fraction: float = 0.8,
    end_fraction: float = 1.0,
    max_amp_fraction: float = 1.0,
    phase_weight: float = 0.2,
    length_weight: float = 0.05,
    min_r2: float = 0.0,
    late_penalty: float = 0.0,
    min_slope: float | None = None,
    min_slope_frac: float = 0.0,
    slope_var_weight: float = 0.0,
    window_method: str = "loglinear",
    mode_method: str = "project",
    terms: LinearTerms | None = None,
    sample_stride: int | None = None,
    fit_signal: str = "auto",
    init_cfg: InitializationConfig | None = None,
    diagnostic_norm: str = "none",
    use_jit: bool = True,
    gx_reference: bool | None = None,
    show_progress: bool = False,
    status_callback: Callable[[str], None] | None = None,
) -> CycloneRunResult:
    """Run the linear Cyclone benchmark and extract growth rate.

    The run is restricted to the single ``ky`` grid point chosen by
    ``select_ky_index`` for ``ky_target``.  Depending on ``solver`` the growth
    rate and frequency come from a time integration followed by a fit, from a
    Krylov eigen-solve, or from the time path with a Krylov fallback.

    Args:
        ky_target: Target ``ky``; the grid index is picked by
            ``select_ky_index``.
        Nl, Nm: Moment resolution passed to the initial condition and the
            linear cache (``Nm`` is also passed as ``nhermite`` when the
            default parameters are built).
        dt, steps: Time step and step count; the integration end time is
            ``dt * steps``.
        method: Integrator name for the explicit time paths.  Names starting
            with ``"imex"`` or ``"implicit"`` bypass the ``cfg.time`` diffrax
            path.
        params: Linear parameters.  When ``None`` they are built from
            ``cfg.model`` and the Cyclone default constants.
        cfg: Benchmark configuration; defaults to ``CycloneBaseCase()``.
        time_cfg: Optional runtime time configuration.  A copy with ``dt`` and
            ``t_max = dt * steps`` (and ``sample_stride`` if given) is used.
        solver: ``"krylov"`` runs only the Krylov solve; ``"auto"`` runs the
            time path and falls back to Krylov when the time path raises
            ``ValueError`` or yields a growth rate rejected by
            ``require_positive`` / finiteness; any other value runs the time
            path only.
        krylov_cfg: Krylov settings; defaults to ``CYCLONE_KRYLOV_DEFAULT``.
        tmin, tmax: Explicit fit window.  The automatic window search is used
            only when both are ``None`` and ``auto_window`` is true.
        auto_window, window_fraction, min_points, start_fraction,
        growth_weight, require_positive, min_amp_fraction, max_fraction,
        end_fraction, max_amp_fraction, phase_weight, length_weight, min_r2,
        late_penalty, min_slope, min_slope_frac, slope_var_weight,
        window_method: Forwarded unchanged to ``fit_growth_rate_auto`` /
            ``_select_fit_signal_auto``.  ``require_positive`` additionally
            controls whether non-positive growth triggers the Krylov fallback
            in ``solver="auto"``.
        mode_method: Mode extraction method forwarded to the fit-signal
            selection.  In GX-reference mode anything other than
            ``"z_index"``/``"max"`` is replaced by ``"z_index"``.
        terms: Linear term switches.  Defaults to ``LinearTerms()``, or
            ``LinearTerms(bpar=0.0)`` when ``cfg.model.adiabatic_ions`` is set.
        sample_stride: Optional output stride for the saved time series.
        fit_signal: ``"phi"``, ``"density"`` or ``"auto"`` (case-insensitive).
        init_cfg: Initial-condition settings; falls back to ``cfg.init`` and
            then to ``InitializationConfig()``.
        diagnostic_norm: Normalisation key passed to
            ``_normalize_growth_rate``; ``"none"`` becomes ``"gx"`` in
            GX-reference mode.
        use_jit: When false (or when a density signal is needed) the
            diagnostics integrator is used instead of ``integrate_linear``.
        gx_reference: Overrides ``cfg.gx_reference`` when not ``None``.
        show_progress: Forwarded to the integrators.
        status_callback: Optional sink for human-readable progress messages.

    Returns:
        ``CycloneRunResult`` with the time axis, the saved ``phi`` history,
        ``gamma``, ``omega``, the selected ``ky`` and the mode selection.  On
        the Krylov path ``phi_t`` holds a single snapshot computed from the
        eigenvector and ``t`` is ``[0.0]``.

    Raises:
        ValueError: If ``fit_signal`` is not ``"phi"``, ``"density"`` or
            ``"auto"``.
    """

    def _status(message: str) -> None:
        if status_callback is not None:
            status_callback(message)

    cfg = cfg or CycloneBaseCase()
    init_cfg = init_cfg or getattr(cfg, "init", None) or InitializationConfig()
    _status("building spectral grid")
    grid_full = build_spectral_grid(cfg.grid)
    gx_reference_use = (
        bool(cfg.gx_reference) if gx_reference is None else bool(gx_reference)
    )
    geom_cfg = cfg.geometry
    if gx_reference_use:
        geom_cfg = replace(geom_cfg, drift_scale=1.0)
        if diagnostic_norm == "none":
            diagnostic_norm = "gx"
        if mode_method not in {"z_index", "max"}:
            mode_method = "z_index"
    _status("building s-alpha geometry")
    geom = SAlphaGeometry.from_config(geom_cfg)
    if params is None:
        _status("building Cyclone linear parameters")
        params = LinearParams(
            R_over_Ln=cfg.model.R_over_Ln,
            R_over_LTi=cfg.model.R_over_LTi,
            R_over_LTe=cfg.model.R_over_LTe,
            omega_d_scale=CYCLONE_OMEGA_D_SCALE,
            omega_star_scale=CYCLONE_OMEGA_STAR_SCALE,
            rho_star=CYCLONE_RHO_STAR,
            kpar_scale=float(geom.gradpar()),
            nu=cfg.model.nu_i,
            damp_ends_amp=REFERENCE_DAMP_ENDS_AMP,
            damp_ends_widthfrac=REFERENCE_DAMP_ENDS_WIDTHFRAC,
        )
        params = _apply_gx_hypercollisions(params, nhermite=Nm)
    if terms is None:
        if getattr(cfg.model, "adiabatic_ions", False):
            terms = LinearTerms(bpar=0.0)
        else:
            terms = LinearTerms()
    solver_key = solver.strip().lower()
    fit_key = fit_signal.strip().lower()
    if fit_key not in {"phi", "density", "auto"}:
        raise ValueError("fit_signal must be 'phi', 'density', or 'auto'")
    need_density = fit_key in {"density", "auto"}
    ky_index = select_ky_index(np.asarray(grid_full.ky), ky_target)
    # The working grid holds only the selected ky, hence ky_index=0 below.
    grid = select_ky_grid(grid_full, ky_index)
    sel = ModeSelection(ky_index=0, kx_index=0, z_index=_midplane_index(grid))
    _status(f"selected ky index {ky_index} at ky={float(grid.ky[sel.ky_index]):.4f}")
    _status("building initial condition")
    # Keep a host-side copy so every solve starts from a fresh device array.
    G0_base = np.asarray(
        _build_initial_condition(
            grid,
            geom,
            ky_index=sel.ky_index,
            kx_index=sel.kx_index,
            Nl=Nl,
            Nm=Nm,
            init_cfg=init_cfg,
        )
    )

    def _fresh_G0() -> jnp.ndarray:
        return jnp.asarray(G0_base)

    _status("building linear cache")
    cache = build_linear_cache(grid, geom, params, Nl, Nm)

    def _is_valid_growth(gamma_val: float, omega_val: float) -> bool:
        if not np.isfinite(gamma_val) or not np.isfinite(omega_val):
            return False
        if require_positive and gamma_val <= 0.0:
            return False
        return True

    def _run_krylov() -> tuple[float, float, np.ndarray, np.ndarray]:
        _status("starting Krylov solve")
        kcfg = krylov_cfg or CYCLONE_KRYLOV_DEFAULT

        def _march_seed(G0_seed, cache_seed) -> tuple[float, float]:
            """Estimate (gamma, omega) from a short fixed-step GX time march."""
            seed_time_cfg = ExplicitTimeConfig(
                dt=float(kcfg.power_dt),
                t_max=min(150.0, float(kcfg.power_dt) * 15000.0),
                sample_stride=1,
                fixed_dt=True,
            )
            t_short, phi_seed, _g_t, _o_t = integrate_linear_gx(
                G0_seed,
                grid,
                cache_seed,
                params,
                geom,
                seed_time_cfg,
                terms=terms,
                mode_method="z_index",
                show_progress=show_progress,
            )
            sel_seed = ModeSelection(
                ky_index=0, kx_index=0, z_index=_midplane_index(grid)
            )
            gamma_est, omega_est, _g, _o, _t_mid = gx_growth_rate_from_phi(
                phi_seed,
                t_short,
                sel_seed,
                navg_fraction=0.5,
                mode_method="z_index",
            )
            return gamma_est, omega_est

        # GX-style time seed to stabilize branch selection. If the caller
        # supplied an explicit shift, respect it directly and avoid the seed
        # march; this keeps explicit-shift scans bounded and deterministic.
        gamma_seed = 0.0
        omega_seed = 0.0
        seed_ok = False
        omega_ok = False
        if kcfg.shift is None:
            try:
                _status("estimating frequency seed with short GX time march")
                gamma_seed, omega_seed = _march_seed(_fresh_G0(), cache)
                omega_ok = np.isfinite(omega_seed) and abs(omega_seed) > 1.0e-8
                seed_ok = omega_ok and np.isfinite(gamma_seed) and gamma_seed > 0.0
            except Exception as exc:
                # Best-effort seed: report the failure and carry on unseeded.
                _status(f"frequency seed march failed ({exc})")
                seed_ok = False
                omega_ok = False
            if not seed_ok:
                try:
                    _status(
                        "primary seed failed; retrying reduced Hermite-Laguerre seed"
                    )
                    Nl_seed = min(Nl, 16)
                    Nm_seed = min(Nm, 12)
                    cache_seed = build_linear_cache(
                        grid, geom, params, Nl_seed, Nm_seed
                    )
                    # NOTE: ``sel`` is the enclosing function's selection. It
                    # must never be rebound inside this helper, otherwise it
                    # becomes function-local and this lookup raises
                    # UnboundLocalError whenever the primary march failed early.
                    G0_seed = _build_initial_condition(
                        grid,
                        geom,
                        ky_index=sel.ky_index,
                        kx_index=sel.kx_index,
                        Nl=Nl_seed,
                        Nm=Nm_seed,
                        init_cfg=init_cfg,
                    )
                    gamma_seed, omega_seed = _march_seed(
                        jnp.asarray(np.asarray(G0_seed)), cache_seed
                    )
                    omega_ok = np.isfinite(omega_seed) and abs(omega_seed) > 1.0e-8
                    seed_ok = (
                        omega_ok and np.isfinite(gamma_seed) and gamma_seed > 0.0
                    )
                except Exception as exc:
                    _status(f"reduced frequency seed march failed ({exc})")
                    seed_ok = False
                    omega_ok = False
        # Eigenvalue convention used here: real part = gamma, imag part = -omega.
        shift = None
        if omega_ok:
            shift = complex(float(gamma_seed) if seed_ok else 0.0, float(-omega_seed))
        G0_krylov = _fresh_G0()
        _status("running dominant eigenpair solve")
        eig, vec = dominant_eigenpair(
            G0_krylov,
            cache,
            params,
            terms=terms,
            krylov_dim=kcfg.krylov_dim,
            restarts=kcfg.restarts,
            omega_min_factor=kcfg.omega_min_factor,
            omega_target_factor=kcfg.omega_target_factor,
            omega_cap_factor=kcfg.omega_cap_factor,
            omega_sign=kcfg.omega_sign,
            method=kcfg.method,
            power_iters=kcfg.power_iters,
            power_dt=kcfg.power_dt,
            shift=shift if shift is not None else kcfg.shift,
            shift_source=kcfg.shift_source,
            shift_tol=kcfg.shift_tol,
            shift_maxiter=kcfg.shift_maxiter,
            shift_restart=kcfg.shift_restart,
            shift_solve_method=kcfg.shift_solve_method,
            shift_preconditioner=kcfg.shift_preconditioner,
            shift_selection=kcfg.shift_selection,
            mode_family=kcfg.mode_family,
            fallback_method=kcfg.fallback_method,
            fallback_real_floor=kcfg.fallback_real_floor,
            status_callback=_status,
        )
        term_cfg = linear_terms_to_term_config(terms)
        phi = compute_fields_cached(vec, cache, params, terms=term_cfg).phi
        # Single snapshot with a leading length-1 "time" axis.
        phi_t_out = np.asarray(phi)[None, ...]
        t_out = np.array([0.0])
        gamma_out = float(np.real(eig))
        omega_out = float(-np.imag(eig))
        if seed_ok:
            seed_strong = (gamma_seed > 0.0) and (abs(omega_seed) > 1.0e-6)
            if seed_strong:
                # Prefer the seed when the eigen-solve is non-finite, has the
                # wrong growth sign, or differs from the seed by more than 15%.
                omega_tol = 0.15 * max(abs(omega_seed), 1.0e-6)
                gamma_tol = 0.15 * max(abs(gamma_seed), 1.0e-6)
                use_seed = (
                    not np.isfinite(gamma_out)
                    or not np.isfinite(omega_out)
                    or (gamma_seed > 0.0 and gamma_out < 0.0)
                    or abs(omega_out - omega_seed) > omega_tol
                    or abs(gamma_out - gamma_seed) > gamma_tol
                )
                if use_seed:
                    gamma_out = float(gamma_seed)
                    omega_out = float(omega_seed)
        if kcfg.omega_sign != 0:
            omega_out = float(np.sign(kcfg.omega_sign)) * abs(omega_out)
        gamma_out, omega_out = _normalize_growth_rate(
            gamma_out, omega_out, params, diagnostic_norm
        )
        _status(f"Krylov solve complete: gamma={gamma_out:.6f} omega={omega_out:.6f}")
        return gamma_out, omega_out, phi_t_out, t_out

    def _run_time() -> tuple[float, float, np.ndarray, np.ndarray]:
        _status(f"starting time integration path with fit_signal={fit_key}")
        method_key = method.lower()
        phi_t: jnp.ndarray | np.ndarray
        density_t: jnp.ndarray | np.ndarray | None
        time_cfg_use = None
        if time_cfg is not None:
            time_cfg_use = replace(time_cfg, dt=float(dt), t_max=float(dt) * int(steps))
            if sample_stride is not None:
                time_cfg_use = replace(time_cfg_use, sample_stride=sample_stride)
        elif cfg.time.use_diffrax and not (
            method_key.startswith("imex") or method_key.startswith("implicit")
        ):
            time_cfg_use = replace(cfg.time, dt=float(dt), t_max=float(dt) * int(steps))
            if sample_stride is not None:
                time_cfg_use = replace(time_cfg_use, sample_stride=sample_stride)
        if gx_reference_use:
            # GX integrator applies damping with per-time scaling internally.
            params_use = params
            _status("running GX-reference time integrator")
            t_max_val = (
                float(dt) * int(steps)
                if time_cfg_use is None
                else float(time_cfg_use.t_max)
            )
            stride = (
                int(sample_stride)
                if sample_stride is not None
                else (1 if time_cfg_use is None else int(time_cfg_use.sample_stride))
            )
            gx_time_cfg = ExplicitTimeConfig(
                dt=float(dt),
                t_max=t_max_val,
                sample_stride=stride,
                fixed_dt=True,
            )
            t, phi_t, _g_t, _o_t = integrate_linear_gx(
                _fresh_G0(),
                grid,
                cache,
                params_use,
                geom,
                gx_time_cfg,
                terms=terms,
                mode_method="z_index",
                show_progress=show_progress,
            )
            sel_local = ModeSelection(
                ky_index=0, kx_index=0, z_index=_midplane_index(grid)
            )
            gamma, omega, _g, _o, _t_mid = gx_growth_rate_from_phi(
                phi_t, t, sel_local, navg_fraction=0.5, mode_method="z_index"
            )
            gamma, omega = _normalize_growth_rate(
                gamma, omega, params_use, diagnostic_norm
            )
            return gamma, omega, np.asarray(phi_t), np.asarray(t)
        params_use = params
        if time_cfg_use is not None:
            _status(
                f"running runtime-configured integrator over {int(steps)} steps with sample_stride={int(time_cfg_use.sample_stride)}"
            )
            if need_density:
                _status(
                    "saving phi and density diagnostics for automatic fit selection"
                )
                _, saved = integrate_linear_from_config(
                    _fresh_G0(),
                    grid,
                    geom,
                    params_use,
                    time_cfg_use,
                    terms=terms,
                    save_field="phi+density",
                    density_species_index=0,
                    show_progress=show_progress,
                )
                phi_t, density_t = saved
            else:
                _, phi_t = integrate_linear_from_config(
                    _fresh_G0(),
                    grid,
                    geom,
                    params_use,
                    time_cfg_use,
                    terms=terms,
                    show_progress=show_progress,
                )
                density_t = None
            stride = time_cfg_use.sample_stride
        else:
            stride = 1 if sample_stride is None else int(sample_stride)
            if need_density or not use_jit:
                _status(
                    f"running explicit diagnostics integrator over {int(steps)} steps with sample_stride={stride}"
                )
                _diag = integrate_linear_diagnostics(
                    _fresh_G0(),
                    grid,
                    geom,
                    params_use,
                    dt=dt,
                    steps=steps,
                    method=method,
                    terms=terms,
                    sample_stride=stride,
                    species_index=0,
                    record_hl_energy=False,
                    show_progress=show_progress,
                )
                phi_t = _diag[1]
                density_t = _diag[2] if len(_diag) > 2 else None
            else:
                _status(
                    f"running cached linear integrator over {int(steps)} steps with sample_stride={stride}"
                )
                _, phi_out_time = integrate_linear(
                    _fresh_G0(),
                    grid,
                    geom,
                    params_use,
                    dt=dt,
                    steps=steps,
                    method=method,
                    terms=terms,
                    sample_stride=stride,
                    show_progress=show_progress,
                )
                phi_t = phi_out_time
                density_t = None
        phi_t_np = np.asarray(phi_t)
        # Time axis reconstructed from the number of saved samples.
        t_arr = np.arange(phi_t_np.shape[0]) * dt * stride
        density_np = None if density_t is None else np.asarray(density_t)
        _status(
            f"integration complete; fitting growth rate from {phi_t_np.shape[0]} saved samples"
        )
        if fit_key == "auto":
            _signal, _name, gamma_out, omega_out = _select_fit_signal_auto(
                t_arr,
                phi_t_np,
                density_np,
                sel,
                mode_method=mode_method,
                tmin=tmin,
                tmax=tmax,
                window_fraction=window_fraction,
                min_points=min_points,
                start_fraction=start_fraction,
                growth_weight=growth_weight,
                require_positive=require_positive,
                min_amp_fraction=min_amp_fraction,
                max_amp_fraction=max_amp_fraction,
                window_method=window_method,
                max_fraction=max_fraction,
                end_fraction=end_fraction,
                num_windows=8,
                phase_weight=phase_weight,
                length_weight=length_weight,
                min_r2=min_r2,
                late_penalty=late_penalty,
                min_slope=min_slope,
                min_slope_frac=min_slope_frac,
                slope_var_weight=slope_var_weight,
            )
            _status(f"automatic fit selected signal '{_name}'")
            if not np.isfinite(gamma_out) or not np.isfinite(omega_out):
                gamma_out, omega_out = 0.0, 0.0
        else:
            signal = _select_fit_signal(
                phi_t_np,
                density_np,
                sel,
                fit_signal=fit_key,
                mode_method=mode_method,
            )
            if auto_window and tmin is None and tmax is None:
                gamma_out, omega_out, _tmin, _tmax = fit_growth_rate_auto(
                    t_arr,
                    signal,
                    window_fraction=window_fraction,
                    min_points=min_points,
                    start_fraction=start_fraction,
                    growth_weight=growth_weight,
                    require_positive=require_positive,
                    min_amp_fraction=min_amp_fraction,
                    max_fraction=max_fraction,
                    end_fraction=end_fraction,
                    max_amp_fraction=max_amp_fraction,
                    phase_weight=phase_weight,
                    length_weight=length_weight,
                    min_r2=min_r2,
                    late_penalty=late_penalty,
                    min_slope=min_slope,
                    min_slope_frac=min_slope_frac,
                    slope_var_weight=slope_var_weight,
                    window_method=window_method,
                )
            else:
                gamma_out, omega_out = fit_growth_rate(
                    t_arr, signal, tmin=tmin, tmax=tmax
                )
        gamma_out, omega_out = _normalize_growth_rate(
            gamma_out, omega_out, params_use, diagnostic_norm
        )
        _status(
            f"time integration fit complete: gamma={gamma_out:.6f} omega={omega_out:.6f}"
        )
        return float(gamma_out), float(omega_out), phi_t_np, t_arr

    if solver_key == "krylov":
        gamma, omega, phi_t_np, t = _run_krylov()
    elif solver_key == "auto":
        try:
            gamma, omega, phi_t_np, t = _run_time()
        except ValueError as exc:
            _status(f"time-path failed ({exc}); falling back to Krylov solve")
            gamma, omega, phi_t_np, t = _run_krylov()
        else:
            # Only a result produced by the time path is validated here; a
            # Krylov result obtained in the handler above is final and is not
            # recomputed.
            if not _is_valid_growth(gamma, omega):
                _status("time-path result rejected; falling back to Krylov solve")
                gamma, omega, phi_t_np, t = _run_krylov()
    else:
        gamma, omega, phi_t_np, t = _run_time()
    _status(f"completed Cyclone linear run at ky={float(grid.ky[sel.ky_index]):.4f}")
    return CycloneRunResult(
        t=t,
        phi_t=phi_t_np,
        gamma=gamma,
        omega=omega,
        ky=float(grid.ky[sel.ky_index]),
        selection=sel,
    )
[docs]
def run_cyclone_scan(
ky_values: np.ndarray,
Nl: int = 6,
Nm: int = 12,
dt: float | np.ndarray = 0.01,
steps: int | np.ndarray = 800,
method: str = "imex2",
params: LinearParams | None = None,
cfg: CycloneBaseCase | None = None,
time_cfg: TimeConfig | None = None,
solver: str = "auto",
krylov_cfg: KrylovConfig | None = None,
tmin: float | None = None,
tmax: float | None = None,
auto_window: bool = True,
window_fraction: float = 0.4,
min_points: int = 40,
start_fraction: float = 0.2,
growth_weight: float = 1.0,
require_positive: bool = True,
min_amp_fraction: float = 0.0,
max_fraction: float = 0.8,
end_fraction: float = 1.0,
max_amp_fraction: float = 1.0,
phase_weight: float = 0.2,
length_weight: float = 0.05,
min_r2: float = 0.0,
late_penalty: float = 0.0,
min_slope: float | None = None,
min_slope_frac: float = 0.0,
slope_var_weight: float = 0.0,
window_method: str = "loglinear",
mode_method: str = "project",
mode_only: bool = True,
terms: LinearTerms | None = None,
sample_stride: int | None = None,
fit_signal: str = "auto",
diagnostic_norm: str = "none",
use_jit: bool = True,
ky_batch: int = 4,
fixed_batch_shape: bool = True,
streaming_fit: bool = True,
streaming_amp_floor: float = 1.0e-30,
mode_follow: bool = True,
gx_reference: bool | None = None,
show_progress: bool = False,
) -> CycloneScanResult:
"""Run the linear Cyclone benchmark for a list of ky values.
If ``time_cfg`` is provided, its ``dt`` and ``t_max`` override ``dt``/``steps``.
"""
cfg = cfg or CycloneBaseCase()
init_cfg = getattr(cfg, "init", None) or InitializationConfig()
grid_full = build_spectral_grid(cfg.grid)
gx_reference_use = (
bool(cfg.gx_reference) if gx_reference is None else bool(gx_reference)
)
geom_cfg = cfg.geometry
if gx_reference_use:
geom_cfg = replace(geom_cfg, drift_scale=1.0)
if diagnostic_norm == "none":
diagnostic_norm = "gx"
if mode_method not in {"z_index", "max"}:
mode_method = "z_index"
geom = SAlphaGeometry.from_config(geom_cfg)
if params is None:
params = LinearParams(
R_over_Ln=cfg.model.R_over_Ln,
R_over_LTi=cfg.model.R_over_LTi,
R_over_LTe=cfg.model.R_over_LTe,
omega_d_scale=CYCLONE_OMEGA_D_SCALE,
omega_star_scale=CYCLONE_OMEGA_STAR_SCALE,
rho_star=CYCLONE_RHO_STAR,
kpar_scale=float(geom.gradpar()),
nu=cfg.model.nu_i,
damp_ends_amp=REFERENCE_DAMP_ENDS_AMP,
damp_ends_widthfrac=REFERENCE_DAMP_ENDS_WIDTHFRAC,
)
params = _apply_gx_hypercollisions(params, nhermite=Nm)
if terms is None:
if getattr(cfg.model, "adiabatic_ions", False):
terms = LinearTerms(bpar=0.0)
else:
terms = LinearTerms()
solver_key = normalize_solver_key(solver)
fit_key = normalize_fit_signal(fit_signal)
auto_solver = solver_key == "auto"
if auto_solver:
solver_key = "gx_time" if gx_reference_use else "time"
streaming_fit, mode_only = apply_auto_fit_scan_policy(
fit_key, streaming_fit=streaming_fit, mode_only=mode_only
)
need_density = fit_key in {"density", "auto"}
gammas = []
omegas = []
ky_out = []
mode_method = resolve_scan_mode_method(mode_method, mode_only=mode_only)
use_batch = should_use_ky_batch(
ky_batch=ky_batch,
solver_key=solver_key,
dt=dt,
steps=steps,
tmin=tmin,
tmax=tmax,
)
fit_policy = ScanFitWindowPolicy(
tmin=tmin,
tmax=tmax,
auto_window=auto_window,
window_fraction=window_fraction,
min_points=min_points,
start_fraction=start_fraction,
growth_weight=growth_weight,
require_positive=require_positive,
min_amp_fraction=min_amp_fraction,
max_fraction=max_fraction,
end_fraction=end_fraction,
max_amp_fraction=max_amp_fraction,
phase_weight=phase_weight,
length_weight=length_weight,
min_r2=min_r2,
late_penalty=late_penalty,
min_slope=min_slope,
min_slope_frac=min_slope_frac,
slope_var_weight=slope_var_weight,
window_method=window_method,
fit_growth_rate_fn=fit_growth_rate,
fit_growth_rate_auto_fn=fit_growth_rate_auto,
normalize_growth_rate_fn=_normalize_growth_rate,
)
def _fit_signal(
signal: np.ndarray, idx: int, dt_i: float, stride: int
) -> tuple[float, float]:
return fit_policy.fit_signal(
signal,
idx=idx,
dt=dt_i,
stride=stride,
params=params,
diagnostic_norm=diagnostic_norm,
)
ky_values_arr = np.asarray(ky_values, dtype=float)
phi_t: jnp.ndarray | np.ndarray
density_t: jnp.ndarray | np.ndarray | None
if solver_key == "krylov":
if ky_values_arr.size == 0:
return CycloneScanResult(
ky=ky_values_arr, gamma=np.array([]), omega=np.array([])
)
order = (
np.argsort(ky_values_arr) if mode_follow else np.arange(ky_values_arr.size)
)
gamma_out = np.zeros_like(ky_values_arr, dtype=float)
omega_out = np.zeros_like(ky_values_arr, dtype=float)
v_ref: jnp.ndarray | None = None
prev_eig: complex | None = None
cfg_use = krylov_cfg or CYCLONE_KRYLOV_DEFAULT
for idx in order:
ky_val = float(ky_values_arr[idx])
ky_index = select_ky_index(np.asarray(grid_full.ky), ky_val)
grid = select_ky_grid(grid_full, ky_index)
G0_jax = _build_initial_condition(
grid,
geom,
ky_index=0,
kx_index=0,
Nl=Nl,
Nm=Nm,
init_cfg=init_cfg,
)
cache = build_linear_cache(grid, geom, params, Nl, Nm)
# Use a short GX-style time integration to seed the branch.
gamma_seed = 0.0
omega_seed = 0.0
seed_ok = False
omega_ok = False
if prev_eig is None:
try:
t_seed = min(150.0, float(cfg_use.power_dt) * 15000.0)
gx_time_cfg = ExplicitTimeConfig(
dt=float(cfg_use.power_dt),
t_max=t_seed,
sample_stride=1,
fixed_dt=True,
)
G0_seed = jnp.array(G0_jax)
t_short, phi_seed, _g_t, _o_t = integrate_linear_gx(
G0_seed,
grid,
cache,
params,
geom,
gx_time_cfg,
terms=terms,
mode_method="z_index",
show_progress=show_progress,
)
sel = ModeSelection(
ky_index=0, kx_index=0, z_index=_midplane_index(grid)
)
gamma_seed, omega_seed, _g, _o, _t_mid = gx_growth_rate_from_phi(
phi_seed,
t_short,
sel,
navg_fraction=0.5,
mode_method="z_index",
)
omega_ok = np.isfinite(omega_seed) and abs(omega_seed) > 1.0e-8
seed_ok = omega_ok and np.isfinite(gamma_seed) and gamma_seed > 0.0
except Exception:
seed_ok = False
omega_ok = False
if not seed_ok:
try:
Nl_seed = min(Nl, 16)
Nm_seed = min(Nm, 12)
cache_seed = build_linear_cache(
grid, geom, params, Nl_seed, Nm_seed
)
G0_seed = _build_initial_condition(
grid,
geom,
ky_index=0,
kx_index=0,
Nl=Nl_seed,
Nm=Nm_seed,
init_cfg=init_cfg,
)
t_seed = min(150.0, float(cfg_use.power_dt) * 15000.0)
gx_time_cfg = ExplicitTimeConfig(
dt=float(cfg_use.power_dt),
t_max=t_seed,
sample_stride=1,
fixed_dt=True,
)
t_short, phi_seed, _g_t, _o_t = integrate_linear_gx(
G0_seed,
grid,
cache_seed,
params,
geom,
gx_time_cfg,
terms=terms,
mode_method="z_index",
show_progress=show_progress,
)
sel_seed = ModeSelection(
ky_index=0, kx_index=0, z_index=_midplane_index(grid)
)
gamma_seed, omega_seed, _g, _o, _t_mid = gx_growth_rate_from_phi(
phi_seed,
t_short,
sel_seed,
navg_fraction=0.5,
mode_method="z_index",
)
omega_ok = np.isfinite(omega_seed) and abs(omega_seed) > 1.0e-8
seed_ok = omega_ok and np.isfinite(gamma_seed) and gamma_seed > 0.0
except Exception:
seed_ok = False
omega_ok = False
shift: complex | None
if prev_eig is not None and np.isfinite(prev_eig):
shift = prev_eig
elif omega_ok:
shift = complex(
float(gamma_seed) if seed_ok else 0.0, float(-omega_seed)
)
else:
shift = None
eig, vec = dominant_eigenpair(
G0_jax,
cache,
params,
terms=terms,
v_ref=v_ref,
select_overlap=v_ref is not None,
krylov_dim=cfg_use.krylov_dim,
restarts=cfg_use.restarts,
omega_min_factor=cfg_use.omega_min_factor,
omega_target_factor=cfg_use.omega_target_factor,
omega_cap_factor=cfg_use.omega_cap_factor,
omega_sign=cfg_use.omega_sign,
method=cfg_use.method,
power_iters=cfg_use.power_iters,
power_dt=cfg_use.power_dt,
shift=shift if shift is not None else cfg_use.shift,
shift_source=cfg_use.shift_source,
shift_tol=cfg_use.shift_tol,
shift_maxiter=cfg_use.shift_maxiter,
shift_restart=cfg_use.shift_restart,
shift_solve_method=cfg_use.shift_solve_method,
shift_preconditioner=cfg_use.shift_preconditioner,
shift_selection=cfg_use.shift_selection,
mode_family=cfg_use.mode_family,
fallback_method=cfg_use.fallback_method,
fallback_real_floor=cfg_use.fallback_real_floor,
)
gamma = float(np.real(eig))
omega = float(-np.imag(eig))
# If Krylov lands on the wrong branch, fall back to GX-style seed.
use_seed = False
if seed_ok:
seed_strong = (gamma_seed > 0.0) and (abs(omega_seed) > 1.0e-6)
if seed_strong:
omega_tol = 0.15 * max(abs(omega_seed), 1.0e-6)
gamma_tol = 0.15 * max(abs(gamma_seed), 1.0e-6)
use_seed = (
not np.isfinite(gamma)
or not np.isfinite(omega)
or (gamma_seed > 0.0 and gamma < 0.0)
or abs(omega - omega_seed) > omega_tol
or abs(gamma - gamma_seed) > gamma_tol
)
if use_seed and seed_ok:
gamma = float(gamma_seed)
omega = float(omega_seed)
else:
v_ref = vec
prev_eig = complex(float(gamma), float(-omega))
gamma, omega = _normalize_growth_rate(gamma, omega, params, diagnostic_norm)
gamma_out[idx] = gamma
omega_out[idx] = omega
return CycloneScanResult(ky=ky_values_arr, gamma=gamma_out, omega=omega_out)
if solver_key == "gx_time":
if ky_values_arr.size == 0:
return CycloneScanResult(
ky=ky_values_arr, gamma=np.array([]), omega=np.array([])
)
gamma_out = np.zeros_like(ky_values_arr, dtype=float)
omega_out = np.zeros_like(ky_values_arr, dtype=float)
prev_omega: float | None = None
prev_prev_omega: float | None = None
kcfg = krylov_cfg or CYCLONE_KRYLOV_DEFAULT
time_base = time_cfg or cfg.time
for idx, ky_val in enumerate(ky_values_arr):
ky_index = select_ky_index(np.asarray(grid_full.ky), float(ky_val))
grid = select_ky_grid(grid_full, ky_index)
G0_jax = _build_initial_condition(
grid,
geom,
ky_index=0,
kx_index=0,
Nl=Nl,
Nm=Nm,
init_cfg=init_cfg,
)
cache = build_linear_cache(grid, geom, params, Nl, Nm)
dt_i = float(dt[idx]) if isinstance(dt, np.ndarray) else float(dt)
steps_i = int(steps[idx]) if isinstance(steps, np.ndarray) else int(steps)
t_max_val = dt_i * float(steps_i)
if gx_reference_use and time_cfg is None:
fixed_dt_i = True
dt_min_i = dt_i
dt_max_i: float | None = dt_i
cfl_i = 1.0
cfl_fac_i = 1.0
else:
fixed_dt_i = bool(time_base.fixed_dt)
dt_min_i = float(time_base.dt_min)
dt_max_i = None if time_base.dt_max is None else float(time_base.dt_max)
cfl_i = float(time_base.cfl)
cfl_fac_i = resolve_cfl_fac(str(time_base.method), time_base.cfl_fac)
gx_time_cfg = ExplicitTimeConfig(
dt=dt_i,
t_max=t_max_val,
sample_stride=1,
fixed_dt=fixed_dt_i,
dt_min=dt_min_i,
dt_max=dt_max_i,
cfl=cfl_i,
cfl_fac=cfl_fac_i,
)
G0_time = jnp.array(G0_jax)
t, phi_gx, _g_t, _o_t = integrate_linear_gx(
G0_time,
grid,
cache,
params,
geom,
gx_time_cfg,
terms=terms,
mode_method="z_index",
show_progress=show_progress,
)
sel_local = ModeSelection(
ky_index=0, kx_index=0, z_index=_midplane_index(grid)
)
gx_growth_ok = True
try:
gamma, omega, _g, _o, _t_mid = gx_growth_rate_from_phi(
phi_gx, t, sel_local, navg_fraction=0.5, mode_method="z_index"
)
gamma, omega = _normalize_growth_rate(
gamma, omega, params, diagnostic_norm
)
except ValueError:
gx_growth_ok = False
gamma = float("nan")
omega = float("nan")
if gx_reference_use and prev_omega is None and omega < 0.0:
omega = abs(omega)
need_reselect = (
(gx_reference_use and gx_growth_ok)
and prev_omega is not None
and prev_omega > 0.0
and (omega <= 0.0 or ((idx >= 2) and (omega < 0.85 * prev_omega)))
)
if need_reselect or not gx_growth_ok:
target_omega: float | None = (
prev_omega if (gx_growth_ok and prev_omega is not None) else None
)
if (
target_omega is not None
and prev_prev_omega is not None
and prev_omega is not None
and prev_omega > prev_prev_omega
):
target_omega = prev_omega + (prev_omega - prev_prev_omega)
G0_krylov = jnp.array(G0_jax)
eig, _vec = dominant_eigenpair(
G0_krylov,
cache,
params,
terms=terms,
krylov_dim=kcfg.krylov_dim,
restarts=kcfg.restarts,
omega_min_factor=kcfg.omega_min_factor,
omega_target_factor=kcfg.omega_target_factor,
omega_cap_factor=kcfg.omega_cap_factor,
omega_sign=kcfg.omega_sign,
method=kcfg.method,
power_iters=kcfg.power_iters,
power_dt=kcfg.power_dt,
shift=kcfg.shift,
shift_source=kcfg.shift_source,
shift_tol=kcfg.shift_tol,
shift_maxiter=kcfg.shift_maxiter,
shift_restart=kcfg.shift_restart,
shift_solve_method=kcfg.shift_solve_method,
shift_preconditioner=kcfg.shift_preconditioner,
shift_selection=kcfg.shift_selection,
mode_family=kcfg.mode_family,
fallback_method=kcfg.fallback_method,
fallback_real_floor=kcfg.fallback_real_floor,
)
gamma_k = float(np.real(eig))
omega_k = float(abs(-np.imag(eig)))
gamma_k, omega_k = _normalize_growth_rate(
gamma_k, omega_k, params, diagnostic_norm
)
if not gx_growth_ok:
gamma, omega = gamma_k, omega_k
else:
assert target_omega is not None
candidates: list[tuple[float, float]] = [
(float(gamma), float(abs(omega)))
]
gamma_base = abs(float(gamma))
gamma_delta_limit = max(3.0 * gamma_base, gamma_base + 0.05, 1.0e-3)
if (
np.isfinite(gamma_k)
and np.isfinite(omega_k)
and gamma_k > 0.0
and abs(gamma_k - float(gamma)) <= gamma_delta_limit
):
candidates.append((gamma_k, omega_k))
def _score(candidate: tuple[float, float]) -> float:
g_val, o_val = candidate
penalty = 0.0 if g_val > 0.0 else 1.0e3
return penalty + abs(o_val - target_omega)
gamma, omega = min(candidates, key=_score)
gamma_out[idx] = gamma
omega_out[idx] = omega
prev_prev_omega = prev_omega
prev_omega = float(omega)
return CycloneScanResult(ky=ky_values_arr, gamma=gamma_out, omega=omega_out)
if use_batch:
ky_iter = _iter_ky_batches(
ky_values_arr,
ky_batch=ky_batch,
fixed_batch_shape=fixed_batch_shape,
)
else:
ky_iter = _iter_ky_batches(ky_values_arr, ky_batch=1, fixed_batch_shape=False)
ky_slice: np.ndarray
ky_indices: list[int]
sel_scan: ModeSelection | ModeSelectionBatch
for batch_start, ky_slice, valid_count in ky_iter:
if use_batch:
ky_indices = [
select_ky_index(np.asarray(grid_full.ky), float(ky)) for ky in ky_slice
]
grid = select_ky_grid(grid_full, ky_indices)
sel_indices = np.arange(len(ky_indices), dtype=int)
sel_scan = ModeSelectionBatch(sel_indices, 0, _midplane_index(grid))
dt_i = float(dt)
steps_i = int(steps)
else:
ky_indices = [select_ky_index(np.asarray(grid_full.ky), float(ky_slice[0]))]
grid = select_ky_grid(grid_full, ky_indices[0])
sel_scan = ModeSelection(
ky_index=0, kx_index=0, z_index=_midplane_index(grid)
)
dt_i = float(dt[batch_start]) if isinstance(dt, np.ndarray) else float(dt)
steps_i = (
int(steps[batch_start]) if isinstance(steps, np.ndarray) else int(steps)
)
ky_local = np.arange(len(ky_indices))
G0_jax = _build_initial_condition(
grid,
geom,
ky_index=ky_local,
kx_index=0,
Nl=Nl,
Nm=Nm,
init_cfg=init_cfg,
)
cache = build_linear_cache(grid, geom, params, Nl, Nm)
time_cfg_i = None
if time_cfg is not None:
time_cfg_i = replace(time_cfg, dt=dt_i, t_max=dt_i * steps_i)
if sample_stride is not None:
time_cfg_i = replace(time_cfg_i, sample_stride=sample_stride)
params_use = params
if time_cfg_i is not None and time_cfg_i.use_diffrax and streaming_fit:
t_total = float(time_cfg_i.t_max)
tmin_i, tmax_i = _resolve_streaming_window(
t_total,
indexed_float_value(tmin, batch_start),
indexed_float_value(tmax, batch_start),
start_fraction,
window_fraction,
1.0,
)
_, gamma_vals, omega_vals = integrate_linear_diffrax_streaming(
G0_jax,
grid,
geom,
params_use,
dt=dt_i,
steps=steps_i,
method=time_cfg_i.diffrax_solver,
cache=cache,
terms=terms,
adaptive=False,
rtol=time_cfg_i.diffrax_rtol,
atol=time_cfg_i.diffrax_atol,
max_steps=time_cfg_i.diffrax_max_steps,
progress_bar=time_cfg_i.progress_bar,
checkpoint=time_cfg_i.checkpoint,
tmin=tmin_i,
tmax=tmax_i,
fit_signal="phi",
show_progress=show_progress,
mode_ky_indices=ky_local[:valid_count],
mode_kx_index=0,
mode_z_index=_midplane_index(grid),
mode_method=mode_method,
amp_floor=streaming_amp_floor,
return_state=False,
)
gamma_arr = np.asarray(gamma_vals)
omega_arr = np.asarray(omega_vals)
for local_idx in range(valid_count):
ky_val = ky_slice[local_idx]
gamma_i, omega_i = _normalize_growth_rate(
float(gamma_arr[local_idx]),
float(omega_arr[local_idx]),
params_use,
diagnostic_norm,
)
gammas.append(gamma_i)
omegas.append(omega_i)
ky_out.append(float(ky_val))
continue
if time_cfg_i is not None:
save_field = (
"phi+density"
if fit_key == "auto"
else ("density" if fit_key == "density" else "phi")
)
save_mode = None if fit_key == "auto" else (sel_scan if mode_only else None)
_, saved = integrate_linear_from_config(
G0_jax,
grid,
geom,
params_use,
time_cfg_i,
cache=cache,
terms=terms,
save_mode=save_mode,
mode_method=mode_method,
save_field=save_field,
density_species_index=0 if need_density else None,
)
if fit_key == "auto":
phi_t, density_t = saved
phi_t = np.asarray(phi_t)
density_t = np.asarray(density_t)
else:
phi_t = np.asarray(saved)
density_t = None
stride = time_cfg_i.sample_stride
else:
stride = 1 if sample_stride is None else int(sample_stride)
if use_jit and not need_density:
_, phi_out_time = integrate_linear(
G0_jax,
grid,
geom,
params_use,
dt=dt_i,
steps=steps_i,
method=method,
cache=cache,
terms=terms,
sample_stride=stride,
show_progress=show_progress,
)
phi_t = phi_out_time
density_t = None
else:
_diag = integrate_linear_diagnostics(
G0_jax,
grid,
geom,
params_use,
dt=dt_i,
steps=steps_i,
method=method,
cache=cache,
terms=terms,
sample_stride=stride,
species_index=None,
record_hl_energy=False,
)
phi_t = np.asarray(_diag[1])
density_t = np.asarray(_diag[2]) if len(_diag) > 2 else None
phi_t_np = np.asarray(phi_t)
signal_t = None
if mode_only and phi_t_np.ndim == 2:
signal_t = phi_t_np
density_np = None if density_t is None else np.asarray(density_t)
t = np.arange(phi_t_np.shape[0]) * dt_i * stride
def _is_valid_growth(gamma_val: float, omega_val: float) -> bool:
if not np.isfinite(gamma_val) or not np.isfinite(omega_val):
return False
if require_positive and gamma_val <= 0.0:
return False
return True
for local_idx in range(valid_count):
ky_val = ky_slice[local_idx]
if signal_t is None:
sel_local = ModeSelection(
ky_index=local_idx, kx_index=0, z_index=_midplane_index(grid)
)
if fit_key == "auto":
signal, _name, gamma, omega = _select_fit_signal_auto(
t,
phi_t_np,
density_np,
sel_local,
mode_method=mode_method,
tmin=indexed_float_value(tmin, batch_start + local_idx),
tmax=indexed_float_value(tmax, batch_start + local_idx),
window_fraction=window_fraction,
min_points=min_points,
start_fraction=start_fraction,
growth_weight=growth_weight,
require_positive=require_positive,
min_amp_fraction=min_amp_fraction,
max_amp_fraction=max_amp_fraction,
window_method=window_method,
max_fraction=max_fraction,
end_fraction=end_fraction,
num_windows=8,
phase_weight=phase_weight,
length_weight=length_weight,
min_r2=min_r2,
late_penalty=late_penalty,
min_slope=min_slope,
min_slope_frac=min_slope_frac,
slope_var_weight=slope_var_weight,
)
gamma, omega = _normalize_growth_rate(
gamma, omega, params_use, diagnostic_norm
)
if auto_solver and not _is_valid_growth(gamma, omega):
res = run_cyclone_linear(
ky_target=float(ky_val),
Nl=Nl,
Nm=Nm,
dt=dt_i,
steps=steps_i,
method=method,
params=params,
cfg=cfg,
time_cfg=time_cfg,
solver="krylov",
krylov_cfg=krylov_cfg,
diagnostic_norm=diagnostic_norm,
fit_signal="phi",
show_progress=show_progress,
)
gamma = float(res.gamma)
omega = float(res.omega)
gammas.append(gamma)
omegas.append(omega)
ky_out.append(float(ky_val))
continue
signal = extract_mode_time_series(
phi_t_np, sel_local, method=mode_method
)
else:
signal = signal_t[:, local_idx] if signal_t.ndim > 1 else signal_t
gamma, omega = _fit_signal(signal, batch_start + local_idx, dt_i, stride)
if auto_solver and not _is_valid_growth(gamma, omega):
res = run_cyclone_linear(
ky_target=float(ky_val),
Nl=Nl,
Nm=Nm,
dt=dt_i,
steps=steps_i,
method=method,
params=params,
cfg=cfg,
time_cfg=time_cfg,
solver="krylov",
krylov_cfg=krylov_cfg,
diagnostic_norm=diagnostic_norm,
fit_signal="phi",
show_progress=show_progress,
)
gamma = float(res.gamma)
omega = float(res.omega)
gammas.append(gamma)
omegas.append(omega)
ky_out.append(float(ky_val))
return CycloneScanResult(
ky=np.array(ky_out), gamma=np.array(gammas), omega=np.array(omegas)
)
# [docs] (Sphinx source-link marker left over from HTML extraction; not code)
def run_etg_linear(
    ky_target: float = 3.0,
    Nl: int = 6,
    Nm: int = 12,
    dt: float = 0.01,
    steps: int = 800,
    method: str = "rk4",
    params: LinearParams | None = None,
    cfg: ETGBaseCase | None = None,
    time_cfg: TimeConfig | None = None,
    solver: str = "auto",
    krylov_cfg: KrylovConfig | None = None,
    tmin: float | None = None,
    tmax: float | None = None,
    auto_window: bool = True,
    window_fraction: float = 0.4,
    min_points: int = 40,
    start_fraction: float = 0.2,
    growth_weight: float = 1.0,
    require_positive: bool = True,
    min_amp_fraction: float = 0.0,
    mode_method: str = "project",
    terms: LinearTerms | None = None,
    sample_stride: int | None = None,
    fit_signal: str = "auto",
    streaming_fit: bool = False,
    streaming_amp_floor: float = 1.0e-30,
    gx_growth: bool = False,
    gx_navg_fraction: float = 0.5,
    diagnostic_norm: str = "none",
    show_progress: bool = False,
) -> LinearRunResult:
    """Run a single-ky ETG linear benchmark and extract growth rate and frequency.

    ``solver="auto"`` tries the Krylov eigen-solve first and falls back to time
    integration when the eigenvalue is non-finite (or non-growing while
    ``require_positive`` is set).  Any ``solver`` value other than ``"auto"``
    or ``"krylov"`` goes straight to time integration.

    When ``time_cfg`` is given, its ``dt`` and ``t_max`` replace ``dt`` and
    ``steps``.  ``fit_signal`` must be ``"phi"``, ``"density"`` or ``"auto"``;
    ``"auto"`` disables ``streaming_fit``.

    Raises:
        ValueError: if ``fit_signal`` is not one of the accepted values, or if
            the streaming integrator returns no final state.
    """
    cfg = cfg or ETGBaseCase()
    grid_full = build_spectral_grid(cfg.grid)
    geom = SAlphaGeometry.from_config(cfg.geometry)

    if params is None:
        # Both parameter builders take the same keyword set; only the species
        # content differs between the adiabatic-ion and two-species variants.
        make_params = (
            _electron_only_params
            if getattr(cfg.model, "adiabatic_ions", False)
            else _two_species_params
        )
        params = make_params(
            cfg.model,
            kpar_scale=float(geom.gradpar()),
            omega_d_scale=ETG_OMEGA_D_SCALE,
            omega_star_scale=ETG_OMEGA_STAR_SCALE,
            rho_star=ETG_RHO_STAR,
            damp_ends_amp=0.0,
            damp_ends_widthfrac=0.0,
            nhermite=Nm,
        )
    if terms is None:
        # The ETG benchmark contract is electrostatic for both the adiabatic-ion
        # and two-species variants. Keep the default ETG wrappers aligned with
        # the tracked ETG asset-generation tools.
        terms = LinearTerms(apar=0.0, bpar=0.0, hypercollisions=1.0)

    ky_index = select_ky_index(np.asarray(grid_full.ky), ky_target)
    grid = select_ky_grid(grid_full, ky_index)
    sel = ModeSelection(ky_index=0, kx_index=0, z_index=_midplane_index(grid))

    # Seed only the species with the smallest charge_sign (the electrons).
    charge = np.atleast_1d(np.asarray(params.charge_sign))
    electron_index = int(np.argmin(charge))
    state_shape = (
        int(charge.size),
        Nl,
        Nm,
        grid.ky.size,
        grid.kx.size,
        grid.z.size,
    )
    G0 = np.zeros(state_shape, dtype=np.complex64)
    seed = _build_initial_condition(
        grid,
        geom,
        ky_index=sel.ky_index,
        kx_index=sel.kx_index,
        Nl=Nl,
        Nm=Nm,
        init_cfg=cfg.init,
    )
    G0[electron_index] = np.asarray(seed, dtype=np.complex64)
    G0_jax = jnp.asarray(G0)

    solver_key = solver.strip().lower()
    fit_key = fit_signal.strip().lower()
    if fit_key not in {"phi", "density", "auto"}:
        raise ValueError("fit_signal must be 'phi', 'density', or 'auto'")
    if streaming_fit and fit_key == "auto":
        streaming_fit = False
    auto_solver = solver_key == "auto"
    if auto_solver:
        solver_key = "krylov"
    wants_density = fit_key in {"density", "auto"}

    def _is_valid_growth(gamma_val: float, omega_val: float) -> bool:
        finite = np.isfinite(gamma_val) and np.isfinite(omega_val)
        return bool(finite and not (require_positive and gamma_val <= 0.0))

    def _result(
        t_out: np.ndarray, phi_out: np.ndarray, gamma_out: float, omega_out: float
    ) -> LinearRunResult:
        return LinearRunResult(
            t=t_out,
            phi_t=phi_out,
            gamma=gamma_out,
            omega=omega_out,
            ky=float(grid.ky[sel.ky_index]),
            selection=sel,
        )

    def _auto_fit(series: np.ndarray) -> tuple[float, float]:
        g_fit, w_fit, _lo, _hi = fit_growth_rate_auto(
            t,
            series,
            window_fraction=window_fraction,
            min_points=min_points,
            start_fraction=start_fraction,
            growth_weight=growth_weight,
            require_positive=require_positive,
            min_amp_fraction=min_amp_fraction,
        )
        return g_fit, w_fit

    # ------------------------------------------------------------------
    # Krylov eigen-solve
    # ------------------------------------------------------------------
    if solver_key == "krylov":
        krylov_cfg = krylov_cfg or ETG_KRYLOV_DEFAULT
        cache = build_linear_cache(grid, geom, params, Nl, Nm)
        krylov_fields = (
            "krylov_dim",
            "restarts",
            "omega_min_factor",
            "omega_target_factor",
            "omega_cap_factor",
            "omega_sign",
            "method",
            "power_iters",
            "power_dt",
            "shift",
            "shift_source",
            "shift_tol",
            "shift_maxiter",
            "shift_restart",
            "shift_solve_method",
            "shift_preconditioner",
            "shift_selection",
            "mode_family",
            "fallback_method",
            "fallback_real_floor",
        )
        krylov_kwargs = {name: getattr(krylov_cfg, name) for name in krylov_fields}
        eig, vec = dominant_eigenpair(
            G0_jax, cache, params, terms=terms, **krylov_kwargs
        )
        term_cfg = linear_terms_to_term_config(terms)
        phi = compute_fields_cached(vec, cache, params, terms=term_cfg).phi
        phi_t_np = np.asarray(phi)[None, ...]
        t = np.array([0.0])
        gamma = float(np.real(eig))
        omega = float(-np.imag(eig))
        if krylov_cfg.omega_sign != 0:
            omega = float(np.sign(krylov_cfg.omega_sign)) * abs(omega)
        gamma, omega = _normalize_growth_rate(gamma, omega, params, diagnostic_norm)
        fall_back = auto_solver and not _is_valid_growth(gamma, omega)
        if not fall_back:
            return _result(t, phi_t_np, gamma, omega)

    # ------------------------------------------------------------------
    # Time integration
    # ------------------------------------------------------------------
    time_cfg_use = time_cfg
    if time_cfg_use is None and streaming_fit and cfg.time.use_diffrax:
        max_steps = max(int(cfg.time.diffrax_max_steps), int(steps))
        time_cfg_use = replace(
            cfg.time,
            dt=dt,
            t_max=dt * steps,
            diffrax_max_steps=max_steps,
        )

    if time_cfg_use is None:
        stride = 1 if sample_stride is None else int(sample_stride)
        if wants_density:
            diag = integrate_linear_diagnostics(
                G0_jax,
                grid,
                geom,
                params,
                dt=dt,
                steps=steps,
                method=method,
                terms=terms,
                sample_stride=stride,
                species_index=electron_index,
            )
            phi_t = diag[1]
            density_t = diag[2] if len(diag) > 2 else None
        else:
            _, phi_t = integrate_linear(
                G0_jax,
                grid,
                geom,
                params,
                dt=dt,
                steps=steps,
                method=method,
                terms=terms,
                sample_stride=stride,
                show_progress=show_progress,
            )
            density_t = None
    else:
        if sample_stride is not None:
            time_cfg_use = replace(time_cfg_use, sample_stride=sample_stride)
        if time_cfg is not None:
            # A caller-supplied time configuration dictates the step layout.
            dt = float(time_cfg_use.dt)
            steps = int(round(time_cfg_use.t_max / time_cfg_use.dt))
        cache = build_linear_cache(grid, geom, params, Nl, Nm)

        if wants_density and streaming_fit and time_cfg_use.use_diffrax:
            # Streaming fit: growth rates are accumulated during integration
            # and only the final state is turned into a one-sample phi record.
            t_total = float(dt * steps)
            tmin_i, tmax_i = _resolve_streaming_window(
                t_total, tmin, tmax, start_fraction, window_fraction, 1.0
            )
            G_last, gamma_vals, omega_vals = integrate_linear_diffrax_streaming(
                G0_jax,
                grid,
                geom,
                params,
                dt=dt,
                steps=steps,
                method=time_cfg_use.diffrax_solver,
                cache=cache,
                terms=terms,
                adaptive=False,
                rtol=time_cfg_use.diffrax_rtol,
                atol=time_cfg_use.diffrax_atol,
                max_steps=time_cfg_use.diffrax_max_steps,
                show_progress=show_progress,
                progress_bar=time_cfg_use.progress_bar,
                checkpoint=time_cfg_use.checkpoint,
                tmin=tmin_i,
                tmax=tmax_i,
                fit_signal="density",
                mode_ky_indices=np.array([0], dtype=int),
                mode_kx_index=0,
                mode_z_index=_midplane_index(grid),
                mode_method=mode_method,
                amp_floor=streaming_amp_floor,
                density_species_index=electron_index,
                return_state=True,
            )
            gamma, omega = _normalize_growth_rate(
                float(np.asarray(gamma_vals)[0]),
                float(np.asarray(omega_vals)[0]),
                params,
                diagnostic_norm,
            )
            if G_last is not None and G_last.ndim == 7:
                G_last = G_last[0]
            term_cfg = linear_terms_to_term_config(terms)
            if G_last is None:
                raise ValueError("Expected final state from streaming fit; got None.")
            phi_last = compute_fields_cached(
                G_last, cache, params, terms=term_cfg
            ).phi
            phi_t_np = np.asarray(jnp.asarray(phi_last)[None, ...])
            return _result(np.array([tmax_i]), phi_t_np, gamma, omega)

        if not wants_density:
            _, phi_t = integrate_linear_from_config(
                G0_jax,
                grid,
                geom,
                params,
                time_cfg_use,
                cache=cache,
                terms=terms,
                show_progress=show_progress,
            )
            density_t = None
        elif time_cfg_use.use_diffrax:
            _, saved = integrate_linear_diffrax(
                G0_jax,
                grid,
                geom,
                params,
                dt=dt,
                steps=steps,
                method=time_cfg_use.diffrax_solver,
                cache=cache,
                terms=terms,
                adaptive=time_cfg_use.diffrax_adaptive,
                rtol=time_cfg_use.diffrax_rtol,
                atol=time_cfg_use.diffrax_atol,
                max_steps=time_cfg_use.diffrax_max_steps,
                show_progress=show_progress,
                progress_bar=time_cfg_use.progress_bar,
                checkpoint=time_cfg_use.checkpoint,
                sample_stride=time_cfg_use.sample_stride,
                return_state=time_cfg_use.save_state,
                save_field="phi+density",
                density_species_index=electron_index,
            )
            phi_t, density_t = saved
        else:
            diag = integrate_linear_diagnostics(
                G0_jax,
                grid,
                geom,
                params,
                dt=dt,
                steps=steps,
                method=time_cfg_use.method,
                cache=cache,
                terms=terms,
                sample_stride=time_cfg_use.sample_stride,
                species_index=electron_index,
            )
            phi_t = diag[1]
            density_t = diag[2] if len(diag) > 2 else None
        stride = time_cfg_use.sample_stride

    phi_t_np = np.asarray(phi_t)
    t = np.arange(phi_t_np.shape[0]) * dt * stride
    density_np = None if density_t is None else np.asarray(density_t)

    # ------------------------------------------------------------------
    # Growth-rate extraction from the saved time series
    # ------------------------------------------------------------------
    if gx_growth and fit_key == "phi":
        gamma, omega, _gamma_t, _omega_t, _t_mid = gx_growth_rate_from_phi(
            phi_t_np,
            t,
            sel,
            navg_fraction=gx_navg_fraction,
            mode_method=mode_method,
        )
        gamma, omega = _normalize_growth_rate(gamma, omega, params, diagnostic_norm)
        return _result(t, phi_t_np, gamma, omega)

    if fit_key == "auto":
        _signal, _name, gamma, omega = _select_fit_signal_auto(
            t,
            phi_t_np,
            density_np,
            sel,
            mode_method=mode_method,
            tmin=tmin,
            tmax=tmax,
            window_fraction=window_fraction,
            min_points=min_points,
            start_fraction=start_fraction,
            growth_weight=growth_weight,
            require_positive=require_positive,
            min_amp_fraction=min_amp_fraction,
            max_amp_fraction=0.9,
            window_method="loglinear",
            max_fraction=0.8,
            end_fraction=0.9,
            num_windows=8,
            phase_weight=0.2,
            length_weight=0.05,
            min_r2=0.0,
            late_penalty=0.1,
            min_slope=None,
            min_slope_frac=0.0,
            slope_var_weight=0.0,
        )
        gamma, omega = _normalize_growth_rate(gamma, omega, params, diagnostic_norm)
        return _result(t, phi_t_np, gamma, omega)

    signal = _select_fit_signal(
        phi_t_np,
        density_np,
        sel,
        fit_signal=fit_key,
        mode_method=mode_method,
    )
    # Use the caller's window only when one was requested and it is valid for
    # the sampled time axis; otherwise (or if the fixed-window fit raises)
    # fall back to the automatic window search.
    fixed_window = not (
        auto_window and tmin is None and tmax is None
    ) and scan_window_valid(t, tmin, tmax)
    if fixed_window:
        try:
            gamma, omega = fit_growth_rate(t, signal, tmin=tmin, tmax=tmax)
        except ValueError:
            fixed_window = False
    if not fixed_window:
        gamma, omega = _auto_fit(signal)
    gamma, omega = _normalize_growth_rate(gamma, omega, params, diagnostic_norm)
    return _result(t, phi_t_np, gamma, omega)
# [docs] (Sphinx source-link marker left over from HTML extraction; not code)
def run_etg_scan(
    ky_values: np.ndarray,
    Nl: int = 6,
    Nm: int = 12,
    dt: float | np.ndarray = 0.01,
    steps: int | np.ndarray = 800,
    method: str = "imex2",
    params: LinearParams | None = None,
    cfg: ETGBaseCase | None = None,
    time_cfg: TimeConfig | None = None,
    solver: str = "auto",
    krylov_cfg: KrylovConfig | None = None,
    tmin: float | None = None,
    tmax: float | None = None,
    auto_window: bool = True,
    window_fraction: float = 0.4,
    min_points: int = 40,
    start_fraction: float = 0.2,
    growth_weight: float = 1.0,
    require_positive: bool = True,
    min_amp_fraction: float = 0.0,
    max_fraction: float = 0.8,
    end_fraction: float = 1.0,
    max_amp_fraction: float = 1.0,
    phase_weight: float = 0.2,
    length_weight: float = 0.05,
    min_r2: float = 0.0,
    late_penalty: float = 0.0,
    min_slope: float | None = None,
    min_slope_frac: float = 0.0,
    slope_var_weight: float = 0.0,
    window_method: str = "loglinear",
    mode_method: str = "project",
    mode_only: bool = True,
    terms: LinearTerms | None = None,
    sample_stride: int | None = None,
    fit_signal: str = "auto",
    ky_batch: int = 4,
    fixed_batch_shape: bool = True,
    streaming_fit: bool = True,
    streaming_amp_floor: float = 1.0e-30,
    gx_growth: bool = False,
    gx_navg_fraction: float = 0.5,
    diagnostic_norm: str = "none",
    show_progress: bool = False,
) -> LinearScanResult:
    """Run an ETG linear benchmark for a list of ky values.

    ``solver="auto"`` time-integrates every ky and re-solves any ky whose fit
    is non-finite (or non-growing while ``require_positive`` is set) with the
    single-ky Krylov solver, using the same ``terms`` as the scan.
    ``solver="krylov"`` solves one eigenproblem per ky and, when the Krylov
    configuration enables continuation, reuses the previous eigenvector (and
    eigenvalue shift for ``shift_invert``) as the starting guess.

    If ``time_cfg`` is provided it selects the integrator settings, while its
    ``dt`` and ``t_max`` are replaced per ky by ``dt`` and ``dt * steps``.
    ``dt``/``steps`` may be arrays with one entry per ky when batching is off.

    Returns:
        LinearScanResult with ``ky``, ``gamma`` and ``omega`` arrays in scan order.
    """
    cfg = cfg or ETGBaseCase()
    grid_full = build_spectral_grid(cfg.grid)
    geom = SAlphaGeometry.from_config(cfg.geometry)
    if params is None:
        if getattr(cfg.model, "adiabatic_ions", False):
            params = _electron_only_params(
                cfg.model,
                kpar_scale=float(geom.gradpar()),
                omega_d_scale=ETG_OMEGA_D_SCALE,
                omega_star_scale=ETG_OMEGA_STAR_SCALE,
                rho_star=ETG_RHO_STAR,
                damp_ends_amp=0.0,
                damp_ends_widthfrac=0.0,
                nhermite=Nm,
            )
        else:
            params = _two_species_params(
                cfg.model,
                kpar_scale=float(geom.gradpar()),
                omega_d_scale=ETG_OMEGA_D_SCALE,
                omega_star_scale=ETG_OMEGA_STAR_SCALE,
                rho_star=ETG_RHO_STAR,
                damp_ends_amp=0.0,
                damp_ends_widthfrac=0.0,
                nhermite=Nm,
            )
    if terms is None:
        # Keep the ETG scan helper on the same electrostatic benchmark contract
        # as the single-ky ETG wrapper and the tracked ETG figure builders.
        terms = LinearTerms(apar=0.0, bpar=0.0, hypercollisions=1.0)

    solver_key = normalize_solver_key(solver)
    fit_key = normalize_fit_signal(fit_signal)
    auto_solver = solver_key == "auto"
    if auto_solver:
        solver_key = "time"
    streaming_fit, mode_only = apply_auto_fit_scan_policy(
        fit_key, streaming_fit=streaming_fit, mode_only=mode_only
    )
    need_density = fit_key in {"density", "auto"}
    gammas: list[float] = []
    omegas: list[float] = []
    ky_out: list[float] = []
    mode_method = resolve_scan_mode_method(mode_method, mode_only=mode_only)
    use_batch = should_use_ky_batch(
        ky_batch=ky_batch,
        solver_key=solver_key,
        dt=dt,
        steps=steps,
        tmin=tmin,
        tmax=tmax,
    )
    fit_policy = ScanFitWindowPolicy(
        tmin=tmin,
        tmax=tmax,
        auto_window=auto_window,
        window_fraction=window_fraction,
        min_points=min_points,
        start_fraction=start_fraction,
        growth_weight=growth_weight,
        require_positive=require_positive,
        min_amp_fraction=min_amp_fraction,
        max_fraction=max_fraction,
        end_fraction=end_fraction,
        max_amp_fraction=max_amp_fraction,
        phase_weight=phase_weight,
        length_weight=length_weight,
        min_r2=min_r2,
        late_penalty=late_penalty,
        min_slope=min_slope,
        min_slope_frac=min_slope_frac,
        slope_var_weight=slope_var_weight,
        window_method=window_method,
        fit_growth_rate_fn=fit_growth_rate,
        fit_growth_rate_auto_fn=fit_growth_rate_auto,
        normalize_growth_rate_fn=_normalize_growth_rate,
    )

    def _fit_signal(
        signal: np.ndarray, idx: int, dt_i: float, stride: int
    ) -> tuple[float, float]:
        return fit_policy.fit_signal(
            signal,
            idx=idx,
            dt=dt_i,
            stride=stride,
            params=params,
            diagnostic_norm=diagnostic_norm,
        )

    def _is_valid_growth(gamma_val: float, omega_val: float) -> bool:
        if not np.isfinite(gamma_val) or not np.isfinite(omega_val):
            return False
        if require_positive and gamma_val <= 0.0:
            return False
        return True

    def _krylov_fallback(
        ky_val: float, dt_i: float, steps_i: int
    ) -> tuple[float, float]:
        # Re-solve one ky with the single-ky Krylov path.  ``terms`` is forwarded
        # so the fallback uses the same linear operator as the rest of the scan.
        res = run_etg_linear(
            ky_target=float(ky_val),
            cfg=cfg,
            Nl=Nl,
            Nm=Nm,
            dt=dt_i,
            steps=steps_i,
            method=method,
            params=params,
            terms=terms,
            solver="krylov",
            krylov_cfg=krylov_cfg,
            diagnostic_norm=diagnostic_norm,
            fit_signal="phi",
            show_progress=show_progress,
        )
        return float(res.gamma), float(res.omega)

    # The species layout depends only on ``params``, so resolve it once.
    # The species with the smallest charge_sign is used as the electron index.
    charge = np.atleast_1d(np.asarray(params.charge_sign))
    ns = int(charge.size)
    electron_index = int(np.argmin(charge))

    ky_values_arr = np.asarray(ky_values, dtype=float)
    if use_batch:
        ky_iter = _iter_ky_batches(
            ky_values_arr,
            ky_batch=ky_batch,
            fixed_batch_shape=fixed_batch_shape,
        )
    else:
        ky_iter = _iter_ky_batches(ky_values_arr, ky_batch=1, fixed_batch_shape=False)

    # Continuation state carried between consecutive Krylov solves.
    prev_vec: jnp.ndarray | None = None
    prev_eig: complex | None = None
    ky_slice: np.ndarray
    ky_indices: list[int]
    sel: ModeSelection | ModeSelectionBatch
    for batch_start, ky_slice, valid_count in ky_iter:
        if use_batch:
            ky_indices = [
                select_ky_index(np.asarray(grid_full.ky), float(ky)) for ky in ky_slice
            ]
            grid = select_ky_grid(grid_full, ky_indices)
            sel_indices = np.arange(len(ky_indices), dtype=int)
            sel = ModeSelectionBatch(sel_indices, 0, _midplane_index(grid))
            dt_i = float(dt)
            steps_i = int(steps)
        else:
            ky_indices = [select_ky_index(np.asarray(grid_full.ky), float(ky_slice[0]))]
            grid = select_ky_grid(grid_full, ky_indices[0])
            sel = ModeSelection(ky_index=0, kx_index=0, z_index=_midplane_index(grid))
            dt_i = float(dt[batch_start]) if isinstance(dt, np.ndarray) else float(dt)
            steps_i = (
                int(steps[batch_start]) if isinstance(steps, np.ndarray) else int(steps)
            )

        # Only the electron species carries the initial perturbation.
        G0 = np.zeros(
            (ns, Nl, Nm, grid.ky.size, grid.kx.size, grid.z.size), dtype=np.complex64
        )
        G0_single = _build_initial_condition(
            grid,
            geom,
            ky_index=np.arange(len(ky_indices), dtype=int),
            kx_index=0,
            Nl=Nl,
            Nm=Nm,
            init_cfg=cfg.init,
        )
        G0[electron_index] = np.asarray(G0_single, dtype=np.complex64)
        cache = build_linear_cache(grid, geom, params, Nl, Nm)
        G0_jax = jnp.asarray(G0)

        if solver_key == "krylov":
            cfg_use = krylov_cfg or ETG_KRYLOV_DEFAULT
            use_cont = bool(cfg_use.continuation)
            v0_use = G0_jax
            v_ref = None
            shift_override = cfg_use.shift
            shift_selection_use = cfg_use.shift_selection
            if use_cont and prev_vec is not None and prev_vec.shape == G0_jax.shape:
                v0_use = prev_vec
                v_ref = prev_vec
                if (
                    cfg_use.method.strip().lower() == "shift_invert"
                    and prev_eig is not None
                ):
                    if shift_override is None:
                        shift_override = prev_eig
                        # When continuation carries an explicit previous eigenvalue
                        # as the shift, select the closest shifted branch first and
                        # let overlap tracking keep the mode family coherent.
                        shift_selection_use = "shift"
            select_overlap = (
                use_cont
                and v_ref is not None
                and (cfg_use.continuation_selection.strip().lower() == "overlap")
            )
            eig, vec = dominant_eigenpair(
                v0_use,
                cache,
                params,
                terms=terms,
                v_ref=v_ref,
                select_overlap=select_overlap,
                krylov_dim=cfg_use.krylov_dim,
                restarts=cfg_use.restarts,
                omega_min_factor=cfg_use.omega_min_factor,
                omega_target_factor=cfg_use.omega_target_factor,
                omega_cap_factor=cfg_use.omega_cap_factor,
                omega_sign=cfg_use.omega_sign,
                method=cfg_use.method,
                power_iters=cfg_use.power_iters,
                power_dt=cfg_use.power_dt,
                shift=shift_override,
                shift_source=cfg_use.shift_source,
                shift_tol=cfg_use.shift_tol,
                shift_maxiter=cfg_use.shift_maxiter,
                shift_restart=cfg_use.shift_restart,
                shift_solve_method=cfg_use.shift_solve_method,
                shift_preconditioner=cfg_use.shift_preconditioner,
                shift_selection=shift_selection_use,
                mode_family=cfg_use.mode_family,
                fallback_method=cfg_use.fallback_method,
                fallback_real_floor=cfg_use.fallback_real_floor,
            )
            if use_cont:
                # Only a finite eigenvalue is allowed to seed the next solve.
                eig_host = complex(np.asarray(eig))
                if np.isfinite(eig_host.real) and np.isfinite(eig_host.imag):
                    prev_vec = vec
                    prev_eig = eig_host
                else:
                    prev_vec = None
                    prev_eig = None
            gamma = float(np.real(eig))
            omega = float(-np.imag(eig))
            if cfg_use.omega_sign != 0:
                omega = float(np.sign(cfg_use.omega_sign)) * abs(omega)
            gamma, omega = _normalize_growth_rate(gamma, omega, params, diagnostic_norm)
            gammas.append(gamma)
            omegas.append(omega)
            ky_out.append(float(ky_slice[0]))
            continue

        time_cfg_i = None
        if time_cfg is not None:
            time_cfg_i = replace(time_cfg, dt=dt_i, t_max=dt_i * steps_i)
            if sample_stride is not None:
                time_cfg_i = replace(time_cfg_i, sample_stride=sample_stride)
        params_use = params

        if time_cfg_i is not None and time_cfg_i.use_diffrax and streaming_fit:
            # Streaming fit: growth rates come straight from the integrator.
            t_total = float(time_cfg_i.t_max)
            tmin_i, tmax_i = _resolve_streaming_window(
                t_total,
                indexed_float_value(tmin, batch_start),
                indexed_float_value(tmax, batch_start),
                start_fraction,
                window_fraction,
                1.0,
            )
            _, gamma_vals, omega_vals = integrate_linear_diffrax_streaming(
                G0_jax,
                grid,
                geom,
                params_use,
                dt=dt_i,
                steps=steps_i,
                method=time_cfg_i.diffrax_solver,
                cache=cache,
                terms=terms,
                adaptive=time_cfg_i.diffrax_adaptive,
                rtol=time_cfg_i.diffrax_rtol,
                atol=time_cfg_i.diffrax_atol,
                max_steps=time_cfg_i.diffrax_max_steps,
                progress_bar=time_cfg_i.progress_bar,
                checkpoint=time_cfg_i.checkpoint,
                tmin=tmin_i,
                tmax=tmax_i,
                fit_signal=fit_key,
                mode_ky_indices=np.arange(valid_count, dtype=int),
                mode_kx_index=0,
                mode_z_index=_midplane_index(grid),
                mode_method=mode_method,
                amp_floor=streaming_amp_floor,
                density_species_index=electron_index if fit_key == "density" else None,
                return_state=False,
            )
            gamma_arr = np.asarray(gamma_vals)
            omega_arr = np.asarray(omega_vals)
            for local_idx in range(valid_count):
                ky_val = ky_slice[local_idx]
                gamma_i, omega_i = _normalize_growth_rate(
                    float(gamma_arr[local_idx]),
                    float(omega_arr[local_idx]),
                    params_use,
                    diagnostic_norm,
                )
                gammas.append(gamma_i)
                omegas.append(omega_i)
                ky_out.append(float(ky_val))
            continue

        if time_cfg_i is not None:
            save_field = (
                "phi+density"
                if fit_key == "auto"
                else ("density" if fit_key == "density" else "phi")
            )
            save_mode = (
                None
                if fit_key == "auto"
                else (sel if (mode_only and fit_key == "phi") else None)
            )
            _, saved = integrate_linear_from_config(
                G0_jax,
                grid,
                geom,
                params_use,
                time_cfg_i,
                cache=cache,
                terms=terms,
                save_mode=save_mode,
                mode_method=mode_method,
                save_field=save_field,
                density_species_index=electron_index if need_density else None,
                show_progress=show_progress,
            )
            if fit_key == "auto":
                phi_t, density_t = saved
            else:
                phi_t = saved
                density_t = None
            stride = time_cfg_i.sample_stride
        else:
            stride = 1 if sample_stride is None else int(sample_stride)
            if need_density:
                _diag = integrate_linear_diagnostics(
                    G0_jax,
                    grid,
                    geom,
                    params_use,
                    dt=dt_i,
                    steps=steps_i,
                    method=method,
                    cache=cache,
                    terms=terms,
                    sample_stride=stride,
                    # Use the resolved electron slot: a fixed index of 1 does not
                    # exist for the single-species (adiabatic-ion) parameter set.
                    species_index=electron_index,
                    show_progress=show_progress,
                )
                phi_t = _diag[1]
                density_t = _diag[2] if len(_diag) > 2 else None
            else:
                _, phi_out_time = integrate_linear(
                    G0_jax,
                    grid,
                    geom,
                    params_use,
                    dt=dt_i,
                    steps=steps_i,
                    method=method,
                    cache=cache,
                    terms=terms,
                    sample_stride=stride,
                    show_progress=show_progress,
                )
                phi_t = phi_out_time
                density_t = None

        phi_t_np = np.asarray(phi_t)
        density_np = None if density_t is None else np.asarray(density_t)
        if fit_key == "density" and density_np is None:
            # With save_field="density" the single saved array is the density.
            density_np = phi_t_np
        t = np.arange(phi_t_np.shape[0]) * dt_i * stride

        for local_idx in range(valid_count):
            ky_val = ky_slice[local_idx]
            if fit_key == "auto":
                sel_local = ModeSelection(
                    ky_index=local_idx, kx_index=0, z_index=_midplane_index(grid)
                )
                _signal, _name, gamma, omega = _select_fit_signal_auto(
                    t,
                    phi_t_np,
                    density_np,
                    sel_local,
                    mode_method=mode_method,
                    tmin=indexed_float_value(tmin, batch_start + local_idx),
                    tmax=indexed_float_value(tmax, batch_start + local_idx),
                    window_fraction=window_fraction,
                    min_points=min_points,
                    start_fraction=start_fraction,
                    growth_weight=growth_weight,
                    require_positive=require_positive,
                    min_amp_fraction=min_amp_fraction,
                    max_amp_fraction=max_amp_fraction,
                    window_method=window_method,
                    max_fraction=max_fraction,
                    end_fraction=end_fraction,
                    num_windows=8,
                    phase_weight=phase_weight,
                    length_weight=length_weight,
                    min_r2=min_r2,
                    late_penalty=late_penalty,
                    min_slope=min_slope,
                    min_slope_frac=min_slope_frac,
                    slope_var_weight=slope_var_weight,
                )
                gamma, omega = _normalize_growth_rate(
                    gamma, omega, params_use, diagnostic_norm
                )
                if auto_solver and not _is_valid_growth(gamma, omega):
                    gamma, omega = _krylov_fallback(ky_val, dt_i, steps_i)
                gammas.append(gamma)
                omegas.append(omega)
                ky_out.append(float(ky_val))
                continue

            if mode_only and fit_key == "phi" and phi_t_np.ndim <= 2:
                signal = _extract_mode_only_signal(phi_t_np, local_idx=local_idx)
            else:
                sel_local = ModeSelection(
                    ky_index=local_idx, kx_index=0, z_index=_midplane_index(grid)
                )
                signal = _select_fit_signal(
                    phi_t_np,
                    density_np,
                    sel_local,
                    fit_signal=fit_key,
                    mode_method=mode_method,
                )
            if gx_growth and fit_key == "phi":
                sel_local = ModeSelection(
                    ky_index=local_idx, kx_index=0, z_index=_midplane_index(grid)
                )
                gamma, omega, _gamma_t, _omega_t, _t_mid = gx_growth_rate_from_phi(
                    phi_t_np,
                    t,
                    sel_local,
                    navg_fraction=gx_navg_fraction,
                    mode_method=mode_method,
                )
                gamma, omega = _normalize_growth_rate(
                    gamma, omega, params_use, diagnostic_norm
                )
            else:
                gamma, omega = _fit_signal(
                    signal, batch_start + local_idx, dt_i, stride
                )
            if auto_solver and not _is_valid_growth(gamma, omega):
                gamma, omega = _krylov_fallback(ky_val, dt_i, steps_i)
            gammas.append(gamma)
            omegas.append(omega)
            ky_out.append(float(ky_val))

    return LinearScanResult(
        ky=np.array(ky_out), gamma=np.array(gammas), omega=np.array(omegas)
    )
# [docs] (Sphinx source-link marker left over from HTML extraction; not code)
def run_kinetic_linear(
    ky_target: float = 0.3,
    Nl: int = 6,
    Nm: int = 12,
    dt: float = 0.01,
    steps: int = 800,
    method: str = "rk4",
    params: LinearParams | None = None,
    cfg: KineticElectronBaseCase | None = None,
    time_cfg: TimeConfig | None = None,
    solver: str = "krylov",
    krylov_cfg: KrylovConfig | None = None,
    tmin: float | None = None,
    tmax: float | None = None,
    auto_window: bool = True,
    window_fraction: float = 0.4,
    min_points: int = 40,
    start_fraction: float = 0.2,
    growth_weight: float = 1.0,
    require_positive: bool = True,
    min_amp_fraction: float = 0.0,
    mode_method: str = "project",
    terms: LinearTerms | None = None,
    sample_stride: int | None = None,
    fit_signal: str = "density",
    init_species_index: int = 1,
    density_species_index: int = 1,
    diagnostic_norm: str = "none",
    gx_reference: bool | None = True,
    show_progress: bool = False,
) -> LinearRunResult:
    """Run one kinetic-electron ITG/TEM case at a single ky and extract (gamma, omega).

    The ky grid point picked by ``select_ky_index`` for ``ky_target`` is
    isolated into its own grid, and a two-species state is seeded in species
    ``init_species_index``.  ``solver="krylov"`` (case-insensitive) reads the
    growth rate off the dominant eigenpair; any other value time-integrates
    the system and fits the selected signal.

    When ``gx_reference`` is truthy, a ``diagnostic_norm`` of ``"none"`` is
    promoted to ``"gx"`` and the GX-reference Krylov defaults are used if
    ``krylov_cfg`` is not supplied.  If ``time_cfg`` is given, the step size
    and step count are taken from ``time_cfg.dt`` and ``time_cfg.t_max``
    instead of ``dt``/``steps``; a non-``None`` ``sample_stride`` replaces the
    stride stored in ``time_cfg``.

    Returns:
        A ``LinearRunResult``.  On the Krylov path ``t`` is ``[0.0]`` and
        ``phi_t`` is the eigenvector potential with a leading length-1 axis.

    Raises:
        ValueError: if ``init_species_index`` or ``density_species_index`` is
            negative or not below the species count (2).
    """
    case = cfg or KineticElectronBaseCase()
    full_grid = build_spectral_grid(case.grid)
    geom = SAlphaGeometry.from_config(case.geometry)

    use_gx = bool(gx_reference)
    norm = diagnostic_norm
    if use_gx and norm == "none":
        norm = "gx"
    init_cfg = _kinetic_reference_init_cfg(case.init, gx_reference=use_gx)
    end_amp, end_width = _gx_linked_end_damping(use_gx)

    if params is None:
        params = _two_species_params(
            case.model,
            kpar_scale=float(geom.gradpar()),
            omega_d_scale=Kinetic_OMEGA_D_SCALE,
            omega_star_scale=Kinetic_OMEGA_STAR_SCALE,
            rho_star=Kinetic_RHO_STAR,
            damp_ends_amp=end_amp,
            damp_ends_widthfrac=end_width,
            nhermite=Nm,
        )
        # NOTE(review): nesting reconstructed from a flattened listing; confirm
        # the GX hypercollision step is meant for default-built params only.
        if use_gx:
            params = _apply_gx_hypercollisions(params, nhermite=Nm)
    if terms is None:
        terms = LinearTerms(bpar=0.0)

    # Reduce the grid to the single requested ky; the mode then sits at index 0.
    target_index = select_ky_index(np.asarray(full_grid.ky), ky_target)
    grid = select_ky_grid(full_grid, target_index)
    sel = ModeSelection(ky_index=0, kx_index=0, z_index=_midplane_index(grid))

    ns = 2
    if init_species_index < 0 or init_species_index >= ns:
        raise ValueError("init_species_index out of range for kinetic species")
    if density_species_index < 0 or density_species_index >= ns:
        raise ValueError("density_species_index out of range for kinetic species")

    # Only the chosen species carries the initial perturbation.
    state0 = np.zeros(
        (ns, Nl, Nm, grid.ky.size, grid.kx.size, grid.z.size), dtype=np.complex64
    )
    seed = _build_initial_condition(
        grid,
        geom,
        ky_index=sel.ky_index,
        kx_index=sel.kx_index,
        Nl=Nl,
        Nm=Nm,
        init_cfg=init_cfg,
    )
    state0[int(init_species_index)] = np.asarray(seed, dtype=np.complex64)
    G0_jax = jnp.asarray(state0)

    if solver.lower() == "krylov":
        kcfg = krylov_cfg or (
            KINETIC_KRYLOV_GX_REFERENCE if use_gx else KINETIC_KRYLOV_DEFAULT
        )
        cache = build_linear_cache(grid, geom, params, Nl, Nm)
        # Every option is forwarded under the same name it has on the config.
        option_names = (
            "krylov_dim",
            "restarts",
            "omega_min_factor",
            "omega_target_factor",
            "omega_cap_factor",
            "omega_sign",
            "method",
            "power_iters",
            "power_dt",
            "shift",
            "shift_source",
            "shift_tol",
            "shift_maxiter",
            "shift_restart",
            "shift_solve_method",
            "shift_preconditioner",
            "shift_selection",
            "mode_family",
            "fallback_method",
            "fallback_real_floor",
        )
        eig, vec = dominant_eigenpair(
            G0_jax,
            cache,
            params,
            terms=terms,
            **{name: getattr(kcfg, name) for name in option_names},
        )
        term_cfg = linear_terms_to_term_config(terms)
        eig_phi = compute_fields_cached(vec, cache, params, terms=term_cfg).phi
        # gamma is Re(eig); omega is -Im(eig).
        gamma, omega = _normalize_growth_rate(
            float(np.real(eig)), float(-np.imag(eig)), params, norm
        )
        return LinearRunResult(
            t=np.array([0.0]),
            phi_t=np.asarray(eig_phi)[None, ...],
            gamma=gamma,
            omega=omega,
            ky=float(grid.ky[sel.ky_index]),
            selection=sel,
        )

    # ---- time-integration path -------------------------------------------
    method_key = method.lower()
    wants_density = fit_signal == "density"

    if time_cfg is not None:
        if sample_stride is None:
            run_cfg = time_cfg
        else:
            run_cfg = replace(time_cfg, sample_stride=sample_stride)
        dt_use = float(run_cfg.dt)
        steps_use = int(round(run_cfg.t_max / run_cfg.dt))
        cache = build_linear_cache(grid, geom, params, Nl, Nm)
        if run_cfg.use_diffrax and not method_key.startswith(("imex", "implicit")):
            # The saved field is either the density or phi, returned in the
            # same slot of the integrator's result.
            _, phi_t = integrate_linear_from_config(
                G0_jax,
                grid,
                geom,
                params,
                run_cfg,
                cache=cache,
                terms=terms,
                save_field="density" if wants_density else "phi",
                density_species_index=(
                    density_species_index if wants_density else None
                ),
            )
            density_t = phi_t if wants_density else None
        elif wants_density:
            diag = integrate_linear_diagnostics(
                G0_jax,
                grid,
                geom,
                params,
                dt=dt_use,
                steps=steps_use,
                method=run_cfg.method,
                cache=cache,
                terms=terms,
                sample_stride=run_cfg.sample_stride,
                species_index=density_species_index,
            )
            phi_t = diag[1]
            density_t = diag[2] if len(diag) > 2 else None
        else:
            _, phi_t = integrate_linear_from_config(
                G0_jax,
                grid,
                geom,
                params,
                run_cfg,
                cache=cache,
                terms=terms,
                density_species_index=None,
            )
            density_t = None
        stride = run_cfg.sample_stride
    else:
        dt_use = dt
        steps_use = steps
        if sample_stride is None:
            stride = 1
        else:
            stride = int(sample_stride)
        if wants_density:
            diag = integrate_linear_diagnostics(
                G0_jax,
                grid,
                geom,
                params,
                dt=dt_use,
                steps=steps_use,
                method=method,
                terms=terms,
                sample_stride=stride,
                species_index=density_species_index,
            )
            phi_t = diag[1]
            density_t = diag[2] if len(diag) > 2 else None
        else:
            _, phi_t = integrate_linear(
                G0_jax,
                grid,
                geom,
                params,
                dt=dt_use,
                steps=steps_use,
                method=method,
                terms=terms,
                sample_stride=stride,
                show_progress=show_progress,
            )
            density_t = None

    phi_t_np = np.asarray(phi_t)
    times = np.arange(phi_t_np.shape[0]) * dt_use * stride
    density_np = None if density_t is None else np.asarray(density_t)
    signal = _select_fit_signal(
        phi_t_np,
        density_np,
        sel,
        fit_signal=fit_signal,
        mode_method=mode_method,
    )

    auto_fit_kwargs = dict(
        window_fraction=window_fraction,
        min_points=min_points,
        start_fraction=start_fraction,
        growth_weight=growth_weight,
        require_positive=require_positive,
        min_amp_fraction=min_amp_fraction,
    )
    # Use the automatic window when requested with no explicit bounds, or when
    # the explicit bounds fail the window validity check.
    fit_automatically = bool(auto_window and tmin is None and tmax is None)
    if not fit_automatically:
        fit_automatically = not scan_window_valid(times, tmin, tmax)
    if fit_automatically:
        gamma, omega, _, _ = fit_growth_rate_auto(times, signal, **auto_fit_kwargs)
    else:
        try:
            gamma, omega = fit_growth_rate(times, signal, tmin=tmin, tmax=tmax)
        except ValueError:
            # The fixed-window fit was rejected; fall back to the auto window.
            gamma, omega, _, _ = fit_growth_rate_auto(
                times, signal, **auto_fit_kwargs
            )
    gamma, omega = _normalize_growth_rate(gamma, omega, params, norm)

    return LinearRunResult(
        t=times,
        phi_t=phi_t_np,
        gamma=gamma,
        omega=omega,
        ky=float(grid.ky[sel.ky_index]),
        selection=sel,
    )
[docs]
def run_kinetic_scan(
    ky_values: np.ndarray,
    Nl: int = 6,
    Nm: int = 12,
    dt: float | np.ndarray = 0.01,
    steps: int | np.ndarray = 800,
    method: str = "imex2",
    params: LinearParams | None = None,
    cfg: KineticElectronBaseCase | None = None,
    time_cfg: TimeConfig | None = None,
    solver: str = "auto",
    krylov_cfg: KrylovConfig | None = None,
    tmin: float | None = None,
    tmax: float | None = None,
    auto_window: bool = True,
    window_fraction: float = 0.4,
    min_points: int = 40,
    start_fraction: float = 0.2,
    growth_weight: float = 1.0,
    require_positive: bool = True,
    min_amp_fraction: float = 0.0,
    mode_method: str = "project",
    mode_only: bool = True,
    terms: LinearTerms | None = None,
    sample_stride: int | None = None,
    fit_signal: str = "density",
    ky_batch: int = 4,
    fixed_batch_shape: bool = True,
    streaming_fit: bool = True,
    streaming_amp_floor: float = 1.0e-30,
    init_species_index: int = 1,
    density_species_index: int = 1,
    diagnostic_norm: str = "none",
    gx_reference: bool | None = True,
    show_progress: bool = False,
) -> LinearScanResult:
    """Run a kinetic-electron ITG/TEM benchmark for a list of ky values.

    The ky values are processed in batches produced by ``_iter_ky_batches``.
    For each batch one of three paths is taken:

    * ``solver`` normalizes to ``"krylov"``: the dominant eigenpair gives
      ``gamma``/``omega`` directly;
    * ``time_cfg`` is given with ``use_diffrax`` set and ``streaming_fit`` is
      true: the streaming diffrax integrator returns fitted rates;
    * otherwise: the field history is integrated (config-driven when
      ``time_cfg`` is given, explicit ``method`` otherwise) and fitted through
      ``ScanFitWindowPolicy``.

    If ``time_cfg`` is provided, a per-batch copy is used whose ``dt`` and
    ``t_max`` are replaced by the resolved ``dt`` and ``dt * steps``; a
    non-``None`` ``sample_stride`` replaces its stride as well.

    When ``gx_reference`` is truthy, a ``diagnostic_norm`` of ``"none"`` is
    promoted to ``"gx"``.

    Returns:
        ``LinearScanResult`` with ``ky``, ``gamma`` and ``omega`` arrays in the
        order the values were processed.

    Raises:
        ValueError: if ``init_species_index`` or ``density_species_index`` is
            negative or not below 2.
    """
    cfg = cfg or KineticElectronBaseCase()
    grid_full = build_spectral_grid(cfg.grid)
    geom = SAlphaGeometry.from_config(cfg.geometry)
    gx_reference_use = bool(gx_reference)
    # GX-reference runs default to the "gx" diagnostic normalization.
    if gx_reference_use and diagnostic_norm == "none":
        diagnostic_norm = "gx"
    init_cfg_use = _kinetic_reference_init_cfg(cfg.init, gx_reference=gx_reference_use)
    damp_ends_amp, damp_ends_widthfrac = _gx_linked_end_damping(gx_reference_use)
    if params is None:
        params = _two_species_params(
            cfg.model,
            kpar_scale=float(geom.gradpar()),
            omega_d_scale=Kinetic_OMEGA_D_SCALE,
            omega_star_scale=Kinetic_OMEGA_STAR_SCALE,
            rho_star=Kinetic_RHO_STAR,
            damp_ends_amp=damp_ends_amp,
            damp_ends_widthfrac=damp_ends_widthfrac,
            nhermite=Nm,
        )
        # NOTE(review): nesting reconstructed from a flattened listing; confirm
        # the GX hypercollision step is meant for default-built params only.
        if gx_reference_use:
            params = _apply_gx_hypercollisions(params, nhermite=Nm)
    if terms is None:
        terms = LinearTerms(bpar=0.0)
    solver_key = normalize_solver_key(solver)
    fit_key = normalize_fit_signal(fit_signal)
    # Output accumulators, filled in processing order.
    gammas = []
    omegas = []
    ky_out = []
    mode_method = resolve_scan_mode_method(mode_method, mode_only=mode_only)
    use_batch = should_use_ky_batch(
        ky_batch=ky_batch,
        solver_key=solver_key,
        dt=dt,
        steps=steps,
        tmin=tmin,
        tmax=tmax,
    )
    # Bundles the fit-window options and the fit/normalize callables used by
    # the windowed (non-streaming, non-Krylov) path.
    fit_policy = ScanFitWindowPolicy(
        tmin=tmin,
        tmax=tmax,
        auto_window=auto_window,
        window_fraction=window_fraction,
        min_points=min_points,
        start_fraction=start_fraction,
        growth_weight=growth_weight,
        require_positive=require_positive,
        min_amp_fraction=min_amp_fraction,
        fit_growth_rate_fn=fit_growth_rate,
        fit_growth_rate_auto_fn=fit_growth_rate_auto,
        normalize_growth_rate_fn=_normalize_growth_rate,
    )

    def _fit_signal(
        signal: np.ndarray, idx: int, dt_i: float, stride: int
    ) -> tuple[float, float]:
        # Delegate to the policy; ``idx`` is the position in the full ky list.
        return fit_policy.fit_signal(
            signal,
            idx=idx,
            dt=dt_i,
            stride=stride,
            params=params,
            diagnostic_norm=diagnostic_norm,
        )

    ky_values_arr = np.asarray(ky_values, dtype=float)
    if use_batch:
        ky_iter = _iter_ky_batches(
            ky_values_arr,
            ky_batch=ky_batch,
            fixed_batch_shape=fixed_batch_shape,
        )
    else:
        # Without batching, iterate one ky at a time.
        ky_iter = _iter_ky_batches(ky_values_arr, ky_batch=1, fixed_batch_shape=False)
    ky_slice: np.ndarray
    ky_indices: list[int]
    sel: ModeSelection | ModeSelectionBatch
    if init_species_index < 0 or init_species_index >= 2:
        raise ValueError("init_species_index out of range for kinetic species")
    if density_species_index < 0 or density_species_index >= 2:
        raise ValueError("density_species_index out of range for kinetic species")
    # Only the first ``valid_count`` entries of each slice are recorded below;
    # presumably any further entries are padding added for fixed batch shapes
    # (TODO confirm against ``_iter_ky_batches``).
    for batch_start, ky_slice, valid_count in ky_iter:
        if use_batch:
            ky_indices = [
                select_ky_index(np.asarray(grid_full.ky), float(ky)) for ky in ky_slice
            ]
            grid = select_ky_grid(grid_full, ky_indices)
            sel_indices = np.arange(len(ky_indices), dtype=int)
            sel = ModeSelectionBatch(sel_indices, 0, _midplane_index(grid))
            # Batched runs use dt/steps as scalars.
            dt_i = float(dt)
            steps_i = int(steps)
        else:
            ky_indices = [select_ky_index(np.asarray(grid_full.ky), float(ky_slice[0]))]
            grid = select_ky_grid(grid_full, ky_indices[0])
            sel = ModeSelection(ky_index=0, kx_index=0, z_index=_midplane_index(grid))
            # Array-valued dt/steps are indexed by position in the ky list.
            dt_i = float(dt[batch_start]) if isinstance(dt, np.ndarray) else float(dt)
            steps_i = (
                int(steps[batch_start]) if isinstance(steps, np.ndarray) else int(steps)
            )
        ns = 2
        # Two-species state; only ``init_species_index`` is seeded.
        G0 = np.zeros(
            (ns, Nl, Nm, grid.ky.size, grid.kx.size, grid.z.size), dtype=np.complex64
        )
        G0_single = _build_initial_condition(
            grid,
            geom,
            ky_index=np.arange(len(ky_indices), dtype=int),
            kx_index=0,
            Nl=Nl,
            Nm=Nm,
            init_cfg=init_cfg_use,
        )
        G0[int(init_species_index)] = np.asarray(G0_single, dtype=np.complex64)
        cache = build_linear_cache(grid, geom, params, Nl, Nm)
        G0_jax = jnp.asarray(G0)
        if solver_key == "krylov":
            cfg_use = krylov_cfg or (
                KINETIC_KRYLOV_GX_REFERENCE
                if gx_reference_use
                else KINETIC_KRYLOV_DEFAULT
            )
            eig, _vec = dominant_eigenpair(
                G0_jax,
                cache,
                params,
                terms=terms,
                krylov_dim=cfg_use.krylov_dim,
                restarts=cfg_use.restarts,
                omega_min_factor=cfg_use.omega_min_factor,
                omega_target_factor=cfg_use.omega_target_factor,
                omega_cap_factor=cfg_use.omega_cap_factor,
                omega_sign=cfg_use.omega_sign,
                method=cfg_use.method,
                power_iters=cfg_use.power_iters,
                power_dt=cfg_use.power_dt,
                shift=cfg_use.shift,
                shift_source=cfg_use.shift_source,
                shift_tol=cfg_use.shift_tol,
                shift_maxiter=cfg_use.shift_maxiter,
                shift_restart=cfg_use.shift_restart,
                shift_solve_method=cfg_use.shift_solve_method,
                shift_preconditioner=cfg_use.shift_preconditioner,
                shift_selection=cfg_use.shift_selection,
                mode_family=cfg_use.mode_family,
                fallback_method=cfg_use.fallback_method,
                fallback_real_floor=cfg_use.fallback_real_floor,
            )
            # gamma is Re(eig); omega is -Im(eig).
            gamma = float(np.real(eig))
            omega = float(-np.imag(eig))
            gamma, omega = _normalize_growth_rate(gamma, omega, params, diagnostic_norm)
            gammas.append(gamma)
            omegas.append(omega)
            # Only the first ky of the slice is recorded here; this assumes the
            # Krylov path always runs one ky per batch (TODO confirm against
            # ``should_use_ky_batch``).
            ky_out.append(float(ky_slice[0]))
            continue
        time_cfg_i = None
        if time_cfg is not None:
            # The resolved dt/steps overwrite the values stored in time_cfg.
            time_cfg_i = replace(time_cfg, dt=dt_i, t_max=dt_i * steps_i)
            if sample_stride is not None:
                time_cfg_i = replace(time_cfg_i, sample_stride=sample_stride)
        params_use = params
        if time_cfg_i is not None and time_cfg_i.use_diffrax and streaming_fit:
            # Streaming path: the integrator returns fitted rates directly.
            t_total = float(time_cfg_i.t_max)
            tmin_i, tmax_i = _resolve_streaming_window(
                t_total,
                indexed_float_value(tmin, batch_start),
                indexed_float_value(tmax, batch_start),
                start_fraction,
                window_fraction,
                1.0,
            )
            _, gamma_vals, omega_vals = integrate_linear_diffrax_streaming(
                G0_jax,
                grid,
                geom,
                params_use,
                dt=dt_i,
                steps=steps_i,
                method=time_cfg_i.diffrax_solver,
                cache=cache,
                terms=terms,
                adaptive=time_cfg_i.diffrax_adaptive,
                rtol=time_cfg_i.diffrax_rtol,
                atol=time_cfg_i.diffrax_atol,
                max_steps=time_cfg_i.diffrax_max_steps,
                progress_bar=time_cfg_i.progress_bar,
                checkpoint=time_cfg_i.checkpoint,
                tmin=tmin_i,
                tmax=tmax_i,
                fit_signal=fit_key,
                mode_ky_indices=np.arange(valid_count, dtype=int),
                mode_kx_index=0,
                mode_z_index=_midplane_index(grid),
                mode_method=mode_method,
                amp_floor=streaming_amp_floor,
                density_species_index=density_species_index
                if fit_key == "density"
                else None,
                return_state=False,
            )
            gamma_arr = np.asarray(gamma_vals)
            omega_arr = np.asarray(omega_vals)
            for local_idx in range(valid_count):
                ky_val = ky_slice[local_idx]
                # Streaming results are normalized here, per mode.
                gamma_i, omega_i = _normalize_growth_rate(
                    float(gamma_arr[local_idx]),
                    float(omega_arr[local_idx]),
                    params_use,
                    diagnostic_norm,
                )
                gammas.append(gamma_i)
                omegas.append(omega_i)
                ky_out.append(float(ky_val))
            continue
        if time_cfg_i is not None:
            # Any mode_method other than "z_index"/"max" is replaced by
            # "z_index" for the save step.
            save_mode_method = (
                mode_method if mode_method in {"z_index", "max"} else "z_index"
            )
            _, phi_t = integrate_linear_from_config(
                G0_jax,
                grid,
                geom,
                params_use,
                time_cfg_i,
                cache=cache,
                terms=terms,
                save_mode=sel if (mode_only and fit_key == "phi") else None,
                mode_method=save_mode_method,
                save_field="density" if fit_key == "density" else "phi",
                density_species_index=density_species_index
                if fit_key == "density"
                else None,
            )
            stride = time_cfg_i.sample_stride
            density_t = None
        else:
            stride = 1 if sample_stride is None else int(sample_stride)
            if fit_key == "density":
                _diag = integrate_linear_diagnostics(
                    G0_jax,
                    grid,
                    geom,
                    params_use,
                    dt=dt_i,
                    steps=steps_i,
                    method=method,
                    cache=cache,
                    terms=terms,
                    sample_stride=stride,
                    species_index=density_species_index,
                )
                phi_t = _diag[1]
                density_t = _diag[2] if len(_diag) > 2 else None
            else:
                _, phi_t = integrate_linear(
                    G0_jax,
                    grid,
                    geom,
                    params_use,
                    dt=dt_i,
                    steps=steps_i,
                    method=method,
                    cache=cache,
                    terms=terms,
                    sample_stride=stride,
                    show_progress=show_progress,
                )
                density_t = None
        phi_t_np = np.asarray(phi_t)
        density_np = None if density_t is None else np.asarray(density_t)
        # On the config-driven path with save_field="density", the density
        # history comes back in the ``phi_t`` slot; reuse it as the density.
        if fit_key == "density" and density_np is None:
            density_np = phi_t_np
        for local_idx in range(valid_count):
            ky_val = ky_slice[local_idx]
            # Low-rank arrays are treated as already reduced to per-mode series.
            if mode_only and fit_key == "phi" and phi_t_np.ndim <= 2:
                signal = _extract_mode_only_signal(phi_t_np, local_idx=local_idx)
            elif (
                mode_only
                and fit_key == "density"
                and density_np is not None
                and density_np.ndim <= 3
            ):
                signal = _extract_mode_only_signal(
                    density_np,
                    local_idx=local_idx,
                    species_index=density_species_index,
                )
            else:
                sel_local = ModeSelection(
                    ky_index=local_idx, kx_index=0, z_index=_midplane_index(grid)
                )
                signal = _select_fit_signal(
                    phi_t_np,
                    density_np,
                    sel_local,
                    fit_signal=fit_key,
                    mode_method=mode_method,
                )
            gamma, omega = _fit_signal(signal, batch_start + local_idx, dt_i, stride)
            gammas.append(gamma)
            omegas.append(omega)
            ky_out.append(float(ky_val))
    return LinearScanResult(
        ky=np.array(ky_out), gamma=np.array(gammas), omega=np.array(omegas)
    )
[docs]
def run_tem_linear(
    ky_target: float = 0.3,
    Nl: int = 6,
    Nm: int = 12,
    dt: float = 0.01,
    steps: int = 800,
    method: str = "rk4",
    params: LinearParams | None = None,
    cfg: TEMBaseCase | None = None,
    time_cfg: TimeConfig | None = None,
    solver: str = "krylov",
    krylov_cfg: KrylovConfig | None = None,
    tmin: float | None = None,
    tmax: float | None = None,
    auto_window: bool = True,
    window_fraction: float = 0.4,
    min_points: int = 40,
    start_fraction: float = 0.2,
    growth_weight: float = 1.0,
    require_positive: bool = True,
    min_amp_fraction: float = 0.0,
    mode_method: str = "project",
    fit_signal: str = "phi",
    terms: LinearTerms | None = None,
    sample_stride: int | None = None,
    init_species_index: int = 1,
    density_species_index: int = 1,
    diagnostic_norm: str = "none",
    show_progress: bool = False,
) -> LinearRunResult:
    """Run the TEM benchmark at a single ky and extract (gamma, omega).

    The ky grid point picked by ``select_ky_index`` for ``ky_target`` is
    isolated into its own grid and a two-species state is seeded in species
    ``init_species_index``.  ``solver="krylov"`` (case-insensitive) reads the
    growth rate off the dominant eigenpair; any other value time-integrates
    the system and fits either the potential or the density history.

    If ``time_cfg`` is given, the step size and step count are taken from
    ``time_cfg.dt`` and ``time_cfg.t_max`` instead of ``dt``/``steps``; a
    non-``None`` ``sample_stride`` replaces the stride stored in ``time_cfg``.

    Returns:
        A ``LinearRunResult``.  On the Krylov path ``t`` is ``[0.0]`` and
        ``phi_t`` is the eigenvector potential with a leading length-1 axis.

    Raises:
        ValueError: if a species index is negative or not below 2, or (on the
            time-integration path) if ``fit_signal`` is neither ``"phi"`` nor
            ``"density"``.
    """
    case = cfg or TEMBaseCase()
    full_grid = build_spectral_grid(case.grid)
    geom = SAlphaGeometry.from_config(case.geometry)
    if params is None:
        params = _two_species_params(
            case.model,
            kpar_scale=float(geom.gradpar()),
            omega_d_scale=TEM_OMEGA_D_SCALE,
            omega_star_scale=TEM_OMEGA_STAR_SCALE,
            rho_star=TEM_RHO_STAR,
            damp_ends_amp=0.0,
            damp_ends_widthfrac=0.0,
            nhermite=Nm,
        )
    if terms is None:
        terms = LinearTerms(bpar=0.0)

    # Reduce the grid to the single requested ky; the mode then sits at index 0.
    target_index = select_ky_index(np.asarray(full_grid.ky), ky_target)
    grid = select_ky_grid(full_grid, target_index)
    sel = ModeSelection(ky_index=0, kx_index=0, z_index=_midplane_index(grid))

    ns = 2
    if init_species_index < 0 or init_species_index >= ns:
        raise ValueError("init_species_index out of range for kinetic species")
    if density_species_index < 0 or density_species_index >= ns:
        raise ValueError("density_species_index out of range for kinetic species")

    # Only the chosen species carries the initial perturbation.
    state0 = np.zeros(
        (ns, Nl, Nm, grid.ky.size, grid.kx.size, grid.z.size), dtype=np.complex64
    )
    seed = _build_initial_condition(
        grid,
        geom,
        ky_index=sel.ky_index,
        kx_index=sel.kx_index,
        Nl=Nl,
        Nm=Nm,
        init_cfg=case.init,
    )
    state0[int(init_species_index)] = np.asarray(seed, dtype=np.complex64)
    G0_jax = jnp.asarray(state0)

    if solver.lower() == "krylov":
        kcfg = krylov_cfg or TEM_KRYLOV_DEFAULT
        cache = build_linear_cache(grid, geom, params, Nl, Nm)
        # Every option is forwarded under the same name it has on the config.
        option_names = (
            "krylov_dim",
            "restarts",
            "omega_min_factor",
            "omega_target_factor",
            "omega_cap_factor",
            "omega_sign",
            "method",
            "power_iters",
            "power_dt",
            "shift",
            "shift_source",
            "shift_tol",
            "shift_maxiter",
            "shift_restart",
            "shift_solve_method",
            "shift_preconditioner",
            "shift_selection",
            "mode_family",
            "fallback_method",
            "fallback_real_floor",
        )
        eig, vec = dominant_eigenpair(
            G0_jax,
            cache,
            params,
            terms=terms,
            **{name: getattr(kcfg, name) for name in option_names},
        )
        term_cfg = linear_terms_to_term_config(terms)
        eig_phi = compute_fields_cached(vec, cache, params, terms=term_cfg).phi
        # gamma is Re(eig); omega is -Im(eig).
        gamma, omega = _normalize_growth_rate(
            float(np.real(eig)), float(-np.imag(eig)), params, diagnostic_norm
        )
        return LinearRunResult(
            t=np.array([0.0]),
            phi_t=np.asarray(eig_phi)[None, ...],
            gamma=gamma,
            omega=omega,
            ky=float(grid.ky[sel.ky_index]),
            selection=sel,
        )

    # ---- time-integration path -------------------------------------------
    if fit_signal not in {"phi", "density"}:
        raise ValueError("fit_signal must be 'phi' or 'density'")
    wants_density = fit_signal == "density"

    # Resolve the effective step size, step count and sampling stride.
    run_cfg = None
    if time_cfg is not None:
        if sample_stride is None:
            run_cfg = time_cfg
        else:
            run_cfg = replace(time_cfg, sample_stride=sample_stride)
        dt_use = float(run_cfg.dt)
        steps_use = int(round(run_cfg.t_max / run_cfg.dt))
        stride = run_cfg.sample_stride
    else:
        dt_use = dt
        steps_use = steps
        if sample_stride is None:
            stride = 1
        else:
            stride = int(sample_stride)

    cache = build_linear_cache(grid, geom, params, Nl, Nm)
    density_t = None
    if wants_density:
        # The density fit always goes through the diagnostics integrator with
        # the explicit ``method``, whether or not a time config was given.
        diag = integrate_linear_diagnostics(
            G0_jax,
            grid,
            geom,
            params,
            dt=dt_use,
            steps=steps_use,
            method=method,
            cache=cache,
            terms=terms,
            sample_stride=stride,
            species_index=density_species_index,
        )
        if len(diag) == 4:
            _, phi_t, density_t, _ = diag
        else:
            _, phi_t, density_t = diag
    elif run_cfg is not None:
        _, phi_t = integrate_linear_from_config(
            G0_jax,
            grid,
            geom,
            params,
            run_cfg,
            cache=cache,
            terms=terms,
            show_progress=show_progress,
        )
    else:
        _, phi_t = integrate_linear(
            G0_jax,
            grid,
            geom,
            params,
            dt=dt_use,
            steps=steps_use,
            method=method,
            cache=cache,
            terms=terms,
            sample_stride=stride,
            show_progress=show_progress,
        )

    phi_t_np = np.asarray(phi_t)
    times = np.arange(phi_t_np.shape[0]) * dt_use * stride

    # Fit the density history when it was requested and produced, else phi.
    fit_source = phi_t_np
    if wants_density and density_t is not None:
        fit_source = np.asarray(density_t)
    signal = extract_mode_time_series(fit_source, sel, method=mode_method)

    auto_fit_kwargs = dict(
        window_fraction=window_fraction,
        min_points=min_points,
        start_fraction=start_fraction,
        growth_weight=growth_weight,
        require_positive=require_positive,
        min_amp_fraction=min_amp_fraction,
    )
    if auto_window and tmin is None and tmax is None:
        gamma, omega, _, _ = fit_growth_rate_auto(times, signal, **auto_fit_kwargs)
    else:
        try:
            gamma, omega = fit_growth_rate(times, signal, tmin=tmin, tmax=tmax)
        except ValueError:
            # The fixed-window fit was rejected; fall back to the auto window.
            gamma, omega, _, _ = fit_growth_rate_auto(
                times, signal, **auto_fit_kwargs
            )
    gamma, omega = _normalize_growth_rate(gamma, omega, params, diagnostic_norm)

    return LinearRunResult(
        t=times,
        phi_t=phi_t_np,
        gamma=gamma,
        omega=omega,
        ky=float(grid.ky[sel.ky_index]),
        selection=sel,
    )
[docs]
def run_tem_scan(
    ky_values: np.ndarray,
    Nl: int = 6,
    Nm: int = 12,
    dt: float | np.ndarray = 0.01,
    steps: int | np.ndarray = 800,
    method: str = "imex2",
    params: LinearParams | None = None,
    cfg: TEMBaseCase | None = None,
    time_cfg: TimeConfig | None = None,
    solver: str = "auto",
    krylov_cfg: KrylovConfig | None = None,
    tmin: float | None = None,
    tmax: float | None = None,
    auto_window: bool = True,
    window_fraction: float = 0.4,
    min_points: int = 40,
    start_fraction: float = 0.2,
    growth_weight: float = 1.0,
    require_positive: bool = True,
    min_amp_fraction: float = 0.0,
    mode_method: str = "project",
    mode_only: bool = True,
    terms: LinearTerms | None = None,
    sample_stride: int | None = None,
    ky_batch: int = 4,
    fixed_batch_shape: bool = True,
    streaming_fit: bool = True,
    streaming_amp_floor: float = 1.0e-30,
    init_species_index: int = 1,
    density_species_index: int = 1,
    diagnostic_norm: str = "none",
    show_progress: bool = False,
) -> LinearScanResult:
    """Run the TEM benchmark for a list of ky values.

    The ky values are processed in batches produced by ``_iter_ky_batches``.
    For each batch one of three paths is taken:

    * ``solver`` normalizes to ``"krylov"``: the dominant eigenpair gives
      ``gamma``/``omega`` directly;
    * ``time_cfg`` is given with ``use_diffrax`` set and ``streaming_fit`` is
      true: the streaming diffrax integrator returns fitted rates;
    * otherwise: the potential history is integrated (config-driven when
      ``time_cfg`` is given, explicit ``method`` otherwise) and fitted through
      ``ScanFitWindowPolicy``.

    All three paths pass their rates through ``_normalize_growth_rate`` with
    ``diagnostic_norm``, so the reported values use the same normalization
    regardless of which path produced them.

    If ``time_cfg`` is provided, a per-batch copy is used whose ``dt`` and
    ``t_max`` are replaced by the resolved ``dt`` and ``dt * steps``; a
    non-``None`` ``sample_stride`` replaces its stride as well.

    Returns:
        ``LinearScanResult`` with ``ky``, ``gamma`` and ``omega`` arrays in the
        order the values were processed.

    Raises:
        ValueError: if ``init_species_index`` or ``density_species_index`` is
            negative or not below 2.
    """
    cfg = cfg or TEMBaseCase()
    grid_full = build_spectral_grid(cfg.grid)
    geom = SAlphaGeometry.from_config(cfg.geometry)
    if params is None:
        params = _two_species_params(
            cfg.model,
            kpar_scale=float(geom.gradpar()),
            omega_d_scale=TEM_OMEGA_D_SCALE,
            omega_star_scale=TEM_OMEGA_STAR_SCALE,
            rho_star=TEM_RHO_STAR,
            damp_ends_amp=0.0,
            damp_ends_widthfrac=0.0,
            nhermite=Nm,
        )
    if terms is None:
        terms = LinearTerms(bpar=0.0)
    solver_key = normalize_solver_key(solver)
    # Output accumulators, filled in processing order.
    gammas = []
    omegas = []
    ky_out = []
    mode_method = resolve_scan_mode_method(mode_method, mode_only=mode_only)
    use_batch = should_use_ky_batch(
        ky_batch=ky_batch,
        solver_key=solver_key,
        dt=dt,
        steps=steps,
        tmin=tmin,
        tmax=tmax,
    )
    # Bundles the fit-window options and the fit/normalize callables used by
    # the windowed (non-streaming, non-Krylov) path.
    fit_policy = ScanFitWindowPolicy(
        tmin=tmin,
        tmax=tmax,
        auto_window=auto_window,
        window_fraction=window_fraction,
        min_points=min_points,
        start_fraction=start_fraction,
        growth_weight=growth_weight,
        require_positive=require_positive,
        min_amp_fraction=min_amp_fraction,
        fit_growth_rate_fn=fit_growth_rate,
        fit_growth_rate_auto_fn=fit_growth_rate_auto,
        normalize_growth_rate_fn=_normalize_growth_rate,
    )

    def _fit_signal(
        signal: np.ndarray, idx: int, dt_i: float, stride: int
    ) -> tuple[float, float]:
        # Delegate to the policy; ``idx`` is the position in the full ky list.
        return fit_policy.fit_signal(
            signal,
            idx=idx,
            dt=dt_i,
            stride=stride,
            params=params,
            diagnostic_norm=diagnostic_norm,
        )

    ky_values_arr = np.asarray(ky_values, dtype=float)
    if use_batch:
        ky_iter = _iter_ky_batches(
            ky_values_arr,
            ky_batch=ky_batch,
            fixed_batch_shape=fixed_batch_shape,
        )
    else:
        # Without batching, iterate one ky at a time.
        ky_iter = _iter_ky_batches(ky_values_arr, ky_batch=1, fixed_batch_shape=False)
    ky_slice: np.ndarray
    ky_indices: list[int]
    sel: ModeSelection | ModeSelectionBatch
    if init_species_index < 0 or init_species_index >= 2:
        raise ValueError("init_species_index out of range for kinetic species")
    if density_species_index < 0 or density_species_index >= 2:
        raise ValueError("density_species_index out of range for kinetic species")

    # The full ky axis does not change between batches; convert it once.
    ky_grid_values = np.asarray(grid_full.ky)

    # Only the first ``valid_count`` entries of each slice are recorded below;
    # presumably any further entries are padding added for fixed batch shapes
    # (TODO confirm against ``_iter_ky_batches``).
    for batch_start, ky_slice, valid_count in ky_iter:
        if use_batch:
            ky_indices = [select_ky_index(ky_grid_values, float(ky)) for ky in ky_slice]
            grid = select_ky_grid(grid_full, ky_indices)
            sel_indices = np.arange(len(ky_indices), dtype=int)
            sel = ModeSelectionBatch(sel_indices, 0, _midplane_index(grid))
            # Batched runs use dt/steps as scalars.
            dt_i = float(dt)
            steps_i = int(steps)
        else:
            ky_indices = [select_ky_index(ky_grid_values, float(ky_slice[0]))]
            grid = select_ky_grid(grid_full, ky_indices[0])
            sel = ModeSelection(ky_index=0, kx_index=0, z_index=_midplane_index(grid))
            # Array-valued dt/steps are indexed by position in the ky list.
            dt_i = float(dt[batch_start]) if isinstance(dt, np.ndarray) else float(dt)
            steps_i = (
                int(steps[batch_start]) if isinstance(steps, np.ndarray) else int(steps)
            )
        ns = 2
        # Two-species state; only ``init_species_index`` is seeded.
        G0 = np.zeros(
            (ns, Nl, Nm, grid.ky.size, grid.kx.size, grid.z.size), dtype=np.complex64
        )
        G0_single = _build_initial_condition(
            grid,
            geom,
            ky_index=np.arange(len(ky_indices), dtype=int),
            kx_index=0,
            Nl=Nl,
            Nm=Nm,
            init_cfg=cfg.init,
        )
        G0[int(init_species_index)] = np.asarray(G0_single, dtype=np.complex64)
        cache = build_linear_cache(grid, geom, params, Nl, Nm)
        G0_jax = jnp.asarray(G0)

        if solver_key == "krylov":
            cfg_use = krylov_cfg or TEM_KRYLOV_DEFAULT
            eig, _vec = dominant_eigenpair(
                G0_jax,
                cache,
                params,
                terms=terms,
                krylov_dim=cfg_use.krylov_dim,
                restarts=cfg_use.restarts,
                omega_min_factor=cfg_use.omega_min_factor,
                omega_target_factor=cfg_use.omega_target_factor,
                omega_cap_factor=cfg_use.omega_cap_factor,
                omega_sign=cfg_use.omega_sign,
                method=cfg_use.method,
                power_iters=cfg_use.power_iters,
                power_dt=cfg_use.power_dt,
                shift=cfg_use.shift,
                shift_source=cfg_use.shift_source,
                shift_tol=cfg_use.shift_tol,
                shift_maxiter=cfg_use.shift_maxiter,
                shift_restart=cfg_use.shift_restart,
                shift_solve_method=cfg_use.shift_solve_method,
                shift_preconditioner=cfg_use.shift_preconditioner,
                shift_selection=cfg_use.shift_selection,
                mode_family=cfg_use.mode_family,
                fallback_method=cfg_use.fallback_method,
                fallback_real_floor=cfg_use.fallback_real_floor,
            )
            # gamma is Re(eig); omega is -Im(eig).
            gamma = float(np.real(eig))
            omega = float(-np.imag(eig))
            gamma, omega = _normalize_growth_rate(gamma, omega, params, diagnostic_norm)
            gammas.append(gamma)
            omegas.append(omega)
            # Only the first ky of the slice is recorded here; this assumes the
            # Krylov path always runs one ky per batch (TODO confirm against
            # ``should_use_ky_batch``).
            ky_out.append(float(ky_slice[0]))
            continue

        time_cfg_i = None
        if time_cfg is not None:
            # The resolved dt/steps overwrite the values stored in time_cfg.
            time_cfg_i = replace(time_cfg, dt=dt_i, t_max=dt_i * steps_i)
            if sample_stride is not None:
                time_cfg_i = replace(time_cfg_i, sample_stride=sample_stride)
        params_use = params

        if time_cfg_i is not None and time_cfg_i.use_diffrax and streaming_fit:
            # Streaming path: the integrator returns fitted rates directly.
            t_total = float(time_cfg_i.t_max)
            tmin_i, tmax_i = _resolve_streaming_window(
                t_total,
                indexed_float_value(tmin, batch_start),
                indexed_float_value(tmax, batch_start),
                start_fraction,
                window_fraction,
                1.0,
            )
            _, gamma_vals, omega_vals = integrate_linear_diffrax_streaming(
                G0_jax,
                grid,
                geom,
                params_use,
                dt=dt_i,
                steps=steps_i,
                method=time_cfg_i.diffrax_solver,
                cache=cache,
                terms=terms,
                adaptive=time_cfg_i.diffrax_adaptive,
                rtol=time_cfg_i.diffrax_rtol,
                atol=time_cfg_i.diffrax_atol,
                max_steps=time_cfg_i.diffrax_max_steps,
                progress_bar=time_cfg_i.progress_bar,
                checkpoint=time_cfg_i.checkpoint,
                tmin=tmin_i,
                tmax=tmax_i,
                fit_signal="phi",
                show_progress=show_progress,
                mode_ky_indices=np.arange(valid_count, dtype=int),
                mode_kx_index=0,
                mode_z_index=_midplane_index(grid),
                mode_method=mode_method,
                amp_floor=streaming_amp_floor,
                return_state=False,
            )
            gamma_arr = np.asarray(gamma_vals)
            omega_arr = np.asarray(omega_vals)
            for local_idx in range(valid_count):
                ky_val = ky_slice[local_idx]
                # Apply the diagnostic normalization here as well; the Krylov
                # and windowed-fit paths already do, and without it
                # ``diagnostic_norm`` was silently ignored for streaming fits.
                gamma_i, omega_i = _normalize_growth_rate(
                    float(gamma_arr[local_idx]),
                    float(omega_arr[local_idx]),
                    params_use,
                    diagnostic_norm,
                )
                gammas.append(gamma_i)
                omegas.append(omega_i)
                ky_out.append(float(ky_val))
            continue

        if time_cfg_i is not None:
            _, phi_t = integrate_linear_from_config(
                G0_jax,
                grid,
                geom,
                params_use,
                time_cfg_i,
                cache=cache,
                terms=terms,
                save_mode=sel if mode_only else None,
                mode_method=mode_method,
            )
            stride = time_cfg_i.sample_stride
        else:
            stride = 1 if sample_stride is None else int(sample_stride)
            _, phi_t = integrate_linear(
                G0_jax,
                grid,
                geom,
                params_use,
                dt=dt_i,
                steps=steps_i,
                method=method,
                cache=cache,
                terms=terms,
                sample_stride=stride,
                # Honor the caller's progress request on the explicit path too.
                show_progress=show_progress,
            )
        phi_t_np = np.asarray(phi_t)
        for local_idx in range(valid_count):
            ky_val = ky_slice[local_idx]
            # Low-rank arrays are treated as already reduced to per-mode series.
            if mode_only and phi_t_np.ndim <= 2:
                signal = _extract_mode_only_signal(phi_t_np, local_idx=local_idx)
            else:
                sel_local = ModeSelection(
                    ky_index=local_idx, kx_index=0, z_index=_midplane_index(grid)
                )
                signal = extract_mode_time_series(
                    phi_t_np, sel_local, method=mode_method
                )
            gamma, omega = _fit_signal(signal, batch_start + local_idx, dt_i, stride)
            gammas.append(gamma)
            omegas.append(omega)
            ky_out.append(float(ky_val))
    return LinearScanResult(
        ky=np.array(ky_out), gamma=np.array(gammas), omega=np.array(omegas)
    )
[docs]
def run_kbm_beta_scan(
betas: np.ndarray,
ky_target: float = 0.3,
Nl: int = 6,
Nm: int = 12,
dt: float | np.ndarray = 0.01,
steps: int | np.ndarray = 800,
method: str = "imex2",
cfg: KBMBaseCase | None = None,
time_cfg: TimeConfig | None = None,
solver: str = "auto",
krylov_cfg: KrylovConfig | None = None,
kbm_target_factors: Sequence[float] | None = (0.7, 1.5),
kbm_beta_transition: float | None = None,
tmin: float | None = None,
tmax: float | None = None,
auto_window: bool = True,
window_fraction: float = 0.4,
min_points: int = 40,
start_fraction: float = 0.2,
growth_weight: float = 1.0,
require_positive: bool = True,
min_amp_fraction: float = 0.0,
mode_method: str = "project",
mode_only: bool = True,
terms: LinearTerms | None = None,
sample_stride: int | None = None,
fit_signal: str = "auto",
ky_batch: int = 4,
fixed_batch_shape: bool = True,
streaming_fit: bool = True,
streaming_amp_floor: float = 1.0e-30,
init_species_index: int = 1,
density_species_index: int = 1,
diagnostic_norm: str = "none",
fapar_override: float | None = None,
apar_beta_scale: float | None = None,
ampere_g0_scale: float | None = None,
bpar_beta_scale: float | None = None,
gx_reference: bool | None = True,
) -> LinearScanResult:
"""Run a KBM beta scan at fixed ky.
If ``time_cfg`` is provided, its ``dt`` and ``t_max`` override ``dt``/``steps``.
"""
cfg = cfg or KBMBaseCase()
grid_full = build_spectral_grid(cfg.grid)
geom = SAlphaGeometry.from_config(cfg.geometry)
if terms is None:
terms = LinearTerms(bpar=0.0)
gx_reference_use = bool(gx_reference)
if gx_reference_use and diagnostic_norm == "none":
diagnostic_norm = "gx"
damp_ends_amp, damp_ends_widthfrac = _gx_linked_end_damping(gx_reference_use)
solver_key = normalize_solver_key(solver)
fit_key = normalize_fit_signal(fit_signal)
streaming_fit, mode_only = apply_auto_fit_scan_policy(
fit_key, streaming_fit=streaming_fit, mode_only=mode_only
)
krylov_cfg_use = krylov_cfg or KBM_KRYLOV_DEFAULT
use_continuation = bool(getattr(krylov_cfg_use, "continuation", False))
prev_vec = None
prev_eig = None
gammas = []
omegas = []
beta_out = []
ky_index = select_ky_index(np.asarray(grid_full.ky), ky_target)
grid = select_ky_grid(grid_full, ky_index)
sel = ModeSelection(ky_index=0, kx_index=0, z_index=_midplane_index(grid))
fit_policy = ScanFitWindowPolicy(
tmin=tmin,
tmax=tmax,
auto_window=auto_window,
window_fraction=window_fraction,
min_points=min_points,
start_fraction=start_fraction,
growth_weight=growth_weight,
require_positive=require_positive,
min_amp_fraction=min_amp_fraction,
fit_growth_rate_fn=fit_growth_rate,
fit_growth_rate_auto_fn=fit_growth_rate_auto,
normalize_growth_rate_fn=_normalize_growth_rate,
)
if init_species_index < 0 or init_species_index >= 2:
raise ValueError("init_species_index out of range for kinetic species")
if density_species_index < 0 or density_species_index >= 2:
raise ValueError("density_species_index out of range for kinetic species")
def _is_valid_growth(gamma_val: float, omega_val: float) -> bool:
if not np.isfinite(gamma_val) or not np.isfinite(omega_val):
return False
if require_positive and gamma_val <= 0.0:
return False
return True
for i, beta in enumerate(betas):
dt_i = float(dt[i]) if isinstance(dt, np.ndarray) else float(dt)
steps_i = int(steps[i]) if isinstance(steps, np.ndarray) else int(steps)
params = _two_species_params(
cfg.model,
kpar_scale=float(geom.gradpar()),
omega_d_scale=KBM_OMEGA_D_SCALE,
omega_star_scale=KBM_OMEGA_STAR_SCALE,
rho_star=KBM_RHO_STAR,
beta_override=float(beta),
fapar_override=fapar_override,
apar_beta_scale=apar_beta_scale,
ampere_g0_scale=ampere_g0_scale,
bpar_beta_scale=bpar_beta_scale,
damp_ends_amp=damp_ends_amp,
damp_ends_widthfrac=damp_ends_widthfrac,
nhermite=Nm,
)
cache = build_linear_cache(grid, geom, params, Nl, Nm)
ns = 2
G0 = np.zeros(
(ns, Nl, Nm, grid.ky.size, grid.kx.size, grid.z.size), dtype=np.complex64
)
G0_single = _build_initial_condition(
grid,
geom,
ky_index=sel.ky_index,
kx_index=sel.kx_index,
Nl=Nl,
Nm=Nm,
init_cfg=cfg.init,
)
G0[int(init_species_index)] = np.asarray(G0_single, dtype=np.complex64)
G0_jax = jnp.asarray(G0)
solver_use = select_kbm_solver_auto(
solver_key, ky_target=ky_target, gx_reference=gx_reference_use
)
if solver_use == "gx_time":
gx_mode_method = (
mode_method if mode_method in {"z_index", "max"} else "z_index"
)
gx_time_cfg = ExplicitTimeConfig(
dt=dt_i,
t_max=dt_i * steps_i,
sample_stride=max(int(sample_stride or 1), 1),
fixed_dt=bool(time_cfg.fixed_dt) if time_cfg is not None else False,
use_dealias_mask=bool(getattr(time_cfg, "use_dealias_mask", False))
if time_cfg is not None
else False,
dt_min=float(time_cfg.dt_min) if time_cfg is not None else 1.0e-7,
dt_max=float(time_cfg.dt_max)
if (time_cfg is not None and time_cfg.dt_max is not None)
else None,
cfl=float(time_cfg.cfl) if time_cfg is not None else 0.9,
cfl_fac=(
resolve_cfl_fac(str(time_cfg.method), time_cfg.cfl_fac)
if time_cfg is not None
else float(ExplicitTimeConfig.cfl_fac)
),
)
t_arr, _phi_t, gamma_t, omega_t, _gx_diag = integrate_linear_gx_diagnostics(
G0_jax,
grid,
cache,
params,
geom,
gx_time_cfg,
terms=terms,
mode_method=gx_mode_method,
z_index=sel.z_index,
jit=True,
)
if t_arr.size > 1:
phi_np = np.asarray(_phi_t)
t_np = np.asarray(t_arr, dtype=float)
if mode_method in {"z_index", "max"}:
try:
gamma, omega, _g_t, _o_t, _t_mid = gx_growth_rate_from_phi(
phi_np,
t_np,
sel,
navg_fraction=0.5,
mode_method=mode_method,
)
except ValueError:
try:
gamma, omega, _g_t, _o_t = gx_growth_rate_from_omega_series(
np.asarray(gamma_t),
np.asarray(omega_t),
sel,
navg_fraction=0.5,
)
except ValueError:
signal = extract_mode_time_series(
phi_np, sel, method=mode_method
)
gamma, omega, _tmin, _tmax = fit_growth_rate_auto(
t_np,
signal,
window_method="fixed",
window_fraction=window_fraction,
min_points=min_points,
start_fraction=start_fraction,
growth_weight=growth_weight,
require_positive=require_positive,
min_amp_fraction=min_amp_fraction,
)
else:
signal = extract_mode_time_series(phi_np, sel, method=mode_method)
gamma, omega, _tmin, _tmax = fit_growth_rate_auto(
t_np,
signal,
window_method="fixed",
window_fraction=window_fraction,
min_points=min_points,
start_fraction=start_fraction,
growth_weight=growth_weight,
require_positive=require_positive,
min_amp_fraction=min_amp_fraction,
)
else:
gamma = float("nan")
omega = float("nan")
gamma, omega = _normalize_growth_rate(gamma, omega, params, diagnostic_norm)
elif solver_use == "krylov":
shift_val = krylov_cfg_use.shift
shift_selection = krylov_cfg_use.shift_selection
if use_continuation and prev_eig is not None:
shift_val = complex(np.asarray(prev_eig))
targets: Sequence[float] | None = (
kbm_target_factors if kbm_target_factors else None
)
use_multi_target = _kbm_use_multi_target_krylov(
krylov_cfg_use,
targets,
shift=shift_val,
)
if use_multi_target:
assert targets is not None
beta_transition = (
float(cfg.model.beta)
if kbm_beta_transition is None
else float(kbm_beta_transition)
)
eig_candidates = []
vec_candidates = []
for target in targets:
eig_i, vec_i = dominant_eigenpair(
G0_jax,
cache,
params,
terms=terms,
v_ref=None,
select_overlap=False,
krylov_dim=krylov_cfg_use.krylov_dim,
restarts=krylov_cfg_use.restarts,
omega_min_factor=krylov_cfg_use.omega_min_factor,
omega_target_factor=float(target),
omega_cap_factor=krylov_cfg_use.omega_cap_factor,
omega_sign=krylov_cfg_use.omega_sign,
method=krylov_cfg_use.method,
power_iters=krylov_cfg_use.power_iters,
power_dt=krylov_cfg_use.power_dt,
shift=None,
shift_source="target",
shift_tol=krylov_cfg_use.shift_tol,
shift_maxiter=krylov_cfg_use.shift_maxiter,
shift_restart=krylov_cfg_use.shift_restart,
shift_solve_method=krylov_cfg_use.shift_solve_method,
shift_preconditioner=krylov_cfg_use.shift_preconditioner,
shift_selection="targeted",
mode_family=krylov_cfg_use.mode_family,
fallback_method=krylov_cfg_use.fallback_method,
fallback_real_floor=krylov_cfg_use.fallback_real_floor,
)
eig_candidates.append(eig_i)
vec_candidates.append(vec_i)
if len(eig_candidates) >= 2 and np.isfinite(beta_transition):
pick_high = float(beta) >= beta_transition
idx = 1 if pick_high else 0
eig = eig_candidates[idx]
_vec = vec_candidates[idx]
else:
eig_arr = np.asarray(
[complex(np.asarray(e)) for e in eig_candidates]
)
growth = np.real(eig_arr)
if np.all(~np.isfinite(growth)):
eig = eig_candidates[0]
_vec = vec_candidates[0]
else:
idx = int(
np.nanargmax(np.where(np.isfinite(growth), growth, -np.inf))
)
eig = eig_candidates[idx]
_vec = vec_candidates[idx]
else:
eig, _vec = dominant_eigenpair(
G0_jax,
cache,
params,
terms=terms,
v_ref=prev_vec,
select_overlap=use_continuation,
krylov_dim=krylov_cfg_use.krylov_dim,
restarts=krylov_cfg_use.restarts,
omega_min_factor=krylov_cfg_use.omega_min_factor,
omega_target_factor=krylov_cfg_use.omega_target_factor,
omega_cap_factor=krylov_cfg_use.omega_cap_factor,
omega_sign=krylov_cfg_use.omega_sign,
method=krylov_cfg_use.method,
power_iters=krylov_cfg_use.power_iters,
power_dt=krylov_cfg_use.power_dt,
shift=shift_val,
shift_source=krylov_cfg_use.shift_source,
shift_tol=krylov_cfg_use.shift_tol,
shift_maxiter=krylov_cfg_use.shift_maxiter,
shift_restart=krylov_cfg_use.shift_restart,
shift_solve_method=krylov_cfg_use.shift_solve_method,
shift_preconditioner=krylov_cfg_use.shift_preconditioner,
shift_selection=shift_selection,
mode_family=krylov_cfg_use.mode_family,
fallback_method=krylov_cfg_use.fallback_method,
fallback_real_floor=krylov_cfg_use.fallback_real_floor,
)
gamma = float(np.real(eig))
omega = float(-np.imag(eig))
if krylov_cfg_use.omega_sign != 0:
omega = float(np.sign(krylov_cfg_use.omega_sign)) * abs(omega)
gamma, omega = _normalize_growth_rate(gamma, omega, params, diagnostic_norm)
if solver_key == "auto" and not _is_valid_growth(gamma, omega):
solver_use = "time"
elif use_continuation:
prev_vec = _vec
prev_eig = eig
if solver_use not in {"krylov", "gx_time"}:
time_cfg_i = None
if time_cfg is not None:
time_cfg_i = replace(time_cfg, dt=dt_i, t_max=dt_i * steps_i)
if sample_stride is not None:
time_cfg_i = replace(time_cfg_i, sample_stride=sample_stride)
params_use = params
if time_cfg_i is not None and time_cfg_i.use_diffrax and streaming_fit:
t_total = float(time_cfg_i.t_max)
tmin_i, tmax_i = _resolve_streaming_window(
t_total,
indexed_float_value(tmin, i),
indexed_float_value(tmax, i),
start_fraction,
window_fraction,
1.0,
)
_, gamma_vals, omega_vals = integrate_linear_diffrax_streaming(
G0_jax,
grid,
geom,
params_use,
dt=dt_i,
steps=steps_i,
method=time_cfg_i.diffrax_solver,
cache=cache,
terms=terms,
adaptive=time_cfg_i.diffrax_adaptive,
rtol=time_cfg_i.diffrax_rtol,
atol=time_cfg_i.diffrax_atol,
max_steps=time_cfg_i.diffrax_max_steps,
progress_bar=time_cfg_i.progress_bar,
checkpoint=time_cfg_i.checkpoint,
tmin=tmin_i,
tmax=tmax_i,
fit_signal=fit_key,
mode_ky_indices=[0],
mode_kx_index=0,
mode_z_index=_midplane_index(grid),
mode_method=mode_method,
amp_floor=streaming_amp_floor,
density_species_index=density_species_index
if fit_key == "density"
else None,
return_state=False,
)
gamma = float(np.asarray(gamma_vals)[0])
omega = float(np.asarray(omega_vals)[0])
gamma, omega = _normalize_growth_rate(
gamma, omega, params_use, diagnostic_norm
)
else:
if time_cfg_i is not None:
stride = time_cfg_i.sample_stride
if time_cfg_i.use_diffrax:
save_mode_method = (
mode_method
if mode_method in {"z_index", "max"}
else "z_index"
)
_, phi_t = integrate_linear_from_config(
G0_jax,
grid,
geom,
params_use,
time_cfg_i,
cache=cache,
terms=terms,
save_mode=sel if mode_only else None,
mode_method=save_mode_method,
save_field="phi+density"
if fit_key == "auto"
else ("density" if fit_key == "density" else "phi"),
density_species_index=density_species_index
if fit_key in {"density", "auto"}
else None,
)
if fit_key == "auto":
phi_t, density_t = phi_t
else:
density_t = None
else:
if fit_key in {"density", "auto"}:
diag_out = integrate_linear_diagnostics(
G0_jax,
grid,
geom,
params_use,
dt=dt_i,
steps=steps_i,
method=method,
cache=cache,
terms=terms,
sample_stride=stride,
species_index=density_species_index,
)
phi_t = diag_out[1]
density_t = diag_out[2] if len(diag_out) > 2 else None
else:
_, phi_t = integrate_linear(
G0_jax,
grid,
geom,
params_use,
dt=dt_i,
steps=steps_i,
method=method,
cache=cache,
terms=terms,
sample_stride=stride,
)
density_t = None
else:
stride = 1 if sample_stride is None else int(sample_stride)
if fit_key in {"density", "auto"}:
diag_out = integrate_linear_diagnostics(
G0_jax,
grid,
geom,
params_use,
dt=dt_i,
steps=steps_i,
method=method,
cache=cache,
terms=terms,
sample_stride=stride,
species_index=density_species_index,
)
phi_t = diag_out[1]
density_t = diag_out[2] if len(diag_out) > 2 else None
else:
_, phi_t = integrate_linear(
G0_jax,
grid,
geom,
params_use,
dt=dt_i,
steps=steps_i,
method=method,
cache=cache,
terms=terms,
sample_stride=stride,
)
density_t = None
phi_t_np = np.asarray(phi_t)
density_np = None if density_t is None else np.asarray(density_t)
if fit_key == "density" and density_np is None:
density_np = phi_t_np
if fit_key == "auto":
signal, _name, gamma, omega = _select_fit_signal_auto(
np.arange(phi_t_np.shape[0]) * dt_i * stride,
phi_t_np,
density_np,
sel,
mode_method=mode_method,
tmin=indexed_float_value(tmin, i),
tmax=indexed_float_value(tmax, i),
window_fraction=window_fraction,
min_points=min_points,
start_fraction=start_fraction,
growth_weight=growth_weight,
require_positive=require_positive,
min_amp_fraction=min_amp_fraction,
max_amp_fraction=0.9,
window_method="loglinear",
max_fraction=0.8,
end_fraction=0.9,
num_windows=8,
phase_weight=0.2,
length_weight=0.05,
min_r2=0.0,
late_penalty=0.1,
min_slope=None,
min_slope_frac=0.0,
slope_var_weight=0.0,
)
gamma, omega = _normalize_growth_rate(
gamma, omega, params_use, diagnostic_norm
)
gammas.append(gamma)
omegas.append(omega)
beta_out.append(float(beta))
continue
if (
mode_only
and fit_key == "density"
and density_np is not None
and density_np.ndim <= 3
):
signal = _extract_mode_only_signal(
density_np,
local_idx=0,
species_index=density_species_index,
)
elif mode_only and phi_t_np.ndim <= 2:
signal = _extract_mode_only_signal(phi_t_np, local_idx=0)
else:
signal = _select_fit_signal(
phi_t_np,
density_np,
sel,
fit_signal=fit_key,
mode_method=mode_method,
)
gamma, omega = fit_policy.fit_signal(
signal,
idx=i,
dt=dt_i,
stride=stride,
params=params_use,
diagnostic_norm=diagnostic_norm,
)
gammas.append(gamma)
omegas.append(omega)
beta_out.append(float(beta))
return LinearScanResult(
ky=np.array(beta_out), gamma=np.array(gammas), omega=np.array(omegas)
)
[docs]
def run_kbm_linear(
    ky_target: float = 0.3,
    *,
    beta_value: float | None = None,
    Nl: int = 6,
    Nm: int = 12,
    dt: float = 0.01,
    steps: int = 800,
    method: str = "imex2",
    params: LinearParams | None = None,
    cfg: KBMBaseCase | None = None,
    time_cfg: TimeConfig | None = None,
    solver: str = "auto",
    krylov_cfg: KrylovConfig | None = None,
    kbm_target_factors: Sequence[float] | None = (0.7, 1.5),
    kbm_beta_transition: float | None = None,
    tmin: float | None = None,
    tmax: float | None = None,
    auto_window: bool = True,
    window_fraction: float = 0.4,
    min_points: int = 40,
    start_fraction: float = 0.2,
    growth_weight: float = 1.0,
    require_positive: bool = True,
    min_amp_fraction: float = 0.0,
    mode_method: str = "project",
    terms: LinearTerms | None = None,
    sample_stride: int | None = None,
    fit_signal: str = "auto",
    streaming_fit: bool = False,
    init_species_index: int = 1,
    density_species_index: int = 1,
    diagnostic_norm: str = "none",
    fapar_override: float | None = None,
    apar_beta_scale: float | None = None,
    ampere_g0_scale: float | None = None,
    bpar_beta_scale: float | None = None,
    gx_reference: bool | None = True,
    show_progress: bool = False,
) -> LinearRunResult:
    """Run a single linear KBM point and return the stored field history.

    The solver key returned by ``select_kbm_solver_auto`` picks one of three
    paths:

    * ``"gx_time"``: explicit integration through
      ``integrate_linear_gx_diagnostics``; the result also carries the
      ``gamma_t``/``omega_t`` series returned by that integrator.
    * ``"krylov"``: eigenpair solve through ``dominant_eigenpair``; the result
      holds a single ``phi`` snapshot at ``t = 0`` computed from the
      eigenvector.
    * anything else: time integration (diffrax when ``time_cfg.use_diffrax``
      is set, otherwise ``integrate_linear`` /
      ``integrate_linear_diagnostics``) followed by a growth-rate fit.

    Parameters
    ----------
    ky_target:
        Requested ``ky``; the grid is reduced to the index chosen by
        ``select_ky_index``.
    beta_value:
        Beta used for the run. ``None`` keeps ``cfg.model.beta``.
    Nl, Nm:
        Leading moment dimensions of the initial state array; ``Nm`` is also
        passed as ``nhermite`` when ``params`` is built here.
    dt, steps:
        Step size and step count. They override ``time_cfg.dt`` and
        ``time_cfg.t_max`` (``t_max = dt * steps``) when ``time_cfg`` is given.
    method:
        Integrator name forwarded to the non-diffrax time integrators.
    params:
        Pre-built linear parameters. When ``None`` they are assembled with
        ``_two_species_params`` from ``cfg``, the KBM scale constants and the
        ``*_override`` / ``*_scale`` arguments; when given, those arguments
        are not applied.
    cfg:
        Base case; defaults to ``KBMBaseCase()``.
    time_cfg:
        Optional time configuration (see the solver paths above).
    solver:
        Solver request passed to ``select_kbm_solver_auto``.
    krylov_cfg:
        Krylov settings; defaults to ``KBM_KRYLOV_DEFAULT``.
    kbm_target_factors:
        Candidate ``omega_target_factor`` values for the multi-target Krylov
        solve. An empty or ``None`` value disables the candidates.
    kbm_beta_transition:
        Beta threshold used to pick between the first two Krylov candidates;
        ``None`` uses the configured beta.
    tmin, tmax:
        Explicit fit window. The automatic window is used when both are
        ``None`` and ``auto_window`` is true, when ``scan_window_valid``
        rejects the window, or when the explicit fit raises ``ValueError``.
    auto_window, window_fraction, min_points, start_fraction, growth_weight,
    require_positive, min_amp_fraction:
        Options forwarded to the automatic fit helpers.
    mode_method:
        Mode-extraction method. The GX integrator only receives ``"z_index"``
        or ``"max"``; any other value is replaced by ``"z_index"`` there.
    terms:
        Linear term switches; defaults to ``LinearTerms(bpar=0.0)``.
    sample_stride:
        Sampling stride; ``None`` means 1, or the stride stored in
        ``time_cfg`` when one is supplied.
    fit_signal:
        ``"phi"``, ``"density"`` or ``"auto"`` (case-insensitive, surrounding
        whitespace ignored).
    streaming_fit:
        Accepted but not read by this function.
    init_species_index, density_species_index:
        Species slots (0 or 1) for the initial condition and for the density
        diagnostic.
    diagnostic_norm:
        Normalisation passed to ``_normalize_growth_rate``; ``"none"`` is
        promoted to ``"gx"`` when ``gx_reference`` is truthy.
    fapar_override, apar_beta_scale, ampere_g0_scale, bpar_beta_scale:
        Forwarded to ``_two_species_params`` when ``params`` is ``None``.
    gx_reference:
        Truthy values select the GX-linked end damping from
        ``_gx_linked_end_damping`` and the ``"gx"`` diagnostic norm default.
    show_progress:
        Forwarded to ``integrate_linear`` (the phi-only fixed-step path),
        whether or not ``time_cfg`` is supplied. The other paths do not
        receive it.

    Returns
    -------
    LinearRunResult
        Time axis, stored ``phi`` history, normalised ``gamma``/``omega``,
        ``ky_target`` and the mode selection.

    Raises
    ------
    ValueError
        If ``fit_signal`` is not recognised, or a species index is outside
        ``{0, 1}``.
    """
    cfg_in = cfg or KBMBaseCase()
    beta_use = float(cfg_in.model.beta) if beta_value is None else float(beta_value)
    cfg_use = replace(cfg_in, model=replace(cfg_in.model, beta=beta_use))
    geom = build_flux_tube_geometry(cfg_use.geometry)
    grid_full = build_spectral_grid(apply_geometry_grid_defaults(geom, cfg_use.grid))
    if terms is None:
        terms = LinearTerms(bpar=0.0)
    gx_reference_use = bool(gx_reference)
    if gx_reference_use and diagnostic_norm == "none":
        diagnostic_norm = "gx"
    damp_ends_amp, damp_ends_widthfrac = _gx_linked_end_damping(gx_reference_use)

    fit_key = fit_signal.strip().lower()
    if fit_key not in {"phi", "density", "auto"}:
        raise ValueError("fit_signal must be 'phi', 'density', or 'auto'")
    if init_species_index < 0 or init_species_index >= 2:
        raise ValueError("init_species_index out of range for kinetic species")
    if density_species_index < 0 or density_species_index >= 2:
        raise ValueError("density_species_index out of range for kinetic species")
    # Both "density" and "auto" need the density diagnostic to be recorded.
    wants_density = fit_key in {"density", "auto"}

    if params is None:
        params = _two_species_params(
            cfg_use.model,
            kpar_scale=float(geom.gradpar()),
            omega_d_scale=KBM_OMEGA_D_SCALE,
            omega_star_scale=KBM_OMEGA_STAR_SCALE,
            rho_star=KBM_RHO_STAR,
            beta_override=beta_use,
            fapar_override=fapar_override,
            apar_beta_scale=apar_beta_scale,
            ampere_g0_scale=ampere_g0_scale,
            bpar_beta_scale=bpar_beta_scale,
            damp_ends_amp=damp_ends_amp,
            damp_ends_widthfrac=damp_ends_widthfrac,
            nhermite=Nm,
        )

    # Reduce the grid to the requested ky; the selection indexes that
    # reduced grid, hence ky_index=0.
    ky_index = select_ky_index(np.asarray(grid_full.ky), ky_target)
    grid = select_ky_grid(grid_full, ky_index)
    sel = ModeSelection(ky_index=0, kx_index=0, z_index=_midplane_index(grid))
    cache = build_linear_cache(grid, geom, params, Nl, Nm)

    # Two-species state: only the slot given by init_species_index is seeded.
    G0 = np.zeros(
        (2, Nl, Nm, grid.ky.size, grid.kx.size, grid.z.size), dtype=np.complex64
    )
    G0_single = _build_initial_condition(
        grid,
        geom,
        ky_index=sel.ky_index,
        kx_index=sel.kx_index,
        Nl=Nl,
        Nm=Nm,
        init_cfg=cfg_use.init,
    )
    G0[int(init_species_index)] = np.asarray(G0_single, dtype=np.complex64)
    G0_jax = jnp.asarray(G0)

    solver_key = select_kbm_solver_auto(
        solver,
        ky_target=float(ky_target),
        gx_reference=gx_reference_use,
    )
    krylov_cfg_use = krylov_cfg or KBM_KRYLOV_DEFAULT

    def _auto_fit(
        signal: np.ndarray, t_arr: np.ndarray, **extra: object
    ) -> tuple[float, float]:
        """Run the automatic-window fit with this call's window options."""
        gamma_val, omega_val, _lo, _hi = fit_growth_rate_auto(
            t_arr,
            signal,
            window_fraction=window_fraction,
            min_points=min_points,
            start_fraction=start_fraction,
            growth_weight=growth_weight,
            require_positive=require_positive,
            min_amp_fraction=min_amp_fraction,
            **extra,
        )
        return gamma_val, omega_val

    def _fit_with_window(signal: np.ndarray, t_arr: np.ndarray) -> tuple[float, float]:
        """Fit with the explicit window, falling back to the automatic one."""
        use_auto = auto_window and tmin is None and tmax is None
        if not use_auto and not scan_window_valid(t_arr, tmin, tmax):
            use_auto = True
        if use_auto:
            return _auto_fit(signal, t_arr)
        try:
            gamma_val, omega_val = fit_growth_rate(t_arr, signal, tmin=tmin, tmax=tmax)
        except ValueError:
            return _auto_fit(signal, t_arr)
        return gamma_val, omega_val

    # ------------------------------------------------------------------
    # GX-style explicit time integration.
    # ------------------------------------------------------------------
    if solver_key == "gx_time":
        gx_mode_method = mode_method if mode_method in {"z_index", "max"} else "z_index"
        if time_cfg is None:
            step_control = dict(
                fixed_dt=False,
                use_dealias_mask=False,
                dt_min=1.0e-7,
                dt_max=None,
                cfl=0.9,
                cfl_fac=float(ExplicitTimeConfig.cfl_fac),
            )
        else:
            step_control = dict(
                fixed_dt=bool(time_cfg.fixed_dt),
                use_dealias_mask=bool(getattr(time_cfg, "use_dealias_mask", False)),
                dt_min=float(time_cfg.dt_min),
                dt_max=None if time_cfg.dt_max is None else float(time_cfg.dt_max),
                cfl=float(time_cfg.cfl),
                cfl_fac=resolve_cfl_fac(str(time_cfg.method), time_cfg.cfl_fac),
            )
        gx_time_cfg = ExplicitTimeConfig(
            dt=dt,
            t_max=dt * steps,
            sample_stride=max(int(sample_stride or 1), 1),
            **step_control,
        )
        t_arr, phi_t, gamma_t, omega_t, _gx_diag = integrate_linear_gx_diagnostics(
            G0_jax,
            grid,
            cache,
            params,
            geom,
            gx_time_cfg,
            terms=terms,
            mode_method=gx_mode_method,
            z_index=sel.z_index,
            jit=True,
        )
        t_out = np.asarray(t_arr, dtype=float)
        phi_t_np = np.asarray(phi_t)

        if t_out.size <= 1:
            # Fewer than two samples: nothing to fit.
            gamma = float("nan")
            omega = float("nan")
        elif mode_method in {"z_index", "max"}:
            # Fallback chain: phi-based GX estimate, then the integrator's
            # own gamma/omega series, then a windowed fit.
            try:
                gamma, omega, _g_t, _o_t, _t_mid = gx_growth_rate_from_phi(
                    phi_t_np,
                    t_out,
                    sel,
                    navg_fraction=0.5,
                    mode_method=mode_method,
                )
            except ValueError:
                try:
                    gamma, omega, _g_t, _o_t = gx_growth_rate_from_omega_series(
                        np.asarray(gamma_t),
                        np.asarray(omega_t),
                        sel,
                        navg_fraction=0.5,
                    )
                except ValueError:
                    signal = extract_mode_time_series(
                        phi_t_np, sel, method=mode_method
                    )
                    gamma, omega = _fit_with_window(signal, t_out)
        else:
            signal = extract_mode_time_series(phi_t_np, sel, method=mode_method)
            if auto_window and tmin is None and tmax is None:
                gamma, omega = _auto_fit(signal, t_out, window_method="fixed")
            else:
                gamma, omega = _fit_with_window(signal, t_out)

        gamma, omega = _normalize_growth_rate(gamma, omega, params, diagnostic_norm)
        return LinearRunResult(
            t=t_out,
            phi_t=phi_t_np,
            gamma=gamma,
            omega=omega,
            ky=float(ky_target),
            selection=sel,
            gamma_t=np.asarray(gamma_t),
            omega_t=np.asarray(omega_t),
        )

    # ------------------------------------------------------------------
    # Krylov eigen-solve.
    # ------------------------------------------------------------------
    if solver_key == "krylov":
        shift_val = krylov_cfg_use.shift
        targets: Sequence[float] | None = (
            kbm_target_factors if kbm_target_factors else None
        )
        use_multi_target = _kbm_use_multi_target_krylov(
            krylov_cfg_use,
            targets,
            shift=shift_val,
        )
        # Keyword arguments common to the single- and multi-target solves.
        shared_kwargs = dict(
            terms=terms,
            v_ref=None,
            select_overlap=False,
            krylov_dim=krylov_cfg_use.krylov_dim,
            restarts=krylov_cfg_use.restarts,
            omega_min_factor=krylov_cfg_use.omega_min_factor,
            omega_cap_factor=krylov_cfg_use.omega_cap_factor,
            omega_sign=krylov_cfg_use.omega_sign,
            method=krylov_cfg_use.method,
            power_iters=krylov_cfg_use.power_iters,
            power_dt=krylov_cfg_use.power_dt,
            shift_tol=krylov_cfg_use.shift_tol,
            shift_maxiter=krylov_cfg_use.shift_maxiter,
            shift_restart=krylov_cfg_use.shift_restart,
            shift_solve_method=krylov_cfg_use.shift_solve_method,
            shift_preconditioner=krylov_cfg_use.shift_preconditioner,
            mode_family=krylov_cfg_use.mode_family,
            fallback_method=krylov_cfg_use.fallback_method,
            fallback_real_floor=krylov_cfg_use.fallback_real_floor,
        )
        if use_multi_target:
            assert targets is not None  # narrowing only; guarded above
            beta_transition = (
                float(cfg_use.model.beta)
                if kbm_beta_transition is None
                else float(kbm_beta_transition)
            )
            eig_candidates = []
            vec_candidates = []
            for target in targets:
                eig_i, vec_i = dominant_eigenpair(
                    G0_jax,
                    cache,
                    params,
                    omega_target_factor=float(target),
                    shift=None,
                    shift_source="target",
                    shift_selection="targeted",
                    **shared_kwargs,
                )
                eig_candidates.append(eig_i)
                vec_candidates.append(vec_i)
            if len(eig_candidates) >= 2 and np.isfinite(beta_transition):
                # At or above the transition beta take the second candidate,
                # below it the first.
                idx = 1 if beta_use >= beta_transition else 0
            else:
                # Otherwise take the candidate with the largest finite real
                # part, or the first one when none is finite.
                growth = np.real(
                    np.asarray([complex(np.asarray(e)) for e in eig_candidates])
                )
                finite = np.isfinite(growth)
                if np.all(~finite):
                    idx = 0
                else:
                    idx = int(np.nanargmax(np.where(finite, growth, -np.inf)))
            eig = eig_candidates[idx]
            vec = vec_candidates[idx]
        else:
            eig, vec = dominant_eigenpair(
                G0_jax,
                cache,
                params,
                omega_target_factor=krylov_cfg_use.omega_target_factor,
                shift=shift_val,
                shift_source=krylov_cfg_use.shift_source,
                shift_selection=krylov_cfg_use.shift_selection,
                **shared_kwargs,
            )

        gamma = float(np.real(eig))
        omega = float(-np.imag(eig))
        if krylov_cfg_use.omega_sign != 0:
            # Force the sign of omega to the configured branch.
            omega = float(np.sign(krylov_cfg_use.omega_sign)) * abs(omega)
        gamma, omega = _normalize_growth_rate(gamma, omega, params, diagnostic_norm)
        term_cfg = linear_terms_to_term_config(terms)
        phi = compute_fields_cached(vec, cache, params, terms=term_cfg).phi
        return LinearRunResult(
            t=np.array([0.0], dtype=float),
            phi_t=np.asarray(phi)[None, ...],
            gamma=gamma,
            omega=omega,
            ky=float(ky_target),
            selection=sel,
        )

    # ------------------------------------------------------------------
    # Generic time integration followed by a growth-rate fit.
    # ------------------------------------------------------------------
    stride = 1 if sample_stride is None else int(sample_stride)
    time_cfg_use = time_cfg
    if time_cfg_use is not None:
        time_cfg_use = replace(time_cfg_use, dt=dt, t_max=dt * steps)
        if sample_stride is not None:
            time_cfg_use = replace(time_cfg_use, sample_stride=stride)
        stride = int(time_cfg_use.sample_stride)

    if time_cfg_use is not None and time_cfg_use.use_diffrax:
        _, saved = integrate_linear_from_config(
            G0_jax,
            grid,
            geom,
            params,
            time_cfg_use,
            cache=cache,
            terms=terms,
            save_mode=sel if fit_key == "phi" else None,
            mode_method=mode_method,
            save_field="phi+density" if wants_density else "phi",
            density_species_index=density_species_index if wants_density else None,
        )
        if wants_density:
            phi_t_np, density_np = (np.asarray(saved[0]), np.asarray(saved[1]))
        else:
            phi_t_np = np.asarray(saved)
            density_np = None
    elif wants_density:
        diag_out = integrate_linear_diagnostics(
            G0_jax,
            grid,
            geom,
            params,
            dt=dt,
            steps=steps,
            method=method,
            cache=cache,
            terms=terms,
            sample_stride=stride,
            species_index=density_species_index,
        )
        phi_t_np = np.asarray(diag_out[1])
        density_np = None if len(diag_out) <= 2 else np.asarray(diag_out[2])
    else:
        # show_progress is forwarded whether or not a time_cfg was supplied.
        _, phi_out_time = integrate_linear(
            G0_jax,
            grid,
            geom,
            params,
            dt=dt,
            steps=steps,
            method=method,
            cache=cache,
            terms=terms,
            sample_stride=stride,
            show_progress=show_progress,
        )
        phi_t_np = np.asarray(phi_out_time)
        density_np = None

    if fit_key == "density" and density_np is None:
        # No density diagnostic came back: fit phi instead.
        density_np = phi_t_np

    if fit_key == "auto":
        _signal, _name, gamma, omega = _select_fit_signal_auto(
            np.arange(phi_t_np.shape[0]) * dt * stride,
            phi_t_np,
            density_np,
            sel,
            mode_method=mode_method,
            tmin=tmin,
            tmax=tmax,
            window_fraction=window_fraction,
            min_points=min_points,
            start_fraction=start_fraction,
            growth_weight=growth_weight,
            require_positive=require_positive,
            min_amp_fraction=min_amp_fraction,
            max_amp_fraction=0.9,
            window_method="loglinear",
            max_fraction=0.8,
            end_fraction=0.9,
            num_windows=8,
            phase_weight=0.2,
            length_weight=0.05,
            min_r2=0.0,
            late_penalty=0.1,
            min_slope=None,
            min_slope_frac=0.0,
            slope_var_weight=0.0,
        )
    else:
        signal = _select_fit_signal(
            phi_t_np,
            density_np,
            sel,
            fit_signal=fit_key,
            mode_method=mode_method,
        )
        t_out = np.arange(signal.shape[0]) * dt * stride
        gamma, omega = _fit_with_window(signal, t_out)

    gamma, omega = _normalize_growth_rate(gamma, omega, params, diagnostic_norm)
    return LinearRunResult(
        t=np.arange(phi_t_np.shape[0]) * dt * stride,
        phi_t=phi_t_np,
        gamma=gamma,
        omega=omega,
        ky=float(ky_target),
        selection=sel,
    )
[docs]
def run_kbm_scan(
    ky_values: np.ndarray,
    *,
    beta_value: float | None = None,
    Nl: int = 6,
    Nm: int = 12,
    dt: float | np.ndarray = 0.01,
    steps: int | np.ndarray = 800,
    method: str = "imex2",
    cfg: KBMBaseCase | None = None,
    time_cfg: TimeConfig | None = None,
    solver: str = "auto",
    krylov_cfg: KrylovConfig | None = None,
    kbm_target_factors: Sequence[float] | None = (0.7, 1.5),
    kbm_beta_transition: float | None = None,
    tmin: float | np.ndarray | None = None,
    tmax: float | np.ndarray | None = None,
    auto_window: bool = True,
    window_fraction: float = 0.4,
    min_points: int = 40,
    start_fraction: float = 0.2,
    growth_weight: float = 1.0,
    require_positive: bool = True,
    min_amp_fraction: float = 0.0,
    mode_method: str = "project",
    mode_only: bool = True,
    terms: LinearTerms | None = None,
    sample_stride: int | None = None,
    fit_signal: str = "auto",
    ky_batch: int = 4,
    fixed_batch_shape: bool = True,
    streaming_fit: bool = True,
    streaming_amp_floor: float = 1.0e-30,
    init_species_index: int = 1,
    density_species_index: int = 1,
    diagnostic_norm: str = "none",
    fapar_override: float | None = None,
    apar_beta_scale: float | None = None,
    ampere_g0_scale: float | None = None,
    bpar_beta_scale: float | None = None,
    gx_reference: bool | None = True,
) -> LinearScanResult:
    """Scan KBM growth rates over ``ky`` at one fixed beta.

    Each entry of ``ky_values`` is evaluated by calling
    :func:`run_kbm_beta_scan` with a one-element beta array and that entry as
    ``ky_target``; the first ``gamma``/``omega`` of each call is collected.

    ``dt``, ``steps``, ``tmin`` and ``tmax`` are resolved per point through
    ``indexed_scan_value``. For ``dt`` and ``steps`` a ``None`` result falls
    back to the argument as given. All remaining keyword arguments are
    forwarded unchanged.

    ``beta_value=None`` uses ``cfg.model.beta`` (``cfg`` defaults to
    ``KBMBaseCase()``); the chosen beta is written back into the
    configuration handed to the beta scan.

    Returns a ``LinearScanResult`` whose ``ky`` field holds the scanned ky
    values, with matching ``gamma`` and ``omega`` arrays.
    """
    base_cfg = cfg or KBMBaseCase()
    beta_use = (
        float(base_cfg.model.beta) if beta_value is None else float(beta_value)
    )
    cfg_use = replace(base_cfg, model=replace(base_cfg.model, beta=beta_use))

    # Options that are identical for every ky point.
    fixed_options = dict(
        Nl=Nl,
        Nm=Nm,
        method=method,
        cfg=cfg_use,
        time_cfg=time_cfg,
        solver=solver,
        krylov_cfg=krylov_cfg,
        kbm_target_factors=kbm_target_factors,
        kbm_beta_transition=kbm_beta_transition,
        auto_window=auto_window,
        window_fraction=window_fraction,
        min_points=min_points,
        start_fraction=start_fraction,
        growth_weight=growth_weight,
        require_positive=require_positive,
        min_amp_fraction=min_amp_fraction,
        mode_method=mode_method,
        mode_only=mode_only,
        terms=terms,
        sample_stride=sample_stride,
        fit_signal=fit_signal,
        ky_batch=ky_batch,
        fixed_batch_shape=fixed_batch_shape,
        streaming_fit=streaming_fit,
        streaming_amp_floor=streaming_amp_floor,
        init_species_index=init_species_index,
        density_species_index=density_species_index,
        diagnostic_norm=diagnostic_norm,
        fapar_override=fapar_override,
        apar_beta_scale=apar_beta_scale,
        ampere_g0_scale=ampere_g0_scale,
        bpar_beta_scale=bpar_beta_scale,
        gx_reference=gx_reference,
    )

    ky_points = np.asarray(ky_values, dtype=float)
    ky_seen: list[float] = []
    gamma_seen: list[float] = []
    omega_seen: list[float] = []
    for point, ky_here in enumerate(ky_points):
        # Per-point step size and count, falling back to the raw argument.
        dt_here = indexed_scan_value(dt, point)
        steps_here = indexed_scan_value(steps, point)
        if dt_here is None:
            dt_here = dt
        if steps_here is None:
            steps_here = steps

        single = run_kbm_beta_scan(
            betas=np.asarray([beta_use], dtype=float),
            ky_target=float(ky_here),
            dt=float(dt_here),
            steps=int(steps_here),
            tmin=indexed_scan_value(tmin, point),
            tmax=indexed_scan_value(tmax, point),
            **fixed_options,
        )
        ky_seen.append(float(ky_here))
        gamma_seen.append(float(single.gamma[0]))
        omega_seen.append(float(single.omega[0]))

    return LinearScanResult(
        ky=np.asarray(ky_seen, dtype=float),
        gamma=np.asarray(gamma_seen, dtype=float),
        omega=np.asarray(omega_seen, dtype=float),
    )