Source code for spectraxgk.plotting

"""Plotting utilities for publication-ready figures."""

from __future__ import annotations

from dataclasses import dataclass
import json
from pathlib import Path
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np

from spectraxgk.benchmarks import CycloneReference, CycloneScanResult
from spectraxgk.analysis import fit_growth_rate


def set_plot_style() -> None:
    """Install the shared publication style into matplotlib's global ``rcParams``.

    The update is process-wide: every figure created afterwards picks up these
    font sizes, dashed translucent grid and screen DPI.
    """
    font_sizes = {
        "font.size": 12,
        "axes.labelsize": 12,
        "axes.titlesize": 12,
        "legend.fontsize": 10,
        "xtick.labelsize": 11,
        "ytick.labelsize": 11,
    }
    grid_style = {
        "axes.grid": True,
        "grid.alpha": 0.3,
        "grid.linestyle": "--",
    }
    plt.rcParams.update({**font_sizes, **grid_style, "figure.dpi": 120})
def _normalize_by_real_max(eigenfunction: np.ndarray) -> np.ndarray:
    """Scale a mode so that its largest absolute real part becomes 1.

    If the real part is identically zero, the largest magnitude is used instead.
    Empty or all-zero input is returned unscaled. The result is always a
    ``complex128`` array.
    """
    mode = np.asarray(eigenfunction, dtype=np.complex128)
    if mode.size == 0:
        return mode
    real_peak = float(np.abs(mode.real).max())
    # Written as "not <= 0" (rather than "> 0") so a NaN peak still takes the division path.
    if not real_peak <= 0.0:
        return mode / real_peak
    abs_peak = float(np.abs(mode).max())
    return mode / abs_peak if abs_peak > 0.0 else mode
def linear_runtime_panel_figure(
    *,
    t: np.ndarray,
    signal: np.ndarray,
    z: np.ndarray,
    eigenfunction: np.ndarray,
    gamma: float,
    omega: float,
    title: str = "SPECTRAX-GK Linear Runtime",
) -> Tuple[plt.Figure, np.ndarray]:
    """Build the standard two-panel summary of a linear run.

    Left panel: ``|signal|**2`` against ``t`` on a log axis. Values are floored at
    1e-30 so zeros stay plottable. The panel is annotated with the supplied
    ``gamma`` and ``omega``.

    Right panel: real and imaginary parts of ``eigenfunction`` against ``z``,
    rescaled by ``_normalize_by_real_max``.

    Returns the figure and the length-2 array of axes.
    """
    set_plot_style()
    fig, axes = plt.subplots(1, 2, figsize=(10.5, 4.1))
    history_ax, mode_ax = axes

    time = np.asarray(t, dtype=float)
    power = np.maximum(np.abs(np.asarray(signal, dtype=np.complex128)) ** 2, 1.0e-30)
    history_ax.plot(time, power, color="#0f4c81", linewidth=2.4)
    history_ax.set_yscale("log")
    history_ax.set_xlabel("t")
    history_ax.set_ylabel(r"$|\phi|^2$")
    history_ax.set_title("Linear growth history")
    rate_text = "\n".join((rf"$\gamma={gamma:.5f}$", rf"$\omega={omega:.5f}$"))
    history_ax.text(
        0.04,
        0.96,
        rate_text,
        transform=history_ax.transAxes,
        va="top",
        ha="left",
        bbox={"boxstyle": "round,pad=0.3", "facecolor": "white", "alpha": 0.9, "edgecolor": "#cccccc"},
    )

    grid = np.asarray(z, dtype=float)
    mode = _normalize_by_real_max(eigenfunction)
    mode_ax.plot(grid, np.real(mode), color="#0f4c81", linewidth=2.4, label="Re")
    mode_ax.plot(grid, np.imag(mode), color="#c44e52", linewidth=2.2, linestyle="--", label="Im")
    mode_ax.set_xlabel(r"$\theta$")
    mode_ax.set_ylabel(r"$\phi / \max |\Re(\phi)|$")
    mode_ax.set_title("Eigenfunction")
    mode_ax.legend(loc="best", frameon=False)

    fig.suptitle(title, y=1.02)
    fig.tight_layout()
    return fig, axes
def nonlinear_runtime_panel_figure(
    *,
    t: np.ndarray,
    phi2: np.ndarray | None = None,
    wphi: np.ndarray | None = None,
    heat_flux: np.ndarray | None = None,
    gamma: np.ndarray | None = None,
    omega: np.ndarray | None = None,
    title: str = "SPECTRAX-GK Nonlinear Runtime",
) -> Tuple[plt.Figure, np.ndarray]:
    """Build the standard three-panel summary of a nonlinear run.

    Panels, left to right:

    * field history: ``phi2`` on a log axis (floored at 1e-30) when given,
      otherwise ``wphi`` on a linear axis;
    * every supplied series among ``wphi``, ``gamma`` and ``omega`` on one axis;
    * ``heat_flux``.

    Series left as ``None`` are skipped. Returns the figure and the length-3
    array of axes.
    """
    set_plot_style()
    fig, axes = plt.subplots(1, 3, figsize=(14.0, 4.0))
    time = np.asarray(t, dtype=float)
    field_ax, diag_ax, flux_ax = axes

    if phi2 is not None:
        floored = np.maximum(np.asarray(phi2, dtype=float), 1.0e-30)
        field_ax.plot(time, floored, color="#0f4c81", linewidth=2.4)
        field_ax.set_yscale("log")
        field_ax.set_ylabel(r"$|\phi|^2$")
        field_ax.set_title("Field amplitude")
    elif wphi is not None:
        field_ax.plot(time, np.asarray(wphi, dtype=float), color="#0f4c81", linewidth=2.4)
        field_ax.set_ylabel(r"$W_\phi$")
        field_ax.set_title("Electrostatic energy")

    diagnostics = (
        (wphi, {"color": "#2a9d8f", "linewidth": 2.4, "label": r"$W_\phi$"}),
        (gamma, {"color": "#f4a261", "linewidth": 2.0, "linestyle": "--", "label": r"$\gamma$"}),
        (omega, {"color": "#c44e52", "linewidth": 2.0, "linestyle": ":", "label": r"$\omega$"}),
    )
    drawn = 0
    for series, style in diagnostics:
        if series is not None:
            diag_ax.plot(time, np.asarray(series, dtype=float), **style)
            drawn += 1
    diag_ax.set_xlabel("t")
    diag_ax.set_title("Resolved diagnostics")
    if drawn:
        diag_ax.legend(loc="best", frameon=False)

    if heat_flux is not None:
        flux_ax.plot(time, np.asarray(heat_flux, dtype=float), color="#c44e52", linewidth=2.4)
    flux_ax.set_xlabel("t")
    flux_ax.set_ylabel("Heat flux")
    flux_ax.set_title("Transport")

    field_ax.set_xlabel("t")
    for ax in axes:
        ax.grid(True, alpha=0.25)
    fig.suptitle(title, y=1.02)
    fig.tight_layout()
    return fig, axes
def _artifact_base(path: Path) -> Path:
    """Strip a known runtime-artifact suffix from *path* to get the bundle base.

    ``<base>.summary.json``, ``<base>.timeseries.csv``, ``<base>.eigenfunction.csv``,
    ``<base>.diagnostics.csv`` and ``<base>.out.nc`` all map to ``<base>``.
    Otherwise a single ``.json``/``.csv``/``.nc`` suffix is removed, and any other
    path is returned unchanged. Matching is case-insensitive.
    """
    name = path.name
    lowered = name.lower()
    for suffix in (".summary.json", ".timeseries.csv", ".eigenfunction.csv", ".diagnostics.csv", ".out.nc"):
        if lowered.endswith(suffix):
            return path.with_name(name[: -len(suffix)])
    if path.suffix.lower() in {".json", ".csv", ".nc"}:
        return path.with_suffix("")
    return path


def _bundle_file(base: Path, suffix: str) -> Path:
    """Return the sibling artifact ``<base.name><suffix>`` next to *base*.

    ``Path.with_suffix`` must not be used here: it *replaces* any dotted tail
    already in the base name, e.g. ``cyclone_ky0.3`` would become
    ``cyclone_ky0.summary.json`` instead of ``cyclone_ky0.3.summary.json``.
    """
    return base.with_name(base.name + suffix)


def _float_column(table: np.ndarray, name: str) -> np.ndarray:
    """Return column *name* of a ``genfromtxt`` structured array as a 1-D float array.

    ``genfromtxt`` returns a 0-d record for a single data row. ``atleast_1d``
    keeps the result array-shaped in that case.
    """
    return np.atleast_1d(np.asarray(table[name], dtype=float))


def _load_linear_bundle(base: Path) -> tuple[dict, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load a linear runtime bundle written next to *base*.

    Reads ``<base>.summary.json``, ``<base>.timeseries.csv`` (columns ``t``,
    ``signal_real``, ``signal_imag``) and ``<base>.eigenfunction.csv`` (columns
    ``z``, ``eigen_real``, ``eigen_imag``).

    Returns:
        ``(summary, t, signal, z, eigenfunction)``. ``signal`` and
        ``eigenfunction`` are complex arrays assembled from their real and
        imaginary columns.
    """
    summary = json.loads(_bundle_file(base, ".summary.json").read_text(encoding="utf-8"))
    timeseries = np.genfromtxt(_bundle_file(base, ".timeseries.csv"), delimiter=",", names=True, dtype=float)
    eigen = np.genfromtxt(_bundle_file(base, ".eigenfunction.csv"), delimiter=",", names=True, dtype=float)
    t = _float_column(timeseries, "t")
    signal = _float_column(timeseries, "signal_real") + 1j * _float_column(timeseries, "signal_imag")
    z = _float_column(eigen, "z")
    eig = _float_column(eigen, "eigen_real") + 1j * _float_column(eigen, "eigen_imag")
    return summary, t, signal, z, eig


def _load_nonlinear_csv(
    base: Path,
) -> tuple[dict, np.ndarray, np.ndarray | None, np.ndarray | None, np.ndarray | None, np.ndarray | None]:
    """Load a nonlinear runtime bundle (summary JSON plus diagnostics CSV) next to *base*.

    The diagnostics CSV must have a ``t`` column. ``Wphi``, ``heat_flux``,
    ``gamma`` and ``omega`` are optional; each is returned as ``None`` when its
    column is absent.

    Returns:
        ``(summary, t, wphi, heat_flux, gamma, omega)``.
    """
    summary = json.loads(_bundle_file(base, ".summary.json").read_text(encoding="utf-8"))
    diag = np.genfromtxt(_bundle_file(base, ".diagnostics.csv"), delimiter=",", names=True, dtype=float)
    names = set(diag.dtype.names or ())
    t = _float_column(diag, "t")
    wphi, heat_flux, gamma, omega = (
        _float_column(diag, column) if column in names else None
        for column in ("Wphi", "heat_flux", "gamma", "omega")
    )
    return summary, t, wphi, heat_flux, gamma, omega


def _as_float_filled(values) -> np.ndarray:
    """Convert a (possibly masked) netCDF slice to a float array, with masked entries as NaN.

    netCDF4 returns masked arrays when a variable has fill values.
    ``np.asarray`` would silently keep the raw fill numbers (e.g. 9.97e36), which
    then dominate sums and plot ranges.
    """
    return np.ma.filled(np.ma.asarray(values, dtype=float), np.nan)


def _load_nonlinear_netcdf(
    path: Path,
) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None, np.ndarray | None]:
    """Read nonlinear diagnostics from the ``Diagnostics`` group of a ``*.out.nc`` file.

    Returns:
        ``(t, phi2, wphi, heat_flux)``. ``phi2`` comes from ``Phi2_t``.
        ``wphi`` and ``heat_flux`` are ``Wphi_st`` and ``HeatFlux_st`` summed over
        axis 1. Each of these three is ``None`` when its variable is missing.

    Raises:
        SystemExit: if the optional ``netCDF4`` package is not installed.
    """
    try:
        import netCDF4
    except ModuleNotFoundError as exc:  # pragma: no cover - optional runtime dependency
        raise SystemExit("netCDF4 is required to plot *.out.nc runtime bundles") from exc

    with netCDF4.Dataset(path) as root:
        variables = root.groups["Diagnostics"].variables
        t = _as_float_filled(variables["t"][:])
        phi2 = _as_float_filled(variables["Phi2_t"][:]) if "Phi2_t" in variables else None
        wphi = None
        heat_flux = None
        # NOTE(review): summing over axis 1 assumes the *_st variables are laid out as
        # (time, species, ...) — confirm against the writer.
        if "Wphi_st" in variables:
            wphi = np.sum(_as_float_filled(variables["Wphi_st"][:]), axis=1)
        if "HeatFlux_st" in variables:
            heat_flux = np.sum(_as_float_filled(variables["HeatFlux_st"][:]), axis=1)
    return t, phi2, wphi, heat_flux
def plot_saved_output(path: str | Path, *, out: str | Path | None = None) -> Path:
    """Plot a saved linear or nonlinear runtime artifact bundle to an image file.

    ``path`` may be any member of a bundle (``<base>.summary.json``,
    ``<base>.timeseries.csv``, ...) or a ``*.nc`` file. NetCDF input is plotted
    as a nonlinear run. Otherwise the sibling ``<base>.summary.json`` is read,
    and its ``"kind"`` field (``"linear"`` or ``"nonlinear"``) selects the loader
    and the figure layout.

    Args:
        path: any file belonging to the runtime bundle.
        out: destination image. Defaults to ``<base>.plot.png``. Missing parent
            directories are created.

    Returns:
        The path of the written image.

    Raises:
        FileNotFoundError: no ``<base>.summary.json`` exists for non-NetCDF input.
        ValueError: the summary's ``kind`` is not ``"linear"`` or ``"nonlinear"``.
    """
    in_path = Path(path)
    base = _artifact_base(in_path)
    out_path = Path(out) if out is not None else Path(f"{base}.plot.png")

    # "*.out.nc" also ends in ".nc", so a single suffix test covers both spellings.
    if in_path.suffix.lower() == ".nc":
        t, phi2, wphi, heat_flux = _load_nonlinear_netcdf(in_path)
        fig, _axes = nonlinear_runtime_panel_figure(
            t=t,
            phi2=phi2,
            wphi=wphi,
            heat_flux=heat_flux,
            title=f"SPECTRAX-GK nonlinear runtime: {base.name}",
        )
    else:
        # Append rather than use Path.with_suffix, which would replace a dotted tail of
        # the base name (e.g. "cyclone_ky0.3" -> "cyclone_ky0.summary.json").
        summary_path = base.with_name(base.name + ".summary.json")
        if not summary_path.exists():
            raise FileNotFoundError(f"Could not infer runtime summary from {in_path}")
        summary = json.loads(summary_path.read_text(encoding="utf-8"))
        kind = summary.get("kind")
        if kind == "linear":
            _summary, t, signal, z, eig = _load_linear_bundle(base)
            fig, _axes = linear_runtime_panel_figure(
                t=t,
                signal=signal,
                z=z,
                eigenfunction=eig,
                gamma=float(summary["gamma"]),
                omega=float(summary["omega"]),
                title=f"SPECTRAX-GK linear runtime: {base.name}",
            )
        elif kind == "nonlinear":
            _summary, t, wphi, heat_flux, gamma, omega = _load_nonlinear_csv(base)
            fig, _axes = nonlinear_runtime_panel_figure(
                t=t,
                wphi=wphi,
                heat_flux=heat_flux,
                gamma=gamma,
                omega=omega,
                title=f"SPECTRAX-GK nonlinear runtime: {base.name}",
            )
        else:
            raise ValueError(f"Unsupported runtime artifact kind: {kind!r}")

    try:
        # Create the output directory only once there is a figure to save, so a failed
        # load does not leave empty directories behind.
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=220, bbox_inches="tight")
    finally:
        # Release the figure even when saving fails, so pyplot does not accumulate open figures.
        plt.close(fig)
    return out_path
def cyclone_reference_figure(ref: CycloneReference) -> Tuple[plt.Figure, np.ndarray]:
    """Plot the Cyclone base-case reference growth rates and frequencies.

    Two stacked panels share a logarithmic ``k_y`` axis: ``ref.gamma`` on top
    and ``ref.omega`` below. Returns the figure and the length-2 array of axes.
    """
    set_plot_style()
    fig, axes = plt.subplots(2, 1, sharex=True, figsize=(5.0, 5.0))
    rows = (
        (ref.gamma, "#1f77b4", r"$\gamma a / v_{ti}$"),
        (ref.omega, "#ff7f0e", r"$\omega a / v_{ti}$"),
    )
    for ax, (values, color, ylabel) in zip(axes, rows):
        ax.plot(ref.ky, values, marker="o", color=color, label="Reference")
        ax.set_ylabel(ylabel)
        ax.legend(loc="best")
        ax.set_xscale("log")
    axes[0].set_title("Cyclone base case (adiabatic electrons)")
    axes[1].set_xlabel(r"$k_y \rho_i$")
    fig.tight_layout()
    return fig, axes
def cyclone_comparison_figure(
    ref: CycloneReference,
    scan: CycloneScanResult,
    label: str = "SPECTRAX-GK",
) -> Tuple[plt.Figure, np.ndarray]:
    """Overlay a solver ``k_y`` scan on the Cyclone base-case reference.

    Top panel: growth rate; bottom panel: frequency. Both use a shared
    logarithmic ``k_y`` axis. Reference points use circle markers and solver
    points use squares labelled ``label``. Returns the figure and the length-2
    array of axes.
    """
    set_plot_style()
    fig, axes = plt.subplots(2, 1, sharex=True, figsize=(5.0, 5.0))
    quantities = (
        ("gamma", "#1f77b4", "#2ca02c", r"$\gamma a / v_{ti}$"),
        ("omega", "#ff7f0e", "#d62728", r"$\omega a / v_{ti}$"),
    )
    for ax, (field, ref_color, run_color, ylabel) in zip(axes, quantities):
        ax.plot(ref.ky, getattr(ref, field), marker="o", color=ref_color, label="Reference")
        ax.plot(scan.ky, getattr(scan, field), marker="s", color=run_color, label=label)
        ax.set_ylabel(ylabel)
        ax.legend(loc="best")
        ax.set_xscale("log")
    axes[0].set_title("Cyclone base case (adiabatic electrons)")
    axes[1].set_xlabel(r"$k_y \rho_i$")
    fig.tight_layout()
    return fig, axes
def scan_comparison_figure(
    x: np.ndarray,
    gamma: np.ndarray,
    omega: np.ndarray,
    x_label: str,
    title: str,
    x_ref: np.ndarray | None = None,
    gamma_ref: np.ndarray | None = None,
    omega_ref: np.ndarray | None = None,
    label: str = "SPECTRAX-GK",
    ref_label: str = "Reference",
    log_x: bool = False,
) -> Tuple[plt.Figure, np.ndarray]:
    """Plot growth rate (top) and frequency (bottom) of a generic 1-D scan.

    Reference markers are added to a panel only when ``x_ref`` and that panel's
    reference series (``gamma_ref`` or ``omega_ref``) are both given. The x axis
    is shared and optionally logarithmic. Returns the figure and the length-2
    array of axes.
    """
    set_plot_style()
    fig, axes = plt.subplots(2, 1, sharex=True, figsize=(5.0, 5.0))
    quantities = (
        (gamma, gamma_ref, "#2ca02c", r"$\gamma a / v_{ti}$"),
        (omega, omega_ref, "#d62728", r"$\omega a / v_{ti}$"),
    )
    for ax, (values, ref_values, color, ylabel) in zip(axes, quantities):
        ax.plot(x, values, marker="o", color=color, label=label)
        have_reference = x_ref is not None and ref_values is not None
        if have_reference:
            ax.plot(x_ref, ref_values, marker="o", linestyle="None", color="#1f77b4", label=ref_label)
        ax.set_ylabel(ylabel)
        ax.legend(loc="best")
        if log_x:
            ax.set_xscale("log")
    axes[0].set_title(title)
    axes[1].set_xlabel(x_label)
    fig.tight_layout()
    return fig, axes
def etg_trend_figure(
    R_over_LTe: np.ndarray,
    gamma: np.ndarray,
    omega: np.ndarray,
    ky_target: float,
) -> Tuple[plt.Figure, np.ndarray]:
    """Plot growth rate and frequency against ``R/L_Te`` at a fixed ``k_y``.

    Two stacked panels share the ``R/L_Te`` axis. The title reports
    ``ky_target`` to two decimals. Returns the figure and the length-2 array of
    axes.
    """
    set_plot_style()
    fig, axes = plt.subplots(2, 1, sharex=True, figsize=(5.0, 5.0))
    growth_ax, freq_ax = axes
    for ax, values, color, ylabel in (
        (growth_ax, gamma, "#1f77b4", r"$\gamma a / v_{ti}$"),
        (freq_ax, omega, "#ff7f0e", r"$\omega a / v_{ti}$"),
    ):
        ax.plot(R_over_LTe, values, marker="o", color=color)
        ax.set_ylabel(ylabel)
    growth_ax.set_title(rf"ETG trend at $k_y={ky_target:.2f}$")
    freq_ax.set_xlabel(r"$R/L_{Te}$")
    fig.tight_layout()
    return fig, axes
@dataclass(frozen=True)
class LinearValidationPanel:
    """Data for one row of ``linear_validation_figure``.

    The row shows the eigenfunction (real and imaginary parts against ``z``),
    ``gamma`` against ``x`` and ``omega`` against ``x``. Reference points are
    overlaid only when ``x_ref`` and the matching ``gamma_ref``/``omega_ref``
    are both set.
    """

    name: str  # row label; drawn as the y-axis label of the eigenfunction axis
    z: np.ndarray  # abscissa of the eigenfunction (the axis is labelled theta)
    eigenfunction: np.ndarray  # mode values along z; its .real and .imag parts are plotted
    x: np.ndarray  # scan coordinate for gamma and omega
    gamma: np.ndarray  # growth rates at x
    omega: np.ndarray  # frequencies at x
    x_label: str  # x-axis label for the gamma and omega axes
    x_ref: np.ndarray | None = None  # optional reference scan coordinate
    gamma_ref: np.ndarray | None = None  # optional reference growth rates (plotted only with x_ref)
    omega_ref: np.ndarray | None = None  # optional reference frequencies (plotted only with x_ref)
    ref_label: str = "Reference"  # legend label for the reference points
    log_x: bool = False  # use a logarithmic x axis for the gamma and omega axes
@dataclass(frozen=True)
class ReferenceSeries:
    """One external reference curve for growth-rate and frequency comparisons.

    ``linear_validation_multi_reference_figure`` and
    ``scan_multi_reference_figure`` plot ``gamma`` and ``omega`` against ``x``
    using the style fields below.
    """

    label: str  # legend label
    x: np.ndarray  # scan coordinate shared by gamma and omega
    gamma: np.ndarray  # reference growth rates at x
    omega: np.ndarray  # reference frequencies at x
    color: str  # matplotlib color used for both curves
    marker: str = "o"  # matplotlib marker style
    linestyle: str = "--"  # matplotlib line style
@dataclass(frozen=True)
class MultiReferenceValidationPanel:
    """Data for one row of ``linear_validation_multi_reference_figure``.

    This is like ``LinearValidationPanel``, except that reference data comes as
    any number of ``ReferenceSeries`` curves rather than one optional set of
    points.
    """

    name: str  # row label; drawn as the y-axis label of the eigenfunction axis
    z: np.ndarray  # abscissa of the eigenfunction (the axis is labelled theta)
    eigenfunction: np.ndarray  # mode values along z; its .real and .imag parts are plotted
    x: np.ndarray  # scan coordinate for gamma and omega
    gamma: np.ndarray  # growth rates at x
    omega: np.ndarray  # frequencies at x
    x_label: str  # x-axis label for the gamma and omega axes
    references: list[ReferenceSeries]  # reference curves overlaid on the gamma and omega axes
    log_x: bool = False  # use a logarithmic x axis for the gamma and omega axes
def linear_validation_figure(
    panels: list[LinearValidationPanel],
) -> Tuple[plt.Figure, np.ndarray]:
    """Create a multi-row summary of eigenfunctions, growth rates and frequencies.

    Each ``LinearValidationPanel`` becomes one row of three axes:

    * the eigenfunction (real and imaginary parts against ``z``);
    * ``gamma`` against ``x``;
    * ``omega`` against ``x``.

    Reference points are overlaid when ``x_ref`` and the matching
    ``gamma_ref``/``omega_ref`` are both given. Column titles and legends are
    drawn on the first row only.

    Returns:
        The figure and a ``(len(panels), 3)`` array of axes.

    Raises:
        ValueError: if ``panels`` is empty.
    """
    if len(panels) == 0:
        raise ValueError("panels must be non-empty")
    set_plot_style()
    nrows = len(panels)
    # Rows may scan different parameters: each panel has its own x_label and
    # log_x. Sharing x per column would merge x-limits across rows, hide inner
    # tick labels, and let one row's log scale leak into every other row of
    # that column. So each axis keeps its own x range.
    fig, axes = plt.subplots(nrows, 3, figsize=(12.0, 3.0 * nrows), sharex=False)
    if nrows == 1:
        axes = np.asarray([axes])
    for i, panel in enumerate(panels):
        ax0, ax1, ax2 = axes[i]
        eigen = np.asarray(panel.eigenfunction)
        ax0.plot(panel.z, eigen.real, color="#1f77b4", label="Re")
        ax0.plot(panel.z, eigen.imag, color="#ff7f0e", linestyle="--", label="Im")
        ax0.set_ylabel(panel.name)
        ax0.set_xlabel(r"$\theta$")

        ax1.plot(panel.x, panel.gamma, marker="o", color="#2ca02c", label="SPECTRAX-GK")
        if panel.x_ref is not None and panel.gamma_ref is not None:
            ax1.plot(panel.x_ref, panel.gamma_ref, marker="o", linestyle="None", color="#1f77b4", label=panel.ref_label)
        ax1.set_xlabel(panel.x_label)
        ax1.set_ylabel(r"$\gamma a / v_{ti}$")

        ax2.plot(panel.x, panel.omega, marker="o", color="#d62728", label="SPECTRAX-GK")
        if panel.x_ref is not None and panel.omega_ref is not None:
            ax2.plot(panel.x_ref, panel.omega_ref, marker="o", linestyle="None", color="#1f77b4", label=panel.ref_label)
        ax2.set_xlabel(panel.x_label)
        ax2.set_ylabel(r"$\omega a / v_{ti}$")

        if panel.log_x:
            ax1.set_xscale("log")
            ax2.set_xscale("log")

        if i == 0:
            ax0.set_title("Eigenfunction")
            ax1.set_title("Growth rate")
            ax2.set_title("Frequency")
            ax0.legend(loc="best", fontsize=9)
            ax1.legend(loc="best", fontsize=9)
            ax2.legend(loc="best", fontsize=9)
    fig.tight_layout()
    return fig, axes
def linear_validation_multi_reference_figure(
    panels: list[MultiReferenceValidationPanel],
) -> Tuple[plt.Figure, np.ndarray]:
    """Build one row per panel comparing SPECTRAX-GK against several reference curves.

    Each row holds three axes:

    * the eigenfunction (real and imaginary parts against ``z``);
    * ``gamma`` against ``x``;
    * ``omega`` against ``x``.

    Every ``ReferenceSeries`` in ``panel.references`` is drawn on the last two
    axes. Column titles and legends appear on the first row only.

    Returns:
        The figure and a ``(len(panels), 3)`` array of axes.

    Raises:
        ValueError: if ``panels`` is empty.
    """
    if len(panels) == 0:
        raise ValueError("panels must be non-empty")
    set_plot_style()
    row_count = len(panels)
    # Keep each row on its own x-range so Cyclone- and ETG-scale ky scans
    # remain readable in the combined summary figure.
    fig, axes = plt.subplots(row_count, 3, figsize=(12.0, 3.0 * row_count), sharex=False)
    if row_count == 1:
        axes = np.asarray([axes])

    quantity_specs = (
        ("gamma", "#2ca02c", r"$\gamma a / v_{ti}$"),
        ("omega", "#d62728", r"$\omega a / v_{ti}$"),
    )
    for row, panel in enumerate(panels):
        mode_ax, *scan_axes = axes[row]
        mode_ax.plot(panel.z, panel.eigenfunction.real, color="#1f77b4", label="Re")
        mode_ax.plot(panel.z, panel.eigenfunction.imag, color="#ff7f0e", linestyle="--", label="Im")
        mode_ax.set_ylabel(panel.name)
        mode_ax.set_xlabel(r"$\theta$")

        # Materialize once: the references are drawn on two axes.
        refs = tuple(panel.references)
        for ax, (field, color, ylabel) in zip(scan_axes, quantity_specs):
            ax.plot(panel.x, getattr(panel, field), marker="o", color=color, label="SPECTRAX-GK")
            for ref in refs:
                ax.plot(
                    ref.x,
                    getattr(ref, field),
                    marker=ref.marker,
                    linestyle=ref.linestyle,
                    color=ref.color,
                    label=ref.label,
                )
            ax.set_xlabel(panel.x_label)
            ax.set_ylabel(ylabel)
            if panel.log_x:
                ax.set_xscale("log")

        if row == 0:
            for ax, heading in zip(axes[row], ("Eigenfunction", "Growth rate", "Frequency")):
                ax.set_title(heading)
                ax.legend(loc="best", fontsize=9)
    fig.tight_layout()
    return fig, axes
def scan_multi_reference_figure(
    x: np.ndarray,
    gamma: np.ndarray,
    omega: np.ndarray,
    x_label: str,
    title: str,
    references: list[ReferenceSeries],
    *,
    log_x: bool = False,
) -> Tuple[plt.Figure, np.ndarray]:
    """Compare one SPECTRAX-GK scan against any number of reference curves.

    Top panel: growth rate; bottom panel: frequency. Both share the x axis,
    which is optionally logarithmic. Each ``ReferenceSeries`` is drawn with its
    own color, marker and line style. Returns the figure and the length-2 array
    of axes.
    """
    set_plot_style()
    fig, axes = plt.subplots(2, 1, sharex=True, figsize=(5.5, 5.0))
    growth_ax, freq_ax = axes
    # Materialize once: the references are drawn on both panels.
    refs = tuple(references)
    for ax, values, field, color, ylabel in (
        (growth_ax, gamma, "gamma", "#2ca02c", r"$\gamma a / v_{ti}$"),
        (freq_ax, omega, "omega", "#d62728", r"$\omega a / v_{ti}$"),
    ):
        ax.plot(x, values, marker="o", color=color, label="SPECTRAX-GK")
        for series in refs:
            ax.plot(
                series.x,
                getattr(series, field),
                marker=series.marker,
                linestyle=series.linestyle,
                color=series.color,
                label=series.label,
            )
        ax.set_ylabel(ylabel)
        if log_x:
            ax.set_xscale("log")
        ax.legend(loc="best")
    growth_ax.set_title(title)
    freq_ax.set_xlabel(x_label)
    fig.tight_layout()
    return fig, axes
def _cell_edges(values: np.ndarray) -> tuple[float, float]:
    """Return the outer edges of the cells centred on the samples of a 1-D grid.

    ``imshow``'s ``extent`` is the outer boundary of the image. Passing the
    first and last sample directly would shift every pixel by half a cell, so
    half of the first and last spacing is added on each side. This is exact for
    uniform grids and approximate otherwise. A single sample gets a unit-width
    cell instead of a degenerate zero-width extent. Descending grids give
    ``lo > hi``, which keeps the image flipped accordingly.
    """
    centres = np.asarray(values, dtype=float).ravel()
    if centres.size == 1:
        only = float(centres[0])
        return only - 0.5, only + 0.5
    lo = centres[0] - 0.5 * (centres[1] - centres[0])
    hi = centres[-1] + 0.5 * (centres[-1] - centres[-2])
    return float(lo), float(hi)


def growth_rate_heatmap(
    x: np.ndarray,
    y: np.ndarray,
    gamma: np.ndarray,
    title: str,
    x_label: str,
    y_label: str,
    cmap: str = "jet",
) -> Tuple[plt.Figure, plt.Axes]:
    """Render a growth-rate heatmap over two scanned gradient axes.

    ``gamma`` is drawn with ``imshow(origin="lower")``, so ``gamma[j, i]`` is
    placed at ``(x[i], y[j])``. It is therefore presumably shaped
    ``(len(y), len(x))``; this is not checked. Pixels are centred on the sample
    values, with edges from ``_cell_edges``.

    Returns:
        The figure and its single axes.
    """
    set_plot_style()
    fig, ax = plt.subplots(1, 1, figsize=(5.5, 4.5))
    x_lo, x_hi = _cell_edges(x)
    y_lo, y_hi = _cell_edges(y)
    im = ax.imshow(gamma, origin="lower", aspect="auto", extent=(x_lo, x_hi, y_lo, y_hi), cmap=cmap)
    ax.set_title(title)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    fig.colorbar(im, ax=ax, label=r"$\gamma a / v_{ti}$")
    fig.tight_layout()
    return fig, ax
def growth_fit_figure(
    t: np.ndarray,
    signal: np.ndarray,
    *,
    tmin: float | None = None,
    tmax: float | None = None,
    title: str = "Growth-fit window",
) -> Tuple[plt.Figure, np.ndarray]:
    """Plot :math:`|s|^2` and :math:`\\log |s|^2` with an optional fit window.

    When both ``tmin`` and ``tmax`` are given with ``tmax > tmin``:

    * the window is shaded on both panels;
    * ``fit_growth_rate`` is called on that window;
    * a straight line of slope ``2 * gamma`` is drawn on the log panel
      (``|s|^2`` grows at twice the amplitude rate).

    Returns:
        The figure and the length-2 array of axes (top: energy, bottom: log energy).
    """
    set_plot_style()
    fig, axes = plt.subplots(2, 1, sharex=True, figsize=(5.0, 4.5))
    ax0, ax1 = axes
    t_arr = np.asarray(t, dtype=float)
    energy = np.abs(signal) ** 2
    # Floor at the smallest positive float so zeros do not produce -inf.
    tiny = np.finfo(float).tiny
    log_energy = np.log(np.maximum(energy, tiny))
    ax0.plot(t_arr, energy, label=r"$|s|^2$")
    ax0.set_ylabel("energy")
    ax1.plot(t_arr, log_energy, label=r"$\log|s|^2$")
    ax1.set_ylabel("log energy")
    ax1.set_xlabel("t")
    ax0.set_title(title)
    if tmin is not None and tmax is not None and tmax > tmin:
        ax0.axvspan(tmin, tmax, color="orange", alpha=0.2, label="fit window")
        ax1.axvspan(tmin, tmax, color="orange", alpha=0.2)
        gamma, _omega = fit_growth_rate(t, signal, tmin=tmin, tmax=tmax)
        slope = 2.0 * float(gamma)
        # With the slope fixed, the least-squares intercept over the window is the mean
        # residual. The line then passes through the windowed data instead of being
        # pinned to a single sample taken at t >= tmin but drawn as if it were at tmin.
        window = (t_arr >= tmin) & (t_arr <= tmax)
        if np.any(window):
            intercept = float(np.mean(log_energy[window] - slope * t_arr[window]))
        else:
            # No samples inside the window: anchor at the first sample at or after tmin
            # (or the last sample), evaluated at that sample's own time.
            idx = min(int(np.searchsorted(t_arr, tmin)), log_energy.size - 1)
            intercept = float(log_energy[idx]) - slope * float(t_arr[idx])
        fit_line = slope * t_arr + intercept
        ax1.plot(t_arr, fit_line, color="red", linestyle="--", label="fit line")
    ax0.legend(loc="best", fontsize=9)
    ax1.legend(loc="best", fontsize=9)
    fig.tight_layout()
    return fig, axes
def eigenfunction_overlap_summary_figure(
    ky: np.ndarray,
    overlap: np.ndarray,
    relative_l2: np.ndarray,
    *,
    title: str = "Eigenfunction overlap summary",
    x_label: str = r"$k_y \rho_i$",
    overlap_label: str = "Normalized overlap",
    rel_l2_label: str = "Relative $L^2$ error",
    log_x: bool = True,
) -> Tuple[plt.Figure, np.ndarray]:
    """Plot eigenfunction overlap (top) and relative L2 error (bottom) against ``ky``.

    The overlap axis starts at 0. Its upper limit is the largest finite overlap
    plus 0.02, clamped to ``[1.0, 1.02]``. Returns the figure and the length-2
    array of axes.
    """
    set_plot_style()
    fig, axes = plt.subplots(2, 1, sharex=True, figsize=(5.6, 5.2))
    overlap_ax, error_ax = axes
    ky_vals = np.asarray(ky, dtype=float)
    overlap_vals = np.asarray(overlap, dtype=float)
    error_vals = np.asarray(relative_l2, dtype=float)

    overlap_ax.plot(ky_vals, overlap_vals, color="#0f4c81", marker="o", linewidth=2.2, label=overlap_label)
    overlap_ax.set_ylabel("overlap")
    top = float(np.nanmax(overlap_vals)) + 0.02
    top = min(1.02, max(1.0, top))
    overlap_ax.set_ylim(0.0, top)
    overlap_ax.set_title(title)
    overlap_ax.legend(loc="best", frameon=False)

    error_ax.plot(ky_vals, error_vals, color="#c44e52", marker="s", linewidth=2.2, label=rel_l2_label)
    error_ax.set_xlabel(x_label)
    error_ax.set_ylabel(r"relative $L^2$")
    error_ax.legend(loc="best", frameon=False)

    if log_x:
        for ax in (overlap_ax, error_ax):
            ax.set_xscale("log")
    for ax in axes:
        ax.grid(True, alpha=0.25)
    fig.tight_layout()
    return fig, axes
def eigenfunction_reference_overlay_figure(
    theta: np.ndarray,
    eigenfunction: np.ndarray,
    theta_ref: np.ndarray,
    reference: np.ndarray,
    *,
    title: str = "Eigenfunction overlay",
) -> Tuple[plt.Figure, np.ndarray]:
    """Overlay a computed eigenfunction on a frozen reference mode.

    ``eigenfunction`` is phase-aligned to ``reference`` with
    ``phase_align_eigenfunction`` before plotting. The real part, imaginary part
    and magnitude are then drawn in three side-by-side panels. The overlap and
    relative L2 error from ``compare_eigenfunctions`` (computed on the unaligned
    input) are printed on the magnitude panel.

    Returns:
        The figure and the length-3 array of axes.

    Raises:
        ValueError: if ``eigenfunction`` and ``reference`` differ in shape.
    """
    from spectraxgk.benchmarking import compare_eigenfunctions, phase_align_eigenfunction

    set_plot_style()
    grid = np.asarray(theta, dtype=float)
    mode = np.asarray(eigenfunction, dtype=np.complex128)
    ref_grid = np.asarray(theta_ref, dtype=float)
    ref_mode = np.asarray(reference, dtype=np.complex128)
    if mode.shape != ref_mode.shape:
        raise ValueError("eigenfunction and reference must have the same shape")

    aligned, _ = phase_align_eigenfunction(mode, ref_mode)
    scores = compare_eigenfunctions(mode, ref_mode)

    fig, axes = plt.subplots(1, 3, figsize=(12.0, 3.9))
    views = (
        (np.real, "real", "Real part", "Reference Re", "SPECTRAX Re", "best"),
        (np.imag, "imag", "Imaginary part", "Reference Im", "SPECTRAX Im", "best"),
        (np.abs, r"$|\phi|$", "Amplitude", "Reference $|\\phi|$", "SPECTRAX $|\\phi|$", "upper right"),
    )
    for ax, (part, ylabel, heading, ref_label, run_label, legend_loc) in zip(axes, views):
        ax.plot(ref_grid, part(ref_mode), color="#0f4c81", linewidth=2.4, label=ref_label)
        ax.plot(grid, part(aligned), color="#c44e52", linewidth=2.0, linestyle="--", label=run_label)
        ax.set_xlabel(r"$\theta$")
        ax.set_ylabel(ylabel)
        ax.set_title(heading)
        ax.legend(loc=legend_loc, frameon=False)

    amplitude_ax = axes[2]
    amplitude_ax.text(
        0.03,
        0.04,
        f"overlap = {scores.overlap:.4f}\nrel $L^2$ = {scores.relative_l2:.4f}",
        transform=amplitude_ax.transAxes,
        va="bottom",
        ha="left",
        bbox={"boxstyle": "round,pad=0.3", "facecolor": "white", "alpha": 0.9, "edgecolor": "#cccccc"},
    )
    for ax in axes:
        ax.grid(True, alpha=0.25)
    fig.suptitle(title, y=1.02)
    fig.tight_layout()
    return fig, axes
def zonal_flow_response_figure(
    t: np.ndarray,
    response: np.ndarray,
    *,
    metrics=None,
    title: str = "Zonal-flow response",
    y_label: str = "normalized response",
) -> Tuple[plt.Figure, np.ndarray]:
    """Render a zonal-flow response trace and its envelope summary.

    Left panel:

    * ``response / metrics.initial_level`` against ``t``;
    * the residual level with a +/- ``residual_std`` band;
    * the fit window, shaded;
    * the in-window maxima/minima used for the fit, marked only when
      ``damping_method == "branchwise_extrema"``.

    Right panel:

    * ``|normalized response - residual|`` on a log axis;
    * the envelope peaks and, for the ``combined_envelope`` method only, an
      exponential envelope fit;
    * a text box listing the metric values.

    Args:
        t: 1-D sample times.
        response: 1-D response samples, same length as ``t``.
        metrics: precomputed metrics object. When ``None`` it is computed with
            ``zonal_flow_response_metrics(t, response)``. The optional attributes
            ``peak_fit_count``, ``fit_tmin``, ``fit_tmax``, ``damping_method``,
            ``frequency_method``, ``max_peak_times``/``max_peak_values``,
            ``min_peak_times``/``min_peak_values`` and ``initial_policy`` fall
            back to defaults when absent.
        title: figure suptitle.
        y_label: y-axis label of the left panel.

    Returns:
        The figure and the length-2 array of axes.

    Raises:
        ValueError: if ``t`` and ``response`` are not 1-D arrays of equal length.
    """
    from spectraxgk.benchmarking import zonal_flow_response_metrics

    set_plot_style()
    t_arr = np.asarray(t, dtype=float)
    resp = np.asarray(response, dtype=float)
    if t_arr.ndim != 1 or resp.ndim != 1 or t_arr.size != resp.size:
        raise ValueError("t and response must be one-dimensional arrays of equal length")
    if metrics is None:
        metrics = zonal_flow_response_metrics(t_arr, resp)
    # NOTE(review): no guard for initial_level == 0 (the division would give inf/nan) —
    # confirm the metrics never report a zero initial level.
    response_norm = resp / float(metrics.initial_level)
    residual = float(metrics.residual_level)
    env_t = np.asarray(metrics.peak_times, dtype=float)
    env_y = np.asarray(metrics.peak_envelope, dtype=float)
    # Optional metric fields are read with getattr defaults so metrics objects lacking
    # them still plot: fit over all envelope peaks, window = full time range,
    # no branchwise extrema.
    fit_count = int(getattr(metrics, "peak_fit_count", env_t.size))
    fit_tmin = float(getattr(metrics, "fit_tmin", t_arr[0]))
    fit_tmax = float(getattr(metrics, "fit_tmax", t_arr[-1]))
    damping_method = str(getattr(metrics, "damping_method", "combined_envelope"))
    frequency_method = str(getattr(metrics, "frequency_method", "peak_spacing"))
    max_peak_t = np.asarray(getattr(metrics, "max_peak_times", np.asarray([], dtype=float)), dtype=float)
    max_peak_y = np.asarray(getattr(metrics, "max_peak_values", np.asarray([], dtype=float)), dtype=float)
    min_peak_t = np.asarray(getattr(metrics, "min_peak_times", np.asarray([], dtype=float)), dtype=float)
    min_peak_y = np.asarray(getattr(metrics, "min_peak_values", np.asarray([], dtype=float)), dtype=float)

    fig, axes = plt.subplots(1, 2, figsize=(11.0, 4.0))
    ax0, ax1 = axes

    # Left panel: normalized trace, residual level and the shaded fit window.
    ax0.plot(t_arr, response_norm, color="#0f4c81", linewidth=2.2, label="response")
    ax0.axhline(residual, color="#c44e52", linestyle="--", linewidth=2.0, label="residual")
    ax0.axvspan(fit_tmin, fit_tmax, color="#d9ead3", alpha=0.22, linewidth=0.0)
    if damping_method == "branchwise_extrema":
        # Mark only the extrema inside [fit_tmin, fit_tmax]. NOTE(review): the peak values
        # are plotted as supplied by metrics, on the same axis as response_norm —
        # presumably already normalized; confirm.
        if max_peak_t.size:
            keep = (max_peak_t >= fit_tmin) & (max_peak_t <= fit_tmax)
            ax0.plot(max_peak_t[keep], max_peak_y[keep], linestyle="none", marker="o", color="#2a9d8f", markersize=5.2, label="maxima fit points")
        if min_peak_t.size:
            keep = (min_peak_t >= fit_tmin) & (min_peak_t <= fit_tmax)
            ax0.plot(min_peak_t[keep], min_peak_y[keep], linestyle="none", marker="o", color="#7b2cbf", markersize=5.2, label="minima fit points")
    # +/- one residual_std band around the residual level.
    ax0.fill_between(
        t_arr,
        residual - float(metrics.residual_std),
        residual + float(metrics.residual_std),
        color="#c44e52",
        alpha=0.15,
        linewidth=0.0,
    )
    ax0.set_xlabel("t")
    ax0.set_ylabel(y_label)
    ax0.set_title("Normalized response")
    ax0.legend(loc="best", frameon=False)

    # Right panel: deviation from the residual. Floored at 1e-14 so exact zeros remain
    # drawable on the log axis.
    ax1.plot(t_arr, np.maximum(np.abs(response_norm - residual), 1.0e-14), color="#4c956c", linewidth=2.0, alpha=0.5)
    if env_t.size:
        ax1.plot(env_t, env_y, color="#c44e52", marker="o", linewidth=1.8, label="envelope peaks")
        fit_env_t = env_t[(env_t >= fit_tmin) & (env_t <= fit_tmax)]
        fit_env_y = env_y[(env_t >= fit_tmin) & (env_t <= fit_tmax)]
        # Draw exp(-gam_damping_rate * dt) anchored at the first in-window envelope peak,
        # over at most fit_count in-window peaks (combined_envelope method only).
        if damping_method == "combined_envelope" and fit_count >= 2 and np.isfinite(float(metrics.gam_damping_rate)) and fit_env_t.size:
            fit_n = min(fit_count, fit_env_t.size)
            fit_t = fit_env_t[:fit_n]
            fit = fit_env_y[0] * np.exp(-float(metrics.gam_damping_rate) * (fit_t - fit_t[0]))
            label = "envelope fit" if fit_n == fit_env_t.size else f"envelope fit (first {fit_n} peaks)"
            ax1.plot(fit_t, fit, color="#2a9d8f", linestyle="--", linewidth=2.0, label=label)
    ax1.set_yscale("log")
    ax1.set_xlabel("t")
    ax1.set_ylabel("envelope")
    ax1.set_title("GAM envelope")
    if env_t.size:
        ax1.legend(loc="best", frameon=False)
    # Summary of the metric values and methods in the upper-left corner.
    ax1.text(
        0.03,
        0.97,
        (
            f"residual = {metrics.residual_level:.4f}\n"
            f"std = {metrics.residual_std:.4f}\n"
            f"ω_GAM = {metrics.gam_frequency:.4f}\n"
            f"γ_damp = {metrics.gam_damping_rate:.4f}\n"
            f"fit_peaks = {fit_count}\n"
            f"norm = {getattr(metrics, 'initial_policy', 'window_abs_mean')}\n"
            f"damp = {damping_method}\n"
            f"freq = {frequency_method}\n"
            f"fit_t = [{fit_tmin:.1f}, {fit_tmax:.1f}]"
        ),
        transform=ax1.transAxes,
        va="top",
        ha="left",
        fontsize=9,
        bbox={"boxstyle": "round,pad=0.3", "facecolor": "white", "alpha": 0.9, "edgecolor": "#cccccc"},
    )
    for axis in axes:
        axis.grid(True, alpha=0.25)
    fig.suptitle(title, y=1.02)
    fig.tight_layout()
    return fig, axes