# -*- coding: utf-8 -*-
"""
Stage 3 of the ALBuMS pipeline: flows, the aggregated result, and the computed
compatibility matrix.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Iterable
import numpy as np
from albums.equilibrium import (
AlvesEquilibrium,
BoschEquilibrium,
Capability,
Equilibrium,
EquilibriumSolver,
VenturiniEquilibrium,
)
from albums.theories import (
ALL_THEORY_CLASSES,
BoschMode1,
Category,
HOMCoupledBunch,
InstabilityTheory,
PTBLAlves,
PTBLHe,
RobinsonCoupling,
RobinsonNoCoupling,
TheoryResult,
ZeroFrequency,
names_in_category,
)
class IncompatibleFlowError(ValueError):
    """Signal that a flow pairs a theory with a solver lacking a capability.

    Raised by ``Flow.validate``. Because it subclasses ``ValueError``, callers
    may also catch it as an ordinary invalid-argument error.
    """
# Theory names grouped by physical category. Result looks results up through
# these groups, so callers need not know which concrete Robinson/PTBL theory a
# particular flow selected.
_ROBINSON_NAMES, _PTBL_NAMES = map(
    names_in_category, (Category.ROBINSON, Category.PTBL)
)
@dataclass
class Result:
    """
    Aggregated outcome of running a flow at one operating point.

    Attributes
    ----------
    converged
        Convergence flag of the run.
    equilibrium
        The equilibrium produced by the flow's solver.
    theories
        Per-theory results keyed by theory name.
    Touschek
        Touschek quantity for this point; ``nan`` unless set (presumably
        filled in by the caller after the run -- confirm).
    xi
        Optional scalar attached to the result; ``None`` unless set.
    """

    converged: bool
    equilibrium: Equilibrium
    theories: dict[str, TheoryResult] = field(default_factory=dict)
    Touschek: float = np.nan
    xi: float | None = None

    def __repr__(self) -> str:
        # Compact repr: list theory names only, not the full TheoryResult
        # objects or the equilibrium.
        names = ", ".join(self.theories)
        return (
            f"Result(converged={self.converged}, Touschek={self.Touschek}, "
            f"xi={self.xi}, theories=[{names}])"
        )

    def get(self, *names: str) -> TheoryResult | None:
        """Return the first present TheoryResult among *names*, else None."""
        for n in names:
            if n in self.theories:
                return self.theories[n]
        return None

    def _flag(self, *names: str) -> bool:
        """Reduce the ``unstable`` field of the first matching theory to a bool.

        ``np.any`` is used (as in ``any_unstable``) so array-valued flags mean
        "any element set" instead of raising ``ValueError`` the way
        ``bool(array)`` does for multi-element arrays. Returns False when no
        theory among *names* is present.
        """
        t = self.get(*names)
        return bool(np.any(t.unstable)) if t is not None else False

    @property
    def any_unstable(self) -> bool:
        """True if any stored theory flags any instability."""
        return any(bool(np.any(tr.unstable)) for tr in self.theories.values())

    # -- physical-category accessors (flow-agnostic) ----------------------- #
    @property
    def zero_frequency(self) -> bool:
        """Instability flag of the ``zero_frequency`` theory (False if absent)."""
        return self._flag("zero_frequency")

    @property
    def robinson_flags(self) -> np.ndarray:
        """Flags of the Robinson-category theory; four ``False`` if absent."""
        t = self.get(*_ROBINSON_NAMES)
        return np.asarray(t.unstable) if t is not None else np.zeros(4, dtype=bool)

    @property
    def hom(self) -> bool:
        """Instability flag of ``hom_coupled_bunch`` (False if absent)."""
        return self._flag("hom_coupled_bunch")

    @property
    def ptbl(self) -> bool:
        """Instability flag of the PTBL-category theory (False if absent)."""
        return self._flag(*_PTBL_NAMES)

    # -- growth rates (flow-agnostic) -------------------------------------- #
    @property
    def growth_robinson(self) -> np.ndarray:
        """Robinson-category growth rates as an array; four zeros if absent."""
        t = self.get(*_ROBINSON_NAMES)
        return np.asarray(t.growth_rate) if t is not None else np.zeros(4)

    @property
    def growth_zero_frequency(self) -> float:
        """Growth rate of ``zero_frequency`` as a float; 0.0 if absent."""
        t = self.get("zero_frequency")
        return float(t.growth_rate) if t is not None else 0.0

    @property
    def growth_PTBL(self) -> float | np.ndarray:
        """PTBL-category growth rate, passed through unchanged; 0.0 if absent."""
        t = self.get(*_PTBL_NAMES)
        return t.growth_rate if t is not None else 0.0
@dataclass
class Flow:
    """A user-chosen pipeline: one equilibrium solver plus a subset of theories.

    Attributes
    ----------
    equilibrium
        Solver whose ``provides`` capabilities are checked by ``validate``.
    theories
        Instability theories, each declaring the capabilities it ``requires``.
    name
        Optional label; messages fall back to ``'custom'`` when it is unset.
    """

    equilibrium: EquilibriumSolver
    theories: list[InstabilityTheory]
    name: str | None = None

    def required_capabilities(self) -> frozenset[Capability]:
        """Return the union of every theory's ``requires`` capabilities."""
        return frozenset(
            capability
            for theory in self.theories
            for capability in theory.requires
        )

    def validate(self) -> None:
        """Check that the solver provides everything each theory requires.

        Call this before any compute.

        Raises
        ------
        IncompatibleFlowError
            Listing, one line per offending theory, the capabilities it needs
            that the solver does not provide.
        """
        offered = self.equilibrium.provides
        report = []
        for theory in self.theories:
            absent = set(theory.requires) - set(offered)
            if not absent:
                continue
            # Sort capability values so the message is deterministic.
            listed = ", ".join(sorted(cap.value for cap in absent))
            report.append(
                f" - theory '{theory.name}' needs [{listed}] which "
                f"'{self.equilibrium.name}' does not provide"
            )
        if report:
            raise IncompatibleFlowError(
                f"Flow '{self.name or 'custom'}' is invalid:\n" + "\n".join(report)
            )

    def is_valid(self) -> bool:
        """Return True if ``validate`` passes, False if it raises the flow error."""
        try:
            self.validate()
        except IncompatibleFlowError:
            return False
        return True
# --------------------------------------------------------------------------- #
# Default flows — exact equivalents of the old method= strings.               #
# --------------------------------------------------------------------------- #
# Every flow gets its own freshly constructed solver and theory instances; the
# dict key always matches the flow's ``name``.
DEFAULT_FLOWS: dict[str, Flow] = {
    "Bosch": Flow(
        equilibrium=BoschEquilibrium(),
        theories=[
            ZeroFrequency(),
            RobinsonCoupling(),
            HOMCoupledBunch(),
            BoschMode1(),
        ],
        name="Bosch",
    ),
    "Bosch_no_coupling": Flow(
        equilibrium=BoschEquilibrium(),
        theories=[
            ZeroFrequency(),
            RobinsonNoCoupling(),
            HOMCoupledBunch(),
            BoschMode1(),
        ],
        name="Bosch_no_coupling",
    ),
    "Venturini": Flow(
        equilibrium=VenturiniEquilibrium(),
        theories=[
            ZeroFrequency(),
            RobinsonCoupling(),
            HOMCoupledBunch(),
            PTBLHe(),
        ],
        name="Venturini",
    ),
    "Alves": Flow(
        equilibrium=AlvesEquilibrium(),
        theories=[
            ZeroFrequency(),
            RobinsonCoupling(),
            HOMCoupledBunch(),
            PTBLAlves(),
        ],
        name="Alves",
    ),
}
# --------------------------------------------------------------------------- #
# Generated compatibility matrix — the single source for README/docs.         #
# --------------------------------------------------------------------------- #
# One instance of every known solver and theory; compatibility_matrix() falls
# back to these when its arguments are not given.
ALL_SOLVERS: list[EquilibriumSolver] = [
    solver_cls()
    for solver_cls in (BoschEquilibrium, VenturiniEquilibrium, AlvesEquilibrium)
]
ALL_THEORIES: list[InstabilityTheory] = [
    theory_cls() for theory_cls in ALL_THEORY_CLASSES
]
def flow_label(flow) -> str:
    """Return a short, filename-friendly label for *flow*.

    A ``Flow`` yields its ``name`` (``"custom"`` when the name is unset or
    empty); any other object, such as a legacy method string, goes through
    ``str()``.
    """
    if not isinstance(flow, Flow):
        return str(flow)
    return flow.name or "custom"
def is_compatible(solver: EquilibriumSolver, theory: InstabilityTheory) -> bool:
    """Return True when *solver* provides every capability *theory* requires."""
    unmet = set(theory.requires).difference(solver.provides)
    return not unmet
def compatibility_matrix(
    solvers: Iterable[EquilibriumSolver] | None = None,
    theories: Iterable[InstabilityTheory] | None = None,
) -> str:
    """Render the solver x theory compatibility table as a Markdown string.

    Generated from the capability sets so the docs cannot drift from the code.

    Parameters
    ----------
    solvers
        Table columns. ``None`` selects ``ALL_SOLVERS``; any other iterable,
        including an empty one, is used exactly as given.
    theories
        Table rows. ``None`` selects ``ALL_THEORIES``; any other iterable,
        including an empty one, is used exactly as given.

    Returns
    -------
    str
        Markdown table with one row per theory and one ✅/❌ cell per solver.
    """
    # Compare against None explicitly. ``x or default`` would silently swap an
    # explicitly empty list for the defaults (while an empty generator, being
    # truthy, would not), and raises outright for NumPy arrays.
    solvers = list(ALL_SOLVERS if solvers is None else solvers)
    theories = list(ALL_THEORIES if theories is None else theories)
    header = "| Theory | " + " | ".join(s.name for s in solvers) + " |"
    sep = "|---|" + "|".join([":--:"] * len(solvers)) + "|"
    rows = [header, sep]
    for t in theories:
        cells = ["✅" if is_compatible(s, t) else "❌" for s in solvers]
        rows.append(f"| `{t.name}` | " + " | ".join(cells) + " |")
    return "\n".join(rows)
if __name__ == "__main__":
    # Self-check when run as a script: each default flow must pass
    # validate() (it raises IncompatibleFlowError otherwise), then the
    # Markdown compatibility table is printed for the docs.
    for default_flow in DEFAULT_FLOWS.values():
        default_flow.validate()
    print(compatibility_matrix())