# Source code for spectraxgk.benchmark_scan

"""Shared scan policies for benchmark runners.

The public benchmark API stays in :mod:`spectraxgk.benchmarks`; this module
keeps small, deterministic scan decisions separate from solver orchestration.
"""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

import numpy as np

from spectraxgk.analysis import fit_growth_rate, fit_growth_rate_auto
from spectraxgk.benchmark_helpers import _is_array_like, _normalize_growth_rate
from spectraxgk.linear import LinearParams


# Accepted ``fit_signal`` selectors, compared after strip()/lower() normalization.
VALID_FIT_SIGNALS = frozenset(("phi", "density", "auto"))


def normalize_solver_key(solver: str) -> str:
    """Return the solver selector trimmed of surrounding whitespace and lower-cased.

    Only the spelling is normalized; the selector's meaning is unchanged.
    """
    trimmed = solver.strip()
    return trimmed.lower()
def normalize_fit_signal(fit_signal: str) -> str:
    """Normalize a fit-signal selector and check it against ``VALID_FIT_SIGNALS``.

    Returns:
        The stripped, lower-cased selector.

    Raises:
        ValueError: if the normalized selector is not an accepted value.
    """
    candidate = fit_signal.strip().lower()
    if candidate in VALID_FIT_SIGNALS:
        return candidate
    raise ValueError("fit_signal must be 'phi', 'density', or 'auto'")
def apply_auto_fit_scan_policy(
    fit_key: str, *, streaming_fit: bool, mode_only: bool
) -> tuple[bool, bool]:
    """Return the ``(streaming_fit, mode_only)`` flags to use for a scan.

    Auto signal selection (``fit_key == "auto"``) needs both saved fields, so
    both flags are forced off in that case; otherwise they pass through.
    """
    if fit_key != "auto":
        return streaming_fit, mode_only
    return False, False
def resolve_scan_mode_method(mode_method: str, *, mode_only: bool) -> str:
    """Choose the mode-extraction method for a scan.

    When only a mode time series was saved (``mode_only``), methods other than
    ``"z_index"`` and ``"max"`` are replaced by ``"z_index"``. Otherwise
    ``mode_method`` is returned unchanged.
    """
    if not mode_only:
        return mode_method
    if mode_method in {"z_index", "max"}:
        return mode_method
    return "z_index"
[docs] def indexed_float_value(value: Any, idx: int) -> float | None: """Return a scalar or indexed scan value as ``float`` for window policies.""" if value is None: return None if isinstance(value, (list, tuple, np.ndarray)): return float(value[idx]) return float(value)
def indexed_scan_value(value: Any, idx: int) -> Any:
    """Return a scalar or indexed scan value while preserving non-float types.

    Args:
        value: ``None``, a scalar of any type, or a per-scan-point
            list/tuple/ndarray.
        idx: Scan-point index used when ``value`` is a sequence.

    Returns:
        ``None`` for ``None``. For ndarrays, the selected element converted to
        a Python scalar via ``.item()``; a 0-d array is treated as a single
        shared value. For lists and tuples, ``value[idx]`` unchanged. For any
        other value, ``value`` itself.
    """
    if value is None:
        return None
    if isinstance(value, np.ndarray):
        # Indexing a 0-d array would raise IndexError; use its single value.
        element = value if value.ndim == 0 else value[idx]
        return element.item()
    if isinstance(value, (list, tuple)):
        return value[idx]
    return value
[docs] def scan_window_valid( t: np.ndarray, tmin: float | None, tmax: float | None, *, min_points: int = 2 ) -> bool: """Return whether an explicit fit window contains enough sampled points.""" if tmin is None or tmax is None: return False mask = (t >= tmin) & (t <= tmax) return int(np.count_nonzero(mask)) >= int(min_points)
def should_use_ky_batch(
    *,
    ky_batch: int,
    solver_key: str,
    dt: Any,
    steps: Any,
    tmin: Any,
    tmax: Any,
) -> bool:
    """Decide whether a ky scan may take the fixed-shape batched path.

    Batching requires ``ky_batch > 1``, a non-Krylov solver, and no
    array-like (per-ky) value among ``dt``, ``steps``, ``tmin`` and ``tmax``.

    Raises:
        ValueError: if ``ky_batch`` is smaller than 1.
    """
    if ky_batch < 1:
        raise ValueError("ky_batch must be >= 1")
    if not (ky_batch > 1) or solver_key == "krylov":
        return False
    # Any per-ky setting forces the unbatched path; checked in order, stopping early.
    return not any(_is_array_like(option) for option in (dt, steps, tmin, tmax))
@dataclass(frozen=True)
class ScanFitWindowPolicy:
    """Window-selection and normalization policy shared by benchmark scans.

    ``tmin`` and ``tmax`` may be scalars or per-scan-point sequences; they are
    read through :func:`indexed_float_value`. The numeric and string tuning
    fields from ``window_fraction`` through ``window_method`` are forwarded
    verbatim to ``fit_growth_rate_auto_fn`` by :meth:`auto_kwargs`. The three
    ``*_fn`` fields are injectable, so callers can substitute the fitting and
    normalization routines.
    """

    tmin: Any = None
    tmax: Any = None
    auto_window: bool = True
    window_fraction: float = 0.3
    min_points: int = 20
    start_fraction: float = 0.0
    growth_weight: float = 0.0
    require_positive: bool = False
    min_amp_fraction: float = 0.0
    max_fraction: float = 0.8
    end_fraction: float = 0.9
    max_amp_fraction: float = 0.9
    phase_weight: float = 0.2
    length_weight: float = 0.05
    min_r2: float = 0.0
    late_penalty: float = 0.1
    min_slope: float | None = None
    min_slope_frac: float = 0.0
    slope_var_weight: float = 0.0
    window_method: str = "loglinear"
    # Explicit-window fitter: called as fn(t, signal, tmin=..., tmax=...) -> (gamma, omega).
    fit_growth_rate_fn: Callable[..., tuple[float, float]] = fit_growth_rate
    # Auto-window fitter: called as fn(t, signal, **auto_kwargs()); returns 4 values,
    # of which only the first two (gamma, omega) are used here.
    fit_growth_rate_auto_fn: Callable[..., tuple[float, float, float, float]] = (
        fit_growth_rate_auto
    )
    # Called as fn(gamma, omega, params, diagnostic_norm) -> (gamma, omega).
    normalize_growth_rate_fn: Callable[
        [float, float, LinearParams, str], tuple[float, float]
    ] = _normalize_growth_rate

    def window_at(self, idx: int) -> tuple[float | None, float | None]:
        """Return the explicit ``(tmin, tmax)`` bounds for scan point ``idx``.

        Either bound is ``None`` when it was not configured.
        """
        return indexed_float_value(self.tmin, idx), indexed_float_value(self.tmax, idx)

    def use_auto_window(
        self, t: np.ndarray, idx: int
    ) -> tuple[bool, float | None, float | None]:
        """Decide whether scan point ``idx`` uses automatic window selection.

        Returns ``(use_auto, tmin_i, tmax_i)``. Automatic selection is chosen
        when ``auto_window`` is set and neither bound is configured. It is also
        chosen whenever the explicit window fails :func:`scan_window_valid`
        (a missing bound, or fewer than two samples of ``t`` inside it). As a
        result, ``auto_window=False`` without explicit bounds still falls back
        to automatic selection.
        """
        tmin_i, tmax_i = self.window_at(idx)
        use_auto = self.auto_window and tmin_i is None and tmax_i is None
        if not use_auto and not scan_window_valid(t, tmin_i, tmax_i):
            use_auto = True
        return use_auto, tmin_i, tmax_i

    def auto_kwargs(self) -> dict[str, Any]:
        """Return the keyword arguments forwarded to ``fit_growth_rate_auto_fn``."""
        return {
            "window_fraction": self.window_fraction,
            "min_points": self.min_points,
            "start_fraction": self.start_fraction,
            "growth_weight": self.growth_weight,
            "require_positive": self.require_positive,
            "min_amp_fraction": self.min_amp_fraction,
            "max_fraction": self.max_fraction,
            "end_fraction": self.end_fraction,
            "max_amp_fraction": self.max_amp_fraction,
            "phase_weight": self.phase_weight,
            "length_weight": self.length_weight,
            "min_r2": self.min_r2,
            "late_penalty": self.late_penalty,
            "min_slope": self.min_slope,
            "min_slope_frac": self.min_slope_frac,
            "slope_var_weight": self.slope_var_weight,
            "window_method": self.window_method,
        }

    def _fit_auto(self, t: np.ndarray, signal: np.ndarray) -> tuple[float, float]:
        """Run the auto-window fit and return ``(gamma, omega)``, discarding the window."""
        gamma, omega, _tmin, _tmax = self.fit_growth_rate_auto_fn(
            t,
            signal,
            **self.auto_kwargs(),
        )
        return gamma, omega

    def fit_signal(
        self,
        signal: np.ndarray,
        *,
        idx: int,
        dt: float,
        stride: int,
        params: LinearParams,
        diagnostic_norm: str,
    ) -> tuple[float, float]:
        """Fit one scan signal and apply the configured diagnostic normalization.

        Args:
            signal: Sampled time series; its first axis is time.
            idx: Scan-point index used to pick per-point ``tmin``/``tmax``.
            dt: Solver time step.
            stride: Number of solver steps between saved samples.
            params: Passed through to ``normalize_growth_rate_fn``.
            diagnostic_norm: Passed through to ``normalize_growth_rate_fn``.

        Returns:
            The ``(gamma, omega)`` pair produced by ``normalize_growth_rate_fn``.
        """
        # Sample times: one entry per saved frame, spaced dt * stride apart.
        t = np.arange(signal.shape[0]) * float(dt) * int(stride)
        use_auto, tmin_i, tmax_i = self.use_auto_window(t, idx)
        if use_auto:
            gamma, omega = self._fit_auto(t, signal)
        else:
            try:
                gamma, omega = self.fit_growth_rate_fn(
                    t, signal, tmin=tmin_i, tmax=tmax_i
                )
            except ValueError:
                # The fitter rejected the explicit window; fall back to auto selection.
                gamma, omega = self._fit_auto(t, signal)
        return self.normalize_growth_rate_fn(gamma, omega, params, diagnostic_norm)