Source code for albums.result_io

# -*- coding: utf-8 -*-
"""
Turn a grid of Result objects into a nested dict of numpy
arrays whose structure mirrors Result itself — the single source of truth for
the scan / HDF5 schema (it is generated from the dataclass definitions, never
hand-maintained).

The output layout, for a scan over var1 × var2 grid points::

    converged                         (V1, V2)            bool
    Touschek                          (V1, V2)            float
    xi                                (V1, V2)            float
    equilibrium/<field>               (V1, V2[, k])       float / bool
    theories/<name>/<field>           (V1, V2[, n])       float / bool

Only numeric / boolean (and numeric ndarray) fields are persisted; live backend
handles (Equilibrium.backend_BLE / backend_LE), the cavity_list and
the capabilities set are non-serialisable and skipped automatically by the
type filter. Missing points (failed/None results, or a theory absent from a given
flow) are filled with a sentinel: NaN for floats, False for booleans.
"""

from __future__ import annotations

from dataclasses import fields
from typing import Callable

import numpy as np

from albums.equilibrium import Equilibrium
from albums.theories import Category, TheoryResult, names_in_category


def _is_numeric(v) -> bool:
    """Return True if ``v`` can be stored in an HDF5 dataset as-is.

    Accepted values are Python / numpy bool, int and float scalars, and
    numpy arrays whose dtype kind is boolean, signed/unsigned integer,
    float or complex. Everything else (strings, objects, containers,
    complex *scalars*, ...) is rejected.
    """
    if isinstance(v, np.ndarray):
        return v.dtype.kind in "biufc"
    return isinstance(v, (bool, np.bool_, int, np.integer, float, np.floating))


def _collect(grid_shape: tuple, getter: Callable, *, skip_if_all_none: bool = False):
    """Build one numpy array over the grid from ``getter`` applied per point.

    Parameters
    ----------
    grid_shape : tuple
        Shape of the scan grid; every index of that shape is visited.
    getter : Callable
        Called with a grid index tuple; returns the field value at that
        point, or None if the result is missing / the field is unset.
    skip_if_all_none : bool, optional
        What to return when every grid point yields None: None if True,
        otherwise a NaN-filled float array of shape ``grid_shape``.

    Returns
    -------
    numpy.ndarray or None
        Array of shape ``grid_shape + trailing`` where ``trailing`` is the
        shape of the first non-None sample. The dtype is bool for boolean
        samples, complex for complex-array samples and float otherwise.
        Missing points hold the sentinel False (bool) or NaN (float /
        complex). None is returned when the first non-None sample is not
        numeric (see ``_is_numeric``).
    """
    # First pass: find the first populated grid point; it fixes the trailing
    # shape and the dtype of the whole output array.
    sample = None
    for idx in np.ndindex(grid_shape):
        sample = getter(idx)
        if sample is not None:
            break

    if sample is None:
        return None if skip_if_all_none else np.full(grid_shape, np.nan)
    if not _is_numeric(sample):
        return None

    sample_arr = np.asarray(sample)
    kind = sample_arr.dtype.kind
    if kind == "b":
        dtype, sentinel = bool, False
    elif kind == "c":
        # _is_numeric lets complex arrays through; storing them in a float
        # array would silently drop the imaginary part.
        dtype, sentinel = complex, np.nan
    else:
        dtype, sentinel = float, np.nan

    out = np.full(grid_shape + sample_arr.shape, sentinel, dtype=dtype)

    # Second pass: fill every populated point; the rest keep the sentinel.
    for idx in np.ndindex(grid_shape):
        value = getter(idx)
        if value is not None:
            out[idx] = value
    return out
def results_to_arrays(results: np.ndarray) -> dict:
    """Convert an object array of Result into a nested dict of arrays.

    Parameters
    ----------
    results : numpy.ndarray
        Object array (any leading shape, typically 1D or 2D) whose entries
        are Result instances, or None for grid points that were skipped.

    Returns
    -------
    dict
        Nested ``{field: array | {subfield: array}}`` mirroring Result:
        top-level scalar fields, an ``"equilibrium"`` group and a
        ``"theories"`` group keyed by theory name. Fields that are unset at
        every grid point, or whose values are not numeric, are left out.
    """
    shape = results.shape

    def gather(getter):
        # Every group uses the same policy: drop a field that is None
        # everywhere instead of storing an all-NaN array.
        return _collect(shape, getter, skip_if_all_none=True)

    nested: dict = {}

    # Top-level Result scalars; the nested members get their own groups below.
    for attr in fields_of_result():
        if attr in ("equilibrium", "theories"):
            continue
        column = gather(lambda idx, n=attr: _result_attr(results[idx], n))
        if column is not None:
            nested[attr] = column

    # Equilibrium sub-fields; non-numeric members are dropped by _collect.
    equilibrium: dict = {}
    for spec in fields(Equilibrium):
        column = gather(lambda idx, n=spec.name: _eq_attr(results[idx], n))
        if column is not None:
            equilibrium[spec.name] = column
    if equilibrium:
        nested["equilibrium"] = equilibrium

    # Per-theory fields, for whichever theories appear anywhere on the grid.
    seen = set()
    for res in results.flat:
        if res is not None:
            seen.update(res.theories.keys())

    theories: dict = {}
    for theory_name in sorted(seen):
        per_field: dict = {}
        for spec in fields(TheoryResult):
            if spec.name == "name":
                continue
            column = gather(
                lambda idx, nm=theory_name, a=spec.name: _theory_attr(
                    results[idx], nm, a
                )
            )
            if column is not None:
                per_field[spec.name] = column
        theories[theory_name] = per_field
    if theories:
        nested["theories"] = theories

    return nested
def fields_of_result():
    """Return the dataclass field names of ``Result`` as a list.

    ``Result`` is imported inside the function (not at module level) to
    avoid an import cycle with ``albums.flow``.
    """
    from albums.flow import Result

    return [spec.name for spec in fields(Result)]
def _result_attr(r, name):
    """Return attribute ``name`` of result ``r``, or None when ``r`` is None."""
    if r is None:
        return None
    return getattr(r, name)
def _eq_attr(r, name):
    """Return attribute ``name`` of ``r.equilibrium``.

    Yields None when the result itself or its equilibrium is None.
    """
    equilibrium = None if r is None else r.equilibrium
    if equilibrium is None:
        return None
    return getattr(equilibrium, name)
def _theory_attr(r, theory_name, attr):
    """Return attribute ``attr`` of the theory result named ``theory_name``.

    Yields None when the result is None or does not contain that theory.
    """
    if r is not None and theory_name in r.theories:
        return getattr(r.theories[theory_name], attr)
    return None
# Theory-name groups. A flow uses one theory per category, so the grid-level
# accessors resolve "the Robinson result" / "the PTBL result" by trying the
# names of the category in order.
_ROBINSON_NAMES = names_in_category(Category.ROBINSON)
_PTBL_NAMES = names_in_category(Category.PTBL)
[docs]class ResultArray: """A grid of scan results, stored as the nested dict of arrays produced by results_to_arrays. Wraps that dict with vectorized, flow-agnostic accessors (the grid-level analogue of Result's category accessors), ND indexing, HDF5 save/load and a plotting shortcut. Missing categories or grid points read back as sentinels (False for flags, NaN for floats). A scan additionally attaches *scan metadata* — the swept axes, their display units/labels, the flow and a name — so the object is self-describing: result.plot("xi") and result.save() need no further arguments. A bare ResultArray(data) (built by results_to_arrays or loaded from a plain HDF5 file) leaves these as None/identity and still works. The extra dict holds optimisation-only maps (e.g. the optimised psi grid) that are not part of the Result schema but are plottable as a background quantity. Parameters ---------- data : dict Nested dict from results_to_arrays (or loaded from HDF5). var1, var2 : numpy.ndarray, optional The swept axes (outer, inner). None for a bare result. var1_unit, var2_unit : float Multiplicative display scaling for each axis. var1_label, var2_label : str, optional Axis labels for plots. name : str Base name used for plot titles and saved files. flow : Flow or str, optional The flow that produced the scan (used for the saved-file label only; marker derivation reads the theory names already stored in data). tau_boundary : float, optional Resolved integration half-window, used only in the Bosch plot filename. extra : dict, optional Named 2D arrays outside the Result schema (e.g. {"psi": grid}). """
[docs] def __init__( self, data: dict, *, var1=None, var2=None, var1_unit: float = 1.0, var2_unit: float = 1.0, var1_label: str | None = None, var2_label: str | None = None, name: str | None = None, flow=None, tau_boundary=None, extra: dict | None = None, ): self.data = data self.var1 = var1 self.var2 = var2 self.var1_unit = var1_unit self.var2_unit = var2_unit self.var1_label = var1_label self.var2_label = var2_label self.name = name self.flow = flow self.tau_boundary = tau_boundary self.extra = extra or {}
# -- shape / raw access ------------------------------------------------- # @property def shape(self) -> tuple: """Grid shape (var1, var2).""" return self.data["converged"].shape
[docs] def to_dict(self) -> dict: """The underlying nested dict of arrays.""" return self.data
def __getitem__(self, key): """ra["converged"] returns a top-level field; ra[i, j] returns the result at one grid point as a nested dict.""" if isinstance(key, str): return self.data[key] return self.point(key)
[docs] def point(self, idx) -> dict: """The values at one grid point as a nested dict (no Result rebuilt).""" def walk(d): out = {} for k, v in d.items(): out[k] = walk(v) if isinstance(v, dict) else v[idx] return out return walk(self.data)
# -- helpers ------------------------------------------------------------ #
[docs] def _theory(self, names): theories = self.data.get("theories", {}) for n in names: if n in theories: return theories[n] return None
[docs] def _field(self, group, key, tail=(), fill=False, dtype=bool): if group is not None and key in group: return group[key] return np.full(self.shape + tail, fill, dtype=dtype)
# -- top-level scalars -------------------------------------------------- # @property def converged(self): return self.data["converged"] @property def Touschek(self): return self.data.get("Touschek", np.full(self.shape, np.nan)) @property def xi(self): return self.data.get("xi", np.full(self.shape, np.nan)) @property def bunch_length(self): """RMS bunch length per grid point, in [s].""" eq = self.data.get("equilibrium", {}) return eq.get("bunch_length", np.full(self.shape, np.nan)) @property def theories(self) -> dict: return self.data.get("theories", {}) # -- per-category accessors (flow-agnostic) ----------------------------- # @property def zero_frequency(self): return self._field(self.theories.get("zero_frequency"), "unstable") @property def robinson_flags(self): return self._field(self._theory(_ROBINSON_NAMES), "unstable", (4,)) @property def modes(self): return self._field( self._theory(_ROBINSON_NAMES), "mode_frequency", (4,), np.nan, float ) @property def converged_modes(self): return self._field(self._theory(_ROBINSON_NAMES), "converged", (4,)) @property def hom(self): return self._field(self.theories.get("hom_coupled_bunch"), "unstable") @property def ptbl(self): return self._field(self._theory(_PTBL_NAMES), "unstable") @property def growth_robinson(self): return self._field( self._theory(_ROBINSON_NAMES), "growth_rate", (4,), np.nan, float ) @property def growth_zero_frequency(self): return self._field( self.theories.get("zero_frequency"), "growth_rate", (), np.nan, float ) @property def growth_ptbl(self): return self._field(self._theory(_PTBL_NAMES), "growth_rate", (), np.nan, float)
[docs] def _collapse_trailing(self, arr: np.ndarray, reduce: Callable) -> np.ndarray: """Reduce an array of shape self.shape followed by trailing dimensions down to self.shape, where the trailing dimension(s) are whatever per-mode axes the field carries (possibly none). Grid-dimension-agnostic: works for 1D and 2D scans alike.""" grid_ndim = len(self.shape) if arr.ndim > grid_ndim: return reduce(arr, axis=tuple(range(grid_ndim, arr.ndim))) return arr
# -- flow-agnostic aggregates across every theory present --------------- # @property def any_unstable(self) -> np.ndarray: """True where at least one theory in this result flags instability at that point. Grid-level analogue of Result.any_unstable. Points where a theory didn't run (e.g. because the equilibrium itself failed there) read as not-unstable.""" out = np.zeros(self.shape, dtype=bool) for theory_fields in self.theories.values(): unstable = theory_fields.get("unstable") if unstable is None: continue out |= self._collapse_trailing(np.asarray(unstable), np.any) return out @property def theory_converged(self) -> np.ndarray: """True where every theory that tracks its own convergence (TheoryResult.converged) succeeded. Theories that don't track convergence default to True and have no effect, so a new theory with a 'converged' field is picked up automatically.""" out = np.ones(self.shape, dtype=bool) for theory_fields in self.theories.values(): converged = theory_fields.get("converged") if converged is None: continue out &= self._collapse_trailing(np.asarray(converged), np.all) return out
[docs] def categories(self) -> dict: """Flat per-category arrays under the legacy names used by the plotters.""" bl = self.bunch_length return { "zero_freq_coup": self.zero_frequency, "robinson_coup": self.robinson_flags, "modes_coup": self.modes, "HOM_coup": self.hom, "converged_coup": self.converged_modes, "PTBL_coup": self.ptbl, "bl": bl * 1e12, "xi": self.xi, "Touschek": self.Touschek, "growth_robinson": self.growth_robinson, "growth_zero_freq": self.growth_zero_frequency, "growth_PTBL": self.growth_ptbl, }
# -- scan metadata <-> HDF5 -------------------------------------------- #
[docs] def _meta_dict(self) -> dict: """The scan-metadata group written next to the data on save.""" from albums.flow import flow_label meta: dict = { "var1_unit": float(self.var1_unit), "var2_unit": float(self.var2_unit), "var1_label": self.var1_label or "", "var2_label": self.var2_label or "", "name": self.name or "", "flow_label": flow_label(self.flow) if self.flow is not None else "", } if self.var1 is not None: meta["var1"] = np.asarray(self.var1) if self.var2 is not None: meta["var2"] = np.asarray(self.var2) if self.tau_boundary is not None: meta["tau_boundary"] = float(self.tau_boundary) if self.extra: meta["extra"] = {k: np.asarray(v) for k, v in self.extra.items()} return meta
[docs] @staticmethod def _meta_kwargs(meta: dict) -> dict: """Turn a loaded scan-metadata group back into __init__ keyword args.""" def _str(x): v = np.asarray(x).item() return v.decode() if isinstance(v, bytes) else str(v) if not meta: return {} kw = { "var1_unit": float(np.asarray(meta["var1_unit"])), "var2_unit": float(np.asarray(meta["var2_unit"])), "var1_label": _str(meta["var1_label"]), "var2_label": _str(meta["var2_label"]), "name": _str(meta["name"]) or None, "flow": _str(meta["flow_label"]) or None, } if "var1" in meta: kw["var1"] = np.asarray(meta["var1"]) if "var2" in meta: kw["var2"] = np.asarray(meta["var2"]) if "tau_boundary" in meta: kw["tau_boundary"] = float(np.asarray(meta["tau_boundary"])) if "extra" in meta: kw["extra"] = {k: np.asarray(v) for k, v in meta["extra"].items()} return kw
# -- I/O ---------------------------------------------------------------- #
[docs] def save(self, name: str) -> None: """Save to <name>.hdf5 (nested HDF5 groups), including scan metadata.""" if not name: raise ValueError("ResultArray.save requires a non-empty name.") from albums.saveload import save_hdf5 payload = dict(self.data) payload["scan_meta"] = self._meta_dict() save_hdf5(f"{name}.hdf5", payload)
[docs] @classmethod def load(cls, file: str) -> "ResultArray": """Load a ResultArray (with any scan metadata) from a nested-HDF5 file.""" from albums.saveload import load_hdf5 data = load_hdf5(file) meta = data.pop("scan_meta", None) return cls(data, **cls._meta_kwargs(meta or {}))
# -- plotting shortcuts ------------------------------------------------- #
[docs] def plot( self, quantity, var1=None, var2=None, var1_unit=None, var2_unit=None, var1_label=None, var2_label=None, plot_opts=None, name=None, flow=None, tau_boundary=None, ): """Plot one background quantity plus the flow's instability markers. Axes metadata (var1/var2, units, labels, name, flow, tau_boundary) fall back to the values stored on the result, so a scan output plots with just result.plot("xi"). A bare ResultArray must supply var1 and var2. """ from albums.plot_func import plot_scan plot_scan( self, self.var1 if var1 is None else var1, self.var2 if var2 is None else var2, self.var1_unit if var1_unit is None else var1_unit, self.var2_unit if var2_unit is None else var2_unit, (self.var1_label or "var1") if var1_label is None else var1_label, (self.var2_label or "var2") if var2_label is None else var2_label, quantity, plot_opts=plot_opts, name=self.name if name is None else name, flow=(self.flow or "") if flow is None else flow, tau_boundary=self.tau_boundary if tau_boundary is None else tau_boundary, )