Source code for spectraxgk.runtime_orchestration

"""Runtime orchestration helpers split from public runtime entry points.

This module owns coordination policy that is not itself a solver kernel:
progress/ETA formatting, combined-ky scan batching, and nonlinear artifact
restart/checkpoint handoff.  Callers pass dependency tables so legacy
``spectraxgk.runtime`` and ``spectraxgk.runtime_artifacts`` monkeypatch seams
remain effective.
"""

from __future__ import annotations

from dataclasses import dataclass, replace
from pathlib import Path
from typing import Any, Callable, Protocol

import numpy as np

from spectraxgk.analysis import ModeSelection
from spectraxgk.diagnostics import SimulationDiagnostics
from spectraxgk.runtime_config import RuntimeConfig
from spectraxgk.runtime_results import RuntimeLinearScanResult, RuntimeNonlinearResult

# Public names exported by ``from spectraxgk.runtime_orchestration import *``.
# The underscore-prefixed restart helpers defined in this module are
# intentionally left out.
__all__ = [
    "NonlinearArtifactPolicy",
    "RuntimeArtifactHandoffDeps",
    "RuntimeProgressSnapshot",
    "RuntimeScanBatchDeps",
    "build_runtime_progress_message",
    "format_duration",
    "resolve_nonlinear_artifact_policy",
    "run_runtime_scan_batch",
    "run_runtime_nonlinear_artifact_handoff",
]


class RuntimeScanBatchDeps(Protocol):
    """Structural type for the callables consumed by ``run_runtime_scan_batch``.

    Any object that exposes these attributes satisfies the protocol, so a
    caller can pass a namespace of (possibly monkeypatched) functions instead
    of a dedicated class instance.
    """

    # Problem setup: geometry, linear-operator inputs, and the initial state.
    build_runtime_geometry: Callable[[RuntimeConfig], Any]
    build_runtime_linear_params: Callable[..., Any]
    build_runtime_linear_terms: Callable[[RuntimeConfig], Any]
    build_initial_condition: Callable[..., Any]

    # Grid construction and index lookup.
    apply_geometry_grid_defaults: Callable[..., Any]
    build_spectral_grid: Callable[[Any], Any]
    select_ky_index: Callable[[Any, float], int]
    midplane_index: Callable[[Any], int]

    # Time integration and per-mode signal extraction.
    integrate_linear_diagnostics: Callable[..., Any]
    extract_mode_time_series: Callable[..., Any]

    # Growth-rate fitting and normalization of the fitted values.
    fit_growth_rate_auto_with_stats: Callable[..., Any]
    fit_growth_rate_auto: Callable[..., Any]
    fit_growth_rate: Callable[..., Any]
    apply_diagnostic_normalization: Callable[..., tuple[float, float]]
@dataclass(frozen=True)
class RuntimeProgressSnapshot:
    """Immutable record of the timing values behind one progress line.

    ``build_runtime_progress_message`` returns one of these alongside the
    formatted message so callers can act on the numbers directly.
    """

    # Completed fraction; build_runtime_progress_message clamps it to [0, 1].
    progress: float
    # Projected remaining wall time; infinite when no progress has been made.
    eta_seconds: float
    # Wall time of the chunk that just finished (clamped to >= 0 by the builder).
    chunk_wall_seconds: float
    # Total wall time so far (clamped to >= 0 by the builder).
    elapsed_seconds: float
@dataclass(frozen=True)
class NonlinearArtifactPolicy:
    """Output, restart, and checkpoint decisions for one nonlinear handoff.

    Instances are produced by ``resolve_nonlinear_artifact_policy``; the field
    comments below describe how that resolver fills them in.
    """

    # Requested artifact destination, or None when nothing is written.
    out_path: Path | None
    # Whether ``out_path`` was classified as a GX-style NetCDF target.
    gx_target: bool
    # Effective diagnostics switch (explicit argument wins over the config).
    diagnostics_on: bool
    # Restart file to read from / write to; both are None for non-GX targets.
    restart_from: Path | None
    restart_to: Path | None
    # True when the config asks for a restart or names an init file.
    resume_requested: bool
    # Total step budget, or None when the step count is not known up front.
    remaining_steps: int | None
    # Steps per checkpoint chunk, or None to run everything in a single call.
    checkpoint_steps: int | None
@dataclass(frozen=True)
class RuntimeArtifactHandoffDeps:
    """Bundle of injectable callables used by the nonlinear artifact handoff.

    The orchestration code calls these through the bundle rather than
    importing them directly, so callers can substitute patched versions.
    """

    # Output-target classification and restart / bundle path resolution.
    is_gx_netcdf_target: Callable[[Path], bool]
    resolve_restart_path: Callable[[str | Path, Any], Path]
    resolve_restart_write_path: Callable[[str | Path, Any], Path]
    gx_bundle_base: Callable[[Path], Path]

    # Diagnostics history: load from disk, condense, and concatenate.
    load_runtime_nonlinear_gx_diagnostics: Callable[
        [str | Path], SimulationDiagnostics
    ]
    condense_gx_diagnostics_for_output: Callable[
        [SimulationDiagnostics], SimulationDiagnostics
    ]
    concat_gx_diagnostics: Callable[
        [list[SimulationDiagnostics]], SimulationDiagnostics
    ]

    # Validate a chunk result, run a chunk, and write artifacts.
    validate_finite_runtime_result: Callable[[Any], None]
    run_runtime_nonlinear: Callable[..., RuntimeNonlinearResult]
    write_runtime_nonlinear_artifacts: Callable[
        [str | Path, Any, Any], dict[str, str]
    ]
def format_duration(seconds: float) -> str:
    """Format elapsed seconds as ``MM:SS`` or ``H:MM:SS``.

    Args:
        seconds: Duration in seconds. It is rounded to the nearest whole
            second, and negative values are clamped to zero.

    Returns:
        ``MM:SS`` for durations under one hour, ``H:MM:SS`` otherwise, or the
        placeholder ``--:--`` when ``seconds`` is NaN or infinite.
    """
    # int(round(...)) raises on NaN/inf; a progress line must never abort a
    # run, so fall back to the same placeholder the ETA field uses.
    if not np.isfinite(seconds):
        return "--:--"
    total_seconds = max(int(round(seconds)), 0)
    minutes, secs = divmod(total_seconds, 60)
    hours, minutes = divmod(minutes, 60)
    if hours > 0:
        return f"{hours:d}:{minutes:02d}:{secs:02d}"
    return f"{minutes:02d}:{secs:02d}"
def build_runtime_progress_message(
    *,
    label: str,
    chunk_index: int,
    t_elapsed: float,
    t_max: float,
    chunk_wall_seconds: float,
    elapsed_seconds: float,
) -> tuple[str, RuntimeProgressSnapshot]:
    """Build the per-chunk progress line together with its numeric snapshot.

    Args:
        label: Text inserted after ``completed`` in the message.
        chunk_index: Index of the chunk that just finished.
        t_elapsed: Simulation time reached so far.
        t_max: Target simulation time; a non-positive value counts as done.
        chunk_wall_seconds: Wall time spent on the finished chunk.
        elapsed_seconds: Wall time spent since the run started.

    Returns:
        ``(message, snapshot)`` where ``snapshot`` holds the clamped progress
        fraction, the projected remaining wall time, and both wall-clock
        durations clamped to be non-negative.
    """
    sim_now = float(t_elapsed)
    sim_end = float(t_max)
    wall_total = float(elapsed_seconds)

    # Completed fraction, clamped to [0, 1].
    if sim_end > 0.0:
        fraction = min(max(sim_now / sim_end, 0.0), 1.0)
    else:
        fraction = 1.0

    # Linear extrapolation of the remaining wall time; undefined at ~0 progress.
    if fraction > 1.0e-12:
        eta = wall_total * (1.0 / fraction - 1.0)
    else:
        eta = float("inf")

    snapshot = RuntimeProgressSnapshot(
        progress=float(fraction),
        eta_seconds=float(eta),
        chunk_wall_seconds=max(float(chunk_wall_seconds), 0.0),
        elapsed_seconds=max(wall_total, 0.0),
    )

    if np.isfinite(eta):
        eta_text = format_duration(eta)
    else:
        eta_text = "--:--"

    message = "".join(
        (
            f"completed {label} chunk {int(chunk_index)}: ",
            f"t={sim_now:.6g}/{sim_end:.6g} ",
            f"progress={100.0 * snapshot.progress:5.1f}% ",
            f"chunk_wall={format_duration(snapshot.chunk_wall_seconds)} ",
            f"elapsed={format_duration(snapshot.elapsed_seconds)} ",
            f"eta={eta_text}",
        )
    )
    return message, snapshot
def run_runtime_scan_batch(
    cfg: RuntimeConfig,
    ky_arr: np.ndarray,
    *,
    Nl: int,
    Nm: int,
    method: str | None,
    dt: float | None,
    steps: int | None,
    sample_stride: int | None,
    auto_window: bool,
    tmin: float | None,
    tmax: float | None,
    window_fraction: float,
    min_points: int,
    start_fraction: float,
    growth_weight: float,
    require_positive: bool,
    min_amp_fraction: float,
    mode_method: str,
    fit_signal: str,
    show_progress: bool,
    deps: RuntimeScanBatchDeps,
) -> RuntimeLinearScanResult:
    """Batch a ky scan using one time integration over the full grid.

    The initial conditions for every requested ``ky`` are summed into a single
    state, integrated once, and a growth rate / frequency pair is then fitted
    per ``ky`` from the recorded time series (always at ``kx_index=0``).

    Args:
        cfg: Runtime configuration; ``grid``, ``species``, ``time`` and
            ``normalization`` are read here.
        ky_arr: Requested ``ky`` values; each is mapped to a grid index via
            ``deps.select_ky_index``.
        Nl: Forwarded to ``deps.build_initial_condition``.
        Nm: Forwarded to ``deps.build_runtime_linear_params`` and
            ``deps.build_initial_condition``.
        method: Optional override for ``cfg.time.method``.
        dt: Optional override for ``cfg.time.dt``.
        steps: Optional step count; sets ``t_max`` to ``steps * dt``.
        sample_stride: Optional override for ``cfg.time.sample_stride``.
        auto_window: For ``fit_signal`` of ``"phi"``/``"density"``, choose
            ``deps.fit_growth_rate_auto`` (True) or ``deps.fit_growth_rate``
            with ``tmin``/``tmax`` (False).
        tmin: Fit-window start for the fixed-window fit.
        tmax: Fit-window end for the fixed-window fit.
        window_fraction: Forwarded to the automatic-window fit routines.
        min_points: Forwarded to the automatic-window fit routines.
        start_fraction: Forwarded to the automatic-window fit routines.
        growth_weight: Forwarded to the automatic-window fit routines and
            used to weight the growth rate in the ``"auto"`` signal score.
        require_positive: Forwarded to the automatic-window fit routines.
        min_amp_fraction: Forwarded to the automatic-window fit routines.
        mode_method: Forwarded to ``deps.extract_mode_time_series``.
        fit_signal: ``"phi"``, ``"density"`` or ``"auto"`` (case and
            surrounding whitespace are ignored).  ``"auto"`` fits both
            signals and keeps the one with the higher score
            ``r2 + 0.2 * r2p + growth_weight * gamma``.
        show_progress: Forwarded to ``deps.integrate_linear_diagnostics``.
        deps: Patchable functions that do the actual work.

    Returns:
        A ``RuntimeLinearScanResult`` carrying ``ky_arr`` and the fitted,
        normalized ``gamma`` and ``omega`` arrays.

    Raises:
        ValueError: If ``fit_signal`` is not one of the accepted values (now
            checked before any expensive work), or if ``ky_arr`` is empty.
    """
    # Validate the cheap argument first: previously a typo in ``fit_signal``
    # was only reported after the whole time integration had already run.
    fit_key = fit_signal.strip().lower()
    if fit_key not in {"phi", "density", "auto"}:
        raise ValueError("fit_signal must be 'phi', 'density', or 'auto'")

    geom = deps.build_runtime_geometry(cfg)
    grid_cfg = deps.apply_geometry_grid_defaults(geom, cfg.grid)
    grid = deps.build_spectral_grid(grid_cfg)
    params = deps.build_runtime_linear_params(cfg, Nm=Nm, geom=geom)
    terms = deps.build_runtime_linear_terms(cfg)

    ky_grid = np.asarray(grid.ky)  # converted once, reused for every lookup
    ky_indices = np.asarray(
        [deps.select_ky_index(ky_grid, ky) for ky in ky_arr], dtype=int
    )
    # At least one species slot even when no species is flagged kinetic.
    nspecies = max(sum(1 for s in cfg.species if s.kinetic), 1)

    # Superpose the per-ky initial conditions so one integration covers all.
    g0 = None
    for ky_idx in ky_indices:
        g0_local = deps.build_initial_condition(
            grid,
            geom,
            cfg,
            ky_index=int(ky_idx),
            kx_index=0,
            Nl=Nl,
            Nm=Nm,
            nspecies=nspecies,
        )
        g0 = g0_local if g0 is None else g0 + g0_local
    if g0 is None:
        raise ValueError("No ky values provided for batch scan")

    # Apply overrides in order: ``steps`` must see the overridden ``dt``.
    tcfg = cfg.time
    if method is not None:
        tcfg = replace(tcfg, method=str(method))
    if dt is not None:
        tcfg = replace(tcfg, dt=float(dt))
    if steps is not None:
        tcfg = replace(tcfg, t_max=float(steps) * float(tcfg.dt))
    if sample_stride is not None:
        tcfg = replace(tcfg, sample_stride=int(sample_stride))
    steps_val = int(round(tcfg.t_max / tcfg.dt))

    diag = deps.integrate_linear_diagnostics(
        g0,
        grid,
        geom,
        params,
        dt=tcfg.dt,
        steps=steps_val,
        method=tcfg.method,
        terms=terms,
        sample_stride=tcfg.sample_stride,
        species_index=0,
        record_hl_energy=False,
        show_progress=show_progress,
    )
    phi_t_np = np.asarray(diag[1])
    dens_t_np = np.asarray(diag[2])
    # Sample k is stamped at (k + 1) * dt * sample_stride.
    t_arr = (
        float(tcfg.dt)
        * float(tcfg.sample_stride)
        * (np.arange(phi_t_np.shape[0], dtype=float) + 1.0)
    )

    gamma = np.zeros_like(ky_arr, dtype=float)
    omega = np.zeros_like(ky_arr, dtype=float)

    # Shared by every automatic-window fit call below.
    auto_fit_kwargs = {
        "window_fraction": window_fraction,
        "min_points": min_points,
        "start_fraction": start_fraction,
        "growth_weight": growth_weight,
        "require_positive": require_positive,
        "min_amp_fraction": min_amp_fraction,
    }
    rho_star = float(np.asarray(params.rho_star))  # loop-invariant

    for i, ky_idx in enumerate(ky_indices):
        sel = ModeSelection(
            ky_index=int(ky_idx), kx_index=0, z_index=deps.midplane_index(grid)
        )
        if fit_key == "auto":
            # Fit both signals and keep whichever scores better.
            phi_signal = deps.extract_mode_time_series(
                phi_t_np, sel, method=mode_method
            )
            gamma_phi, omega_phi, _, _, r2_phi, r2p_phi = (
                deps.fit_growth_rate_auto_with_stats(
                    t_arr, phi_signal, **auto_fit_kwargs
                )
            )
            dens_signal = deps.extract_mode_time_series(
                dens_t_np, sel, method=mode_method
            )
            gamma_den, omega_den, _, _, r2_den, r2p_den = (
                deps.fit_growth_rate_auto_with_stats(
                    t_arr, dens_signal, **auto_fit_kwargs
                )
            )
            score_phi = r2_phi + 0.2 * r2p_phi + growth_weight * gamma_phi
            score_den = r2_den + 0.2 * r2p_den + growth_weight * gamma_den
            # Ties go to the phi fit.
            g_val, o_val = (
                (gamma_phi, omega_phi)
                if score_phi >= score_den
                else (gamma_den, omega_den)
            )
        else:
            signal = deps.extract_mode_time_series(
                dens_t_np if fit_key == "density" else phi_t_np,
                sel,
                method=mode_method,
            )
            if auto_window:
                g_val, o_val, _tmin, _tmax = deps.fit_growth_rate_auto(
                    t_arr, signal, **auto_fit_kwargs
                )
            else:
                g_val, o_val = deps.fit_growth_rate(
                    t_arr, signal, tmin=tmin, tmax=tmax
                )
        g_val, o_val = deps.apply_diagnostic_normalization(
            g_val,
            o_val,
            rho_star=rho_star,
            diagnostic_norm=cfg.normalization.diagnostic_norm,
        )
        gamma[i] = float(g_val)
        omega[i] = float(o_val)
    return RuntimeLinearScanResult(ky=ky_arr, gamma=gamma, omega=omega)
def resolve_nonlinear_artifact_policy(
    cfg: Any,
    *,
    out: str | Path | None,
    diagnostics: bool | None,
    steps: int | None,
    dt: float | None,
    deps: RuntimeArtifactHandoffDeps,
) -> NonlinearArtifactPolicy:
    """Decide where output goes and how restarts and checkpoints are handled.

    Args:
        cfg: Runtime configuration; ``time``, ``output`` and ``init`` are read.
        out: Artifact destination, or ``None`` when nothing is written.
        diagnostics: Explicit diagnostics switch; ``None`` defers to
            ``cfg.time.diagnostics``.
        steps: Explicit step count; ``None`` derives it from ``t_max / dt``
            when ``cfg.time.fixed_dt`` is set.
        dt: Step-size override used only for that derivation.
        deps: Supplies target classification and restart-path resolution.

    Returns:
        The resolved ``NonlinearArtifactPolicy``.  Restart paths and a
        checkpoint interval are only ever set for GX NetCDF targets.
    """
    target = Path(out) if out is not None else None
    is_gx = target is not None and deps.is_gx_netcdf_target(target)

    if diagnostics is None:
        want_diagnostics = bool(cfg.time.diagnostics)
    else:
        want_diagnostics = bool(diagnostics)

    read_path = None
    write_path = None
    if is_gx:
        assert target is not None
        read_path = deps.resolve_restart_path(target, cfg)
        write_path = deps.resolve_restart_write_path(target, cfg)

    restart_flag = bool(getattr(cfg.output, "restart", False))
    resuming = restart_flag or cfg.init.init_file is not None

    # Total step budget: explicit, derived from a fixed dt, or unknown.
    total_steps: int | None = None
    if steps is not None:
        total_steps = int(steps)
    elif bool(cfg.time.fixed_dt):
        horizon = float(cfg.time.t_max)
        step_size = float(dt if dt is not None else cfg.time.dt)
        total_steps = int(round(horizon / step_size))

    # Checkpoint interval: ``time.nstep_restart`` wins over ``output.nsave``.
    interval: int | None = None
    checkpointing = (
        is_gx
        and total_steps is not None
        and bool(getattr(cfg.output, "save_for_restart", True))
    )
    if checkpointing:
        nstep_restart = getattr(cfg.time, "nstep_restart", None)
        if nstep_restart is not None and int(nstep_restart) > 0:
            interval = int(nstep_restart)
        else:
            nsave = getattr(cfg.output, "nsave", 0)
            if int(nsave) > 0:
                interval = int(nsave)

    return NonlinearArtifactPolicy(
        out_path=target,
        gx_target=is_gx,
        diagnostics_on=want_diagnostics,
        restart_from=read_path,
        restart_to=write_path,
        resume_requested=resuming,
        remaining_steps=total_steps,
        checkpoint_steps=interval,
    )
def _restart_init_mode(cfg: Any) -> str:
    """Return ``"add"`` if ``cfg.output.restart_with_perturb`` is truthy, else ``"replace"``."""
    perturb = bool(getattr(cfg.output, "restart_with_perturb", False))
    if perturb:
        return "add"
    return "replace"


def _apply_restart_input(cfg_run: Any, cfg: Any, restart_from: Path) -> Any:
    """Return a copy of ``cfg_run`` whose init section loads ``restart_from``.

    The scale and load mode come from ``cfg.output`` (``restart_scale``,
    defaulting to 1.0, and ``restart_with_perturb``); ``cfg_run`` itself is
    left untouched.
    """
    restart_init = replace(
        cfg_run.init,
        init_file=str(restart_from),
        init_file_scale=float(getattr(cfg.output, "restart_scale", 1.0)),
        init_file_mode=_restart_init_mode(cfg),
    )
    return replace(cfg_run, init=restart_init)
def run_runtime_nonlinear_artifact_handoff(
    cfg: Any,
    *,
    out: str | Path | None,
    ky_target: float,
    kx_target: float | None = None,
    Nl: int | None = None,
    Nm: int | None = None,
    dt: float | None = None,
    steps: int | None = None,
    method: str | None = None,
    sample_stride: int | None = None,
    diagnostics_stride: int | None = None,
    laguerre_mode: str | None = None,
    diagnostics: bool | None = None,
    show_progress: bool = False,
    status_callback: Any = None,
    deps: RuntimeArtifactHandoffDeps,
) -> tuple[RuntimeNonlinearResult, dict[str, str]]:
    """Run nonlinear runtime chunks and hand results to artifact writers.

    Resolves the artifact policy, optionally redirects the run to start from
    a restart file, then calls ``deps.run_runtime_nonlinear`` once or, when
    the policy defines a checkpoint interval, repeatedly in chunks.
    Diagnostics from each chunk are appended to a cumulative record, and when
    ``out`` is given the artifacts are rewritten after every chunk.

    Args:
        cfg: Runtime configuration; ``init``, ``output`` and ``time`` are read.
        out: Artifact destination, or ``None`` to skip writing.
        ky_target: Forwarded to ``deps.run_runtime_nonlinear``.
        kx_target: Forwarded to ``deps.run_runtime_nonlinear``.
        Nl: Forwarded to ``deps.run_runtime_nonlinear``.
        Nm: Forwarded to ``deps.run_runtime_nonlinear``.
        dt: Forwarded to ``deps.run_runtime_nonlinear``; also used when the
            policy derives the step budget.
        steps: Total step budget passed to the policy; each call to
            ``deps.run_runtime_nonlinear`` receives its own chunk size.
        method: Forwarded to ``deps.run_runtime_nonlinear``.
        sample_stride: Forwarded to ``deps.run_runtime_nonlinear``.
        diagnostics_stride: Forwarded to ``deps.run_runtime_nonlinear``.
        laguerre_mode: Forwarded to ``deps.run_runtime_nonlinear``.
        diagnostics: Forwarded to ``deps.run_runtime_nonlinear``; also used
            by the policy to decide whether diagnostics are on.
        show_progress: Forwarded to ``deps.run_runtime_nonlinear``.
        status_callback: Forwarded to ``deps.run_runtime_nonlinear``.
        deps: Patchable functions used for path resolution, running,
            validation, diagnostics handling and artifact writing.

    Returns:
        ``(result, paths)``: the last chunk's result (carrying the cumulative
        diagnostics and time axis whenever the chunk produced diagnostics)
        and the mapping returned by the last
        ``deps.write_runtime_nonlinear_artifacts`` call (empty when ``out``
        is ``None``).

    Raises:
        ValueError: If the target is a GX NetCDF file but diagnostics are off.
        FileNotFoundError: If ``cfg.output.restart`` is set for a GX target
            without an init file and the resolved restart file is missing.
        RuntimeError: If no chunk produced a result.
    """
    policy = resolve_nonlinear_artifact_policy(
        cfg,
        out=out,
        diagnostics=diagnostics,
        steps=steps,
        dt=dt,
        deps=deps,
    )
    if policy.gx_target and not policy.diagnostics_on:
        raise ValueError(
            "GX-style nonlinear NetCDF artifacts require diagnostics output"
        )
    # cfg_run is the configuration actually executed; cfg stays untouched and
    # is what the artifact writer receives.
    cfg_run = cfg
    resume_requested = policy.resume_requested
    if policy.gx_target and cfg.init.init_file is None:
        if (
            bool(getattr(cfg.output, "restart_if_exists", False))
            and policy.restart_from is not None
            and policy.restart_from.exists()
        ):
            # Opportunistic resume: only taken when the restart file is present.
            resume_requested = True
            cfg_run = _apply_restart_input(cfg_run, cfg, policy.restart_from)
        elif (
            bool(getattr(cfg.output, "restart", False))
            and policy.restart_from is not None
        ):
            # Explicit restart request: a missing file is an error.
            if not policy.restart_from.exists():
                raise FileNotFoundError(
                    f"restart file not found: {policy.restart_from}"
                )
            cfg_run = _apply_restart_input(cfg_run, cfg, policy.restart_from)
    elif cfg.init.init_file is not None and bool(
        getattr(cfg.output, "restart_with_perturb", False)
    ):
        # A user-supplied init file is kept, but loaded in "add" mode with the
        # configured restart scale.
        cfg_run = replace(
            cfg_run,
            init=replace(
                cfg_run.init,
                init_file_scale=float(getattr(cfg.output, "restart_scale", 1.0)),
                init_file_mode="add",
            ),
        )
    cumulative_diag: SimulationDiagnostics | None = None
    history_from_file = False
    if (
        policy.gx_target
        and resume_requested
        and bool(getattr(cfg.output, "append_on_restart", True))
    ):
        # Resuming in append mode: seed the cumulative diagnostics from the
        # existing ``<bundle base>.out.nc`` file when it is present.
        assert policy.out_path is not None
        history_path = Path(f"{deps.gx_bundle_base(policy.out_path)}.out.nc")
        if history_path.exists():
            cumulative_diag = deps.load_runtime_nonlinear_gx_diagnostics(history_path)
            history_from_file = True
    remaining_steps = policy.remaining_steps
    checkpoint_steps = policy.checkpoint_steps
    # New samples are shifted by the last time value already on record.
    time_offset = 0.0
    if cumulative_diag is not None and np.asarray(cumulative_diag.t).size:
        time_offset = float(np.asarray(cumulative_diag.t)[-1])
    result_final: RuntimeNonlinearResult | None = None
    paths: dict[str, str] = {}
    while True:
        # Without checkpointing the whole remaining budget runs in one call;
        # with it, each chunk is capped at the checkpoint interval.
        chunk_steps = remaining_steps
        if checkpoint_steps is not None:
            chunk_steps = (
                checkpoint_steps
                if remaining_steps is None
                else min(int(remaining_steps), checkpoint_steps)
            )
        result_chunk = deps.run_runtime_nonlinear(
            cfg_run,
            ky_target=ky_target,
            kx_target=kx_target,
            Nl=Nl,
            Nm=Nm,
            dt=dt,
            steps=chunk_steps,
            method=method,
            sample_stride=sample_stride,
            diagnostics_stride=diagnostics_stride,
            laguerre_mode=laguerre_mode,
            diagnostics=diagnostics,
            return_state=policy.gx_target,
            show_progress=show_progress,
            status_callback=status_callback,
        )
        deps.validate_finite_runtime_result(result_chunk)
        result_effective = result_chunk
        if result_chunk.diagnostics is not None:
            diag_chunk = result_chunk.diagnostics
            if history_from_file:
                # NOTE(review): condensing is applied only when history was
                # loaded from disk, presumably so new chunks match the stored
                # layout; confirm against the condense helper.
                diag_chunk = deps.condense_gx_diagnostics_for_output(diag_chunk)
            if time_offset != 0.0:
                # Continue the time axis from the end of the cumulative record.
                diag_chunk = replace(
                    diag_chunk, t=np.asarray(diag_chunk.t) + time_offset
                )
            cumulative_diag = (
                diag_chunk
                if cumulative_diag is None
                else deps.concat_gx_diagnostics([cumulative_diag, diag_chunk])
            )
            # Keep the previous offset if the cumulative time axis is empty.
            time_offset = (
                float(np.asarray(cumulative_diag.t)[-1])
                if np.asarray(cumulative_diag.t).size
                else time_offset
            )
            # The reported result carries the full history, not just this chunk.
            result_effective = replace(
                result_chunk,
                diagnostics=cumulative_diag,
                t=np.asarray(cumulative_diag.t),
            )
        result_final = result_effective
        if policy.out_path is not None:
            # Rewritten after every chunk, using the original cfg (not cfg_run).
            paths = deps.write_runtime_nonlinear_artifacts(
                policy.out_path, result_effective, cfg
            )
        if checkpoint_steps is None:
            # Single-shot run: nothing further to schedule.
            break
        if remaining_steps is not None:
            assert chunk_steps is not None
            remaining_steps -= int(chunk_steps)
            if remaining_steps <= 0:
                break
        elif (
            result_effective.diagnostics is None
            or time_offset >= float(cfg.time.t_max) - 1.0e-12
        ):
            # No step budget: stop when there are no diagnostics to track time
            # with, or once the recorded time reaches t_max (within 1e-12).
            break
        if policy.restart_to is None:
            # No restart file to continue from, so chunking cannot proceed.
            break
        # The next chunk starts from the file at ``policy.restart_to`` in plain
        # replace mode (presumably written during the artifact step above;
        # TODO confirm).  It is rebuilt from ``cfg``, so the first chunk's
        # restart/perturbation settings do not carry over.
        cfg_run = replace(
            cfg,
            init=replace(
                cfg.init,
                init_file=str(policy.restart_to),
                init_file_scale=1.0,
                init_file_mode="replace",
            ),
        )
    if result_final is None:
        raise RuntimeError("nonlinear runtime produced no result")
    return result_final, paths