"""Deterministic decomposition contracts for parallel work portfolios.
The helpers in this module describe partitioning and reconstruction contracts.
They do not route solver execution, alter nonlinear state layout, or make
speedup claims.
"""
from __future__ import annotations

from collections import Counter
from collections.abc import Callable, Iterable, Sequence
from dataclasses import asdict, dataclass
from typing import Any, Literal, TypeVar, get_args
IndependentWorkload = Literal[
"independent_ky_scan",
"uq_ensemble",
"optimization_ensemble",
]
DiagnosticWorkload = Literal["diagnostic_nonlinear_domain"]
DecompositionWorkload = IndependentWorkload | DiagnosticWorkload
ClaimLevel = Literal[
"production_independent_batching",
"diagnostic_nonlinear_domain_partition",
]
_INDEPENDENT_WORKLOADS: frozenset[str] = frozenset(
{"independent_ky_scan", "uq_ensemble", "optimization_ensemble"}
)
T = TypeVar("T")
@dataclass(frozen=True)
class ShardAssignment:
    """Immutable record of the serial indices owned by a single shard.

    The balanced builders in this module emit contiguous runs
    ``[start, stop)``; ``indices`` spells those serial positions out
    explicitly and ``label`` is a deterministic, human-readable name.
    """

    shard_id: int  # zero-based position of this shard within its contract
    start: int  # first serial index of the run (inclusive)
    stop: int  # end of the run (exclusive)
    indices: tuple[int, ...]  # serial positions owned by this shard
    label: str  # deterministic human-readable identifier

    @property
    def size(self) -> int:
        """How many serial items this shard owns, i.e. ``len(indices)``."""
        owned = self.indices
        return len(owned)

    def to_dict(self) -> dict[str, Any]:
        """Convert to a plain ``dict`` via :func:`dataclasses.asdict`.

        The ``indices`` tuple serializes as a JSON array.
        """
        snapshot = asdict(self)
        return snapshot
@dataclass(frozen=True)
class DecompositionContract:
    """Immutable shard-assignment contract scoped to a single claim level.

    The builders in this module fill every field; ``state_shape`` and
    ``axis`` are only set by the diagnostic nonlinear-domain builder.
    """

    workload: DecompositionWorkload  # portfolio or diagnostic path described
    claim_level: ClaimLevel  # scope of the claim; drives the two properties
    claim_label: str  # human-readable statement of what is (not) claimed
    n_items: int  # serial items partitioned (axis extent for diagnostics)
    requested_shards: int  # shard count the caller asked for
    actual_shards: int  # shards produced (builders set len(shards))
    shards: tuple[ShardAssignment, ...]  # per-shard assignments in order
    independent_work: bool  # builders: True for independent portfolios
    changes_solver_layout: bool  # builders: True only for the diagnostic path
    state_shape: tuple[int, ...] | None = None  # diagnostic builder only
    axis: int | None = None  # canonical non-negative axis; diagnostic only

    @property
    def production_independent_batching(self) -> bool:
        """True when ``claim_level`` is ``production_independent_batching``."""
        return "production_independent_batching" == self.claim_level

    @property
    def diagnostic_nonlinear_partition(self) -> bool:
        """True when ``claim_level`` is ``diagnostic_nonlinear_domain_partition``."""
        return "diagnostic_nonlinear_domain_partition" == self.claim_level

    def to_dict(self) -> dict[str, Any]:
        """Convert (recursively, nested shards included) to a plain ``dict``."""
        snapshot = asdict(self)
        return snapshot
@dataclass(frozen=True)
class ReconstructionIdentityReport:
    """Result of a serial split/reassemble identity check on one contract.

    The first six fields echo the audited contract; the remaining fields
    describe how its shard indices cover ``range(n_items)`` and whether the
    round trip reproduced the serial values.
    """

    workload: DecompositionWorkload  # echoed from the contract
    claim_level: ClaimLevel  # echoed from the contract
    claim_label: str  # echoed from the contract
    n_items: int  # echoed from the contract
    requested_shards: int  # echoed from the contract
    actual_shards: int  # echoed from the contract
    identity_passed: bool  # overall verdict of the identity check
    expected_indices: tuple[int, ...]  # range(n_items) as a tuple
    reconstructed_indices: tuple[int, ...]  # shard indices, concatenated
    missing_indices: tuple[int, ...]  # expected indices covered by no shard
    duplicate_indices: tuple[int, ...]  # indices assigned more than once
    out_of_range_indices: tuple[int, ...]  # indices outside [0, n_items)
    out_of_order: bool  # concatenated shard order differs from serial order

    def to_dict(self) -> dict[str, Any]:
        """Convert the report to a plain ``dict`` via ``dataclasses.asdict``."""
        snapshot = asdict(self)
        return snapshot
def _validate_count(name: str, value: int, *, allow_zero: bool) -> int:
    """Coerce a count to ``int`` and enforce its lower bound.

    Args:
        name: Parameter name used in error messages.
        value: Count to validate. Anything ``int()`` accepts is coerced,
            except that a non-integral ``float`` is rejected instead of being
            silently truncated (``4.0`` is still accepted).
        allow_zero: If true the minimum is 0; otherwise it is 1.

    Returns:
        The validated count as a plain ``int``.

    Raises:
        ValueError: If *value* is a non-integral float, is a string ``int()``
            cannot parse, or is below the minimum. Other errors from
            ``int()`` (e.g. ``TypeError`` for ``None``) propagate unchanged.
    """
    if isinstance(value, float) and not value.is_integer():
        # int(2.5) == 2 would silently drop items or shards from a contract.
        raise ValueError(f"{name} must be an integer")
    count = int(value)
    minimum = 0 if allow_zero else 1
    if count < minimum:
        qualifier = "non-negative" if allow_zero else ">= 1"
        raise ValueError(f"{name} must be {qualifier}")
    return count
def _balanced_assignments(
    *,
    workload: DecompositionWorkload,
    n_items: int,
    requested_shards: int,
    label_prefix: str,
) -> tuple[ShardAssignment, ...]:
    """Partition ``range(n_items)`` into contiguous, near-equal, non-empty shards.

    The number of shards is ``min(requested_shards, n_items)``. When the
    items do not divide evenly, the leading shards each take one extra item.

    Args:
        workload: Only used in the defensive invariant error message.
        n_items: Number of serial items (>= 0).
        requested_shards: Desired shard count (>= 1).
        label_prefix: Prefix for every shard's ``label``.

    Returns:
        The shard assignments in order; an empty tuple when ``n_items`` is 0.

    Raises:
        ValueError: If either count fails validation.
    """
    total = _validate_count("n_items", n_items, allow_zero=True)
    wanted = _validate_count("requested_shards", requested_shards, allow_zero=False)
    if not total:
        return ()
    shard_count = wanted if wanted < total else total
    per_shard, extra = divmod(total, shard_count)

    def _make(shard_id: int) -> ShardAssignment:
        # Shard k follows k full shards plus min(k, extra) bonus items, and
        # only the first ``extra`` shards carry a bonus item themselves.
        lo = shard_id * per_shard + min(shard_id, extra)
        hi = lo + per_shard + (1 if shard_id < extra else 0)
        return ShardAssignment(
            shard_id=shard_id,
            start=lo,
            stop=hi,
            indices=tuple(range(lo, hi)),
            label=f"{label_prefix}:shard_{shard_id:03d}:items_{lo:06d}_{hi:06d}",
        )

    assignments = tuple(_make(shard_id) for shard_id in range(shard_count))
    if assignments[-1].stop != total:  # pragma: no cover - defensive invariant check
        raise AssertionError(f"{workload} assignments did not cover all items")
    return assignments
def _independent_claim_label(workload: IndependentWorkload) -> str:
if workload == "independent_ky_scan":
portfolio = "independent ky scan"
elif workload == "uq_ensemble":
portfolio = "independent UQ ensemble"
elif workload == "optimization_ensemble":
portfolio = "independent optimization ensemble"
else: # pragma: no cover - protected by caller validation
raise ValueError(f"unknown independent workload: {workload}")
return (
f"production independent batching contract for {portfolio}; "
"serial ordering and reconstruction identity only; "
"not a nonlinear state-domain decomposition speedup claim"
)
def build_independent_portfolio_decomposition(
    n_items: int,
    *,
    requested_shards: int,
    workload: IndependentWorkload,
) -> DecompositionContract:
    """Build a production independent-work decomposition contract.

    The assignment is deterministic, balanced, contiguous, and contains no
    empty shards. It covers release-ready independent portfolios only:
    ``independent_ky_scan``, ``uq_ensemble``, and ``optimization_ensemble``.

    Args:
        n_items: Number of independent work items (>= 0).
        requested_shards: Desired shard count (>= 1); at most ``n_items``
            shards are produced.
        workload: Independent portfolio the items belong to.

    Returns:
        A contract flagged ``independent_work=True`` and
        ``changes_solver_layout=False``; it has no shards when ``n_items`` is 0.

    Raises:
        ValueError: If *workload* is unknown or a count is invalid.
    """
    is_known = workload in _INDEPENDENT_WORKLOADS
    if not is_known:
        raise ValueError(
            "workload must be 'independent_ky_scan', 'uq_ensemble', "
            "or 'optimization_ensemble'"
        )
    item_count = _validate_count("n_items", n_items, allow_zero=True)
    shard_request = _validate_count(
        "requested_shards", requested_shards, allow_zero=False
    )
    assignments = _balanced_assignments(
        workload=workload,
        n_items=item_count,
        requested_shards=shard_request,
        label_prefix=workload,
    )
    label = _independent_claim_label(workload)
    return DecompositionContract(
        workload=workload,
        claim_level="production_independent_batching",
        claim_label=label,
        n_items=item_count,
        requested_shards=shard_request,
        actual_shards=len(assignments),
        shards=assignments,
        independent_work=True,
        changes_solver_layout=False,
    )
def _diagnostic_claim_label() -> str:
    """Return the fixed claim label attached to diagnostic domain contracts."""
    clauses = (
        "diagnostic nonlinear state-domain partition contract",
        "serial split/reassemble identity only",
        "no production routing or speedup claim",
    )
    return "; ".join(clauses)
def build_diagnostic_nonlinear_domain_decomposition(
    state_shape: Iterable[int],
    *,
    axis: int,
    requested_shards: int,
) -> DecompositionContract:
    """Build a diagnostic nonlinear-domain partition contract.

    This metadata describes split/reassemble coverage along one state axis.
    It is intentionally not a production nonlinear route and does not claim
    nonlinear speedup.

    Args:
        state_shape: Extents of the state; every entry must be positive.
        axis: Axis to partition. Negative values count from the end, as in
            Python/NumPy indexing; the valid range is ``[-ndim, ndim)``.
        requested_shards: Desired shard count (>= 1); at most the extent of
            the chosen axis is produced.

    Returns:
        A contract whose ``n_items`` is the extent of the chosen axis and
        whose ``axis`` field holds the canonical non-negative axis.

    Raises:
        ValueError: If ``state_shape`` is empty or has a non-positive entry,
            if ``axis`` is outside ``[-ndim, ndim)``, or if
            ``requested_shards`` is invalid.
    """
    shape = tuple(int(size) for size in state_shape)
    if not shape:
        raise ValueError("state_shape must contain at least one axis")
    if any(size <= 0 for size in shape):
        raise ValueError("state_shape entries must be positive")
    ndim = len(shape)
    raw_axis = int(axis)
    # A bare modulo would silently map e.g. axis=5 on a rank-3 state to axis 2;
    # reject out-of-range axes instead, as NumPy does.
    if not -ndim <= raw_axis < ndim:
        raise ValueError(
            f"axis {raw_axis} is out of bounds for state_shape of rank {ndim}"
        )
    canonical_axis = raw_axis % ndim
    requested = _validate_count(
        "requested_shards",
        requested_shards,
        allow_zero=False,
    )
    domain_size = shape[canonical_axis]
    shards = _balanced_assignments(
        workload="diagnostic_nonlinear_domain",
        n_items=domain_size,
        requested_shards=requested,
        label_prefix=f"diagnostic_nonlinear_domain:axis_{canonical_axis}",
    )
    return DecompositionContract(
        workload="diagnostic_nonlinear_domain",
        claim_level="diagnostic_nonlinear_domain_partition",
        claim_label=_diagnostic_claim_label(),
        n_items=domain_size,
        requested_shards=requested,
        actual_shards=len(shards),
        shards=shards,
        independent_work=False,
        changes_solver_layout=True,
        state_shape=shape,
        axis=canonical_axis,
    )
def shard_sequence(
    values: Sequence[T],
    contract: DecompositionContract,
) -> tuple[tuple[T, ...], ...]:
    """Split *values* into per-shard tuples following *contract*.

    Each output tuple holds ``values[i]`` for every ``i`` in the matching
    shard's ``indices``, in that order.

    Raises:
        ValueError: If ``len(values)`` differs from ``contract.n_items``.
    """
    snapshot = tuple(values)
    if contract.n_items != len(snapshot):
        raise ValueError("values length must match contract.n_items")
    grouped = []
    for assignment in contract.shards:
        grouped.append(tuple(map(snapshot.__getitem__, assignment.indices)))
    return tuple(grouped)
[docs]
def reconstruct_serial(
contract: DecompositionContract,
shard_values: Sequence[Sequence[T]],
) -> tuple[T, ...]:
"""Reassemble shard values into serial index order."""
if len(shard_values) != contract.actual_shards:
raise ValueError("shard_values length must match contract.actual_shards")
reconstructed: list[Any] = [None] * contract.n_items
filled = [False] * contract.n_items
for shard, values in zip(contract.shards, shard_values, strict=True):
shard_tuple = tuple(values)
if len(shard_tuple) != shard.size:
raise ValueError("each shard value group must match its assignment size")
for index, value in zip(shard.indices, shard_tuple, strict=True):
if index < 0 or index >= contract.n_items:
raise ValueError("shard assignment index out of range")
if filled[index]:
raise ValueError("shard assignments contain duplicate indices")
reconstructed[index] = value
filled[index] = True
if not all(filled):
raise ValueError("shard assignments do not cover all serial indices")
return tuple(reconstructed)
def _coverage(contract: DecompositionContract) -> tuple[
    tuple[int, ...],
    tuple[int, ...],
    tuple[int, ...],
    tuple[int, ...],
    tuple[int, ...],
    bool,
]:
    """Audit how a contract's shard indices cover ``range(n_items)``.

    Returns:
        ``(expected, reconstructed, missing, duplicates, out_of_range,
        out_of_order)``:
        ``expected`` is ``range(n_items)`` as a tuple; ``reconstructed`` is
        every shard's indices concatenated in shard order; ``missing`` lists
        expected indices no shard covers; ``duplicates`` lists (sorted)
        indices seen more than once; ``out_of_range`` lists, in encounter
        order, indices outside ``[0, n_items)``; ``out_of_order`` is true
        when ``reconstructed`` differs from ``expected``.
    """
    upper = contract.n_items
    expected = tuple(range(upper))
    flattened: list[int] = []
    for assignment in contract.shards:
        flattened.extend(assignment.indices)
    reconstructed = tuple(flattened)
    tally = Counter(reconstructed)
    missing = tuple(position for position in expected if position not in tally)
    duplicates = tuple(
        sorted(position for position, seen in tally.items() if seen > 1)
    )
    out_of_range = tuple(
        position for position in reconstructed if not 0 <= position < upper
    )
    out_of_order = reconstructed != expected
    return expected, reconstructed, missing, duplicates, out_of_range, out_of_order
def _default_equal(left: T, right: T) -> bool:
if left is right:
return True
try:
return bool(left == right)
except (TypeError, ValueError):
return False
def serial_reconstruction_identity_report(
    values: Sequence[T],
    contract: DecompositionContract,
    *,
    equal: Callable[[T, T], bool] | None = None,
) -> ReconstructionIdentityReport:
    """Check that contract sharding reassembles exactly to serial order.

    The contract's index coverage is audited first (via ``_coverage``). Only
    when every serial index is covered exactly once and no index is out of
    range are *values* sharded with ``shard_sequence``, reassembled with
    ``reconstruct_serial``, and compared element-wise against the originals.
    For a contract with broken coverage the value round trip is skipped
    (``reconstruct_serial`` would raise on it). The report then carries
    ``identity_passed=False`` and lists the offending indices.

    Args:
        values: Serial items; the length must equal ``contract.n_items``.
        contract: Decomposition contract to audit.
        equal: Optional element comparator. When omitted, ``_default_equal``
            is used (identity, else ``==``; a comparison that raises counts
            as a mismatch).

    Returns:
        A ``ReconstructionIdentityReport`` describing coverage, ordering, and
        whether the round trip reproduced *values*.

    Raises:
        ValueError: If ``len(values) != contract.n_items``, or if the
            contract's ``shards`` disagree with its ``actual_shards``.
    """
    items = tuple(values)
    # Caller input is validated up front so a length mismatch stays an error
    # even when the contract itself is malformed.
    if len(items) != contract.n_items:
        raise ValueError("values length must match contract.n_items")
    (
        expected_indices,
        reconstructed_indices,
        missing,
        duplicates,
        out_of_range,
        out_of_order,
    ) = _coverage(contract)
    coverage_ok = not (missing or duplicates or out_of_range)
    values_match = False
    if coverage_ok:
        # Safe to round-trip: every index is in range and used exactly once,
        # so shard_sequence cannot mis-index and reconstruct_serial cannot
        # reject the assignment.
        reconstructed_values = reconstruct_serial(
            contract, shard_sequence(items, contract)
        )
        comparator = equal if equal is not None else _default_equal
        values_match = len(reconstructed_values) == len(items) and all(
            comparator(left, right)
            for left, right in zip(reconstructed_values, items, strict=True)
        )
    identity_passed = bool(values_match and coverage_ok and not out_of_order)
    return ReconstructionIdentityReport(
        workload=contract.workload,
        claim_level=contract.claim_level,
        claim_label=contract.claim_label,
        n_items=contract.n_items,
        requested_shards=contract.requested_shards,
        actual_shards=contract.actual_shards,
        identity_passed=identity_passed,
        expected_indices=expected_indices,
        reconstructed_indices=reconstructed_indices,
        missing_indices=missing,
        duplicate_indices=duplicates,
        out_of_range_indices=out_of_range,
        out_of_order=out_of_order,
    )
__all__ = [
    # Public types: workload/claim aliases and the frozen result records.
    "ClaimLevel",
    "DecompositionContract",
    "DecompositionWorkload",
    "DiagnosticWorkload",
    "IndependentWorkload",
    "ReconstructionIdentityReport",
    "ShardAssignment",
    # Public functions: contract builders plus shard/reassemble helpers.
    "build_diagnostic_nonlinear_domain_decomposition",
    "build_independent_portfolio_decomposition",
    "reconstruct_serial",
    "serial_reconstruction_identity_report",
    "shard_sequence",
]