def _leapfrog_frame(
    val,
    key,
    num_chains: Optional[int] = None,
    drop_zeros: Optional[bool] = False,
    drop_nans: Optional[bool] = True,
):
    """Flatten a DataArray with a ``leapfrog`` dim into a long-format DataFrame.

    For each leapfrog step, the remaining values are flattened, optionally
    truncated to the first ``num_chains`` rows, and filtered.

    Args:
        val: DataArray with a ``leapfrog`` dimension. Assumed to be laid out as
            ``(chain, leapfrog, draw)``, so each leapfrog slice is
            ``(chain, draw)``. TODO confirm for all callers.
        key: Column name used for the sample values.
        num_chains: If given, keep only the first ``num_chains`` chains.
        drop_zeros: If truthy, drop samples exactly equal to 0.
        drop_nans: If truthy, drop non-finite samples (NaN *and* +/-inf).

    Returns:
        A DataFrame with one row per kept sample and these columns:
        ``key`` (the sample value), ``"lf"`` (the leapfrog coordinate value
        as ``str``), ``"avg"`` (mean of the kept samples for that step) and
        ``"lf_avg"`` (per-step mean of ``"avg"``, used as the hue). Steps
        left with no samples are omitted. The result is empty if every step
        is empty.
    """
    import pandas as pd

    frames = []
    for idx, lf in enumerate(val.leapfrog.values):
        # Select by *position*: the coordinate values need not be 0..N-1,
        # so they cannot be used as positional indices.
        x = val.isel(leapfrog=idx).values
        if num_chains is not None:
            x = x[:num_chains, :]
        x = x.flatten()
        if drop_zeros:
            x = x[x != 0]
        if drop_nans:
            x = x[np.isfinite(x)]
        if x.size == 0:
            # Nothing left for this step; skip it instead of computing the
            # mean of an empty array.
            continue
        frames.append(pd.DataFrame({key: x, "lf": f"{lf}", "avg": x.mean()}))
    if not frames:
        return pd.DataFrame(columns=[key, "lf", "avg", "lf_avg"])
    lfdf = pd.concat(frames, ignore_index=True)
    lfdf["lf_avg"] = lfdf.groupby("lf")["avg"].transform("mean")
    return lfdf


def make_ridgeplots(
    dataset: xr.Dataset,
    num_chains: Optional[int] = None,
    outdir: Optional[os.PathLike] = None,
    drop_zeros: Optional[bool] = False,
    drop_nans: Optional[bool] = True,
    cmap: Optional[str] = "viridis_r",
    save_plot: bool = True,
):
    """Make one ridgeplot (stacked per-leapfrog KDEs) per data variable.

    Only data variables that have a ``leapfrog`` dimension are plotted; all
    others are skipped. Each leapfrog step becomes one overlapping row. Rows
    are coloured from ``cmap`` according to that step's mean value.

    Args:
        dataset: Dataset whose data variables are plotted.
        num_chains: If given, only the first ``num_chains`` chains are used.
        outdir: Base directory. Figures go under ``<outdir>/ridgeplots``
            (``pngs/`` and ``svgs/``). Defaults to the current directory.
        drop_zeros: Drop samples exactly equal to 0 before plotting.
        drop_nans: Drop non-finite samples (NaN and +/-inf) before plotting.
        cmap: Palette name passed to ``sns.color_palette``.
        save_plot: If True, save each figure as both SVG and PNG.

    Returns:
        ``(fig, ax, data)``:
        - ``fig`` and ``ax`` are ``plt.gcf()`` and ``plt.gca()`` after the
          loop, i.e. the last ridgeplot drawn, if any.
        - ``data`` maps each plotted variable name to the long-format
          DataFrame that was plotted.

    Note:
        This function mutates global matplotlib/seaborn state: the seaborn
        theme, transparent face colours, and ``axes.labelcolor``.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    data = {}
    outdir = Path(os.getcwd()) if outdir is None else Path(outdir)
    outdir = outdir.joinpath("ridgeplots")
    pngdir = outdir.joinpath("pngs")
    svgdir = outdir.joinpath("svgs")
    with sns.plotting_context(context="paper"):
        # set_theme() also (re)sets the plotting context, defaulting to
        # "notebook"; pass "paper" explicitly so it does not override the
        # enclosing plotting_context.
        sns.set_theme(context="paper", style="white", palette="bright")
        # Transparent backgrounds so the overlapping rows show through.
        plt.rcParams["axes.facecolor"] = (0, 0, 0, 0.0)
        plt.rcParams["figure.facecolor"] = (0, 0, 0, 0.0)
        for key, val in dataset.data_vars.items():
            if "leapfrog" not in val.coords.dims:
                continue
            tstart = time.perf_counter()
            lfdf = _leapfrog_frame(
                val,
                key,
                num_chains=num_chains,
                drop_zeros=drop_zeros,
                drop_nans=drop_nans,
            )
            if lfdf.empty:
                logger.warning(f"No data left to plot for {key}; skipping")
                continue
            data[key] = lfdf

            # One colour per distinct hue level, so the palette spans the
            # whole colormap even if some leapfrog steps were dropped.
            pal = sns.color_palette(cmap, n_colors=lfdf["lf_avg"].nunique())
            g = sns.FacetGrid(
                lfdf,
                row="lf",
                hue="lf_avg",
                aspect=15,
                height=0.25,  # type:ignore
                palette=pal,  # type:ignore
            )
            # Filled density for each leapfrog step, plus a baseline at y=0.
            _ = g.map(
                sns.kdeplot,
                key,
                cut=1,
                bw_adjust=1.0,
                clip_on=False,
                fill=True,
                alpha=0.7,
                linewidth=1.25,
            )
            _ = g.map(plt.axhline, y=0, lw=1.0, alpha=0.9, clip_on=False)

            # Label each row with its leapfrog value, coloured like the row.
            for row_name, ax in zip(g.row_names, g.axes.flat):
                _ = ax.set_ylabel("")
                _ = ax.set_yticks([])
                _ = ax.set_yticklabels([])
                # The last line is the axhline baseline, drawn in the row's
                # hue. An empty facet has no lines, so fall back to the
                # default text colour.
                color = ax.lines[-1].get_color() if ax.lines else None
                ax.text(
                    0,
                    0.10,
                    f"{row_name}",
                    fontweight="bold",
                    ha="left",
                    va="center",
                    transform=ax.transAxes,
                    color=color,
                )

            # Negative hspace makes the rows overlap; titles and y ticks
            # would clutter the overlap, so remove them.
            _ = g.figure.subplots_adjust(hspace=-0.75)
            _ = g.set_titles("")
            _ = g.set(yticks=[])
            _ = g.set(yticklabels=[])
            plt.rcParams["axes.labelcolor"] = "#bdbdbd"
            _ = g.set(xlabel=f"{key}")
            _ = g.despine(bottom=True, left=True)

            if save_plot:
                fsvg = svgdir.joinpath(f"{key}_ridgeplot.svg")
                fpng = pngdir.joinpath(f"{key}_ridgeplot.png")
                svgdir.mkdir(exist_ok=True, parents=True)
                pngdir.mkdir(exist_ok=True, parents=True)
                logger.info(f"Saving figure to: {fsvg.as_posix()}")
                _ = g.figure.savefig(
                    fsvg.as_posix(), dpi=400, bbox_inches="tight"
                )
                logger.info(f"Saving figure to: {fpng.as_posix()}")
                _ = g.figure.savefig(
                    fpng.as_posix(), dpi=400, bbox_inches="tight"
                )
            logger.debug(
                f"Ridgeplot for {key} took "
                f"{time.perf_counter() - tstart:.3f}s"
            )
        fig = plt.gcf()
        ax = plt.gca()
    return fig, ax, data