Source code for albums.flow

# -*- coding: utf-8 -*-
"""
Stage 3 of the ALBuMS pipeline: flows, the aggregated result, and the computed
compatibility matrix.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Iterable

import numpy as np

from albums.equilibrium import (
    AlvesEquilibrium,
    BoschEquilibrium,
    Capability,
    Equilibrium,
    EquilibriumSolver,
    VenturiniEquilibrium,
)
from albums.theories import (
    ALL_THEORY_CLASSES,
    BoschMode1,
    Category,
    HOMCoupledBunch,
    InstabilityTheory,
    PTBLAlves,
    PTBLHe,
    RobinsonCoupling,
    RobinsonNoCoupling,
    TheoryResult,
    ZeroFrequency,
    names_in_category,
)


class IncompatibleFlowError(ValueError):
    """Signal that a Flow pairs a theory with a solver lacking a capability it needs.

    Raised by ``Flow.validate``. Because it derives from ``ValueError``, callers
    that already catch ``ValueError`` will also intercept it.
    """
# Theory names grouped by physical category. The Result accessors look results
# up through these groups, so they work regardless of which concrete theory of
# a given category a flow selected.
_ROBINSON_NAMES, _PTBL_NAMES = (
    names_in_category(Category.ROBINSON),
    names_in_category(Category.PTBL),
)
@dataclass
class Result:
    """
    Aggregated outcome of running a flow at one operating point.

    ``theories`` maps theory names to their TheoryResult. The category
    accessors (``robinson_flags``, ``ptbl``, ``growth_robinson``, ...) fetch
    results by physical category rather than by concrete theory name, and fall
    back to "not unstable" / zero growth when no matching theory was run.
    """

    # Convergence flag for this operating point.
    converged: bool
    # Equilibrium object the theory results belong to.
    equilibrium: Equilibrium
    # Theory name -> TheoryResult, in insertion order.
    theories: dict[str, TheoryResult] = field(default_factory=dict)
    # Touschek value for this point; NaN when it was not filled in.
    Touschek: float = np.nan
    # Optional ``xi`` value; None when it was not filled in.
    xi: float | None = None

    def __repr__(self) -> str:
        names = ", ".join(self.theories)
        return (
            f"Result(converged={self.converged}, Touschek={self.Touschek}, "
            f"xi={self.xi}, theories=[{names}])"
        )

    def get(self, *names: str) -> "TheoryResult | None":
        """Return the first present TheoryResult among names, else None."""
        for n in names:
            if n in self.theories:
                return self.theories[n]
        return None

    @staticmethod
    def _is_unstable(tr: "TheoryResult | None") -> bool:
        """Return True if tr is present and any element of tr.unstable is set.

        np.any accepts scalar and array-valued flags alike; plain bool() on a
        multi-element array raises ValueError (ambiguous truth value).
        """
        return tr is not None and bool(np.any(tr.unstable))

    @property
    def any_unstable(self) -> bool:
        """True if any stored theory reports an unstable flag."""
        return any(bool(np.any(tr.unstable)) for tr in self.theories.values())

    # -- physical-category accessors (flow-agnostic) ----------------------- #
    @property
    def zero_frequency(self) -> bool:
        """Zero-frequency instability flag; False when that theory was not run."""
        return self._is_unstable(self.get("zero_frequency"))

    @property
    def robinson_flags(self) -> np.ndarray:
        """Robinson ``unstable`` flags as an array; four False values when absent."""
        t = self.get(*_ROBINSON_NAMES)
        return np.asarray(t.unstable) if t is not None else np.zeros(4, dtype=bool)

    @property
    def hom(self) -> bool:
        """HOM coupled-bunch instability flag; False when that theory was not run."""
        return self._is_unstable(self.get("hom_coupled_bunch"))

    @property
    def ptbl(self) -> bool:
        """PTBL instability flag (any element set); False when no PTBL theory ran."""
        return self._is_unstable(self.get(*_PTBL_NAMES))

    # -- growth rates (flow-agnostic) -------------------------------------- #
    @property
    def growth_robinson(self) -> np.ndarray:
        """Robinson growth rates as an array; four zeros when absent."""
        t = self.get(*_ROBINSON_NAMES)
        return np.asarray(t.growth_rate) if t is not None else np.zeros(4)

    @property
    def growth_zero_frequency(self) -> float:
        """Zero-frequency growth rate as a float; 0.0 when absent."""
        t = self.get("zero_frequency")
        return float(t.growth_rate) if t is not None else 0.0

    @property
    def growth_PTBL(self) -> float | np.ndarray:
        """PTBL growth rate exactly as stored (scalar or array); 0.0 when absent."""
        t = self.get(*_PTBL_NAMES)
        return t.growth_rate if t is not None else 0.0
@dataclass
class Flow:
    """A user-chosen pipeline: one equilibrium solver + a subset of theories.

    ``name`` is optional; unnamed flows are reported as 'custom' in
    validation errors.
    """

    equilibrium: EquilibriumSolver
    theories: list[InstabilityTheory]
    name: str | None = None

    def required_capabilities(self) -> frozenset[Capability]:
        """Return the union of every theory's ``requires`` capabilities."""
        return frozenset(
            capability for theory in self.theories for capability in theory.requires
        )

    def validate(self) -> None:
        """Check every theory against the capabilities the solver provides.

        Call this before any compute. All incompatibilities are collected and
        reported together rather than stopping at the first one.

        Raises:
            IncompatibleFlowError: if at least one theory requires a capability
                missing from ``self.equilibrium.provides``.
        """
        offered = self.equilibrium.provides
        complaints = []
        for theory in self.theories:
            lacking = set(theory.requires) - set(offered)
            if not lacking:
                continue
            listed = ", ".join(sorted(cap.value for cap in lacking))
            complaints.append(
                f" - theory '{theory.name}' needs [{listed}] which "
                f"'{self.equilibrium.name}' does not provide"
            )
        if not complaints:
            return
        raise IncompatibleFlowError(
            f"Flow '{self.name or 'custom'}' is invalid:\n" + "\n".join(complaints)
        )

    def is_valid(self) -> bool:
        """Return True when ``validate`` would succeed, False otherwise."""
        try:
            self.validate()
        except IncompatibleFlowError:
            return False
        return True
# --------------------------------------------------------------------------- #
# Default flows — exact equivalents of the old method= strings.               #
# --------------------------------------------------------------------------- #
# Each dict key is taken from the flow's own ``name`` so the two cannot drift.
DEFAULT_FLOWS: dict[str, Flow] = {
    flow.name: flow
    for flow in (
        Flow(
            BoschEquilibrium(),
            [ZeroFrequency(), RobinsonCoupling(), HOMCoupledBunch(), BoschMode1()],
            name="Bosch",
        ),
        Flow(
            BoschEquilibrium(),
            [ZeroFrequency(), RobinsonNoCoupling(), HOMCoupledBunch(), BoschMode1()],
            name="Bosch_no_coupling",
        ),
        Flow(
            VenturiniEquilibrium(),
            [ZeroFrequency(), RobinsonCoupling(), HOMCoupledBunch(), PTBLHe()],
            name="Venturini",
        ),
        Flow(
            AlvesEquilibrium(),
            [ZeroFrequency(), RobinsonCoupling(), HOMCoupledBunch(), PTBLAlves()],
            name="Alves",
        ),
    )
}

# --------------------------------------------------------------------------- #
# Generated compatibility matrix — the single source for README/docs.         #
# --------------------------------------------------------------------------- #
# One default-constructed instance of every solver and every theory class.
ALL_SOLVERS: list[EquilibriumSolver] = [
    solver_cls()
    for solver_cls in (BoschEquilibrium, VenturiniEquilibrium, AlvesEquilibrium)
]
ALL_THEORIES: list[InstabilityTheory] = [
    theory_cls() for theory_cls in ALL_THEORY_CLASSES
]
def flow_label(flow) -> str:
    """Return a short label suitable for filenames.

    A Flow yields its ``name`` (or 'custom' when the name is empty/None);
    anything else is converted with ``str()``.
    """
    if not isinstance(flow, Flow):
        return str(flow)
    return flow.name or "custom"
def is_compatible(solver: EquilibriumSolver, theory: InstabilityTheory) -> bool:
    """Return True when ``solver.provides`` covers every capability in ``theory.requires``."""
    needed = set(theory.requires)
    return needed <= set(solver.provides)
def compatibility_matrix(
    solvers: Iterable[EquilibriumSolver] | None = None,
    theories: Iterable[InstabilityTheory] | None = None,
) -> str:
    """Render the solver x theory compatibility table as a Markdown string.

    Generated from the capability sets so the docs cannot drift from the code.

    ``solvers`` / ``theories`` default to ALL_SOLVERS / ALL_THEORIES only when
    they are ``None``; an explicitly passed (possibly empty) iterable is used
    as given. The result has a header row naming each solver, an alignment
    row, and one row per theory with ✅ where ``is_compatible`` holds and ❌
    otherwise.
    """
    # Test for None explicitly: truthiness would silently swap an explicit
    # empty iterable for the defaults and is ambiguous for array inputs.
    solvers = list(ALL_SOLVERS if solvers is None else solvers)
    theories = list(ALL_THEORIES if theories is None else theories)
    header = "| Theory | " + " | ".join(s.name for s in solvers) + " |"
    sep = "|---|" + "|".join([":--:"] * len(solvers)) + "|"
    rows = [header, sep]
    for t in theories:
        cells = ["✅" if is_compatible(s, t) else "❌" for s in solvers]
        rows.append(f"| `{t.name}` | " + " | ".join(cells) + " |")
    return "\n".join(rows)
if __name__ == "__main__":
    # Script-mode self-check: each shipped default flow must validate
    # (validate() raises IncompatibleFlowError otherwise); then emit the
    # Markdown compatibility table used in the docs.
    for default_flow in DEFAULT_FLOWS.values():
        default_flow.validate()
    print(compatibility_matrix())