Source code for albums.plot_func

"""
Module with the core plotting functions.
"""

import warnings

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm

from albums.flow import flow_label
from albums.options import PlotOptions
from albums.theories import Category, coupled_robinson_names, names_in_category

# Module-level side effect: importing this module changes matplotlib's global
# default figure size for every figure created afterwards in the process.
plt.rcParams["figure.figsize"] = [16, 9]


def configure_plot(
    ax=None, title=None, xlabel=None, ylabel=None, grid=True, legend=True
):
    """Apply the common cosmetic settings to an Axes and return it.

    A new single-panel figure is created when ``ax`` is None. Axis labels and
    the title are only set when a non-empty string is given; the grid is
    switched to ``grid``; a legend anchored at the lower right is added when
    ``legend`` is true (note: it only lists artists that already exist).
    """
    target = plt.subplots(1, 1)[1] if ax is None else ax
    for text, setter in ((xlabel, target.set_xlabel), (ylabel, target.set_ylabel)):
        if text:
            setter(text)
    target.grid(grid)
    if legend:
        target.legend(loc="lower right")
    if title:
        target.set_title(title)
    return target
def save_plot(name, label, plot_2D, tau_boundary=None):
    """Save the current matplotlib figure as a PNG with a consistent file name.

    The file is ``"{name}_{label}_{plot_2D}.png"`` written at 300 dpi. For the
    "Bosch" label with a ``tau_boundary``, ``_tau_{N}ps`` is inserted after the
    label, where N is ``tau_boundary * 1e12`` truncated to an integer (so
    tau_boundary is presumably in seconds and N in picoseconds — TODO confirm).
    """
    if label == "Bosch" and tau_boundary is not None:
        # The product tau_boundary * 1e12 carries floating-point error and can
        # land just below an integer (e.g. 29.999999999999996 instead of 30);
        # a bare int() would then truncate to the wrong picosecond. Rounding to
        # 6 decimals first absorbs that error while keeping int()'s truncation
        # for genuinely fractional values.
        tau_ps = int(round(tau_boundary * 1e12, 6))
        save_name = f"{name}_{label}_tau_{tau_ps}ps"
    else:
        save_name = f"{name}_{label}"
    # Saves the *current* pyplot figure, not necessarily the one a caller drew on.
    plt.savefig(f"{save_name}_{plot_2D}.png", dpi=300)
def create_grid(var1, var2, var1_unit, var2_unit):
    """Build scaled 2D coordinate grids for a (var1, var2) scan.

    Returns ``(var1_grid, var2_grid, var1_scaled, var2_scaled)``. Both grids
    have shape ``(len(var1), len(var2))`` with ``var1_grid[i, j] == var1[i] *
    var1_unit`` and ``var2_grid[i, j] == var2[j] * var2_unit``; the last two
    items are the 1D axes multiplied by their units.
    """
    # "ij" indexing yields (len(var1), len(var2)) directly, the same values as
    # the default "xy" meshgrid transposed.
    grid_a, grid_b = np.meshgrid(var1, var2, indexing="ij")
    scaled_grid_a = grid_a * var1_unit
    scaled_grid_b = grid_b * var2_unit
    return scaled_grid_a, scaled_grid_b, var1 * var1_unit, var2 * var2_unit
def plot_image(
    ax,
    data,
    extent,
    clabel,
    var1_grid,
    var2_grid,
    cmap="viridis",
    vmin=None,
    vmax=None,
    norm=None,
    colorbar=True,
    contour=False,
    colorplot=True,
    contour_dict=None,
):
    """Draw 2D data on ``ax`` as a colour map and/or labelled contour lines.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    data : array-like
        Image data; rows run along the vertical axis (imshow with
        origin="lower").
    extent : sequence of 4 floats
        ``[left, right, bottom, top]`` passed to imshow.
    clabel : str
        Colour-bar label.
    var1_grid, var2_grid : numpy.ndarray
        Coordinate grids; transposed before contouring to match ``data``.
    cmap, vmin, vmax, norm
        Colour mapping. When ``norm`` is given, vmin/vmax are not passed to
        imshow (the norm carries its own limits).
    colorbar : bool
        Attach a labelled colour bar (only when ``colorplot`` is true).
    contour : bool
        Overlay labelled contour lines; zeros in ``data`` are treated as
        missing (NaN) for contouring only.
    colorplot : bool
        Draw the imshow colour map.
    contour_dict : dict, optional
        Contour settings: "levels" (default: matplotlib's automatic choice),
        "colors" ("black"), "alpha" (1), "linestyles" ("-"),
        "manual_clabel" (False).
    """
    if contour_dict is None:
        contour_dict = {}
    if colorplot:
        imshow_kw = dict(origin="lower", cmap=cmap, aspect="auto", extent=extent)
        if norm is not None:
            # matplotlib does not accept explicit vmin/vmax alongside a norm.
            imshow_kw["norm"] = norm
        else:
            imshow_kw["vmin"], imshow_kw["vmax"] = vmin, vmax
        image = ax.imshow(data, **imshow_kw)
        if colorbar:
            cbar = plt.colorbar(image, ax=ax)
            cbar.set_label(clabel)
    if contour:
        # Float copy: never mutate the caller's (possibly stored) array, and an
        # integer/bool array could not hold the NaN used to blank out zeros.
        contour_data = np.asanyarray(data).astype(float)
        contour_data[contour_data == 0] = np.nan
        contours = ax.contour(
            var1_grid.T,
            var2_grid.T,
            contour_data,
            levels=contour_dict.get("levels"),
            colors=contour_dict.get("colors", "black"),
            alpha=contour_dict.get("alpha", 1),
            linestyles=contour_dict.get("linestyles", "-"),
        )
        ax.clabel(
            contours,
            inline=True,
            fontsize=10,
            manual=contour_dict.get("manual_clabel", False),
        )
def _as_result_array(out):
    """Normalise ``out`` to a ResultArray.

    A ResultArray is returned unchanged; anything else (a raw
    results_to_arrays dict) is wrapped in a new ResultArray.
    """
    # Imported lazily, as before — presumably to avoid an import cycle (TODO confirm).
    from albums.result_io import ResultArray

    if isinstance(out, ResultArray):
        return out
    return ResultArray(out)
def _categories(out):
    """Return the flat per-category arrays consumed by the plotters."""
    result = _as_result_array(out)
    return result.categories()
# --------------------------------------------------------------------------- # # Background colour-map quantities # # --------------------------------------------------------------------------- #
[docs]def _aggregate_growth(ra): """Most-unstable growth rate at each grid point, reduced over every theory in the result that carries a growth rate. """ arrays = [] for fields in ra.theories.values(): growth = fields.get("growth_rate") if growth is None: continue growth = np.asarray(growth, dtype=float) if growth.ndim > 2: # (n1, n2, n_mode) -> per mode arrays.extend(growth[:, :, k] for k in range(growth.shape[-1])) else: arrays.append(growth) if not arrays: raise ValueError("no theory in this result carries a growth rate") return np.maximum.reduce(arrays)
[docs]def _theory_growth(ra, spec): """(data, label) for a per-theory growth quantity. spec is "<theory_name>" or "<theory_name>:<mode>". Growth is read straight from the theory's own storage (ra.theories[name]) and masked to the points where that theory (or the selected mode) flags an instability. """ name, sep, mode = spec.partition(":") if name not in ra.theories: raise ValueError( f"theory '{name}' is not in this result; available: {sorted(ra.theories)}" ) fields = ra.theories[name] growth = fields.get("growth_rate") if growth is None: raise ValueError(f"theory '{name}' does not carry a growth rate") growth = np.asarray(growth, dtype=float) unstable = np.asarray(fields["unstable"]) if sep: # explicit mode index k = int(mode) return np.where( unstable[:, :, k], growth[:, :, k], np.nan ), f"{name} mode {k} growth rate [1/s]" if growth.ndim > 2: # multi-mode -> most unstable masked = np.where(unstable, growth, np.nan) with warnings.catch_warnings(): # all-NaN slices are expected warnings.simplefilter("ignore", RuntimeWarning) data = np.nanmax(masked, axis=-1) return data, f"{name} growth rate [1/s]" return np.where(unstable, growth, np.nan), f"{name} growth rate [1/s]"
[docs]def _background_data(ra, quantity): """The (data, colour-bar label) for one background quantity. Equilibrium quantities ("xi", "bunch_length", "Touschek") and the optimiser-only "psi" map read the flow-agnostic categories. Growth quantities are keyed on the theory name: "growth" gives the most-unstable map across all theories, while "growth_<theory_name>" (optionally "growth_<theory_name>:<mode>") reads one theory's growth rate, masked where it flags an instability. """ cat = ra.categories() if quantity == "xi": return cat["xi"], r"$\xi$" if quantity == "bunch_length": return cat["bl"], "Bunch length [ps]" if quantity == "Touschek": return cat["Touschek"], "Touschek lifetime ratio" if quantity == "psi": psi = ra.extra.get("psi") if psi is None: raise ValueError("quantity 'psi' is only available for optimisation scans") return psi, "psi [deg]" if quantity == "growth": return _aggregate_growth(ra), "Most unstable growth rate [1/s]" if quantity.startswith("growth_"): return _theory_growth(ra, quantity[len("growth_") :]) raise ValueError( "Invalid quantity. Must be 'xi', 'bunch_length', 'Touschek', 'psi', " "'growth' (most unstable across theories), or 'growth_<theory_name>' " "with an optional ':<mode>' index (e.g. 'growth_robinson_coupling:0')." )
# --------------------------------------------------------------------------- # # Instability markers — derived from whichever theories the flow ran # # --------------------------------------------------------------------------- # # Robinson-family theory names; the coupled set draws a single fast-mode marker, # the uncoupled set splits the higher modes (sextupole/octupole). _ROBINSON_NAMES = names_in_category(Category.ROBINSON) _ROBINSON_COUPLED = coupled_robinson_names() # PTBL-family markers, keyed per theory so several can coexist on one plot. _PTBL_MARKERS = { "ptbl_he": ("X", "PTBL (He criterion)", "tab:green"), "ptbl_alves": ("P", "PTBL (Alves)", "tab:cyan"), "bosch_mode1": ("s", "Bosch coupled-bunch l=1", "tab:olive"), }
[docs]def _robinson_markers(unstable, coupled): """(mask, marker, label, color) tuples for a Robinson-family theory's 4 modes.""" if coupled: return [ (unstable[:, :, 0], "o", "Dipole Robinson instability", "tab:blue"), (unstable[:, :, 1], "v", "Quadrupole Robinson instability", "tab:orange"), ( unstable[:, :, 2] | unstable[:, :, 3], "*", "Fast mode-coupling instability", "tab:red", ), ] return [ (unstable[:, :, 0], "o", "Dipole instability", "tab:blue"), (unstable[:, :, 1], "v", "Quadrupole instability", "tab:orange"), (unstable[:, :, 2], ">", "Sextupole instability", "tab:olive"), (unstable[:, :, 3], "<", "Octupole instability", "tab:brown"), ]
def _instability_markers(ra):
    """Return a list of ``(mask, marker, label, color)`` for every instability in ``ra``.

    Driven by the theories actually present in ``ra.theories`` (those without
    an "unstable" entry are skipped), so the plot reflects the flow. Robinson
    theories expand into per-mode markers; several theories of the same
    physical category (e.g. two PTBL theories) each get their own marker for
    benchmarking; unrecognised theories get a grey "P" marker labelled with
    the theory name.
    """
    specs = []
    for theory, fields in ra.theories.items():
        mask = fields.get("unstable")
        if mask is None:
            continue
        if theory in _ROBINSON_NAMES:
            specs.extend(_robinson_markers(mask, theory in _ROBINSON_COUPLED))
            continue
        if theory == "zero_frequency":
            style = ("d", "Zero-frequency instability", "tab:pink")
        elif theory == "hom_coupled_bunch":
            style = ("^", "CBI driven by HOMs", "tab:purple")
        else:
            style = _PTBL_MARKERS.get(theory, ("P", theory, "tab:gray"))
        specs.append((mask, *style))
    return specs
[docs]def _equilibrium_not_converged_mask(ra): """Points where the Stage-1 equilibrium solve itself failed (no theory ran).""" return ~ra.converged
[docs]def _theory_not_converged_mask(ra): """Points where the equilibrium converged but some theory's own convergence check (root-finder) failed, and no theory flags an instability there (so the point isn't already decisively explained, e.g. by zero-frequency instability).""" return ra.converged & ~ra.theory_converged & ~ra.any_unstable
# --------------------------------------------------------------------------- # # Public 2D plotter # # --------------------------------------------------------------------------- #
def _image_edges(values):
    """(first, last) imshow edges that centre the end pixels on the first/last grid points.

    imshow's ``extent`` is the outer boundary of the image, so each end is
    pushed out by half a grid step (assumes a uniformly spaced axis, as imshow
    draws equal-width pixels). Using first/last rather than min/max keeps the
    image aligned with the markers for descending axes too. Axes with fewer
    than two points fall back to (min, max).
    """
    values = np.asarray(values, dtype=float).ravel()
    if values.size < 2:
        return values.min(), values.max()
    half_step = 0.5 * (values[-1] - values[0]) / (values.size - 1)
    return values[0] - half_step, values[-1] + half_step


def _overlay_instabilities(ax, ra, var1_grid, var2_grid, opts):
    """Scatter one marker set per instability plus the non-convergence markers."""
    scatter = list(_instability_markers(ra))
    eq_not_converged = _equilibrium_not_converged_mask(ra)
    if eq_not_converged.any():
        scatter.append((eq_not_converged, "1", "Equilibrium not converged", "black"))
    theory_not_converged = _theory_not_converged_mask(ra)
    if theory_not_converged.any():
        scatter.append((theory_not_converged, "1", "Not converged", "tab:gray"))
    for condition, marker, label, color in scatter:
        if not np.any(condition):
            continue  # keep empty categories out of the legend
        ax.scatter(
            var1_grid[condition],
            var2_grid[condition],
            marker=marker,
            label=label,
            alpha=opts.alpha,
            s=opts.marker_size,
            color=color,
        )


def plot_scan(
    out,
    var1,
    var2,
    var1_unit,
    var2_unit,
    var1_label,
    var2_label,
    quantity,
    plot_opts=None,
    name=None,
    flow="",
    tau_boundary=None,
):
    """Plot one background quantity over the 2D scan grid and overlay every
    instability the flow produced.

    Parameters
    ----------
    out : ResultArray or dict
        The scan result.
    var1, var2 : numpy.ndarray
        Scan axes (outer, inner).
    var1_unit, var2_unit : float
        Display scaling for each axis.
    var1_label, var2_label : str
        Axis labels (applied only when the axes are created here).
    quantity : str
        Background colour-map quantity (see _background_data for the list).
    plot_opts : PlotOptions, optional
        Display settings; defaults to PlotOptions().
    name, flow, tau_boundary : optional
        Used for the title and the saved-file name.

    Raises
    ------
    ValueError
        If saving is requested without a ``name`` (checked before drawing),
        or if ``quantity`` is not recognised.
    """
    opts = plot_opts or PlotOptions()
    if opts.save and not name:
        # Fail before anything is drawn rather than after a full figure is built.
        raise ValueError(
            "Saving a figure (PlotOptions(save=True)) requires a name; "
            "pass name=... to plot()."
        )
    ra = _as_result_array(out)
    var1_grid, var2_grid, var1_plot, var2_plot = create_grid(
        var1, var2, var1_unit, var2_unit
    )
    data, clabel = _background_data(ra, quantity)
    data = np.asanyarray(data).T  # match imshow (rows = var2, cols = var1)

    ax = opts.axes or configure_plot(xlabel=var1_label, ylabel=var2_label, legend=False)
    norm = LogNorm(vmin=opts.cbar_v[0], vmax=opts.cbar_v[1]) if opts.log_color else None
    # Pixel centres must sit on the grid points where markers/contours are drawn;
    # a plain [min, max] extent would shift the colour map by half a cell.
    extent = [*_image_edges(var1_plot), *_image_edges(var2_plot)]
    plot_image(
        ax,
        data,
        extent,
        clabel,
        var1_grid,
        var2_grid,
        cmap=opts.cmap,
        vmin=opts.cbar_v[0],
        vmax=opts.cbar_v[1],
        norm=norm,
        colorbar=opts.colorbar,
        contour=opts.contour,
        contour_dict={
            "levels": opts.n_contour,
            "alpha": opts.contour_alpha,
            "linestyles": opts.contour_linestyles,
            "manual_clabel": opts.manual_clabel,
        },
        colorplot=opts.colorplot,
    )
    _overlay_instabilities(ax, ra, var1_grid, var2_grid, opts)

    if opts.show_legend:
        ax.legend(loc="lower right")
    if opts.title and name:
        ax.set_title(name)
    if opts.save:
        save_plot(name, flow_label(flow), quantity, tau_boundary)