"""
Module with the core plotting functions.
"""
import warnings
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm
from albums.flow import flow_label
from albums.options import PlotOptions
from albums.theories import Category, coupled_robinson_names, names_in_category
# Module-wide default: every figure created after this module is imported
# gets a 16 x 9 inch canvas.
plt.rcParams.update({"figure.figsize": [16, 9]})
def save_plot(name, label, plot_2D, tau_boundary=None):
    """Save the current matplotlib figure as a PNG with a consistent name.

    The file is written at 300 dpi as ``<name>_<label>_<plot_2D>.png``. For
    the "Bosch" label with a ``tau_boundary`` it becomes
    ``<name>_<label>_tau_<N>ps_<plot_2D>.png`` instead.

    Parameters
    ----------
    name : str
        Base of the file name.
    label : str
        Label appended after ``name``.
    plot_2D : str
        Suffix identifying the plotted quantity.
    tau_boundary : float, optional
        Only used with the "Bosch" label. It is multiplied by 1e12 and
        written as an integer number of "ps", so it is presumably given in
        seconds.
    """
    if label == "Bosch" and tau_boundary is not None:
        # Round instead of truncating. Scaled binary floats can land just below
        # an integer (0.29 * 100 == 28.999999999999996), and int() would then
        # drop a whole picosecond from the file name.
        tau_ps = int(round(tau_boundary * 1e12))
        save_name = f"{name}_{label}_tau_{tau_ps}ps"
    else:
        save_name = f"{name}_{label}"
    plt.savefig(f"{save_name}_{plot_2D}.png", dpi=300)
def create_grid(var1, var2, var1_unit, var2_unit):
    """Build unit-scaled 2D coordinate grids for a (var1, var2) scan.

    Returns
    -------
    tuple
        ``(var1_grid, var2_grid, var1_scaled, var2_scaled)``. Both grids have
        shape ``(len(var1), len(var2))``, with var1 varying along axis 0 and
        var2 along axis 1. Every returned array is multiplied by its unit.
    """
    # "ij" indexing yields the (len(var1), len(var2)) layout directly, i.e.
    # the same values as the transposed default ("xy") meshgrid.
    grid_1, grid_2 = np.meshgrid(var1, var2, indexing="ij")
    scaled_grid_1 = grid_1 * var1_unit
    scaled_grid_2 = grid_2 * var2_unit
    return scaled_grid_1, scaled_grid_2, var1 * var1_unit, var2 * var2_unit
def plot_image(
    ax,
    data,
    extent,
    clabel,
    var1_grid,
    var2_grid,
    cmap="viridis",
    vmin=None,
    vmax=None,
    norm=None,
    colorbar=True,
    contour=False,
    colorplot=True,
    contour_dict=None,
):
    """Draw 2D data on ``ax`` as a colour map and/or labelled contour lines.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axes.
    data : array-like
        2D values in the layout shown by ``imshow(origin="lower")``: rows run
        along the vertical axis.
    extent : sequence of 4 floats
        ``[xmin, xmax, ymin, ymax]`` passed to ``imshow``.
    clabel : str
        Colour-bar label.
    var1_grid, var2_grid : numpy.ndarray
        Coordinate grids. They are transposed before being passed to
        ``ax.contour`` so they line up with ``data``.
    cmap : str
        Colour map of the colour plot.
    vmin, vmax : float, optional
        Colour limits. Only used when ``norm`` is None, because matplotlib
        rejects a norm combined with vmin/vmax.
    norm : matplotlib.colors.Normalize, optional
        Explicit normalisation, e.g. ``LogNorm``.
    colorbar : bool
        Add a labelled colour bar to the colour plot.
    contour : bool
        Overlay labelled contour lines.
    colorplot : bool
        Draw the ``imshow`` colour map.
    contour_dict : dict, optional
        Contour settings:

        * "levels" is passed positionally to ``ax.contour``; when absent or
          None, matplotlib picks the levels.
        * "colors", "alpha" and "linestyles" default to black, 1 and "-".
        * "manual_clabel" is forwarded to ``ax.clabel(manual=...)`` and
          defaults to False.
    """
    # None instead of a shared mutable {} default.
    contour_dict = {} if contour_dict is None else contour_dict
    if colorplot:
        imshow_kw = dict(origin="lower", cmap=cmap, aspect="auto", extent=extent)
        if norm is not None:
            imshow_kw["norm"] = norm
        else:
            imshow_kw["vmin"], imshow_kw["vmax"] = vmin, vmax
        c = ax.imshow(data, **imshow_kw)
        if colorbar:
            cbar = plt.colorbar(c, ax=ax)
            cbar.set_label(clabel)
    if contour:
        # Zero entries become NaN so contour skips them. np.where builds a new
        # floating array: the caller's (possibly stored) array is untouched,
        # and integer-typed data no longer fails on the NaN assignment.
        contour_data = np.where(np.asarray(data) == 0, np.nan, data)
        contour_args = [var1_grid.T, var2_grid.T, contour_data]
        levels = contour_dict.get("levels")
        if levels is not None:
            contour_args.append(levels)
        contours = ax.contour(
            *contour_args,
            colors=contour_dict.get("colors", "black"),
            alpha=contour_dict.get("alpha", 1),
            linestyles=contour_dict.get("linestyles", "-"),
        )
        ax.clabel(
            contours,
            inline=True,
            fontsize=10,
            manual=contour_dict.get("manual_clabel", False),
        )
def _as_result_array(out):
    """Return ``out`` as a ResultArray.

    A ResultArray is returned unchanged; anything else (e.g. the raw dict
    produced by results_to_arrays) is wrapped with ``ResultArray(out)``.
    """
    # Function-local import, as in the original; possibly there to avoid an
    # import cycle -- confirm before hoisting it to module level.
    from albums.result_io import ResultArray

    if isinstance(out, ResultArray):
        return out
    return ResultArray(out)
def _categories(out):
    """Return the flat per-category arrays the plotters consume.

    ``out`` may be a ResultArray or a raw results_to_arrays dict.
    """
    result = _as_result_array(out)
    return result.categories()
# --------------------------------------------------------------------------- #
# Background colour-map quantities #
# --------------------------------------------------------------------------- #
def _aggregate_growth(ra):
    """Most-unstable growth rate at each grid point, across all theories.

    Each theory in ``ra.theories`` with a "growth_rate" field contributes:

    * a 2D array contributes once;
    * an array with a trailing mode axis, ``(n1, n2, n_mode)``, contributes
      each mode separately.

    The result is the element-wise maximum. NaN entries are ignored
    (``np.fmax``), so a NaN from one theory no longer hides another theory's
    growth rate at the same point. A point is NaN only when every
    contribution is NaN there.

    Raises
    ------
    ValueError
        If no theory in the result carries a growth rate.
    """
    arrays = []
    for fields in ra.theories.values():
        growth = fields.get("growth_rate")
        if growth is None:
            continue
        growth = np.asarray(growth, dtype=float)
        if growth.ndim > 2:  # (n1, n2, n_mode) -> per mode
            arrays.extend(growth[:, :, k] for k in range(growth.shape[-1]))
        else:
            arrays.append(growth)
    if not arrays:
        raise ValueError("no theory in this result carries a growth rate")
    # fmax (unlike maximum) treats NaN as missing, matching the NaN-ignoring
    # nanmax reduction used for a single theory's modes.
    return np.fmax.reduce(arrays)
def _theory_growth(ra, spec):
    """(data, label) for one theory's growth rate, masked to its instabilities.

    Growth is read straight from the theory's own storage
    (``ra.theories[name]``). Points where the theory (or the selected mode)
    does not flag an instability are set to NaN.

    Parameters
    ----------
    ra : ResultArray
        Result whose ``theories`` mapping holds each theory's fields.
    spec : str
        ``"<theory_name>"`` or ``"<theory_name>:<mode>"``.

    Returns
    -------
    (numpy.ndarray, str)
        The masked growth rate and a colour-bar label.

        * With an explicit mode, only that mode is used.
        * Without one, a multi-mode theory is reduced to its largest
          unstable growth rate over all modes.

    Raises
    ------
    ValueError
        If the theory is absent, lacks "growth_rate" or "unstable", or the
        mode is not an integer index valid for that theory.
    """
    name, sep, mode = spec.partition(":")
    if name not in ra.theories:
        raise ValueError(
            f"theory '{name}' is not in this result; available: {sorted(ra.theories)}"
        )
    fields = ra.theories[name]
    growth = fields.get("growth_rate")
    if growth is None:
        raise ValueError(f"theory '{name}' does not carry a growth rate")
    unstable = fields.get("unstable")
    if unstable is None:
        raise ValueError(f"theory '{name}' does not carry an instability mask")
    growth = np.asarray(growth, dtype=float)
    unstable = np.asarray(unstable)

    if sep:  # explicit mode index
        try:
            k = int(mode)
        except ValueError as err:
            raise ValueError(
                f"mode index in '{spec}' must be an integer, got '{mode}'"
            ) from err
        if growth.ndim < 3:
            raise ValueError(
                f"theory '{name}' has no mode axis; drop the ':{mode}' suffix"
            )
        n_modes = growth.shape[2]
        if not -n_modes <= k < n_modes:
            raise ValueError(
                f"mode {k} is out of range for theory '{name}' ({n_modes} modes)"
            )
        return np.where(
            unstable[:, :, k], growth[:, :, k], np.nan
        ), f"{name} mode {k} growth rate [1/s]"

    if growth.ndim > 2:  # multi-mode -> most unstable
        masked = np.where(unstable, growth, np.nan)
        with warnings.catch_warnings():  # all-NaN slices are expected
            warnings.simplefilter("ignore", RuntimeWarning)
            data = np.nanmax(masked, axis=-1)
        return data, f"{name} growth rate [1/s]"
    return np.where(unstable, growth, np.nan), f"{name} growth rate [1/s]"
def _background_data(ra, quantity):
    """Return ``(data, colour-bar label)`` for one background quantity.

    Supported quantities:

    * "xi", "bunch_length", "Touschek": equilibrium maps read from the
      flow-agnostic ``ra.categories()``.
    * "psi": read from ``ra.extra``; only present for optimisation scans.
    * "growth": most-unstable growth rate across all theories.
    * "growth_<theory_name>" or "growth_<theory_name>:<mode>": one theory's
      growth rate, masked where it flags an instability.

    Raises
    ------
    ValueError
        For an unknown quantity, or for "psi" when the result has no psi map.
    """
    categories = ra.categories()

    # (requested quantity, category key, colour-bar label)
    equilibrium_maps = (
        ("xi", "xi", r"$\xi$"),
        ("bunch_length", "bl", "Bunch length [ps]"),
        ("Touschek", "Touschek", "Touschek lifetime ratio"),
    )
    for known, key, label in equilibrium_maps:
        if quantity == known:
            return categories[key], label

    if quantity == "psi":
        psi_map = ra.extra.get("psi")
        if psi_map is None:
            raise ValueError("quantity 'psi' is only available for optimisation scans")
        return psi_map, "psi [deg]"

    if quantity == "growth":
        return _aggregate_growth(ra), "Most unstable growth rate [1/s]"

    prefix = "growth_"
    if quantity.startswith(prefix):
        return _theory_growth(ra, quantity[len(prefix) :])

    raise ValueError(
        "Invalid quantity. Must be 'xi', 'bunch_length', 'Touschek', 'psi', "
        "'growth' (most unstable across theories), or 'growth_<theory_name>' "
        "with an optional ':<mode>' index (e.g. 'growth_robinson_coupling:0')."
    )
# --------------------------------------------------------------------------- #
# Instability markers — derived from whichever theories the flow ran #
# --------------------------------------------------------------------------- #
# Robinson-family theory names. A theory in the coupled set is drawn with one
# fast mode-coupling marker (modes 2 and 3 merged); the others get separate
# sextupole/octupole markers.
_ROBINSON_NAMES = names_in_category(Category.ROBINSON)
_ROBINSON_COUPLED = coupled_robinson_names()

# Per-theory (marker, legend label, colour) for the PTBL theories and the Bosch
# l=1 theory. Keying per theory lets several coexist on one plot.
_PTBL_MARKERS = dict(
    ptbl_he=("X", "PTBL (He criterion)", "tab:green"),
    ptbl_alves=("P", "PTBL (Alves)", "tab:cyan"),
    bosch_mode1=("s", "Bosch coupled-bunch l=1", "tab:olive"),
)
def _robinson_markers(unstable, coupled):
    """Marker specs for the four modes of a Robinson-family theory.

    Parameters
    ----------
    unstable : numpy.ndarray
        Instability mask indexed as ``unstable[:, :, mode]`` for modes 0-3.
    coupled : bool
        If truthy, modes 2 and 3 are OR-ed into a single "fast mode-coupling"
        marker. Otherwise every mode gets its own marker.

    Returns
    -------
    list of (mask, marker, label, color)
    """
    if not coupled:
        per_mode_style = (
            ("o", "Dipole instability", "tab:blue"),
            ("v", "Quadrupole instability", "tab:orange"),
            (">", "Sextupole instability", "tab:olive"),
            ("<", "Octupole instability", "tab:brown"),
        )
        return [
            (unstable[:, :, mode], marker, label, color)
            for mode, (marker, label, color) in enumerate(per_mode_style)
        ]

    fast_coupling = unstable[:, :, 2] | unstable[:, :, 3]
    return [
        (unstable[:, :, 0], "o", "Dipole Robinson instability", "tab:blue"),
        (unstable[:, :, 1], "v", "Quadrupole Robinson instability", "tab:orange"),
        (fast_coupling, "*", "Fast mode-coupling instability", "tab:red"),
    ]
def _instability_markers(ra):
    """Return a list of (mask, marker, label, color) for every instability in ``ra``.

    Driven by the theories actually present in ``ra.theories`` (those with an
    "unstable" field), so the plot reflects the flow that ran. Several
    theories of the same physical category (e.g. two PTBL theories) each get
    their own marker for benchmarking.

    * Robinson-family theories are expanded per mode via _robinson_markers.
    * Every other theory yields one 2D mask. If its mask has trailing mode
      axes, they are collapsed with ``any``, so a point is marked when any
      mode is unstable; a 3D mask would otherwise break the
      ``grid[mask]`` indexing done by the caller.

    All masks are coerced to bool so they act as boolean (not integer) indices.
    """
    markers = []
    for name, fields in ra.theories.items():
        unstable = fields.get("unstable")
        if unstable is None:
            continue
        unstable = np.asarray(unstable, dtype=bool)
        if name in _ROBINSON_NAMES:
            markers += _robinson_markers(unstable, name in _ROBINSON_COUPLED)
            continue
        if unstable.ndim > 2:  # (n1, n2, n_mode, ...) -> unstable in any mode
            unstable = unstable.any(axis=tuple(range(2, unstable.ndim)))
        if name == "zero_frequency":
            markers.append((unstable, "d", "Zero-frequency instability", "tab:pink"))
        elif name == "hom_coupled_bunch":
            markers.append((unstable, "^", "CBI driven by HOMs", "tab:purple"))
        elif name in _PTBL_MARKERS:
            marker, label, color = _PTBL_MARKERS[name]
            markers.append((unstable, marker, label, color))
        else:
            markers.append((unstable, "P", name, "tab:gray"))
    return markers
def _equilibrium_not_converged_mask(ra):
    """Boolean mask of points where the Stage-1 equilibrium solve itself failed.

    At such points no theory ran. ``np.logical_not`` is used rather than the
    bitwise ``~``: for a bool array the two agree, but for integer flags
    (e.g. 0/1 in uint8) ``~`` yields 255/254 instead of a logical negation.
    """
    return np.logical_not(ra.converged)
def _theory_not_converged_mask(ra):
    """Boolean mask of points where a theory's own convergence check failed.

    A point is flagged when all three hold:

    * the equilibrium converged (``ra.converged``);
    * some theory's convergence check did not (``ra.theory_converged`` is
      false);
    * no theory flags an instability there (``ra.any_unstable`` is false),
      so the point is not already explained, e.g. by a zero-frequency
      instability.

    The flags are coerced to bool first, so integer-typed flags are combined
    logically rather than bitwise.
    """
    converged = np.asarray(ra.converged, dtype=bool)
    theory_converged = np.asarray(ra.theory_converged, dtype=bool)
    any_unstable = np.asarray(ra.any_unstable, dtype=bool)
    return converged & ~theory_converged & ~any_unstable
# --------------------------------------------------------------------------- #
# Public 2D plotter #
# --------------------------------------------------------------------------- #
def plot_scan(
    out,
    var1,
    var2,
    var1_unit,
    var2_unit,
    var1_label,
    var2_label,
    quantity,
    plot_opts=None,
    name=None,
    flow="",
    tau_boundary=None,
):
    """Plot one background quantity over the 2D scan grid and overlay every
    instability the flow produced.

    Parameters
    ----------
    out : ResultArray or dict
        The scan result.
    var1, var2 : numpy.ndarray
        Scan axes (outer, inner).
    var1_unit, var2_unit : float
        Display scaling for each axis.
    var1_label, var2_label : str
        Axis labels.
    quantity : str
        Background colour-map quantity (see _background_data for the list).
    plot_opts : PlotOptions, optional
        Display settings; defaults to PlotOptions().
    name, flow, tau_boundary : optional
        Used for the title and the saved-file name.

    Returns
    -------
    matplotlib.axes.Axes
        The axes drawn on: ``plot_opts.axes`` if given, else the one from
        configure_plot.

    Raises
    ------
    ValueError
        If saving is requested without a ``name``. This is checked before
        anything is drawn. Also raised for an invalid ``quantity``.
    """
    opts = plot_opts or PlotOptions()
    # Fail fast: don't draw a whole figure only to reject the save afterwards.
    if opts.save and not name:
        raise ValueError(
            "Saving a figure (PlotOptions(save=True)) requires a name; "
            "pass name=... to plot()."
        )
    ra = _as_result_array(out)
    var1_grid, var2_grid, var1_plot, var2_plot = create_grid(
        var1, var2, var1_unit, var2_unit
    )
    data, clabel = _background_data(ra, quantity)
    data = data.T  # match imshow (rows = var2, cols = var1)
    ax = opts.axes or configure_plot(xlabel=var1_label, ylabel=var2_label, legend=False)
    norm = LogNorm(vmin=opts.cbar_v[0], vmax=opts.cbar_v[1]) if opts.log_color else None
    plot_image(
        ax,
        data,
        [var1_plot.min(), var1_plot.max(), var2_plot.min(), var2_plot.max()],
        clabel,
        var1_grid,
        var2_grid,
        cmap=opts.cmap,
        vmin=opts.cbar_v[0],
        vmax=opts.cbar_v[1],
        norm=norm,
        colorbar=opts.colorbar,
        contour=opts.contour,
        contour_dict={
            "levels": opts.n_contour,
            "alpha": opts.contour_alpha,
            "linestyles": opts.contour_linestyles,
            "manual_clabel": opts.manual_clabel,
        },
        colorplot=opts.colorplot,
    )

    # Instability markers, then convergence-failure markers on top.
    scatter = list(_instability_markers(ra))
    eq_not_converged = _equilibrium_not_converged_mask(ra)
    if eq_not_converged.any():
        scatter.append((eq_not_converged, "1", "Equilibrium not converged", "black"))
    theory_not_converged = _theory_not_converged_mask(ra)
    if theory_not_converged.any():
        scatter.append((theory_not_converged, "1", "Not converged", "tab:gray"))
    for condition, marker, label, color in scatter:
        if not np.any(condition):
            continue
        ax.scatter(
            var1_grid[condition],
            var2_grid[condition],
            marker=marker,
            label=label,
            alpha=opts.alpha,
            s=opts.marker_size,
            color=color,
        )

    if opts.show_legend:
        ax.legend(loc="lower right")
    if opts.title and name:
        ax.set_title(name)
    if opts.save:
        # Per-mode quantities look like "growth_<theory>:<k>". ':' is not a
        # valid file-name character on Windows, so spell the mode out instead.
        file_tag = quantity.replace(":", "_mode")
        save_plot(name, flow_label(flow), file_tag, tau_boundary)
    return ax