# -*- coding: utf-8 -*-
"""
Turn a grid of Result objects into a nested dict of numpy arrays whose
structure mirrors Result itself — the single source of truth for the scan /
HDF5 schema (it is generated from the dataclass definitions, never
hand-maintained).

The output layout, for a scan over var1 × var2 grid points::

    converged                  (V1, V2)       bool
    Touschek                   (V1, V2)       float
    xi                         (V1, V2)       float
    equilibrium/<field>        (V1, V2[, k])  float / bool
    theories/<name>/<field>    (V1, V2[, n])  float / bool

Only numeric / boolean (and numeric ndarray) fields are persisted; live backend
handles (Equilibrium.backend_BLE / backend_LE), the cavity_list and the
capabilities set are non-serialisable and skipped automatically by the type
filter. Missing points (failed/None results, or a theory absent from a given
flow) are filled with a sentinel: NaN for floats, False for booleans.
"""
from __future__ import annotations
from dataclasses import fields
from typing import Callable
import numpy as np
from albums.equilibrium import Equilibrium
from albums.theories import Category, TheoryResult, names_in_category
def _is_numeric(v) -> bool:
    """Tell whether ``v`` can be written to an HDF5 dataset unchanged.

    Accepted values are Python / numpy booleans, integers and floats, plus
    numpy arrays whose dtype kind is boolean, signed or unsigned integer,
    float or complex. Everything else (strings, ``None``, lists, object
    arrays, arbitrary objects, ...) is rejected.
    """
    # Arrays are judged by their element type, not by their Python class.
    if isinstance(v, np.ndarray):
        return v.dtype.kind in "biufc"
    scalar_types = (bool, np.bool_, int, np.integer, float, np.floating)
    return isinstance(v, scalar_types)
def _collect(grid_shape: tuple, getter: Callable, *, skip_if_all_none: bool = False):
    """Build one numpy array over the grid from ``getter`` applied per point.

    Parameters
    ----------
    grid_shape : tuple
        Shape of the scan grid; every index yielded by ``np.ndindex`` over it
        is passed to ``getter``.
    getter : Callable
        Called as ``getter(idx)`` exactly once per grid point. Returns the
        field value for that point, or None if the result is missing / the
        field is unset.
    skip_if_all_none : bool, keyword-only
        Controls the outcome when no grid point has a value: return None
        (True) or an all-NaN float array of shape ``grid_shape`` (False).

    Returns
    -------
    numpy.ndarray or None
        Array of shape ``grid_shape + trailing`` where ``trailing`` is the
        shape of the first non-None sample. The dtype is bool for boolean
        samples, complex for complex samples and float otherwise (integers are
        stored as float so that NaN can mark missing points). Missing points
        hold False (bool) or NaN (float / complex). None is returned when the
        first non-None sample is not numeric, or when every point is None and
        ``skip_if_all_none`` is set.
    """
    # Evaluate the getter a single time per point; the same values serve both
    # the dtype/shape inference and the fill below.
    values = {idx: getter(idx) for idx in np.ndindex(grid_shape)}
    present = {idx: v for idx, v in values.items() if v is not None}

    if not present:
        return None if skip_if_all_none else np.full(grid_shape, np.nan)

    # The first non-None value (in grid iteration order) fixes dtype and shape.
    sample = next(iter(present.values()))
    if not _is_numeric(sample):
        return None

    sample_arr = np.asarray(sample)
    kind = sample_arr.dtype.kind
    if kind == "b":
        dtype, fill = bool, False
    elif kind == "c":
        # Complex arrays are accepted by the numeric filter; keep them complex
        # instead of silently dropping the imaginary part in a float array.
        dtype, fill = complex, np.nan
    else:
        dtype, fill = float, np.nan

    out = np.full(tuple(grid_shape) + sample_arr.shape, fill, dtype=dtype)
    for idx, v in present.items():
        out[idx] = v
    return out
def results_to_arrays(results: np.ndarray) -> dict:
    """Flatten an object array of Result into nested per-field numpy arrays.

    Parameters
    ----------
    results : numpy.ndarray
        Object array (any leading shape — typically 1D or 2D) whose entries are
        Result instances, or None for grid points that were skipped.

    Returns
    -------
    dict
        Nested {field: array | {subfield: array}} mirroring Result (see
        the module docstring for the layout). Fields for which no grid point
        has a value, or whose values are not numeric, are left out.
    """
    grid = results.shape

    def gather(accessor, *path):
        # One array over the whole grid for accessor(result, *path); None
        # means "do not store this field".
        return _collect(
            grid,
            lambda idx: accessor(results[idx], *path),
            skip_if_all_none=True,
        )

    arrays: dict = {}

    # Top-level Result fields; the two nested members are expanded below.
    for top_name in fields_of_result():
        if top_name == "equilibrium" or top_name == "theories":
            continue
        column = gather(_result_attr, top_name)
        if column is not None:
            arrays[top_name] = column

    # Equilibrium sub-fields; non-numeric members are dropped by the type
    # filter inside _collect.
    equilibrium: dict = {}
    for spec in fields(Equilibrium):
        column = gather(_eq_attr, spec.name)
        if column is not None:
            equilibrium[spec.name] = column
    if equilibrium:
        arrays["equilibrium"] = equilibrium

    # Per-theory fields, for every theory name seen at any grid point.
    theory_names = {
        nm
        for res in results.flat
        if res is not None
        for nm in res.theories.keys()
    }
    theories: dict = {}
    for theory in sorted(theory_names):
        per_theory: dict = {}
        for spec in fields(TheoryResult):
            if spec.name == "name":
                continue
            column = gather(_theory_attr, theory, spec.name)
            if column is not None:
                per_theory[spec.name] = column
        # A theory that ran is recorded even when none of its fields is stored.
        theories[theory] = per_theory
    if theories:
        arrays["theories"] = theories

    return arrays
def fields_of_result():
    """Return the dataclass field names of Result, in declaration order.

    Result is imported inside the function rather than at module level to
    avoid an import cycle with ``albums.flow``.
    """
    from albums.flow import Result

    result_fields = fields(Result)
    return [spec.name for spec in result_fields]
def _result_attr(r, name):
    """Attribute ``name`` of result ``r``, or None when the result is missing."""
    if r is None:
        return None
    return getattr(r, name)
def _eq_attr(r, name):
    """Attribute ``name`` of ``r.equilibrium``.

    Returns None when the result itself or its equilibrium is missing.
    """
    equilibrium = None if r is None else r.equilibrium
    return None if equilibrium is None else getattr(equilibrium, name)
def _theory_attr(r, theory_name, attr):
    """Attribute ``attr`` of the theory result stored under ``theory_name``.

    Returns None when the result is missing or did not run that theory.
    """
    if r is not None and theory_name in r.theories:
        return getattr(r.theories[theory_name], attr)
    return None
# Theory-name groups, computed once at import time.
# A flow uses one theory per category, so the grid-level accessors of
# ResultArray resolve "the Robinson result" / "the PTBL result" by trying the
# names of the category in order and taking the first one present in the data.
_ROBINSON_NAMES = names_in_category(Category.ROBINSON)
_PTBL_NAMES = names_in_category(Category.PTBL)
class ResultArray:
    """A grid of scan results, stored as the nested dict of arrays produced by
    results_to_arrays.

    Wraps that dict with vectorized, flow-agnostic accessors (the grid-level
    analogue of Result's category accessors), ND indexing, HDF5 save/load and a
    plotting shortcut. Missing categories or grid points read back as sentinels
    (False for flags, NaN for floats).

    A scan additionally attaches *scan metadata* — the swept axes, their display
    units/labels, the flow and a name — so the object is self-describing:
    result.plot("xi") and result.save() need no further arguments. A bare
    ResultArray(data) (built by results_to_arrays or loaded from a plain HDF5
    file) leaves these as None/identity and still works. The extra dict holds
    optimisation-only maps (e.g. the optimised psi grid) that are not part of
    the Result schema but are plottable as a background quantity.

    Parameters
    ----------
    data : dict
        Nested dict from results_to_arrays (or loaded from HDF5).
    var1, var2 : numpy.ndarray, optional
        The swept axes (outer, inner). None for a bare result.
    var1_unit, var2_unit : float
        Multiplicative display scaling for each axis.
    var1_label, var2_label : str, optional
        Axis labels for plots.
    name : str
        Base name used for plot titles and saved files.
    flow : Flow or str, optional
        The flow that produced the scan (used for the saved-file label only;
        marker derivation reads the theory names already stored in data).
    tau_boundary : float, optional
        Resolved integration half-window, used only in the Bosch plot filename.
    extra : dict, optional
        Named 2D arrays outside the Result schema (e.g. {"psi": grid}).
    """

    def __init__(
        self,
        data: dict,
        *,
        var1=None,
        var2=None,
        var1_unit: float = 1.0,
        var2_unit: float = 1.0,
        var1_label: str | None = None,
        var2_label: str | None = None,
        name: str | None = None,
        flow=None,
        tau_boundary=None,
        extra: dict | None = None,
    ):
        self.data = data
        self.var1 = var1
        self.var2 = var2
        self.var1_unit = var1_unit
        self.var2_unit = var2_unit
        self.var1_label = var1_label
        self.var2_label = var2_label
        self.name = name
        self.flow = flow
        self.tau_boundary = tau_boundary
        self.extra = extra or {}

    # -- shape / raw access ------------------------------------------------- #
    @property
    def shape(self) -> tuple:
        """Grid shape, read from the ``converged`` array.

        Raises KeyError when ``data`` has no ``"converged"`` entry.
        """
        return self.data["converged"].shape

    def to_dict(self) -> dict:
        """The underlying nested dict of arrays (not a copy)."""
        return self.data

    def __getitem__(self, key):
        """ra["converged"] returns a top-level field; ra[i, j] returns the
        result at one grid point as a nested dict."""
        if isinstance(key, str):
            return self.data[key]
        return self.point(key)

    def point(self, idx) -> dict:
        """The values at one grid point as a nested dict (no Result rebuilt).

        ``idx`` is applied to every array in ``data``; nested dicts are walked
        recursively and keep their structure.
        """

        def walk(node):
            return {
                key: walk(value) if isinstance(value, dict) else value[idx]
                for key, value in node.items()
            }

        return walk(self.data)

    # -- helpers ------------------------------------------------------------ #
    def _theory(self, names):
        """Field dict of the first theory in ``names`` present in the data,
        or None when none of them ran."""
        theories = self.data.get("theories", {})
        for n in names:
            if n in theories:
                return theories[n]
        return None

    def _field(self, group, key, tail=(), fill=False, dtype=bool):
        """Return ``group[key]`` or a sentinel array when it is unavailable.

        The sentinel has shape ``self.shape + tail`` and is filled with
        ``fill`` using ``dtype``. It is only allocated when ``group`` is None
        or lacks ``key``.
        """
        if group is not None and key in group:
            return group[key]
        return np.full(self.shape + tail, fill, dtype=dtype)

    # -- top-level scalars -------------------------------------------------- #
    @property
    def converged(self):
        """The ``converged`` array of the scan."""
        return self.data["converged"]

    @property
    def Touschek(self):
        """The ``Touschek`` array, or all-NaN when it was not stored."""
        # Going through _field avoids building the NaN fallback when the field
        # is present.
        return self._field(self.data, "Touschek", (), np.nan, float)

    @property
    def xi(self):
        """The ``xi`` array, or all-NaN when it was not stored."""
        return self._field(self.data, "xi", (), np.nan, float)

    @property
    def bunch_length(self):
        """RMS bunch length per grid point, in [s]."""
        return self._field(
            self.data.get("equilibrium"), "bunch_length", (), np.nan, float
        )

    @property
    def theories(self) -> dict:
        """The ``theories`` sub-dict ({} when no theory was stored)."""
        return self.data.get("theories", {})

    # -- per-category accessors (flow-agnostic) ----------------------------- #
    # The Robinson accessors fall back to a trailing dimension of 4
    # (presumably one entry per Robinson mode — confirm against the theories).
    @property
    def zero_frequency(self):
        """``unstable`` flag of the ``zero_frequency`` theory (False if absent)."""
        return self._field(self.theories.get("zero_frequency"), "unstable")

    @property
    def robinson_flags(self):
        """``unstable`` flags of the Robinson-category theory (False if absent)."""
        return self._field(self._theory(_ROBINSON_NAMES), "unstable", (4,))

    @property
    def modes(self):
        """``mode_frequency`` of the Robinson-category theory (NaN if absent)."""
        return self._field(
            self._theory(_ROBINSON_NAMES), "mode_frequency", (4,), np.nan, float
        )

    @property
    def converged_modes(self):
        """``converged`` flags of the Robinson-category theory (False if absent)."""
        return self._field(self._theory(_ROBINSON_NAMES), "converged", (4,))

    @property
    def hom(self):
        """``unstable`` flag of the ``hom_coupled_bunch`` theory (False if absent)."""
        return self._field(self.theories.get("hom_coupled_bunch"), "unstable")

    @property
    def ptbl(self):
        """``unstable`` flag of the PTBL-category theory (False if absent)."""
        return self._field(self._theory(_PTBL_NAMES), "unstable")

    @property
    def growth_robinson(self):
        """``growth_rate`` of the Robinson-category theory (NaN if absent)."""
        return self._field(
            self._theory(_ROBINSON_NAMES), "growth_rate", (4,), np.nan, float
        )

    @property
    def growth_zero_frequency(self):
        """``growth_rate`` of the ``zero_frequency`` theory (NaN if absent)."""
        return self._field(
            self.theories.get("zero_frequency"), "growth_rate", (), np.nan, float
        )

    @property
    def growth_ptbl(self):
        """``growth_rate`` of the PTBL-category theory (NaN if absent)."""
        return self._field(self._theory(_PTBL_NAMES), "growth_rate", (), np.nan, float)

    def _collapse_trailing(self, arr: np.ndarray, reduce: Callable) -> np.ndarray:
        """Reduce an array of shape self.shape followed by trailing dimensions
        down to self.shape, where the trailing dimension(s) are whatever per-mode
        axes the field carries (possibly none). Grid-dimension-agnostic: works
        for 1D and 2D scans alike.

        ``reduce`` is called as ``reduce(arr, axis=<trailing axes>)``; arrays
        without trailing dimensions are returned unchanged.
        """
        grid_ndim = len(self.shape)
        if arr.ndim > grid_ndim:
            return reduce(arr, axis=tuple(range(grid_ndim, arr.ndim)))
        return arr

    # -- flow-agnostic aggregates across every theory present --------------- #
    @property
    def any_unstable(self) -> np.ndarray:
        """True where at least one theory in this result flags instability at that
        point. Grid-level analogue of Result.any_unstable.

        Points where a theory didn't run (e.g. because the equilibrium itself
        failed there) read as not-unstable.
        """
        out = np.zeros(self.shape, dtype=bool)
        for theory_fields in self.theories.values():
            unstable = theory_fields.get("unstable")
            if unstable is None:
                continue
            out |= self._collapse_trailing(np.asarray(unstable), np.any)
        return out

    @property
    def theory_converged(self) -> np.ndarray:
        """True where every theory that tracks its own convergence
        (TheoryResult.converged) succeeded. Theories that don't track convergence
        default to True and have no effect, so a new theory with a 'converged'
        field is picked up automatically."""
        out = np.ones(self.shape, dtype=bool)
        for theory_fields in self.theories.values():
            converged = theory_fields.get("converged")
            if converged is None:
                continue
            out &= self._collapse_trailing(np.asarray(converged), np.all)
        return out

    def categories(self) -> dict:
        """Flat per-category arrays under the legacy names used by the plotters.

        ``"bl"`` is the bunch length scaled by 1e12 (seconds to picoseconds,
        given that ``bunch_length`` is documented in seconds).
        """
        return {
            "zero_freq_coup": self.zero_frequency,
            "robinson_coup": self.robinson_flags,
            "modes_coup": self.modes,
            "HOM_coup": self.hom,
            "converged_coup": self.converged_modes,
            "PTBL_coup": self.ptbl,
            "bl": self.bunch_length * 1e12,
            "xi": self.xi,
            "Touschek": self.Touschek,
            "growth_robinson": self.growth_robinson,
            "growth_zero_freq": self.growth_zero_frequency,
            "growth_PTBL": self.growth_ptbl,
        }

    # -- scan metadata <-> HDF5 -------------------------------------------- #
    # -- I/O ---------------------------------------------------------------- #
    def save(self, name: str | None = None) -> None:
        """Save to <name>.hdf5 (nested HDF5 groups), including scan metadata.

        Parameters
        ----------
        name : str, optional
            Base file name without extension. Defaults to the ``name`` stored
            on the result, so a scan output saves with a bare ``save()``.

        Raises
        ------
        ValueError
            If neither ``name`` nor the stored ``name`` is a non-empty string.
        """
        target = self.name if name is None else name
        if not target:
            raise ValueError("ResultArray.save requires a non-empty name.")
        from albums.saveload import save_hdf5

        # Shallow copy so the metadata group is not added to self.data.
        payload = dict(self.data)
        # NOTE(review): _meta_dict is not defined in the visible part of this
        # class — confirm it is provided elsewhere.
        payload["scan_meta"] = self._meta_dict()
        save_hdf5(f"{target}.hdf5", payload)

    @classmethod
    def load(cls, file: str) -> "ResultArray":
        """Load a ResultArray (with any scan metadata) from a nested-HDF5 file."""
        from albums.saveload import load_hdf5

        data = load_hdf5(file)
        # Files without a "scan_meta" group yield a bare result.
        meta = data.pop("scan_meta", None)
        # NOTE(review): _meta_kwargs is not defined in the visible part of this
        # class — confirm it is provided elsewhere.
        return cls(data, **cls._meta_kwargs(meta or {}))

    # -- plotting shortcuts ------------------------------------------------- #
    def plot(
        self,
        quantity,
        var1=None,
        var2=None,
        var1_unit=None,
        var2_unit=None,
        var1_label=None,
        var2_label=None,
        plot_opts=None,
        name=None,
        flow=None,
        tau_boundary=None,
    ):
        """Plot one background quantity plus the flow's instability markers.

        Axes metadata (var1/var2, units, labels, name, flow, tau_boundary) fall
        back to the values stored on the result, so a scan output plots with just
        result.plot("xi"). A bare ResultArray must supply var1 and var2.
        Missing labels default to "var1" / "var2" and a missing flow to "".
        """
        from albums.plot_func import plot_scan

        plot_scan(
            self,
            self.var1 if var1 is None else var1,
            self.var2 if var2 is None else var2,
            self.var1_unit if var1_unit is None else var1_unit,
            self.var2_unit if var2_unit is None else var2_unit,
            (self.var1_label or "var1") if var1_label is None else var1_label,
            (self.var2_label or "var2") if var2_label is None else var2_label,
            quantity,
            plot_opts=plot_opts,
            name=self.name if name is None else name,
            flow=(self.flow or "") if flow is None else flow,
            tau_boundary=self.tau_boundary if tau_boundary is None else tau_boundary,
        )