"""Pure runtime policy helpers shared by runtime runners and tests."""
from __future__ import annotations
from dataclasses import asdict, dataclass
from typing import Any
import numpy as np
from spectraxgk.analysis import select_ky_index
from spectraxgk.grids import SpectralGrid
from spectraxgk.runtime_config import RuntimeConfig
# Names exported by a star import of this module. The underscore-prefixed
# helpers are listed explicitly because ``from <module> import *`` would
# otherwise skip names that start with an underscore.
__all__ = [
    "RuntimeIndependentParallelPlan",
    "_infer_runtime_nonlinear_steps",
    "_midplane_index",
    "_normalize_linear_solver_name",
    "_parallel_requests_combined_ky_scan",
    "_runtime_external_phi",
    "_runtime_independent_parallel_plan",
    "_select_nonlinear_mode_indices",
    "_zero_kx_index",
]
[docs]
@dataclass(frozen=True)
class RuntimeIndependentParallelPlan:
    """Immutable record of a resolved independent-worker scan policy.

    Attributes:
        requested_workers: Number of workers that was asked for.
        effective_workers: Number of workers that will actually be used.
        executor: Name of the executor kind.
        strategy: Name of the parallel strategy.
        axis: Name of the scan axis the plan applies to.
        source: Label describing where the worker request came from.
        problem_size: Number of independent work items in the scan.
    """

    requested_workers: int
    effective_workers: int
    executor: str
    strategy: str
    axis: str
    source: str
    problem_size: int

    @property
    def enabled(self) -> bool:
        """Return ``True`` when more than one worker is effectively in use."""
        return self.effective_workers > 1

    def to_dict(self) -> dict[str, Any]:
        """Return all fields plus the derived ``enabled`` flag as a dict.

        The ``enabled`` key is appended after the dataclass fields so the
        payload is self-describing when written to runtime artifacts.
        """
        return {**asdict(self), "enabled": self.enabled}
def _normalize_linear_solver_name(solver: str) -> str:
    """Return the canonical key for a linear solver name.

    Surrounding whitespace is stripped and the name is lower-cased. The alias
    ``"explicit_time"`` maps to ``"gx_time"``; any other name is returned in
    its normalized form.
    """
    canonical = solver.strip().lower()
    return "gx_time" if canonical == "explicit_time" else canonical
[docs]
def _parallel_requests_combined_ky_scan(cfg: RuntimeConfig) -> bool:
    """Return whether the runtime parallel config requests the combined-ky scan.

    The request is recognised when ``cfg.parallel`` exists, its ``strategy``
    is ``"combined_ky"`` and its ``axis`` is ``"ky"``. Missing attributes
    default to ``"serial"`` and ``"ky"`` respectively.

    Both values are compared after stripping surrounding whitespace and
    lower-casing, which matches the normalization applied by
    ``_runtime_independent_parallel_plan`` so the two helpers agree on
    padded or mixed-case config values.

    Args:
        cfg: Runtime configuration; only its optional ``parallel`` attribute
            is read.

    Returns:
        ``True`` when the combined-ky scan path is requested, else ``False``.
    """
    parallel = getattr(cfg, "parallel", None)
    if parallel is None:
        return False
    strategy = str(getattr(parallel, "strategy", "serial")).strip().lower()
    if strategy != "combined_ky":
        return False
    axis = str(getattr(parallel, "axis", "ky")).strip().lower()
    return axis == "ky"
def _normalize_independent_executor(backend: str, fallback: str) -> str:
    """Map a backend name (or its fallback) to ``"thread"`` or ``"process"``.

    Both inputs are stringified, stripped, lower-cased and have hyphens
    replaced by underscores before lookup. When ``backend`` is empty or
    ``"auto"`` the ``fallback`` decides; otherwise ``backend`` itself must be
    a recognised name.

    Args:
        backend: Requested backend: ``"auto"``, empty, or a thread/process
            spelling.
        fallback: Executor name used when ``backend`` defers to it.

    Returns:
        ``"thread"`` or ``"process"``.

    Raises:
        ValueError: If the name that ends up being looked up is not one of
            the accepted singular/plural spellings.
    """
    requested = str(backend).strip().lower().replace("-", "_")
    default = str(fallback).strip().lower().replace("-", "_")
    canonical = {
        "thread": "thread",
        "threads": "thread",
        "process": "process",
        "processes": "process",
    }
    # Choose which name is looked up and which message reports a failure.
    if requested in {"", "auto"}:
        lookup = default
        complaint = "parallel_executor must be 'thread' or 'process'"
    else:
        lookup = requested
        complaint = (
            "runtime [parallel] backend for independent scans must be "
            "'auto', 'thread', or 'process'"
        )
    try:
        return canonical[lookup]
    except KeyError as exc:
        raise ValueError(complaint) from exc
[docs]
def _runtime_independent_parallel_plan(
    cfg: RuntimeConfig,
    *,
    problem_size: int,
    workers: int,
    executor: str,
) -> RuntimeIndependentParallelPlan:
    """Resolve the independent ``k_y`` worker plan from arguments and config.

    The explicit arguments are used unless ``workers`` is 1 and the optional
    ``cfg.parallel`` section selects ``strategy='batch'``. In that case the
    worker count is read from ``num_devices`` (or ``batch_size`` when
    ``num_devices`` is ``None``), the executor from ``backend`` with
    ``executor`` as fallback, and the plan's source becomes
    ``"runtime_config"``.

    Args:
        cfg: Runtime configuration; only its optional ``parallel`` attribute
            is read.
        problem_size: Number of independent work items; must be >= 0.
        workers: Worker count requested by the caller; must be >= 1.
        executor: Executor name requested by the caller.

    Returns:
        The resolved plan. ``effective_workers`` is 0 for an empty problem
        and otherwise the requested count capped at ``problem_size``.

    Raises:
        ValueError: If ``problem_size`` is negative, ``workers`` is below 1,
            the executor/backend name is not recognised, or the batch
            strategy is configured with an axis other than ``'ky'``.
    """
    n_items = int(problem_size)
    if n_items < 0:
        raise ValueError("problem_size must be non-negative")
    n_workers = int(workers)
    if n_workers < 1:
        raise ValueError("workers must be >= 1")

    # Validate the caller's executor up front, even if the config replaces it.
    chosen_executor = _normalize_independent_executor("auto", executor)
    origin = "arguments"
    strategy_key, axis_key = "serial", "ky"

    parallel_cfg = getattr(cfg, "parallel", None)
    if parallel_cfg is not None:
        strategy_key = str(getattr(parallel_cfg, "strategy", "serial")).strip().lower()
        axis_key = str(getattr(parallel_cfg, "axis", "ky")).strip().lower()

        # The config only takes over when the caller left workers at 1.
        if n_workers == 1 and strategy_key == "batch":
            if axis_key != "ky":
                raise ValueError(
                    "runtime [parallel] strategy='batch' is supported only for axis='ky'"
                )
            worker_hint = getattr(parallel_cfg, "num_devices", None)
            if worker_hint is None:
                worker_hint = getattr(parallel_cfg, "batch_size", None)
            n_workers = max(int(worker_hint or 1), 1)
            chosen_executor = _normalize_independent_executor(
                str(getattr(parallel_cfg, "backend", "auto")), executor
            )
            origin = "runtime_config"

    # Never plan more workers than there are work items.
    usable_workers = min(n_workers, n_items) if n_items else 0
    return RuntimeIndependentParallelPlan(
        requested_workers=n_workers,
        effective_workers=usable_workers,
        executor=chosen_executor,
        strategy=strategy_key,
        axis=axis_key,
        source=origin,
        problem_size=n_items,
    )
def _midplane_index(grid: SpectralGrid) -> int:
    """Return the index used as the midplane along ``grid.z``.

    Grids with at most one ``z`` point yield 0. Otherwise the result is
    ``size // 2 + 1``, clamped to the last valid index.
    """
    n_z = int(grid.z.size)
    if n_z <= 1:
        return 0
    last = n_z - 1
    return min(n_z // 2 + 1, last)
def _zero_kx_index(grid: SpectralGrid) -> int:
    """Return the index of the ``grid.kx`` entry with the smallest magnitude.

    Ties resolve to the first such entry.
    """
    magnitudes = np.abs(np.asarray(grid.kx, dtype=float))
    return int(magnitudes.argmin())
def _nearest_index_from_candidates(
    values: np.ndarray,
    target: float,
    candidates: np.ndarray,
) -> int:
    """Return the candidate index whose value lies closest to ``target``.

    Args:
        values: Coordinate values, indexed by the entries of ``candidates``.
        target: Coordinate to approach.
        candidates: Integer indices into ``values`` that may be returned.

    Returns:
        The entry of ``candidates`` minimizing ``|values[i] - target|``; on a
        tie the candidate listed first wins.

    Raises:
        ValueError: If ``values`` or ``candidates`` is empty.
    """
    coords = np.asarray(values, dtype=float)
    allowed = np.asarray(candidates, dtype=int)
    if coords.size == 0:
        raise ValueError("values must be non-empty")
    if allowed.size == 0:
        raise ValueError("candidate indices must be non-empty")
    distances = np.abs(coords[allowed] - float(target))
    winner = int(np.argmin(distances))
    return int(allowed[winner])
def _validate_dealias_mask_shape(
    mask: Any,
    *,
    ky_size: int,
    kx_size: int,
) -> np.ndarray:
    """Coerce ``mask`` to a boolean array and check it is ``(ky, kx)`` shaped.

    Args:
        mask: Array-like dealias mask.
        ky_size: Expected length of the first axis.
        kx_size: Expected length of the second axis.

    Returns:
        The mask as a boolean ``numpy`` array.

    Raises:
        ValueError: If the mask shape differs from ``(ky_size, kx_size)``.
    """
    bool_mask = np.asarray(mask, dtype=bool)
    wanted_shape = (int(ky_size), int(kx_size))
    if bool_mask.shape == wanted_shape:
        return bool_mask
    raise ValueError(
        "dealias_mask shape must match (ky, kx) grid sizes; "
        f"got {bool_mask.shape}, expected {wanted_shape}"
    )
def _active_ky_indices(mask: np.ndarray, ky_size: int) -> np.ndarray:
    """Return the ky row indices that keep at least one kx entry.

    When no row of ``mask`` has a true entry, every index in
    ``range(ky_size)`` is returned instead.
    """
    row_has_entry = np.any(mask, axis=1)
    retained_rows = np.nonzero(row_has_entry)[0]
    if retained_rows.size:
        return retained_rows
    return np.arange(int(ky_size), dtype=int)
def _active_kx_indices(mask: np.ndarray, ky_index: int, kx_size: int) -> np.ndarray:
    """Return the kx indices that are true in row ``ky_index`` of ``mask``.

    When that row has no true entry, every index in ``range(kx_size)`` is
    returned instead.
    """
    row = mask[int(ky_index)]
    retained_cols = np.nonzero(row)[0]
    if retained_cols.size:
        return retained_cols
    return np.arange(int(kx_size), dtype=int)
def _select_nonlinear_mode_indices(
    grid: SpectralGrid,
    *,
    ky_target: float,
    kx_target: float | None,
    use_dealias_mask: bool,
) -> tuple[int, int]:
    """Pick the ``(ky, kx)`` grid indices closest to the requested mode.

    Args:
        grid: Grid providing ``ky``, ``kx`` and, when the mask is used,
            ``dealias_mask``.
        ky_target: Requested ky value.
        kx_target: Requested kx value; ``None`` means 0.0.
        use_dealias_mask: When true, restrict the choice to entries retained
            by ``grid.dealias_mask``.

    Returns:
        The selected ``(ky_index, kx_index)`` pair.

    Raises:
        ValueError: If the mask is used and its shape does not match the
            ``(ky, kx)`` axes.
    """
    ky_axis = np.asarray(grid.ky, dtype=float)
    kx_axis = np.asarray(grid.kx, dtype=float)
    kx_goal = float(kx_target) if kx_target is not None else 0.0

    if use_dealias_mask:
        retained = _validate_dealias_mask_shape(
            grid.dealias_mask,
            ky_size=ky_axis.size,
            kx_size=kx_axis.size,
        )
        ky_rows = _active_ky_indices(retained, ky_axis.size)
        ky_index = _nearest_index_from_candidates(ky_axis, ky_target, ky_rows)
        # kx candidates depend on which ky row was just selected.
        kx_cols = _active_kx_indices(retained, ky_index, kx_axis.size)
        kx_index = _nearest_index_from_candidates(kx_axis, kx_goal, kx_cols)
        return int(ky_index), int(kx_index)

    # Without the mask, ky selection is delegated to ``select_ky_index`` and
    # every kx entry is a candidate.
    return (
        select_ky_index(ky_axis, ky_target),
        _nearest_index_from_candidates(
            kx_axis, kx_goal, np.arange(kx_axis.size, dtype=int)
        ),
    )
[docs]
def _infer_runtime_nonlinear_steps(
    cfg: RuntimeConfig,
    *,
    dt: float,
    steps: int | None,
) -> int:
    """Return the number of nonlinear explicit steps to run.

    An explicit ``steps`` value is used as given. Otherwise the count is
    derived from ``cfg.time``: with ``fixed_dt`` it is ``t_max / dt`` from
    the config, rounded to the nearest integer; without it, it is ``t_max``
    divided by the step ceiling, rounded up.

    Args:
        cfg: Runtime configuration; ``cfg.time`` is read only when ``steps``
            is ``None``.
        dt: Step size used as the ceiling when ``cfg.time.dt_max`` is unset.
        steps: Explicit step count, or ``None`` to infer it.

    Returns:
        The step count, always at least 1.

    Raises:
        ValueError: If the resulting step count is below 1.
    """
    smallest_dt = 1.0e-12  # guards the divisions against a zero step size

    if steps is not None:
        n_steps = int(steps)
    else:
        time_cfg = cfg.time
        if bool(time_cfg.fixed_dt):
            ratio = float(time_cfg.t_max) / max(float(time_cfg.dt), smallest_dt)
            n_steps = int(np.round(ratio))
        else:
            # Keep runtime inference aligned with GX-style adaptive stepping:
            # when dt_max is unset, the nonlinear integrator clamps at dt itself.
            ceiling = float(dt) if time_cfg.dt_max is None else float(time_cfg.dt_max)
            n_steps = int(np.ceil(float(time_cfg.t_max) / max(ceiling, smallest_dt)))

    if n_steps < 1:
        raise ValueError("steps must be >= 1")
    return n_steps
[docs]
def _runtime_external_phi(cfg: RuntimeConfig) -> float | None:
    """Return the GX-style external-phi value when the config requests one.

    ``cfg.expert.source`` is compared after stripping whitespace and
    lower-casing.

    Returns:
        ``float(cfg.expert.phi_ext)`` when the source is ``"phiext_full"``,
        or ``None`` when it is empty or ``"default"``.

    Raises:
        ValueError: For any other source value.
    """
    raw_source = cfg.expert.source
    source_key = str(raw_source).strip().lower()
    if source_key == "phiext_full":
        return float(cfg.expert.phi_ext)
    if source_key in {"", "default"}:
        return None
    raise ValueError(
        f"unsupported expert.source={raw_source!r}; expected 'default' or 'phiext_full'"
    )