Skip to content

ezpz.plot¶

plot_helpers.py

Contains helpers for plotting.

get_timestamp(fstr=None) ¶

Get formatted timestamp.

Source code in src/ezpz/plot.py
def get_timestamp(fstr: Optional[str] = None) -> str:
    """Return the current local time rendered as a string.

    Args:
        fstr: ``strftime`` format to apply. When ``None`` the default
            ``"%Y-%m-%d-%H%M%S"`` layout (e.g. ``2024-01-31-235959``) is used.

    Returns:
        The formatted timestamp.
    """
    from datetime import datetime as _dt

    fmt = "%Y-%m-%d-%H%M%S" if fstr is None else fstr
    return _dt.now().strftime(fmt)

make_ridgeplots(dataset, num_chains=None, outdir=None, drop_zeros=False, drop_nans=True, cmap='viridis_r', save_plot=True) ¶

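Make ridgeplots: for each dataset variable with a leapfrog dimension, draw one row of overlapping density estimates per leapfrog step.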

Source code in src/ezpz/plot.py
def _ridgeplot_frame(val, key, num_chains=None, drop_zeros=False, drop_nans=True):
    """Flatten each leapfrog slice of ``val`` into one long-form DataFrame.

    Columns: ``key`` (sample values), ``"lf"`` (leapfrog coordinate value as a
    string), ``"avg"`` (mean of that leapfrog's kept samples) and ``"lf_avg"``
    (per-``"lf"`` mean of ``"avg"``, used as the hue variable).
    """
    import pandas as pd

    lf_data = {key: [], "lf": [], "avg": []}
    for idx, lf in enumerate(val.leapfrog.values):
        # Select by *position*: ``lf`` is a coordinate value and need not
        # equal its index along the leapfrog dimension.
        # Layout per the original author: val is (chain, leapfrog, draw), so
        # x is (chain, draw) — TODO confirm for all callers.
        x = val.isel(leapfrog=idx).values
        if num_chains is not None:
            # Keep only the first `num_chains` entries of the leading axis.
            x = x[:num_chains, :]
        x = x.flatten()
        if drop_zeros:
            x = x[x != 0]
        if drop_nans:
            # np.isfinite drops +/-inf as well as NaN.
            x = x[np.isfinite(x)]
        nsamples = len(x)
        # Avoid the empty-slice RuntimeWarning when every sample was filtered.
        avg = x.mean() if nsamples else np.nan
        lf_data[key].extend(x)
        lf_data["lf"].extend([f"{lf}"] * nsamples)
        lf_data["avg"].extend([avg] * nsamples)

    lfdf = pd.DataFrame(lf_data)
    lfdf_avg = lfdf.groupby("lf")["avg"].mean()
    lfdf["lf_avg"] = lfdf["lf"].map(lfdf_avg)  # type:ignore
    return lfdf


def _draw_ridgeplot(lfdf, key, ncolors, cmap="viridis_r"):
    """Draw overlapping per-leapfrog KDE rows of ``lfdf[key]``; return the FacetGrid."""
    import matplotlib.pyplot as plt
    import seaborn as sns

    pal = sns.color_palette(cmap, n_colors=ncolors)
    g = sns.FacetGrid(
        lfdf,
        row="lf",
        hue="lf_avg",
        aspect=15,
        height=0.25,  # type:ignore
        palette=pal,  # type:ignore
    )
    # Filled densities, one per row.
    g.map(
        sns.kdeplot,
        key,
        cut=1,
        bw_adjust=1.0,
        clip_on=False,
        fill=True,
        alpha=0.7,
        linewidth=1.25,
    )
    # Baseline at y=0 under each density; FacetGrid.map draws it in the row's
    # hue colour, which is reused for the row label below.
    g.map(plt.axhline, y=0, lw=1.0, alpha=0.9, clip_on=False)

    # Label every row with its leapfrog value (row_names follow axes order).
    for ax, row_name in zip(g.axes.flat, g.row_names):
        ax.set_ylabel("")
        ax.set_yticks([])
        ax.set_yticklabels([])
        ax.text(
            0,
            0.10,
            f"{row_name}",
            fontweight="bold",
            ha="left",
            va="center",
            transform=ax.transAxes,
            color=ax.lines[-1].get_color(),
        )
    # Negative spacing makes the rows overlap.
    g.figure.subplots_adjust(hspace=-0.75)
    # Remove axes details that don't play well with the overlap.
    g.set_titles("")
    g.set(yticks=[])
    g.set(yticklabels=[])
    plt.rcParams["axes.labelcolor"] = "#bdbdbd"
    g.set(xlabel=f"{key}")
    g.despine(bottom=True, left=True)
    return g


def _save_ridgeplot(fig, key, outdir):
    """Save ``fig`` to ``<outdir>/svgs/<key>_ridgeplot.svg`` and ``<outdir>/pngs/<key>_ridgeplot.png``."""
    outdir = Path(outdir)
    pngdir = outdir.joinpath("pngs")
    svgdir = outdir.joinpath("svgs")
    fsvg = svgdir.joinpath(f"{key}_ridgeplot.svg")
    fpng = pngdir.joinpath(f"{key}_ridgeplot.png")

    svgdir.mkdir(exist_ok=True, parents=True)
    pngdir.mkdir(exist_ok=True, parents=True)

    logger.info(f"Saving figure to: {fsvg.as_posix()}")
    # Save the grid's own figure rather than whatever pyplot considers current.
    fig.savefig(fsvg.as_posix(), dpi=400, bbox_inches="tight")
    fig.savefig(fpng.as_posix(), dpi=400, bbox_inches="tight")


def make_ridgeplots(
    dataset: xr.Dataset,
    num_chains: Optional[int] = None,
    outdir: Optional[os.PathLike] = None,
    drop_zeros: Optional[bool] = False,
    drop_nans: Optional[bool] = True,
    cmap: Optional[str] = "viridis_r",
    save_plot: bool = True,
):
    """Make one ridgeplot per leapfrog-indexed variable of ``dataset``.

    For every data variable with a ``leapfrog`` dimension, each leapfrog step's
    values are flattened (optionally filtered) and drawn as one row of
    overlapping KDEs, coloured by that step's mean. Variables without a
    ``leapfrog`` dimension are skipped.

    Args:
        dataset: Dataset whose leapfrog-indexed variables are plotted.
        num_chains: If given, keep only the first ``num_chains`` entries of
            the leading axis of each per-leapfrog slice.
        outdir: Base directory; outputs go under ``<outdir>/ridgeplots``.
            Defaults to the current working directory.
        drop_zeros: Drop exact zeros before plotting.
        drop_nans: Drop non-finite values (NaN and +/-inf) before plotting.
        cmap: Palette name passed to ``seaborn.color_palette``.
        save_plot: If True, save each figure as
            ``ridgeplots/svgs/<key>_ridgeplot.svg`` and
            ``ridgeplots/pngs/<key>_ridgeplot.png`` under ``outdir``.

    Returns:
        ``(fig, ax, data)``: the current matplotlib figure and axes after
        plotting, and a dict mapping each plotted variable name to the
        long-form ``pandas.DataFrame`` it was drawn from.

    Note:
        ``sns.set_theme`` and the ``plt.rcParams`` assignments modify global
        matplotlib state that persists after this function returns.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    data = {}
    outdir = Path(os.getcwd()) if outdir is None else Path(outdir)
    outdir = outdir.joinpath("ridgeplots")
    with sns.plotting_context(context="paper"):
        sns.set_theme(style="white", palette="bright")
        # Transparent backgrounds so the overlapping rows show through.
        plt.rcParams["axes.facecolor"] = (0, 0, 0, 0.0)
        plt.rcParams["figure.facecolor"] = (0, 0, 0, 0.0)
        for key, val in dataset.data_vars.items():
            if "leapfrog" not in val.dims:
                continue
            tstart = time.time()
            lfdf = _ridgeplot_frame(
                val,
                key,
                num_chains=num_chains,
                drop_zeros=drop_zeros,
                drop_nans=drop_nans,
            )
            data[key] = lfdf
            g = _draw_ridgeplot(
                lfdf, key, ncolors=val.sizes["leapfrog"], cmap=cmap
            )
            if save_plot:
                _save_ridgeplot(g.figure, key, outdir)
            logger.debug(f"Ridgeplot for {key} took {time.time() - tstart:.3f}s")

    fig = plt.gcf()
    ax = plt.gca()
    return fig, ax, data

plot_arr(metric, name=None) ¶

Returns (fig, axis) tuple

Source code in src/ezpz/plot.py
def plot_arr(
    metric: list,
    name: Optional[str] = None,
) -> tuple:
    """Plot a list of recorded metric values; returns (fig, axis) tuple.

    ``metric`` is stacked with ``np.stack`` and dispatched on the rank of a
    single element:

    * 0-d elements (Python/numpy scalars, 0-d arrays) -> ``plot_scalar``
    * 1-d elements -> ``plot_chains``
    * 2-d elements -> ``plot_leapfrogs``

    Args:
        metric: Non-empty list of equally-shaped values.
        name: Optional y-axis label forwarded as ``ylabel``.

    Returns:
        Whatever the selected helper returns (the ``(fig, axis)`` tuple).

    Raises:
        ValueError: If ``metric`` is empty, its elements cannot be stacked,
            or the elements have more than two dimensions.
    """
    if len(metric) == 0:
        raise ValueError("`metric` must contain at least one value")
    y = np.stack(metric)
    # np.stack adds exactly one leading axis, so this is the per-element rank.
    # Unlike an isinstance check it also covers numpy integer/bool scalars,
    # 0-d arrays and nested lists.
    element_ndim = y.ndim - 1
    if element_ndim == 0:
        return plot_scalar(y, ylabel=name)
    if element_ndim == 1:
        return plot_chains(y, ylabel=name)
    if element_ndim == 2:
        return plot_leapfrogs(y, ylabel=name)
    raise ValueError(
        f"Unsupported element shape {y.shape[1:]}; expected 0-, 1- or 2-D elements"
    )

set_size(width=None, fraction=None, subplots=None) ¶

Set figure dimensions to avoid scaling in LaTeX.

Source code in src/ezpz/plot.py
def set_size(
    width: "str | float | None" = None,
    fraction: Optional[float] = None,
    subplots: Optional[tuple] = None,
) -> tuple[float, float]:
    """Set figure dimensions to avoid scaling in LaTeX.

    Args:
        width: Document text width. ``"thesis"`` (426.79135 pt) and
            ``"beamer"`` (307.28987 pt) are presets; a real number is taken
            as the width in points. Anything else, including ``None``, falls
            back to 345 pt.
        fraction: Fraction of the text width the figure occupies (default 1.0).
        subplots: ``(nrows, ncols)`` of the subplot grid (default ``(1, 1)``).
            The height is scaled by ``nrows / ncols``.

    Returns:
        ``(fig_width_in, fig_height_in)`` in inches, where the height is
        ``width * golden_ratio * nrows / ncols``.
    """
    import numbers

    presets = {"thesis": 426.79135, "beamer": 307.28987}
    if isinstance(width, str):
        width_pt = presets.get(width, 345.0)
    elif isinstance(width, numbers.Real) and not isinstance(width, bool):
        # Explicit text width in points (e.g. from LaTeX's \showthe\textwidth).
        width_pt = float(width)
    else:
        width_pt = 345.0
    fraction = 1.0 if fraction is None else fraction
    subplots = (1, 1) if subplots is None else subplots
    # Width of figure (in pts)
    fig_width_pt = width_pt * fraction
    # TeX points: 72.27 pt per inch
    inches_per_pt = 1 / 72.27

    # Golden ratio to set aesthetic figure height
    golden_ratio = (5**0.5 - 1) / 2

    # Figure width and height in inches
    fig_width_in = fig_width_pt * inches_per_pt
    fig_height_in = fig_width_in * golden_ratio * (subplots[0] / subplots[1])

    return (fig_width_in, fig_height_in)

subplots(**kwargs) ¶

Returns (fig, axis) tuple

Source code in src/ezpz/plot.py
def subplots(**kwargs) -> tuple:
    """Returns (fig, axis) tuple from ``matplotlib.pyplot.subplots``.

    All ``kwargs`` are forwarded to ``plt.subplots``. ``ax`` is a single
    ``Axes`` for a 1x1 grid with the default ``squeeze=True``, and a numpy
    array of ``Axes`` for larger grids or ``squeeze=False``.
    """
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(**kwargs)
    # Sanity checks that also narrow the types for static checkers.
    assert isinstance(fig, plt.Figure)
    if isinstance(ax, np.ndarray):
        # Multi-axes grid: plt.subplots returns an object array of Axes.
        assert all(isinstance(a, plt.Axes) for a in ax.flat)
    else:
        assert isinstance(ax, plt.Axes)
    return fig, ax