# Source code for spectraxgk.autodiff_validation

"""Autodiff validation helpers for inverse and UQ examples."""

from __future__ import annotations

from typing import Any, Callable

import jax
import jax.numpy as jnp
import numpy as np

from spectraxgk.parallel import independent_map


def _jax_enable_x64() -> bool:
    """Return whether JAX 64-bit precision is currently enabled.

    The flag is read through ``jax.config.read`` so the lookup does not rely
    on dynamically generated config attributes. Some JAX releases reject
    ``read`` for flags that also have a context manager (``jax_enable_x64``
    is one of them) and raise ``AttributeError``; in that case the value is
    fetched by attribute lookup instead.

    Returns:
        ``True`` when ``jax_enable_x64`` is active, ``False`` otherwise.
    """

    try:
        flag = jax.config.read("jax_enable_x64")
    except AttributeError:
        # Fallback for JAX versions that only expose context-manager flags
        # as attributes of the config object.
        flag = getattr(jax.config, "jax_enable_x64")
    return bool(flag)


def _normalize_fd_executor(executor: str) -> str:
    """Map a user-supplied executor name onto its canonical spelling.

    The name is stringified, stripped of surrounding whitespace and
    lower-cased before lookup, so ``" Threads "`` and ``"thread"`` are
    equivalent.

    Args:
        executor: Requested executor name; singular and plural forms of
            ``thread`` and ``process`` are accepted.

    Returns:
        Either ``"thread"`` or ``"process"``.

    Raises:
        ValueError: If the cleaned name is not one of the accepted aliases.
    """

    aliases = {
        "thread": "thread",
        "threads": "thread",
        "process": "process",
        "processes": "process",
    }
    canonical = aliases.get(str(executor).strip().lower())
    if canonical is None:
        raise ValueError("parallel_executor must be 'thread' or 'process'")
    return canonical


def covariance_diagnostics(
    jacobian: np.ndarray,
    residual: np.ndarray,
    *,
    regularization: float = 1.0e-9,
) -> dict[str, object]:
    """Return covariance and conditioning diagnostics for a least-squares inverse.

    The covariance uses the local Gauss-Newton approximation
    ``sigma^2 (J^T J + lambda I)^-1`` with ``sigma^2 = mean(residual^2) + 1e-12``.
    The returned dictionary contains only plain Python numbers and nested
    lists, and records enough metadata to decide whether an inverse demo is
    identifiable, merely locally sensitive, or ill-conditioned.

    Args:
        jacobian: Two-dimensional sensitivity matrix with one row per residual
            entry and one column per parameter.
        residual: Residual values; flattened to one dimension before use.
        regularization: Non-negative, finite ``lambda`` added to the diagonal
            of ``J^T J`` before inversion.

    Returns:
        A dictionary with the keys ``sigma2``, ``covariance``,
        ``covariance_std``, ``covariance_correlation``,
        ``covariance_eigenvalues``, ``uq_ellipse_area_1sigma``,
        ``jacobian_singular_values``, ``jacobian_condition_number`` and
        ``sensitivity_map_rank``. ``jacobian_condition_number`` is
        ``float("inf")`` when the smallest singular value is not positive, so
        a strict JSON encoder must handle that value explicitly.

    Raises:
        ValueError: If the inputs have inconsistent shapes, contain no rows or
            no columns, contain non-finite values, or if ``regularization``
            is negative or non-finite.
        numpy.linalg.LinAlgError: If the regularized normal matrix is exactly
            singular (possible only with ``regularization=0``).
    """
    jac = np.asarray(jacobian, dtype=float)
    res = np.asarray(residual, dtype=float).reshape(-1)
    if jac.ndim != 2:
        raise ValueError("jacobian must be a two-dimensional array")
    if jac.shape[1] == 0:
        raise ValueError("jacobian must contain at least one parameter column")
    if jac.shape[0] != res.size:
        raise ValueError("residual length must match the number of Jacobian rows")
    if res.size == 0:
        # Without any residual the mean-square estimate below would be NaN.
        raise ValueError("jacobian must contain at least one residual row")
    if not np.all(np.isfinite(jac)):
        raise ValueError("jacobian must contain only finite values")
    if not np.all(np.isfinite(res)):
        raise ValueError("residual must contain only finite values")
    reg = float(regularization)
    if not np.isfinite(reg):
        # NaN compares False against 0.0 and ``inf * eye`` yields NaN
        # off-diagonals, so non-finite values must be rejected explicitly.
        raise ValueError("regularization must be finite")
    if reg < 0.0:
        raise ValueError("regularization must be non-negative")

    # Small floor keeps the variance estimate strictly positive for exact fits.
    sigma2 = float(np.mean(res**2) + 1.0e-12)

    singular_values = np.linalg.svd(jac, compute_uv=False)
    if singular_values.size == 0 or float(singular_values[-1]) <= 0.0:
        condition_number = float("inf")
    else:
        condition_number = float(singular_values[0] / singular_values[-1])
    rank = int(np.linalg.matrix_rank(jac))

    normal = jac.T @ jac + reg * np.eye(jac.shape[1])
    covariance = sigma2 * np.linalg.inv(normal)
    # Symmetrize to remove round-off asymmetry before the eigen-decomposition.
    covariance = 0.5 * (covariance + covariance.T)

    std = np.sqrt(np.maximum(np.diag(covariance), 0.0))
    denom = np.outer(std, std)
    # Entries whose standard deviations vanish are reported as zero correlation.
    correlation = np.divide(
        covariance, denom, out=np.zeros_like(covariance), where=denom > 0.0
    )

    eigvals = np.linalg.eigvalsh(covariance)
    positive = eigvals[eigvals > 0.0]
    # ``eigvalsh`` returns ascending eigenvalues, so the last two entries are
    # the principal axes; fewer than two positive values means no area.
    if positive.size >= 2:
        ellipse_area = float(np.pi * np.sqrt(positive[-1] * positive[-2]))
    else:
        ellipse_area = 0.0

    return {
        "sigma2": sigma2,
        "covariance": covariance.tolist(),
        "covariance_std": std.tolist(),
        "covariance_correlation": correlation.tolist(),
        "covariance_eigenvalues": eigvals.tolist(),
        "uq_ellipse_area_1sigma": ellipse_area,
        "jacobian_singular_values": singular_values.tolist(),
        "jacobian_condition_number": condition_number,
        "sensitivity_map_rank": rank,
    }


def central_finite_difference_jacobian(
    fn: Callable[[jnp.ndarray], Any],
    params: jnp.ndarray | np.ndarray,
    *,
    step: float = 1.0e-4,
    workers: int = 1,
    parallel_executor: str = "thread",
) -> jnp.ndarray:
    """Approximate the Jacobian of ``fn`` with symmetric differences.

    Every parameter is shifted by ``+step`` and ``-step`` along its own axis
    and the flattened outputs of ``fn`` are differenced. The per-parameter
    columns are evaluated through ``independent_map``.

    Args:
        fn: Function of a one-dimensional parameter vector; its output is
            flattened before differencing.
        params: One-dimensional base point. It is cast to float64 when JAX
            64-bit mode is enabled, otherwise to float32.
        step: Positive perturbation size.
        workers: Number of workers (at least one) handed to
            ``independent_map``.
        parallel_executor: ``"thread"`` or ``"process"`` (plural forms are
            accepted); ``"process"`` is rejected when ``workers > 1``.

    Returns:
        Array of shape ``(n_outputs, n_params)``. When ``independent_map``
        yields no columns, ``fn`` is evaluated once at the base point to size
        an empty ``(n_outputs, 0)`` result.

    Raises:
        ValueError: For non-1-D ``params``, non-positive ``step``,
            ``workers < 1``, an unknown executor name, or a multi-worker
            process executor.
    """
    real_dtype = jnp.float64 if _jax_enable_x64() else jnp.float32
    base = jnp.asarray(params, dtype=real_dtype)
    if base.ndim != 1:
        raise ValueError("params must be one-dimensional")

    half_width = float(step)
    if half_width <= 0.0:
        raise ValueError("step must be positive")

    worker_count = int(workers)
    if worker_count < 1:
        raise ValueError("workers must be >= 1")

    backend = _normalize_fd_executor(parallel_executor)
    if worker_count > 1 and backend == "process":
        raise ValueError(
            "parallel finite differences require the thread executor because objective closures are not pickled"
        )

    basis = jnp.eye(base.size, dtype=base.dtype)

    def _flat_output(point: jnp.ndarray) -> jnp.ndarray:
        return jnp.ravel(jnp.asarray(fn(point)))

    def _difference_along(axis: int) -> jnp.ndarray:
        # Forward sample first, then backward, as one symmetric pair.
        forward = _flat_output(base + half_width * basis[axis])
        backward = _flat_output(base - half_width * basis[axis])
        return (forward - backward) / (2.0 * half_width)

    columns = independent_map(
        _difference_along,
        range(int(base.size)),
        workers=worker_count,
        executor=backend,
    )
    if columns:
        return jnp.stack(columns, axis=1)
    # No parameters: evaluate once only to learn the number of outputs.
    return jnp.zeros((_flat_output(base).size, 0), dtype=base.dtype)


def autodiff_finite_difference_report(
    fn: Callable[[jnp.ndarray], Any],
    params: jnp.ndarray | np.ndarray,
    *,
    step: float = 1.0e-4,
    rtol: float = 1.0e-4,
    atol: float = 1.0e-6,
    direction: jnp.ndarray | np.ndarray | None = None,
    workers: int = 1,
    parallel_executor: str = "thread",
) -> dict[str, object]:
    """Cross-check ``jax.jacfwd`` against central finite differences.

    The full Jacobian of the flattened ``fn`` is computed both ways, and one
    directional derivative is compared as well. The check passes when the
    largest absolute Jacobian error is within ``atol`` or the largest relative
    error is within ``rtol``; the directional error is reported but does not
    affect ``passed``.

    Args:
        fn: Function of a one-dimensional parameter vector.
        params: One-dimensional base point (float64 in JAX 64-bit mode,
            float32 otherwise).
        step: Finite-difference perturbation size.
        rtol: Relative tolerance; relative errors use
            ``max(|jac_fd|, atol)`` as denominator.
        atol: Absolute tolerance.
        direction: Optional tangent direction with the shape of ``params``.
            Defaults to a vector of ones divided by ``max(norm, 1)``.
        workers: Worker count forwarded to the finite-difference Jacobian.
        parallel_executor: Executor name forwarded to the finite-difference
            Jacobian.

    Returns:
        Dictionary with the pass flag, tolerances, error maxima, both
        Jacobians and tangents as nested lists, and a
        ``finite_difference_parallel`` sub-dictionary describing the
        requested and effective worker configuration.

    Raises:
        ValueError: For non-1-D ``params``, ``workers < 1``, an unknown
            executor name, or a ``direction`` of the wrong shape.
    """
    real_dtype = jnp.float64 if _jax_enable_x64() else jnp.float32
    base = jnp.asarray(params, dtype=real_dtype)
    if base.ndim != 1:
        raise ValueError("params must be one-dimensional")

    worker_count = int(workers)
    if worker_count < 1:
        raise ValueError("workers must be >= 1")
    backend = _normalize_fd_executor(parallel_executor)

    def _flattened(x: jnp.ndarray) -> jnp.ndarray:
        return jnp.ravel(jnp.asarray(fn(x)))

    def _peak(values: np.ndarray) -> float:
        # Empty arrays have no maximum; report zero error for them.
        return float(np.max(values)) if values.size else 0.0

    jac_ad = jax.jacfwd(_flattened)(base)
    jac_fd = central_finite_difference_jacobian(
        _flattened,
        base,
        step=step,
        workers=worker_count,
        parallel_executor=backend,
    )

    jac_error = np.asarray(jac_ad - jac_fd, dtype=float)
    scale = np.maximum(np.asarray(np.abs(jac_fd), dtype=float), float(atol))
    jac_rel_error = np.abs(jac_error) / scale

    if direction is not None:
        tangent_dir = jnp.asarray(direction, dtype=base.dtype)
    else:
        ones = jnp.ones_like(base)
        unit = jnp.asarray(1.0, dtype=ones.dtype)
        tangent_dir = ones / jnp.maximum(jnp.linalg.norm(ones), unit)
    if tangent_dir.shape != base.shape:
        raise ValueError("direction must have the same shape as params")

    tangent_ad = jac_ad @ tangent_dir
    forward = _flattened(base + step * tangent_dir)
    backward = _flattened(base - step * tangent_dir)
    tangent_fd = (forward - backward) / (2.0 * step)
    tangent_error = np.asarray(tangent_ad - tangent_fd, dtype=float)

    max_abs = _peak(np.abs(jac_error))
    max_rel = _peak(jac_rel_error)
    tangent_max_abs = _peak(np.abs(tangent_error))
    within_tolerance = bool(max_abs <= float(atol) or max_rel <= float(rtol))

    parallel_info = {
        "requested_workers": worker_count,
        "effective_workers": int(min(worker_count, max(int(base.size), 1))),
        "executor": backend,
        "identity_contract": "parallel finite-difference columns must match serial columns",
    }
    return {
        "passed": within_tolerance,
        "step": float(step),
        "rtol": float(rtol),
        "atol": float(atol),
        "max_abs_error": max_abs,
        "max_rel_error": max_rel,
        "tangent_max_abs_error": tangent_max_abs,
        "jacobian_ad": np.asarray(jac_ad, dtype=float).tolist(),
        "jacobian_fd": np.asarray(jac_fd, dtype=float).tolist(),
        "tangent_ad": np.asarray(tangent_ad, dtype=float).tolist(),
        "tangent_fd": np.asarray(tangent_fd, dtype=float).tolist(),
        "finite_difference_parallel": parallel_info,
    }


def explicit_complex_operator_matrix(
    operator: Callable[[jnp.ndarray], Any],
    state_shape: tuple[int, ...],
    *,
    dtype: Any | None = None,
) -> jnp.ndarray:
    """Materialize a small complex linear operator as a dense matrix.

    This helper is intended for validation fixtures, not production solves.
    It applies ``operator`` to each basis vector of ``state_shape`` and
    returns a matrix whose columns are the flattened outputs. Small dense
    matrices make eigenvalue AD-vs-finite-difference gates easy to express
    while keeping the production code matrix-free.

    Args:
        operator: Map from an array of shape ``state_shape`` to an array of
            the same shape. It is evaluated under ``jax.vmap``.
        state_shape: Non-empty tuple of positive dimensions.
        dtype: Element type of the basis vectors. When ``None`` the default is
            complex128 in JAX 64-bit mode and complex64 otherwise. Any other
            value, including ``numpy.dtype`` instances, is used as given.

    Returns:
        Square array of side ``prod(state_shape)`` whose column ``j`` is the
        flattened image of the ``j``-th basis vector.

    Raises:
        ValueError: If ``state_shape`` is empty or has a non-positive entry,
            or if ``operator`` returns an array of a different shape.
    """
    if not state_shape or any(int(size) <= 0 for size in state_shape):
        raise ValueError("state_shape must contain positive dimensions")

    # Compare against None explicitly: ``dtype or default`` would discard
    # falsy dtype objects (e.g. unstructured ``numpy.dtype`` instances, whose
    # length is zero) and silently substitute the default precision.
    if dtype is None:
        matrix_dtype = jnp.complex128 if _jax_enable_x64() else jnp.complex64
    else:
        matrix_dtype = dtype

    size = int(np.prod(tuple(int(dim) for dim in state_shape)))
    eye = jnp.eye(size, dtype=matrix_dtype)

    def column(vec: jnp.ndarray) -> jnp.ndarray:
        out = jnp.asarray(operator(jnp.reshape(vec, state_shape)))
        if tuple(out.shape) != tuple(state_shape):
            raise ValueError("operator output shape must match state_shape")
        return jnp.ravel(out)

    # vmap stacks the images as rows; swap axes so each image is a column.
    columns = jax.vmap(column)(eye)
    return jnp.swapaxes(columns, 0, 1)


def isolated_eigenvalue_sensitivity_report(
    matrix_fn: Callable[[jnp.ndarray], Any],
    params: jnp.ndarray | np.ndarray,
    *,
    selector: str = "max_real",
    step: float = 1.0e-4,
    rtol: float = 1.0e-4,
    atol: float = 1.0e-6,
    gap_floor: float = 1.0e-8,
) -> dict[str, object]:
    """Validate AD sensitivities of one isolated eigenvalue branch.

    The branch index is chosen once at the base point and then held fixed
    while the real and imaginary parts of that eigenvalue are compared
    between forward-mode AD and central finite differences. This mirrors the
    branch-continuity assumption used for linear growth/frequency
    sensitivities.

    Args:
        matrix_fn: Function of the parameter vector returning a square matrix.
        params: One-dimensional base point (float64 in JAX 64-bit mode,
            float32 otherwise).
        selector: ``"max_real"`` for the eigenvalue with the largest real
            part, or ``"index:N"`` for a fixed position in the eigenvalue
            array.
        step: Finite-difference perturbation size.
        rtol: Relative tolerance of the AD-vs-FD comparison.
        atol: Absolute tolerance of the AD-vs-FD comparison.
        gap_floor: Minimum distance to every other eigenvalue for the branch
            to count as isolated.

    Returns:
        The AD-vs-FD report extended with branch metadata. ``passed`` also
        requires the branch to be isolated. When the comparison raises
        ``NotImplementedError`` a reduced dictionary with
        ``ad_supported=False`` and the failure reason is returned instead.

    Raises:
        ValueError: For non-1-D ``params``, a non-square or empty spectrum,
            or an invalid or out-of-range selector.
    """
    real_dtype = jnp.float64 if _jax_enable_x64() else jnp.float32
    base = jnp.asarray(params, dtype=real_dtype)
    if base.ndim != 1:
        raise ValueError("params must be one-dimensional")

    eig_np = np.asarray(jnp.linalg.eigvals(jnp.asarray(matrix_fn(base))))
    if eig_np.ndim != 1 or eig_np.size == 0:
        raise ValueError(
            "matrix_fn must return a square matrix with at least one eigenvalue"
        )

    mode = selector.strip().lower()
    if mode == "max_real":
        index = int(np.argmax(np.real(eig_np)))
    elif mode.startswith("index:"):
        index = int(mode.partition(":")[2])
        if not 0 <= index < eig_np.size:
            raise ValueError(
                f"selector index {index} is out of bounds for {eig_np.size} eigenvalues"
            )
    else:
        raise ValueError("selector must be 'max_real' or 'index:N'")

    target = eig_np[index]
    if eig_np.size > 1:
        gap = float(np.min(np.abs(target - np.delete(eig_np, index))))
    else:
        # A single eigenvalue has no neighbours to collide with.
        gap = float("inf")
    isolated = bool(gap >= float(gap_floor))

    def _real_imag_of_branch(x: jnp.ndarray) -> jnp.ndarray:
        value = jnp.linalg.eigvals(jnp.asarray(matrix_fn(x)))[index]
        return jnp.asarray([jnp.real(value), jnp.imag(value)])

    # Branch metadata shared by the success and the failure report.
    branch_summary = {
        "selector": mode,
        "selected_index": index,
        "eigenvalue_real": float(np.real(target)),
        "eigenvalue_imag": float(np.imag(target)),
        "eigenvalue_gap": gap,
        "gap_floor": float(gap_floor),
        "branch_isolated": isolated,
    }

    try:
        report = autodiff_finite_difference_report(
            _real_imag_of_branch,
            base,
            step=step,
            rtol=rtol,
            atol=atol,
        )
    except NotImplementedError as exc:
        return {
            "passed": False,
            "ad_supported": False,
            "failure_reason": str(exc),
            "step": float(step),
            "rtol": float(rtol),
            "atol": float(atol),
            **branch_summary,
        }

    return {
        **report,
        "passed": bool(report["passed"]) and isolated,
        "ad_supported": True,
        **branch_summary,
    }


def isolated_eigenpair_observable_sensitivity_report(
    matrix_fn: Callable[[jnp.ndarray], Any],
    observable_fn: Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], Any],
    params: jnp.ndarray | np.ndarray,
    *,
    selector: str = "max_real",
    step: float = 1.0e-4,
    rtol: float = 1.0e-4,
    atol: float = 1.0e-6,
    gap_floor: float = 1.0e-8,
) -> dict[str, object]:
    """Validate AD sensitivities of an observable of one isolated eigenpair.

    ``observable_fn`` receives ``(eigenvalue, eigenvector, params)`` for the
    branch selected at the base point. The selected index is held fixed
    during finite differences, so this gate is appropriate for
    branch-continuous, phase-invariant quantities such as
    ``gamma / <k_perp^2>``. Complex observables are split into stacked real
    and imaginary parts before differentiation.

    Args:
        matrix_fn: Function of the parameter vector returning a square matrix.
        observable_fn: Observable of the selected eigenpair and the
            parameters.
        params: One-dimensional base point (float64 in JAX 64-bit mode,
            float32 otherwise).
        selector: ``"max_real"`` or ``"index:N"``.
        step: Finite-difference perturbation size.
        rtol: Relative tolerance of the AD-vs-FD comparison.
        atol: Absolute tolerance of the AD-vs-FD comparison.
        gap_floor: Minimum distance to every other eigenvalue for the branch
            to count as isolated.

    Returns:
        The AD-vs-FD report extended with branch metadata. ``passed`` also
        requires the branch to be isolated. When the comparison raises
        ``NotImplementedError`` a reduced dictionary with
        ``ad_supported=False`` and the failure reason is returned instead.

    Raises:
        ValueError: For non-1-D ``params``, an empty or non-1-D spectrum, an
            eigenvector matrix inconsistent with the eigenvalues, or an
            invalid or out-of-range selector.
    """
    real_dtype = jnp.float64 if _jax_enable_x64() else jnp.float32
    base = jnp.asarray(params, dtype=real_dtype)
    if base.ndim != 1:
        raise ValueError("params must be one-dimensional")

    base_values, base_vectors = jnp.linalg.eig(jnp.asarray(matrix_fn(base)))
    eig_np = np.asarray(base_values)
    if eig_np.ndim != 1 or eig_np.size == 0:
        raise ValueError(
            "matrix_fn must return a square matrix with at least one eigenvalue"
        )
    if np.asarray(base_vectors).shape[1] != eig_np.size:
        raise ValueError("eigenvector matrix shape is inconsistent with eigenvalues")

    mode = selector.strip().lower()
    if mode == "max_real":
        index = int(np.argmax(np.real(eig_np)))
    elif mode.startswith("index:"):
        index = int(mode.partition(":")[2])
        if not 0 <= index < eig_np.size:
            raise ValueError(
                f"selector index {index} is out of bounds for {eig_np.size} eigenvalues"
            )
    else:
        raise ValueError("selector must be 'max_real' or 'index:N'")

    target = eig_np[index]
    if eig_np.size > 1:
        gap = float(np.min(np.abs(target - np.delete(eig_np, index))))
    else:
        # A single eigenvalue has no neighbours to collide with.
        gap = float("inf")
    isolated = bool(gap >= float(gap_floor))

    def _observable_of_branch(x: jnp.ndarray) -> jnp.ndarray:
        values, vectors = jnp.linalg.eig(jnp.asarray(matrix_fn(x)))
        flat = jnp.ravel(
            jnp.asarray(observable_fn(values[index], vectors[:, index], x))
        )
        if not jnp.iscomplexobj(flat):
            return flat
        # Stack real and imaginary parts so the output is real-valued.
        return jnp.concatenate([jnp.real(flat), jnp.imag(flat)])

    # Branch metadata shared by the success and the failure report.
    branch_summary = {
        "selector": mode,
        "selected_index": index,
        "eigenvalue_real": float(np.real(target)),
        "eigenvalue_imag": float(np.imag(target)),
        "eigenvalue_gap": gap,
        "gap_floor": float(gap_floor),
        "branch_isolated": isolated,
    }

    try:
        report = autodiff_finite_difference_report(
            _observable_of_branch,
            base,
            step=step,
            rtol=rtol,
            atol=atol,
        )
    except NotImplementedError as exc:
        return {
            "passed": False,
            "ad_supported": False,
            "failure_reason": str(exc),
            "step": float(step),
            "rtol": float(rtol),
            "atol": float(atol),
            **branch_summary,
        }

    return {
        **report,
        "passed": bool(report["passed"]) and isolated,
        "ad_supported": True,
        **branch_summary,
    }


def implicit_eigenpair_observable_sensitivity_report(
    matrix_fn: Callable[[jnp.ndarray], Any],
    observable_fn: Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], Any],
    params: jnp.ndarray | np.ndarray,
    *,
    selector: str = "max_real",
    step: float = 1.0e-4,
    rtol: float = 1.0e-4,
    atol: float = 1.0e-6,
    gap_floor: float = 1.0e-8,
) -> dict[str, object]:
    """Validate implicit sensitivities of an isolated non-Hermitian eigenpair.

    JAX currently supports first derivatives of non-Hermitian eigenvalues but
    not eigenvectors. This helper avoids differentiating through
    ``jnp.linalg.eig``. It differentiates the matrix entries with JAX, solves
    the left/right eigenvector perturbation equation for each parameter, and
    compares the resulting observable Jacobian against nearest-branch central
    finite differences.

    The observable should be phase-invariant under ``v -> exp(i alpha) v``.
    The implicit solve fixes the gauge with ``w^H dv = 0`` using the left
    eigenvector ``w`` normalized by ``w^H v = 1``.

    Args:
        matrix_fn: Function of the parameter vector returning a square matrix.
        observable_fn: Called as ``observable_fn(eigenvalue, eigenvector,
            params)``; complex outputs are split into stacked real and
            imaginary parts.
        params: One-dimensional base point (float64 in JAX 64-bit mode,
            float32 otherwise).
        selector: ``"max_real"`` for the eigenvalue with the largest real
            part, or ``"index:N"`` for a fixed position in the eigenvalue
            array.
        step: Perturbation size of the finite-difference reference.
        rtol: Relative tolerance; relative errors use
            ``max(|jac_fd|, atol)`` as denominator.
        atol: Absolute tolerance.
        gap_floor: Minimum distance to every other eigenvalue for the branch
            to count as isolated.

    Returns:
        Dictionary with the pass flag (isolation and tolerance), error
        maxima, the implicit and finite-difference Jacobians, the real and
        imaginary eigenvalue sensitivities, and branch metadata including the
        left-eigenvector index and the magnitude of the biorthogonal overlap.

    Raises:
        ValueError: For non-1-D ``params``, a non-square or empty matrix, an
            invalid or out-of-range selector, or a vanishing/non-finite
            left/right eigenvector overlap.
    """
    p = jnp.asarray(params, dtype=jnp.float64 if _jax_enable_x64() else jnp.float32)
    if p.ndim != 1:
        raise ValueError("params must be one-dimensional")
    A = jnp.asarray(matrix_fn(p))
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError("matrix_fn must return a square matrix")
    eigvals, eigvecs = jnp.linalg.eig(A)
    eig_np = np.asarray(eigvals)
    if eig_np.size == 0:
        raise ValueError("matrix_fn must return at least one eigenvalue")

    # Pick the branch once, at the base point.
    selector_key = selector.strip().lower()
    if selector_key == "max_real":
        index = int(np.argmax(np.real(eig_np)))
    elif selector_key.startswith("index:"):
        index = int(selector_key.split(":", 1)[1])
        if index < 0 or index >= eig_np.size:
            raise ValueError(
                f"selector index {index} is out of bounds for {eig_np.size} eigenvalues"
            )
    else:
        raise ValueError("selector must be 'max_real' or 'index:N'")
    lam = eigvals[index]
    v = eigvecs[:, index]

    # Distance from the selected eigenvalue to its nearest neighbour.
    if eig_np.size == 1:
        gap = float("inf")
    else:
        gap = float(np.min(np.abs(np.delete(eig_np, index) - np.asarray(lam))))
    branch_isolated = bool(gap >= float(gap_floor))

    # Left eigenvector: eigenvector of A^H whose eigenvalue is closest to
    # conj(lam).
    left_vals, left_vecs = jnp.linalg.eig(jnp.conj(jnp.swapaxes(A, 0, 1)))
    left_index = int(
        np.argmin(np.abs(np.asarray(left_vals) - np.conj(np.asarray(lam))))
    )
    w = left_vecs[:, left_index]
    overlap = jnp.vdot(w, v)
    overlap_abs = float(np.abs(np.asarray(overlap)))
    if overlap_abs <= 0.0 or not np.isfinite(overlap_abs):
        raise ValueError("left/right eigenvectors are biorthogonally singular")
    # Rescale so that w^H v = 1 (jnp.vdot conjugates its first argument).
    w = w / jnp.conj(overlap)

    def flat_matrix_fn(x: jnp.ndarray) -> jnp.ndarray:
        return jnp.ravel(jnp.asarray(matrix_fn(x)))

    # dA[:, :, i] is the derivative of the matrix with respect to p[i].
    dA_flat = jax.jacfwd(flat_matrix_fn)(p)
    dA = jnp.reshape(dA_flat, A.shape + (p.size,))
    n = int(A.shape[0])
    identity = jnp.eye(n, dtype=A.dtype)

    # Bordered system for the unknowns [dv; dlam]:
    #   (A - lam I) dv - v dlam = -dA v    (derivative of A v = lam v)
    #   w^H dv                  = 0        (gauge condition)
    top = jnp.concatenate([A - lam * identity, -v[:, None]], axis=1)
    bottom = jnp.concatenate(
        [jnp.conj(w)[None, :], jnp.zeros((1, 1), dtype=A.dtype)],
        axis=1,
    )
    augmented = jnp.concatenate([top, bottom], axis=0)
    # One right-hand-side column per parameter.
    rhs_columns = []
    for i in range(int(p.size)):
        rhs_columns.append(
            jnp.concatenate([-dA[:, :, i] @ v, jnp.zeros((1,), dtype=A.dtype)])
        )
    rhs = jnp.stack(rhs_columns, axis=1)
    solution = jnp.linalg.solve(augmented, rhs)
    dv = solution[:n, :]
    dlam = solution[n, :]

    def observable_real(
        lam_i: jnp.ndarray, v_i: jnp.ndarray, p_i: jnp.ndarray
    ) -> jnp.ndarray:
        # Real-valued view of the observable: complex outputs become
        # [real parts, imaginary parts].
        obs = jnp.ravel(jnp.asarray(observable_fn(lam_i, v_i, p_i)))
        if jnp.iscomplexobj(obs):
            return jnp.concatenate([jnp.real(obs), jnp.imag(obs)])
        return jnp.real(obs)

    def observable_real_from_eigenpair(packed: jnp.ndarray) -> jnp.ndarray:
        # packed = [Re lam, Im lam, Re v (n entries), Im v (n entries)];
        # the parameters are frozen at the base point.
        lam_i = packed[0] + 1j * packed[1]
        v_real_start = 2
        v_imag_start = v_real_start + n
        v_i = packed[v_real_start:v_imag_start] + 1j * packed[v_imag_start:]
        return observable_real(lam_i, v_i, p)

    def observable_real_from_params(p_i: jnp.ndarray) -> jnp.ndarray:
        # Explicit parameter dependence with the eigenpair frozen.
        return observable_real(lam, v, p_i)

    # Split the chain rule so expensive parameter-dependent context, e.g.
    # VMEC/Boozer geometry reconstruction, is only differentiated along the
    # actual parameter directions. Differentiating one packed vector
    # [lambda, v, p] is mathematically equivalent but can replicate heavy
    # geometry tangents for every eigenvector component.
    eigenpair_base = jnp.concatenate(
        [jnp.asarray([jnp.real(lam), jnp.imag(lam)]), jnp.real(v), jnp.imag(v)]
    )
    obs_jac_eigenpair = jax.jacfwd(observable_real_from_eigenpair)(eigenpair_base)
    obs_jac_params = jax.jacfwd(observable_real_from_params)(p)
    implicit_cols = []
    eye = jnp.eye(p.size, dtype=p.dtype)
    for i in range(int(p.size)):
        # Tangent of the packed eigenpair along parameter i, in the same
        # layout as ``eigenpair_base``.
        eigenpair_tangent = jnp.concatenate(
            [
                jnp.asarray([jnp.real(dlam[i]), jnp.imag(dlam[i])]),
                jnp.real(dv[:, i]),
                jnp.imag(dv[:, i]),
            ]
        )
        implicit_cols.append(
            obs_jac_eigenpair @ eigenpair_tangent + obs_jac_params @ eye[i]
        )
    jac_implicit = jnp.stack(implicit_cols, axis=1)

    def branch_observable(x: jnp.ndarray) -> jnp.ndarray:
        # Finite-difference reference: re-solve the eigenproblem and follow
        # the eigenvalue nearest to the base-point ``lam``.
        # NOTE(review): the gauge ``w^H dv = 0`` does not by itself preserve
        # the norm of ``v``, while this reference uses whatever normalization
        # ``jnp.linalg.eig`` returns - presumably the observable must also be
        # insensitive to rescaling of ``v``; confirm.
        eigvals_i, eigvecs_i = jnp.linalg.eig(jnp.asarray(matrix_fn(x)))
        branch_index = int(np.argmin(np.abs(np.asarray(eigvals_i) - np.asarray(lam))))
        obs = jnp.ravel(
            jnp.asarray(
                observable_fn(eigvals_i[branch_index], eigvecs_i[:, branch_index], x)
            )
        )
        if jnp.iscomplexobj(obs):
            return jnp.concatenate([jnp.real(obs), jnp.imag(obs)])
        return jnp.real(obs)

    jac_fd = central_finite_difference_jacobian(branch_observable, p, step=step)
    err = np.asarray(jac_implicit - jac_fd, dtype=float)
    # Relative error is measured against max(|jac_fd|, atol) to avoid
    # dividing by near-zero reference entries.
    denom = np.maximum(np.asarray(np.abs(jac_fd), dtype=float), float(atol))
    rel = np.abs(err) / denom
    max_abs = float(np.max(np.abs(err))) if err.size else 0.0
    max_rel = float(np.max(rel)) if rel.size else 0.0
    passed = bool(
        branch_isolated and (max_abs <= float(atol) or max_rel <= float(rtol))
    )
    dlam_np = np.asarray(dlam, dtype=complex)
    return {
        "passed": passed,
        "ad_supported": True,
        "sensitivity_method": "implicit_left_right_eigenpair",
        "observable_chain_rule": "split_eigenpair_and_explicit_parameter",
        "step": float(step),
        "rtol": float(rtol),
        "atol": float(atol),
        "max_abs_error": max_abs,
        "max_rel_error": max_rel,
        "jacobian_implicit": np.asarray(jac_implicit, dtype=float).tolist(),
        "jacobian_fd": np.asarray(jac_fd, dtype=float).tolist(),
        "eigenvalue_sensitivity_real": np.real(dlam_np).tolist(),
        "eigenvalue_sensitivity_imag": np.imag(dlam_np).tolist(),
        "selector": selector_key,
        "selected_index": index,
        "left_selected_index": left_index,
        "eigenvalue_real": float(np.real(np.asarray(lam))),
        "eigenvalue_imag": float(np.imag(np.asarray(lam))),
        "eigenvalue_gap": gap,
        "gap_floor": float(gap_floor),
        "branch_isolated": branch_isolated,
        "biorthogonal_overlap_abs": overlap_abs,
    }


# Public API of this module, listed alphabetically.
__all__ = [
    "autodiff_finite_difference_report",
    "central_finite_difference_jacobian",
    "covariance_diagnostics",
    "explicit_complex_operator_matrix",
    "implicit_eigenpair_observable_sensitivity_report",
    "isolated_eigenpair_observable_sensitivity_report",
    "isolated_eigenvalue_sensitivity_report",
]