class History:
"""
A class to track and log metrics during training or evaluation.
"""
def __init__(
self,
keys: Optional[list[str]] = None,
*,
report_dir: Optional[PathLike] = None,
report_enabled: bool = True,
jsonl_path: Optional[PathLike] = None,
jsonl_overwrite: bool = False,
distributed_history: bool = AUTO_USE_DISTRIBUTED_HISTORY,
) -> None:
"""
Initialize the History object.
Args:
keys (Optional[list[str]]): List of keys to initialize the history with.
If None, initializes with an empty list.
report_dir (Optional[PathLike]): Directory for markdown reports. Defaults
to ``OUTPUTS_DIR/history``.
report_enabled (bool): Toggle automatic markdown generation.
jsonl_path (Optional[PathLike]): Destination for JSONL metric logging.
jsonl_overwrite (bool): Whether to truncate an existing JSONL log.
distributed_history (bool): Enable distributed metric aggregation.
Automatically disabled when the world size is <= 1 or when the
``EZPZ_NO_DISTRIBUTED_HISTORY`` / ``EZPZ_LOCAL_HISTORY``
environment variables are set.
"""
self.keys = [] if keys is None else keys
self.history: dict[str, list[Any]] = {}
self.data = self.history
if (
os.environ.get("EZPZ_NO_DISTRIBUTED_HISTORY", None)
or os.environ.get("EZPZ_LOCAL_HISTORY", False)
or ezpz.dist.get_world_size() <= 1
):
logger.info(
"Not using distributed metrics! Will only be tracked from a single rank..."
)
distributed_history = False
self.distributed_history = distributed_history
logger.info(
f"Using {self.__class__.__name__} with distributed_history={self.distributed_history}"
)
self._rank = get_rank()
now = datetime.now(timezone.utc)
self._run_id = now.strftime("%Y%m%d-%H%M%S")
self.report_enabled = report_enabled
base_report_root = (
Path(report_dir)
if report_dir is not None
else Path(OUTPUTS_DIR).joinpath("history")
)
self._report_root = Path(base_report_root).expanduser().resolve()
self._report_dir = self._report_root.joinpath(self._run_id)
self._report_path: Optional[Path] = None
self._asset_dir: Optional[Path] = None
self._report_filename = "report.md"
self._report_initialized = False
if jsonl_path is None:
default_jsonl_dir = (
self._report_dir if report_enabled else Path(OUTPUTS_DIR)
)
self._jsonl_path = (
Path(default_jsonl_dir)
.expanduser()
.resolve()
.joinpath(f"{self._run_id}.jsonl")
)
else:
self._jsonl_path = Path(jsonl_path).expanduser().resolve()
if jsonl_overwrite and self._jsonl_path.exists():
try:
self._jsonl_path.unlink()
except OSError:
logger.warning(
"Unable to remove existing JSONL log at %s",
self._jsonl_path,
)
self._jsonl_enabled = True
self._dist = torch.distributed
self._environment_written = False
self._metric_summary_written = False
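# Example (illustrative sketch, not part of the class): constructing a
# History with an explicit JSONL destination. The import path and file
# names here are assumptions; adjust them to your own tree.
#
#     from ezpz.history import History
#
#     history = History(
#         report_dir="outputs/history",        # markdown report root
#         jsonl_path="outputs/metrics.jsonl",  # per-step JSONL log
#         jsonl_overwrite=True,                # truncate any existing log
#     )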
# ------------------------------------------------------------------ #
# Internal helpers
# ------------------------------------------------------------------ #
@staticmethod
def _utc_iso() -> str:
"""Return the current UTC timestamp in ISO-8601 format with trailing Z."""
return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
def _configure_report_destination(self, base_dir: Path) -> None:
"""Configure the report directory to live under *base_dir*."""
base_dir = base_dir.expanduser().resolve()
self._report_root = base_dir
self._report_dir = base_dir
self._report_path = base_dir.joinpath(self._report_filename)
self._asset_dir = base_dir.joinpath("assets")
self._report_initialized = False
if self._jsonl_enabled:
self._jsonl_path = base_dir.joinpath(f"{self._run_id}.jsonl")
self._environment_written = False
self._metric_summary_written = False
def _ensure_report_file(self) -> Optional[Path]:
"""Ensure the markdown report and asset directories exist."""
if not self.report_enabled:
return None
if not self._report_initialized:
self._report_dir.mkdir(parents=True, exist_ok=True)
self._asset_dir = self._report_dir.joinpath("assets")
self._asset_dir.mkdir(parents=True, exist_ok=True)
self._report_path = self._report_dir.joinpath(
self._report_filename
)
header = (
f"# History Report ({self._run_id})\n\n"
f"_Generated at {self._utc_iso()}_\n\n"
)
self._report_path.write_text(header, encoding="utf-8")
self._report_initialized = True
return self._report_path
def _prepare_report_asset(self, source: Path) -> Optional[Path]:
"""Copy plot artifacts into the report asset directory."""
report_file = self._ensure_report_file()
if report_file is None:
return None
assert self._asset_dir is not None
source = source.resolve()
try:
if source.is_relative_to(self._asset_dir):
return source
except AttributeError: # Python < 3.9 fallback (not expected)
pass
destination = self._asset_dir.joinpath(source.name)
if destination != source:
try:
shutil.copy2(source, destination)
except OSError:
logger.warning(
"Unable to copy asset %s into report directory.", source
)
return source
return destination
def _write_plot_report(
self,
key: Optional[str],
asset_path: Path,
*,
kind: str,
metadata: Optional[dict[str, Any]] = None,
) -> None:
"""Append a markdown section describing the generated plot."""
report_file = self._ensure_report_file()
if report_file is None:
return
asset_path = asset_path.resolve()
if not asset_path.exists():
return
asset_path = self._prepare_report_asset(asset_path) or asset_path
try:
rel_path = asset_path.relative_to(report_file.parent)
except ValueError:
rel_path = asset_path
title = key or asset_path.stem
timestamp = self._utc_iso()
lines = [
f"## {title}",
"",
f"_Kind_: `{kind}` ",
f"_Generated_: {timestamp}",
"",
]
if asset_path.suffix.lower() in {".txt", ".log"}:
try:
text = asset_path.read_text(encoding="utf-8")
except OSError:
text = ""
snippet = "\n".join(text.splitlines()[:40]).rstrip("\n")
lines.extend(["```", snippet, "```", ""])
else:
lines.append(f"})")
lines.append("")
if metadata:
for meta_key, meta_val in metadata.items():
lines.append(f"- **{meta_key}**: {meta_val}")
lines.append("")
with report_file.open("a", encoding="utf-8") as handle:
handle.write("\n".join(lines))
if not lines[-1].endswith("\n"):
handle.write("\n")
def _wandb_log_matplotlib_asset(
self,
key: Optional[str],
asset_path: Optional[Path],
*,
kind: str = "matplotlib",
commit: bool = False,
) -> None:
if (
not ENABLE_WANDB
or wandb is None
or getattr(wandb, "run", None) is None
or self._rank != 0
):
return
if asset_path is None:
return
asset_path = Path(asset_path)
if not asset_path.exists():
return
if asset_path.suffix.lower() not in {".png", ".jpg", ".jpeg"}:
return
try:
title = key or asset_path.stem
wandb.log(
{
f"plots/{title}": wandb.Image(
str(asset_path), caption=f"{kind}:{title}"
)
},
commit=commit,
)
except Exception as exc:
logger.debug("W&B image logging failed: %s", exc)
def _write_environment_section(
self, env_info: Optional[dict[str, Any]]
) -> None:
"""Write environment details into the report."""
if (
not self.report_enabled
or env_info is None
or self._environment_written
):
return
report_file = self._ensure_report_file()
if report_file is None:
return
lines: list[str] = ["## Environment", ""]
for section, details in env_info.items():
if isinstance(details, dict):
lines.append(f"### {section}")
lines.append("")
lines.extend((f"### {section}", ""))
lines.extend(
f"- **{key}**: {value}" for key, value in details.items()
)
lines.append("")
else:
lines.append(f"- **{section}**: {details}")
with report_file.open("a", encoding="utf-8") as handle:
handle.write("\n".join(lines))
if not lines[-1].endswith("\n"):
handle.write("\n")
self._environment_written = True
def _default_environment_info(self) -> dict[str, dict[str, str]]:
"""Return a minimal environment summary."""
python_info = {
"Version": (
f"{sys.version_info.major}."
f"{sys.version_info.minor}."
f"{sys.version_info.micro}"
),
"Implementation": sys.implementation.name,
"Executable": sys.executable,
"Platform": platform.platform(),
}
try:
torch_version = torch.__version__
except Exception: # pragma: no cover - torch should be importable
torch_version = "unknown"
torch_info = {
"Version": torch_version,
}
path_info = {
"Working directory": str(Path.cwd()),
}
env_vars: dict[str, str] = {}
for key in (
"MASTER_ADDR",
"MASTER_PORT",
"WORLD_SIZE",
"RANK",
"LOCAL_RANK",
):
value = os.environ.get(key)
if value is not None:
env_vars[key] = value
summary: dict[str, dict[str, str]] = {
"Paths": path_info,
"Python": python_info,
"Torch": torch_info,
}
if env_vars:
summary["Environment Variables"] = env_vars
return summary
def _collect_metric_groups(
self, dataset: xr.Dataset
) -> dict[str, dict[str, float]]:
"""Return metric statistics grouped by base metric name."""
assert dataset is not None and hasattr(dataset, "data_vars")
groups: dict[str, dict[str, float]] = {}
for name in sorted(dataset.data_vars):
arr = dataset[name]
if arr.size == 0:
continue
try:
latest = arr.isel({arr.dims[0]: -1})
except Exception:
latest = arr
data = np.asarray(latest)
if data.size == 0:
continue
value = float(data.mean()) if data.ndim > 0 else float(data.item())
base, sep, suffix = name.rpartition("_")
if sep and base and suffix in {"mean", "max", "min", "std"}:
groups.setdefault(base, {})[suffix] = value
else:
groups.setdefault(name, {})["latest"] = value
return groups
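# Example (a hedged sketch of the grouping convention; values hypothetical):
# variables with a ``_mean``/``_max``/``_min``/``_std`` suffix are folded
# into one group keyed by the base metric, everything else is "latest".
#
#     # Given data_vars: "loss_mean", "loss_std", "lr"
#     # _collect_metric_groups returns something like:
#     # {"loss": {"mean": 0.42, "std": 0.01}, "lr": {"latest": 3e-4}}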
def _write_metric_summary(self, dataset: xr.Dataset) -> None:
"""Append a metric overview table grouped by metric."""
if not self.report_enabled or self._metric_summary_written:
return
groups = self._collect_metric_groups(dataset)
if not groups:
return
report_file = self._ensure_report_file()
if report_file is None:
return
with report_file.open("a", encoding="utf-8") as handle:
handle.write("## Metric Overview\n\n")
for metric_name, stats in groups.items():
handle.write(f"### {metric_name}\n\n")
rows: list[tuple[str, str]] = []
for label in ("latest", "mean", "max", "min", "std"):
if label in stats:
value = stats[label]
rows.append((label.capitalize(), f"{value:.6f}"))
if rows:
header = ("Statistic", "Value")
col1 = max(len(header[0]), *(len(r[0]) for r in rows))
col2 = max(len(header[1]), *(len(r[1]) for r in rows))
handle.write(
f"| {header[0]:<{col1}} | {header[1]:>{col2}} |\n"
)
handle.write(
f"|:{'-' * (col1 - 1)} | {'-' * (col2 - 1)}:|\n"
)
for stat_label, stat_value in rows:
handle.write(
f"| {stat_label:<{col1}} | {stat_value:>{col2}} |\n"
)
handle.write("\n")
self._metric_summary_written = True
def _series_from_dataarray(self, data: xr.DataArray) -> np.ndarray:
"""Convert an xarray DataArray into a 1-D numerical series."""
arr = np.asarray(data.values)
if arr.ndim == 1:
return arr
if arr.ndim == 0:
return np.array([float(arr)])
axes = tuple(range(arr.ndim - 1))
return arr.mean(axis=axes)
def _group_metric_variables(
self, dataset: xr.Dataset
) -> dict[str, dict[str, xr.DataArray]]:
"""Group metric variables by base name and associated aggregates."""
groups: dict[str, dict[str, xr.DataArray]] = {}
for name, data_array in dataset.data_vars.items():
base, sep, suffix = name.rpartition("_")
if sep and base and suffix in {"mean", "max", "min", "std"}:
groups.setdefault(base, {})[suffix] = data_array
else:
groups.setdefault(name, {})["raw"] = data_array
return groups
def _plot_metric_group(
self,
name: str,
metric_vars: dict[str, xr.DataArray],
*,
warmup: Optional[float | int] = 0.0,
title: Optional[str] = None,
outdir: Optional[Path] = None,
subplots_kwargs: Optional[dict[str, Any]] = None,
plot_kwargs: Optional[dict[str, Any]] = None,
verbose: bool = False,
) -> Optional[Path]:
"""Render a single matplotlib figure combining metric aggregates."""
import matplotlib.pyplot as plt
import seaborn as sns
subplots_kwargs = (
{} if subplots_kwargs is None else dict(subplots_kwargs)
)
plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
series_candidates = [
metric_vars.get("raw"),
metric_vars.get("mean"),
metric_vars.get("min"),
metric_vars.get("max"),
]
base_series = next(
(
self._series_from_dataarray(candidate)
for candidate in series_candidates
if candidate is not None
),
None,
)
if base_series is None or len(base_series) == 0:
return None
x = np.arange(base_series.shape[-1])
fig, ax = plt.subplots(**subplots_kwargs)
color = plot_kwargs.get("color")
raw_da = metric_vars.get("raw")
if raw_da is not None:
raw_series = self._series_from_dataarray(raw_da)
ax.plot(
x,
raw_series,
label=name,
color=color,
alpha=0.35,
linewidth=1.25,
)
mean_da = metric_vars.get("mean")
std_da = metric_vars.get("std")
min_da = metric_vars.get("min")
max_da = metric_vars.get("max")
mean_series = None
if mean_da is not None:
mean_series = self._series_from_dataarray(mean_da)
ax.plot(
x,
mean_series,
label=f"{name} mean",
color=color,
linewidth=2.0,
)
if mean_series is not None and std_da is not None:
std_series = self._series_from_dataarray(std_da)
upper = mean_series + std_series
lower = mean_series - std_series
ax.fill_between(
x,
lower,
upper,
color=color,
alpha=0.2,
label=f"{name} Β± std",
)
elif min_da is not None and max_da is not None:
min_series = self._series_from_dataarray(min_da)
max_series = self._series_from_dataarray(max_da)
ax.fill_between(
x,
min_series,
max_series,
color=color,
alpha=0.15,
label=f"{name} range",
)
if (
mean_da is None
and raw_da is None
and min_da is None
and max_da is None
):
# fall back to plotting whichever aggregate is available
for label, array in metric_vars.items():
series = self._series_from_dataarray(array)
ax.plot(x, series, label=f"{name} {label}", linewidth=1.75)
ax.set_xlabel("draw")
ax.set_ylabel(name)
if title is not None:
ax.set_title(title)
sns.despine(ax=ax, top=True, right=True)
ax.legend(loc="best", frameon=False)
if outdir is None and self.report_enabled:
save_dir = self._report_dir.joinpath("mplot")
elif outdir is not None:
save_dir = Path(outdir)
else:
save_dir = None
primary_asset: Optional[Path] = None
if save_dir is not None:
save_dir = save_dir.expanduser().resolve()
save_dir.mkdir(parents=True, exist_ok=True)
asset_name = name.replace("/", "_")
dirs = {
"png": save_dir.joinpath("pngs"),
"svg": save_dir.joinpath("svgs"),
}
for directory in dirs.values():
directory.mkdir(parents=True, exist_ok=True)
if verbose:
logger.info("Saving %s plot to: %s", name, save_dir)
for ext, directory in dirs.items():
outfile = directory.joinpath(f"{asset_name}.{ext}")
if outfile.exists():
outfile = directory.joinpath(
f"{asset_name}-{get_timestamp()}.{ext}"
)
fig.savefig(outfile, dpi=400, bbox_inches="tight")
if primary_asset is None and ext == "png":
primary_asset = outfile
plt.close(fig)
return primary_asset
def _tplot_metric_group(
self,
name: str,
metric_vars: dict[str, xr.DataArray],
*,
warmup: Optional[float | int] = 0.0,
outdir: Optional[Path] = None,
plot_type: Optional[str] = None,
marker: Optional[str] = None,
verbose: bool = False,
logfreq: Optional[int] = None,
) -> Optional[Path]:
"""Render grouped metrics into a single text-based plot asset."""
outdir = Path(outdir) if outdir is not None else None
if outdir is None and self.report_enabled:
outdir = self._report_dir.joinpath("tplot")
if outdir is None:
return None
outdir = outdir.expanduser().resolve()
outdir.mkdir(parents=True, exist_ok=True)
asset_path = outdir.joinpath(f"{name.replace('/', '_')}.txt")
summary_path = outdir.joinpath(f"{name.replace('/', '_')}_summary.txt")
hist_path = outdir.joinpath(f"{name.replace('/', '_')}_hist.txt")
order = [
("raw", name),
("mean", f"{name} mean"),
("max", f"{name} max"),
("min", f"{name} min"),
("std", f"{name} std"),
]
stats_keys = ("mean", "max", "min", "std")
stats_present = any(key in metric_vars for key in stats_keys)
stats_nonzero = False
if stats_present:
nonzero_keys = [
key for key in ("mean", "max", "min") if key in metric_vars
]
if nonzero_keys:
stats_nonzero = any(
np.any(
np.nan_to_num(
self._series_from_dataarray(metric_vars[key])
)
!= 0
)
for key in nonzero_keys
)
try:
_ = plotext_prepare_figure(theme="clear")
except ModuleNotFoundError: # pragma: no cover - optional dependency
logger.error(
"Unable to import `plotext` which is needed for text-based plotting."
)
return None
resolved_plot_type = (
plot_type
if plot_type is not None
else os.environ.get("EZPZ_TPLOT_TYPE")
)
resolved_marker = (
marker
if marker is not None
else os.environ.get("EZPZ_TPLOT_MARKER")
)
if resolved_marker is None and resolved_plot_type != "hist":
resolved_marker = DEFAULT_MARKER
def _metric_marker(metric_key: str) -> Optional[str]:
key = metric_key.split("/")[-1]
mapped = MARKER_MAP.get(key)
if mapped == "line":
return None
if mapped is not None:
return mapped
return resolved_marker
use_subplots = stats_present and stats_nonzero
plt = None
left = None
right = None
if use_subplots:
try:
plt, left, right = plotext_subplots(
left_layout=(2, 1),
right_layout=(3, 1),
# height_scale=8.0,
)
except (
ModuleNotFoundError
): # pragma: no cover - optional dependency
use_subplots = False
wrote_any = False
points = 0
if use_subplots:
assert plt is not None and left is not None and right is not None
left_slots = [
(1, "raw", name, None),
(2, "mean", f"{name}/mean", "green"),
]
right_slots = [
(1, "min", f"{name}/min", "cyan"),
(2, "std", f"{name}/std", "magenta"),
(3, "max", f"{name}/max", "red"),
]
for row, key, label, color in left_slots:
data_array = metric_vars.get(key)
if data_array is None:
continue
series = self._series_from_dataarray(data_array)
points = max(points, len(series))
if left is not None and hasattr(left, "subplot"):
left.subplot(row, 1)
plotext_plot_series(
plt,
series,
label=None,
color=color,
plot_type=resolved_plot_type,
marker=_metric_marker(key),
)
if plt is not None and hasattr(plt, "title"):
plt.title(label)
if plt is not None and hasattr(plt, "xlabel"):
plt.xlabel("iter")
if plt is not None and hasattr(plt, "ylabel"):
plt.ylabel(label)
wrote_any = True
for row, key, label, color in right_slots:
data_array = metric_vars.get(key)
if data_array is None:
continue
series = self._series_from_dataarray(data_array)
points = max(points, len(series))
if hasattr(right, "subplot"):
right.subplot(row, 1)
plotext_plot_series(
plt,
series,
label=None,
color=color,
plot_type=resolved_plot_type,
marker=_metric_marker(key),
)
if hasattr(plt, "title"):
plt.title(label)
if hasattr(plt, "xlabel"):
plt.xlabel("iter")
if hasattr(plt, "ylabel"):
plt.ylabel(label)
wrote_any = True
if wrote_any:
plt.show()
plt.savefig(
asset_path.as_posix(), append=False, keep_colors=True
)
if stats_present:
plt = plotext_prepare_figure(theme="clear")
plotext_set_size(plt, min_height=40)
overlay_order = [
("max", f"{name}/max", "red"),
("min", f"{name}/min", "cyan"),
("mean", f"{name}/mean", "green"),
("raw", name, None),
]
overlay_points = 0
for key, label, color in overlay_order:
data_array = metric_vars.get(key)
if data_array is None:
continue
series = self._series_from_dataarray(data_array)
overlay_points = max(overlay_points, len(series))
plotext_plot_series(
plt,
series,
label=label,
color=color,
marker=_metric_marker(key),
)
if overlay_points > 0:
plt.show()
plt.savefig(
summary_path.as_posix(),
append=False,
keep_colors=True,
)
if self.report_enabled:
self._write_plot_report(
f"{name} summary",
summary_path,
kind="tplot",
metadata={
"components": ", ".join(
key
for key, _, _ in overlay_order
if key in metric_vars
),
"points": overlay_points,
},
)
hist_order = [
("mean", f"{name}/mean"),
("max", f"{name}/max"),
("min", f"{name}/min"),
("std", f"{name}/std"),
]
plt = plotext_prepare_figure(theme="clear")
plotext_set_size(plt)
plt.subplots(2, 2)
hist_points = 0
for idx, (key, label) in enumerate(hist_order, start=1):
data_array = metric_vars.get(key)
if data_array is None:
continue
series = self._series_from_dataarray(data_array)
hist_points = max(hist_points, len(series))
row = ((idx - 1) // 2) + 1
col = ((idx - 1) % 2) + 1
if hasattr(plt, "subplot"):
plt.subplot(row, col)
plotext_hist_series(plt, series, label=None)
if hasattr(plt, "title"):
plt.title(f"{label} hist")
if hist_points > 0:
plt.show()
plt.savefig(
hist_path.as_posix(),
append=False,
keep_colors=True,
)
if self.report_enabled:
self._write_plot_report(
f"{name} hist",
hist_path,
kind="tplot-hist",
metadata={
"components": ", ".join(
key
for key, _ in hist_order
if key in metric_vars
),
"points": hist_points,
},
)
if not use_subplots:
append_flag = False
for key, label in order:
data_array = metric_vars.get(key)
if data_array is None:
continue
series = self._series_from_dataarray(data_array)
points = max(points, len(series))
self._tplot(
y=series,
xlabel="iter",
ylabel=label,
append=append_flag,
outfile=asset_path.as_posix(),
verbose=verbose,
plot_type=resolved_plot_type,
marker=_metric_marker(key),
logfreq=(1 if logfreq is None else logfreq),
record_report=False,
)
append_flag = True
wrote_any = True
if stats_present:
overlay_order = [
("raw", name),
("mean", f"{name}/mean"),
("max", f"{name}/max"),
("min", f"{name}/min"),
]
overlay_points = 0
overlay_append = False
for key, label in overlay_order:
data_array = metric_vars.get(key)
if data_array is None:
continue
series = self._series_from_dataarray(data_array)
overlay_points = max(overlay_points, len(series))
self._tplot(
y=series,
xlabel="iter",
ylabel=label,
append=overlay_append,
outfile=summary_path.as_posix(),
verbose=verbose,
plot_type=resolved_plot_type,
marker=_metric_marker(key),
logfreq=(1 if logfreq is None else logfreq),
record_report=False,
)
overlay_append = True
if (
overlay_points > 0
and self.report_enabled
and summary_path.exists()
):
self._write_plot_report(
f"{name} summary",
summary_path,
kind="tplot",
metadata={
"components": ", ".join(
key
for key, _ in overlay_order
if key in metric_vars
),
"points": overlay_points,
},
)
if wrote_any and self.report_enabled:
self._write_plot_report(
name,
asset_path,
kind="tplot",
metadata={
"components": ", ".join(
key for key, _ in order if key in metric_vars
),
"points": points,
},
)
return asset_path
def _write_jsonl_entry(
self,
metrics: dict[str, Any],
aggregated: Optional[dict[str, float]] = None,
) -> None:
"""Append metrics to the configured JSONL log."""
if not self._jsonl_enabled:
return
if self._jsonl_path is None:
return
payload: dict[str, Any] = {
"timestamp": time.time(),
"rank": self._rank,
"metrics": metrics,
}
if aggregated and self._rank == 0:
payload["aggregated"] = aggregated
try:
self._jsonl_path.parent.mkdir(parents=True, exist_ok=True)
with self._jsonl_path.open("a", encoding="utf-8") as handle:
handle.write(
json.dumps(payload, default=self._to_serializable)
)
handle.write("\n")
except OSError:
logger.warning(
"Unable to write JSONL metrics to %s", self._jsonl_path
)
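# Example (hedged sketch): each JSONL line is a self-contained JSON object
# with "timestamp", "rank", "metrics", and (on rank 0) "aggregated" keys,
# so the log can be replayed with nothing but the standard library. The
# helper name is hypothetical.
#
#     import json
#     from pathlib import Path
#
#     def read_jsonl(path: str) -> list[dict]:
#         """Load every record from a History JSONL log."""
#         with Path(path).open("r", encoding="utf-8") as fh:
#             return [json.loads(line) for line in fh if line.strip()]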
@classmethod
def _to_serializable(cls, value: Any) -> Any:
"""Convert values to JSON-serializable structures."""
if value is None or isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, (np.floating, np.integer, np.bool_)):
return value.item()
if isinstance(value, Path):
return value.as_posix()
if torch.is_tensor(value):
tensor = value.detach()
if tensor.numel() == 1:
return tensor.item()
return tensor.cpu().tolist()
if isinstance(value, np.ndarray):
if value.shape == ():
return value.item()
return value.tolist()
if isinstance(value, dict):
return {
key: cls._to_serializable(sub_value)
for key, sub_value in value.items()
}
if isinstance(value, (list, tuple)):
return [cls._to_serializable(item) for item in value]
if hasattr(value, "item"):
try:
return value.item()
except Exception:
pass
return str(value)
@classmethod
def _sanitize_metrics(cls, metrics: dict[str, Any]) -> dict[str, Any]:
"""Return a copy of metrics with values converted to JSON-safe types."""
return {
key: cls._to_serializable(value) for key, value in metrics.items()
}
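# Example (illustrative, values hypothetical): _sanitize_metrics flattens
# tensors, numpy scalars, and Paths into plain JSON-safe types.
#
#     History._sanitize_metrics(
#         {"loss": torch.tensor(0.5), "step": np.int64(3), "dir": Path("/tmp")}
#     )
#     # -> {"loss": 0.5, "step": 3, "dir": "/tmp"}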
def _iter_scalar_metrics(
self, metrics: dict[str, Any]
) -> Iterable[tuple[str, float]]:
"""Yield scalar metrics suitable for distributed reductions."""
for key, value in metrics.items():
if isinstance(value, (int, float)):
yield key, float(value)
elif isinstance(value, np.ndarray) and value.shape == ():
yield key, float(value.item())
elif torch.is_tensor(value) and value.numel() == 1:
yield key, float(value.item())
def _select_metric_device(self) -> torch.device:
"""Return the device to use for distributed metric aggregation."""
candidate = ezpz.get_torch_device(as_torch_device=True)
device = (
candidate
if isinstance(candidate, torch.device)
else torch.device(str(candidate))
)
device_type = device.type
if device_type == "mps":
return torch.device("cpu")
if device_type == "cuda" and not torch.cuda.is_available():
return torch.device("cpu")
xpu_backend = getattr(torch, "xpu", None)
if device_type == "xpu" and not (
xpu_backend and xpu_backend.is_available()
):
return torch.device("cpu")
return device
def _compute_distributed_metrics(
self, metrics: dict[str, Any]
) -> dict[str, float]:
"""Compute distributed reductions for scalar metrics."""
if not self.distributed_history or self._dist is None:
return {}
try:
if (
not self._dist.is_available()
or not torch.distributed.is_initialized() # type: ignore[attr-defined]
):
return {}
except AttributeError:
return {}
scalars = dict(self._iter_scalar_metrics(metrics))
if not scalars:
return {}
metric_device = self._select_metric_device()
dtype = torch.get_default_dtype()
values = torch.tensor(
list(scalars.values()),
dtype=dtype,
device=metric_device,
)
sum_vals = values.clone()
sq_vals = values.square()
max_vals = values.clone()
min_vals = values.clone()
if (world_size := ezpz.dist.get_world_size()) <= 1:
return {
f"{key}/{suffix}": (value if suffix != "std" else 0.0)
for key, value in scalars.items()
for suffix in ("mean", "max", "min", "std")
}
ops = torch.distributed.ReduceOp # type: ignore[attr-defined]
torch.distributed.all_reduce(sum_vals, op=ops.SUM)
torch.distributed.all_reduce(sq_vals, op=ops.SUM)
torch.distributed.all_reduce(max_vals, op=ops.MAX)
torch.distributed.all_reduce(min_vals, op=ops.MIN)
mean_vals = sum_vals.div(world_size)
var_vals = sq_vals.div(world_size).sub(mean_vals.square())
std_vals = var_vals.clamp_min_(0.0).sqrt_()
stats: dict[str, float] = {}
for idx, key in enumerate(scalars):
stats[f"{key}/mean"] = float(mean_vals[idx].item())
stats[f"{key}/max"] = float(max_vals[idx].item())
stats[f"{key}/min"] = float(min_vals[idx].item())
stats[f"{key}/std"] = float(std_vals[idx].item())
return stats
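# How the reduction works (a sketch following the code above): each rank
# contributes its scalar x; SUM-reducing x and x**2 across W ranks gives
# mean = sum(x) / W and var = sum(x**2) / W - mean**2 (clamped at zero
# before the sqrt to absorb floating-point error). With two ranks
# reporting loss 1.0 and 3.0:
#
#     mean = (1.0 + 3.0) / 2 = 2.0
#     var  = (1.0 + 9.0) / 2 - 2.0**2 = 1.0
#     std  = 1.0, max = 3.0, min = 1.0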
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def _update(
self,
key: str,
val: Union[Any, ScalarLike, list, torch.Tensor, np.ndarray],
):
"""
Update the history with a new key-value pair.
Args:
key (str): The key to update in the history.
val (Union[Any, ScalarLike, list, torch.Tensor, np.ndarray]): The value
to associate with the key.
"""
try:
self.history[key].append(val)
except KeyError:
self.history[key] = [val]
return val
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def update(
self,
metrics: dict,
precision: int = 6,
use_wandb: Optional[bool] = True,
commit: Optional[bool] = True,
summarize: Optional[bool] = True,
) -> str:
"""
Update the history with a dictionary of metrics.
Args:
metrics (dict): Dictionary of metrics to update the history with.
precision (int): Precision for summarizing the metrics.
use_wandb (Optional[bool]): Whether to log the metrics to Weights & Biases.
commit (Optional[bool]): Whether to commit the log to Weights & Biases.
summarize (Optional[bool]): Whether to summarize the metrics.
Returns:
str: A formatted summary of the scalar metrics, or an empty
string when ``summarize`` is False or nothing remains after
filtering counter statistics.
"""
for key, val in metrics.items():
try:
self.history[key].append(val)
except KeyError:
self.history[key] = [val]
aggregated_metrics = self._compute_distributed_metrics(metrics)
if aggregated_metrics and self._rank == 0:
for agg_key, agg_val in aggregated_metrics.items():
self._update(agg_key, agg_val)
metrics_for_logging = dict(metrics)
if aggregated_metrics and self._rank == 0:
metrics_for_logging.update(aggregated_metrics)
sanitized_metrics = self._sanitize_metrics(metrics_for_logging)
summary_source = (
sanitized_metrics
if aggregated_metrics and self._rank == 0
else self._sanitize_metrics(metrics)
)
if (
wandb is not None
and use_wandb
# and not WANDB_DISABLED
and getattr(wandb, "run", None) is not None
):
wandb.log(sanitized_metrics, commit=commit)
self._write_jsonl_entry(sanitized_metrics, aggregated_metrics)
if summarize:
scalar_summary = {
key: value
for key, value in summary_source.items()
# skip keys like "train/iter/min", "eval/step/std", etc.,
if not any(
count_str in key
for count_str in [
"iter/min",
"iter/max",
"iter/std",
"iter/avg",
"iter/mean",
"step/min",
"step/max",
"step/std",
"step/avg",
"step/mean",
"epoch/min",
"epoch/max",
"epoch/std",
"epoch/avg",
"epoch/mean",
"batch/min",
"batch/max",
"batch/std",
"batch/avg",
"batch/mean",
"idx/min",
"idx/max",
"idx/std",
"idx/avg",
"idx/mean",
]
)
}
# _ss = {"max", "min", "std"}
# _sk = {"iter", "step", "epoch", "batch", "idx"}
# keys_to_skip = [
# f"{i}/{s}" for s in _ss for i in _sk
# ]
if scalar_summary:
return summarize_dict(scalar_summary, precision=precision)
return ""
return ""
@staticmethod
def split_metrics_for_logging(
metrics: dict[str, Any],
debug_prefixes: tuple[str, ...] = ("hist/",),
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Split metrics into info-level and debug-level groups by key prefix."""
info_metrics: dict[str, Any] = {}
debug_metrics: dict[str, Any] = {}
for key, value in metrics.items():
if key.startswith(debug_prefixes):
debug_metrics[key] = value
else:
info_metrics[key] = value
return info_metrics, debug_metrics
@staticmethod
def summarize_min_max_std(
metrics: dict[str, Any],
) -> dict[str, float]:
"""Compute mean/min/max/std for scalar metrics on the local rank.
Single-process fallback for ``summarize_distributed_min_max_std``;
with one sample per key the std is always zero.
"""
numeric: dict[str, list[float]] = {}
for key, value in metrics.items():
if isinstance(value, (int, float)):
numeric[key] = [float(value)]
elif torch.is_tensor(value) and value.numel() == 1:
numeric[key] = [float(value.item())]
summary: dict[str, float] = {}
for key, values in numeric.items():
if not values:
continue
t = torch.tensor(values)
summary[f"{key}/mean"] = float(t.mean().item())
summary[f"{key}/min"] = float(t.min().item())
summary[f"{key}/max"] = float(t.max().item())
summary[f"{key}/std"] = float(t.std(unbiased=False).item())
return summary
def summarize_distributed_min_max_std(
self, metrics: dict[str, Any]
) -> dict[str, float]:
"""Summarize metrics across ranks, falling back to local statistics.
All-zero stat groups are pruned from the result.
"""
summary_stats = self._compute_distributed_metrics(metrics)
if not summary_stats:
summary_stats = self.summarize_min_max_std(metrics)
filtered: dict[str, float] = {
k: v
for k, v in summary_stats.items()
if k.endswith(("/mean", "/min", "/max", "/std"))
}
keys = {k.rsplit("/", 1)[0] for k in filtered}
pruned: dict[str, float] = {}
for base in keys:
mean_v = filtered.get(f"{base}/mean")
min_v = filtered.get(f"{base}/min")
max_v = filtered.get(f"{base}/max")
std_v = filtered.get(f"{base}/std")
if (
mean_v == 0.0
and min_v == 0.0
and max_v == 0.0
and std_v == 0.0
):
continue
if mean_v is not None:
pruned[f"{base}/mean"] = mean_v
if min_v is not None:
pruned[f"{base}/min"] = min_v
if max_v is not None:
pruned[f"{base}/max"] = max_v
if std_v is not None:
pruned[f"{base}/std"] = std_v
return pruned
def log_metrics(
self,
metrics: dict[str, Any],
*,
logger: Optional[Any] = None,
debug_prefixes: tuple[str, ...] = ("hist/",),
include_summary: bool = True,
rank0_only_summary: bool = True,
precision: int = 6,
omit_counter_metrics: bool = True,
counter_tokens: tuple[str, ...] = (
"iter",
"epoch",
"step",
"batch",
"idx",
"bidx",
),
) -> None:
"""Log metrics, sending ``debug_prefixes``-prefixed keys to DEBUG.
Optionally appends a distributed mean/min/max/std summary (rank 0
only by default), omitting counter-like keys such as ``iter``.
"""
log = logger if logger is not None else get_logger(__name__)
info_metrics, debug_metrics = self.split_metrics_for_logging(
metrics, debug_prefixes=debug_prefixes
)
def _is_counter_key(key: str) -> bool:
parts = key.replace("\\", "/").split("/")
if not parts:
return False
last = parts[-1]
for token in counter_tokens:
if last == token or last.endswith(f"_{token}"):
return True
return False
info_msg = summarize_dict(info_metrics, precision=precision).replace(
"train/", ""
)
if info_msg:
log.info(info_msg)
if include_summary:
summary_input = info_metrics
if omit_counter_metrics:
summary_input = {
k: v
for k, v in info_metrics.items()
if not _is_counter_key(k)
}
summary_stats = self.summarize_distributed_min_max_std(
summary_input
)
if summary_stats and (not rank0_only_summary or self._rank == 0):
summary_msg = summarize_dict(
summary_stats, precision=precision
).replace("train/", "")
if summary_msg:
log.info(summary_msg)
debug_msg = summarize_dict(debug_metrics, precision=precision).replace(
"train/", ""
)
if debug_msg:
log.debug(debug_msg)
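# Example (illustrative): ``log_metrics`` routes "hist/"-prefixed keys to
# DEBUG, everything else to INFO, and (on rank 0 by default) appends a
# distributed mean/min/max/std summary with counter keys filtered out.
#
#     history.log_metrics(
#         {"train/loss": 0.42, "train/iter": 10, "hist/weights": 0.1},
#     )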
def _tplot(
self,
y: np.ndarray,
x: Optional[np.ndarray] = None,
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
append: bool = True,
title: Optional[str] = None,
verbose: bool = False,
outfile: Optional[str] = None,
logfreq: Optional[int] = None,
plot_type: Optional[str] = None,
marker: Optional[str] = None,
record_report: bool = True,
):
"""
Create a text plot of the given data.
Args:
y (np.ndarray): The data to plot.
x (Optional[np.ndarray]): The x-axis data.
xlabel (Optional[str]): The x-axis label.
ylabel (Optional[str]): The y-axis label.
append (bool): Whether to append to an existing plot.
title (Optional[str]): The title of the plot.
verbose (bool): Whether to print the plot.
outfile (Optional[str]): The path to save the plot to.
logfreq (Optional[int]): The log frequency of the plot.
plot_type (Optional[str]): The type of plot to create.
marker (Optional[str]): Marker style passed through to the plotter.
record_report (bool): Whether to append the plot to the markdown report.
"""
outfile_path: Optional[Path] = None
if outfile is None and self.report_enabled:
label = (ylabel or xlabel or "metric").replace("/", "_")
default_dir = self._report_dir.joinpath("tplot")
default_dir.mkdir(parents=True, exist_ok=True)
outfile_path = default_dir.joinpath(
f"{label}-{get_timestamp()}.txt"
)
outfile = outfile_path.as_posix()
elif outfile is not None:
outfile_path = Path(outfile)
if xlabel is not None and ylabel == xlabel:
return
if len(y) > 1:
x = x if x is not None else np.arange(len(y))
assert x is not None
eztplot(
y=y,
x=x,
xlabel=xlabel,
ylabel=ylabel,
logfreq=(1 if logfreq is None else logfreq),
append=append,
verbose=verbose,
outfile=outfile,
plot_type=plot_type,
marker=marker,
title=title,
# plot_type=('scatter' if 'dt' in ylabel else None),
)
if (
record_report
and self.report_enabled
and outfile_path is not None
):
self._write_plot_report(
ylabel,
outfile_path,
kind="tplot",
metadata={"points": len(y)},
)
if ylabel is not None and "dt" in ylabel:
of = Path(outfile) if outfile is not None else None
if of is not None:
of = Path(of.parent).joinpath(f"{of.stem}-hist{of.suffix}")
eztplot(
y=y,
xlabel=ylabel,
title=title,
ylabel="freq",
append=append,
verbose=verbose,
outfile=(of if of is not None else None),
plot_type="hist",
marker=marker,
)
if record_report and self.report_enabled and of is not None:
self._write_plot_report(
f"{ylabel}-hist",
of,
kind="tplot-hist",
metadata={"points": len(y)},
)
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def plot(
self,
val: np.ndarray,
key: Optional[str] = None,
warmup: Optional[float] = 0.0,
num_chains: Optional[int] = 128,
title: Optional[str] = None,
outdir: Optional[os.PathLike] = None,
subplots_kwargs: Optional[dict[str, Any]] = None,
plot_kwargs: Optional[dict[str, Any]] = None,
verbose: bool = False,
):
"""
Plot a single variable from the history.
NOTE: The `warmup` argument can be used to drop the first `warmup`
iterations (as a percent of the total number of iterations) from the
plot.
Args:
val (np.ndarray): The data to plot.
key (Optional[str]): The key for the data.
warmup (Optional[float]): The percentage of iterations to drop from the
beginning of the plot.
num_chains (Optional[int]): The number of chains to plot.
title (Optional[str]): The title of the plot.
outdir (Optional[os.PathLike]): The directory to save the plot to.
subplots_kwargs (Optional[dict[str, Any]]): Additional arguments for
subplots.
plot_kwargs (Optional[dict[str, Any]]): Additional arguments for plotting.
verbose (bool): Emit additional logging when saving plots.
"""
import matplotlib.pyplot as plt
LW = plt.rcParams.get("axes.linewidth", 1.75)
plot_kwargs = {} if plot_kwargs is None else plot_kwargs
subplots_kwargs = {} if subplots_kwargs is None else subplots_kwargs
figsize = subplots_kwargs.get("figsize", ezplot.set_size())
subplots_kwargs.update({"figsize": figsize})
num_chains = 16 if num_chains is None else num_chains
# tmp = val[0]
arr = np.array(val)
subfigs = None
steps = np.arange(arr.shape[0])
if warmup is not None and warmup > 0 and arr.size > 0:
if isinstance(warmup, int) or warmup >= 1:
warmup_frac = float(warmup) / float(arr.shape[0])
else:
warmup_frac = float(warmup)
warmup_frac = min(max(warmup_frac, 0.0), 1.0)
drop = min(int(round(warmup_frac * arr.shape[0])), arr.shape[0])
if drop > 0:
arr = arr[drop:]
steps = steps[drop:]
if len(arr.shape) == 2:
import seaborn as sns
_ = subplots_kwargs.pop("constrained_layout", True)
figsize = (3 * figsize[0], 1.5 * figsize[1])
fig = plt.figure(figsize=figsize, constrained_layout=True)
subfigs = fig.subfigures(1, 2)
gs_kw = {"width_ratios": [1.33, 0.33]}
(ax, ax1) = subfigs[1].subplots(
1, 2, sharey=True, gridspec_kw=gs_kw
)
ax.grid(alpha=0.2)
ax1.grid(False)
color = plot_kwargs.get("color", None)
label = r"$\langle$" + f" {key} " + r"$\rangle$"
ax.plot(
steps, arr.mean(-1), lw=1.5 * LW, label=label, **plot_kwargs
)
sns.kdeplot(y=arr.flatten(), ax=ax1, color=color, fill=True)
ax1.set_xticks([])
ax1.set_xticklabels([])
# ax1.set_yticks([])
# ax1.set_yticklabels([])
sns.despine(ax=ax, top=True, right=True)
sns.despine(ax=ax1, top=True, right=True, left=True, bottom=True)
# ax.legend(loc='best', frameon=False)
ax1.set_xlabel("")
# ax1.set_ylabel('')
# ax.set_yticks(ax.get_yticks())
# ax.set_yticklabels(ax.get_yticklabels())
# ax.set_ylabel(key)
# _ = subfigs[1].subplots_adjust(wspace=-0.75)
axes = (ax, ax1)
else:
if len(arr.shape) == 1:
fig, ax = plt.subplots(**subplots_kwargs)
# assert isinstance(ax, plt.Axes)
ax.plot(steps, arr, **plot_kwargs)
axes = ax
elif len(arr.shape) == 3:
fig, ax = plt.subplots(**subplots_kwargs)
# assert isinstance(ax, plt.Axes)
cmap = plt.get_cmap("viridis")
nlf = arr.shape[1]
for idx in range(nlf):
label = plot_kwargs.pop("label", None)
if label is not None:
label = f"{label}-{idx}"
y = arr[:, idx, :]
plot_kwargs["color"] = cmap(idx / nlf)
if len(y.shape) == 2:
# TODO: Plot chains
if num_chains > 0:
for chain in range(min(num_chains, y.shape[1])):
_ = ax.plot(
steps,
y[:, chain],
lw=LW / 2.0,
alpha=0.8,
**plot_kwargs,
)
_ = ax.plot(
steps,
y.mean(-1), # color=color,
label=label,
**plot_kwargs,
)
else:
_ = ax.plot(
steps,
y, # color=color,
label=label,
**plot_kwargs,
)
axes = ax
else:
raise ValueError("Unexpected shape encountered")
ax.set_ylabel(key)
if num_chains > 0 and len(arr.shape) > 1:
# lw = LW / 2.
for idx in range(min(num_chains, arr.shape[1])):
# ax = subfigs[0].subplots(1, 1)
# plot values of invidual chains, arr[:, idx]
# where arr[:, idx].shape = [ndraws, 1]
ax.plot(
steps, arr[:, idx], alpha=0.5, lw=LW / 2.0, **plot_kwargs
)
ax.set_xlabel("draw")
if title is not None:
fig.suptitle(title)
save_dir: Optional[Path]
if outdir is not None:
save_dir = Path(outdir).expanduser().resolve()
elif self.report_enabled:
save_dir = self._report_dir.joinpath("mplot")
else:
save_dir = None
if save_dir is not None:
# plt.savefig(Path(outdir).joinpath(f'{key}.svg'),
# dpi=400, bbox_inches='tight')
save_dir.mkdir(parents=True, exist_ok=True)
outfile = save_dir.joinpath(f"{key}.svg")
if outfile.is_file():
tstamp = ezpz.get_timestamp()
pngdir = save_dir.joinpath("pngs")
pngdir.mkdir(exist_ok=True, parents=True)
pngfile = pngdir.joinpath(f"{key}-{tstamp}.png")
svgfile = save_dir.joinpath(f"{key}-{tstamp}.svg")
plt.savefig(pngfile, dpi=400, bbox_inches="tight")
plt.savefig(svgfile, dpi=400, bbox_inches="tight")
primary_asset: Optional[Path] = None
if save_dir is not None:
dirs = {
"png": Path(save_dir).joinpath("pngs/"),
"svg": Path(save_dir).joinpath("svgs/"),
}
_ = [i.mkdir(exist_ok=True, parents=True) for i in dirs.values()]
for ext, d in dirs.items():
outfile = d.joinpath(f"{key}.{ext}")
if outfile.is_file():
outfile = d.joinpath(f"{key}-subfig.{ext}")
if verbose:
logger.info(f"Saving {key} plot to: {outfile.resolve()}")
plt.savefig(outfile, dpi=400, bbox_inches="tight")
if primary_asset is None and ext == "png":
primary_asset = outfile
if (
self.report_enabled
and primary_asset is not None
and Path(primary_asset).exists()
):
self._write_plot_report(
key,
primary_asset,
kind="matplotlib",
metadata={"shape": list(arr.shape)},
)
self._wandb_log_matplotlib_asset(key, primary_asset, kind="matplotlib")
return fig, subfigs, axes
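# Example (hedged sketch): plotting one tracked metric. ``plot`` picks a
# layout from the array's rank -- a simple line for 1-D data, a combined
# trace/KDE figure for 2-D [draws, chains] data.
#
#     import numpy as np
#
#     fig, subfigs, axes = history.plot(
#         np.asarray(history.history["train/loss"]),
#         key="train/loss",
#         warmup=0.1,  # drop the first 10% of draws
#         outdir="outputs/plots",
#     )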
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def plot_dataArray(
self,
val: xr.DataArray,
key: Optional[str] = None,
warmup: Optional[float] = 0.0,
num_chains: Optional[int] = 0,
title: Optional[str] = None,
outdir: Optional[str] = None,
subplots_kwargs: Optional[dict[str, Any]] = None,
plot_kwargs: Optional[dict[str, Any]] = None,
verbose: bool = False,
line_labels: bool = False,
logfreq: Optional[int] = None,
):
"""
Plot a single variable from the history as an xarray DataArray.
Args:
val (xr.DataArray): The data to plot.
key (Optional[str]): The key for the data.
warmup (Optional[float]): The percentage of iterations to drop from the
beginning of the plot.
num_chains (Optional[int]): The number of chains to plot.
title (Optional[str]): The title of the plot.
outdir (Optional[str]): The directory to save the plot to.
subplots_kwargs (Optional[dict[str, Any]]): Additional arguments for
subplots.
plot_kwargs (Optional[dict[str, Any]]): Additional arguments for plotting.
verbose (bool): Whether to print the plot.
line_labels (bool): Whether to label lines in the plot.
logfreq (Optional[int]): The log frequency of the plot.
"""
import matplotlib.pyplot as plt
plot_kwargs = {} if plot_kwargs is None else plot_kwargs
subplots_kwargs = {} if subplots_kwargs is None else subplots_kwargs
ezplot.set_plot_style()
plt.rcParams["axes.labelcolor"] = "#bdbdbd"
figsize = subplots_kwargs.get("figsize", ezplot.set_size())
subplots_kwargs.update({"figsize": figsize})
subfigs = None
# if key == 'dt':
# warmup = 0.2
arr = val.values # shape: [nchains, ndraws]
# steps = np.arange(len(val.coords['draw']))
steps = val.coords["draw"]
if warmup is not None and warmup > 0.0 and arr.size > 0:
if isinstance(warmup, int) or warmup >= 1:
warmup_frac = float(warmup) / float(arr.shape[0])
else:
warmup_frac = float(warmup)
warmup_frac = min(max(warmup_frac, 0.0), 1.0)
drop = min(int(round(warmup_frac * arr.shape[0])), arr.shape[0])
if drop > 0:
arr = arr[drop:]
steps = steps[drop:]
if len(arr.shape) == 2:
fig, axes = ezplot.plot_combined(
val,
key=key,
num_chains=num_chains,
plot_kwargs=plot_kwargs,
subplots_kwargs=subplots_kwargs,
)
else:
if len(arr.shape) == 1:
fig, ax = ezplot.subplots(**subplots_kwargs)
try:
ax.plot(steps, arr, **plot_kwargs)
except ValueError:
try:
ax.plot(steps, arr[~np.isnan(arr)], **plot_kwargs)
except Exception:
logger.error(f"Unable to plot {key}! Continuing")
_ = ax.grid(True, alpha=0.2)
axes = ax
elif len(arr.shape) == 3:
fig, ax = ezplot.subplots(**subplots_kwargs)
cmap = plt.get_cmap("viridis")
y = val.mean("chain")
for idx in range(len(val.coords["leapfrog"])):
pkwargs = {
"color": cmap(idx / len(val.coords["leapfrog"])),
"label": f"{idx}",
}
ax.plot(steps, y[idx], **pkwargs)
axes = ax
else:
raise ValueError("Unexpected shape encountered")
ax = plt.gca()
# assert isinstance(ax, plt.Axes)
assert key is not None
_ = ax.set_ylabel(key)
_ = ax.set_xlabel("draw")
if title is not None:
fig = plt.gcf()
_ = fig.suptitle(title)
if logfreq is not None:
ax = plt.gca()
xticks = ax.get_xticks() # type: ignore
_ = ax.set_xticklabels( # type: ignore
[f"{logfreq * int(i)}" for i in xticks] # type: ignore
)
save_dir: Optional[Path]
if outdir is not None:
save_dir = Path(outdir).expanduser().resolve()
elif self.report_enabled:
save_dir = self._report_dir.joinpath("dataarray")
else:
save_dir = None
primary_asset: Optional[Path] = None
if save_dir is not None:
dirs = {
"png": Path(save_dir).joinpath("pngs/"),
"svg": Path(save_dir).joinpath("svgs/"),
}
_ = [i.mkdir(exist_ok=True, parents=True) for i in dirs.values()]
if verbose:
logger.info(
f"Saving {key} plot to: {Path(save_dir).resolve()}"
)
for ext, d in dirs.items():
outfile = d.joinpath(f"{key}.{ext}")
plt.savefig(outfile, dpi=400, bbox_inches="tight")
if primary_asset is None and ext == "png":
primary_asset = outfile
if (
self.report_enabled
and primary_asset is not None
and Path(primary_asset).exists()
):
metadata = {"dims": list(val.dims)}
self._write_plot_report(
key,
primary_asset,
kind="dataarray",
metadata=metadata,
)
self._wandb_log_matplotlib_asset(key, primary_asset, kind="dataarray")
return (fig, subfigs, axes)
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def plot_dataset(
self,
title: Optional[str] = None,
nchains: Optional[int] = None,
outdir: Optional[os.PathLike] = None,
dataset: Optional[xr.Dataset] = None,
data: Optional[dict] = None,
warmup: Optional[int | float] = None,
# subplots_kwargs: Optional[dict[str, Any]] = None,
# plot_kwargs: Optional[dict[str, Any]] = None,
):
"""Plot every variable in ``dataset`` via ``ezplot.plot_dataset``."""
dataset = (
dataset
if dataset is not None
else (
self.get_dataset(
data=(data if data is not None else self.history),
warmup=warmup,
)
)
)
return ezplot.plot_dataset(
dataset=dataset,
nchains=nchains,
title=title,
outdir=outdir,
)
def plot_2d_xarr(
self,
xarr: xr.DataArray,
label: Optional[str] = None,
num_chains: Optional[int] = None,
title: Optional[str] = None,
outdir: Optional[os.PathLike] = None,
subplots_kwargs: Optional[dict[str, Any]] = None,
plot_kwargs: Optional[dict[str, Any]] = None,
):
"""Plot a 2-D ``(chain, draw)`` DataArray as a trace, KDE, and heatmap."""
import matplotlib.pyplot as plt
import seaborn as sns
LW = plt.rcParams.get("axes.linewidth", 1.75)
plot_kwargs = {} if plot_kwargs is None else plot_kwargs
subplots_kwargs = {} if subplots_kwargs is None else subplots_kwargs
assert len(xarr.shape) == 2
assert "draw" in xarr.coords and "chain" in xarr.coords
num_chains = len(xarr.chain) if num_chains is None else num_chains
# _ = subplots_kwargs.pop('constrained_layout', True)
figsize = plt.rcParams.get("figure.figsize", (8, 6))
figsize = (3 * figsize[0], 1.5 * figsize[1])
fig = plt.figure(figsize=figsize, constrained_layout=True)
subfigs = fig.subfigures(1, 2)
gs_kw = {"width_ratios": [1.33, 0.33]}
(ax, ax1) = subfigs[1].subplots(1, 2, sharey=True, gridspec_kw=gs_kw)
ax.grid(alpha=0.2)
ax1.grid(False)
color = plot_kwargs.get("color", f"C{np.random.randint(6)}")
label = r"$\langle$" + f" {label} " + r"$\rangle$"
ax.plot(
xarr.draw.values,
xarr.mean("chain"),
color=color,
lw=1.5 * LW,
label=label,
**plot_kwargs,
)
for idx in range(num_chains):
# ax = subfigs[0].subplots(1, 1)
# plot values of invidual chains, arr[:, idx]
# where arr[:, idx].shape = [ndraws, 1]
ax.plot(
xarr.draw.values,
xarr[xarr.chain == idx][0],
color=color,
alpha=0.5,
lw=LW / 2.0,
**plot_kwargs,
)
axes = (ax, ax1)
sns.kdeplot(y=xarr.values.flatten(), ax=ax1, color=color, fill=True)
ax1.set_xticks([])
ax1.set_xticklabels([])
# ax1.set_yticks([])
# ax1.set_yticklabels([])
sns.despine(ax=ax, top=True, right=True)
sns.despine(ax=ax1, top=True, right=True, left=True, bottom=True)
# ax.legend(loc='best', frameon=False)
ax1.set_xlabel("")
# ax1.set_ylabel('')
# ax.set_yticks(ax.get_yticks())
# ax.set_yticklabels(ax.get_yticklabels())
# ax.set_ylabel(key)
# _ = subfigs[1].subplots_adjust(wspace=-0.75)
sns.despine(subfigs[0])
ax0 = subfigs[0].subplots(1, 1)
im = xarr.plot(ax=ax0) # type:ignore
im.colorbar.set_label(label) # type:ignore
# ax0.plot(
# xarr.draw.values,
# xarr.mean('chain'),
# lw=2.,
# color=color
# )
# for idx in range(min(num_chains, i.shape[1])):
ax.set_xlabel("draw")
if title is not None:
fig.suptitle(title)
if outdir is not None:
assert label is not None
# plt.savefig(Path(outdir).joinpath(f'{key}.svg'),
# dpi=400, bbox_inches='tight')
outfile = Path(outdir).joinpath(f"{label}.svg")
if outfile.is_file():
tstamp = get_timestamp("%Y-%m-%d-%H%M%S")
pngdir = Path(outdir).joinpath("pngs")
pngdir.mkdir(exist_ok=True, parents=True)
pngfile = pngdir.joinpath(f"{label}-{tstamp}.png")
svgfile = Path(outdir).joinpath(f"{label}-{tstamp}.svg")
plt.savefig(pngfile, dpi=400, bbox_inches="tight")
plt.savefig(svgfile, dpi=400, bbox_inches="tight")
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def tplot_all(
self,
outdir: Optional[os.PathLike] = None,
warmup: Optional[float] = 0.0,
append: bool = True,
xkey: Optional[str] = None,
dataset: Optional[xr.Dataset] = None,
data: Optional[dict] = None,
logfreq: Optional[int] = None,
plot_type: Optional[str] = None,
verbose: bool = False,
):
"""Render text-based (plotext) plots for every metric group."""
dataset = (
dataset
if dataset is not None
else (
self.get_dataset(
data=(data if data is not None else self.history),
warmup=warmup,
)
)
)
outdir_path = Path(os.getcwd()) if outdir is None else Path(outdir)
groups = self._group_metric_variables(dataset)
for metric_name, metric_vars in sorted(groups.items()):
parts = metric_name.replace("\\", "/").split("/")
last = parts[-1] if parts else metric_name
if last in {"iter", "epoch", "step", "batch", "idx", "bidx"}:
continue
if any(
last.endswith(f"_{token}")
for token in ("iter", "epoch", "step", "batch", "idx", "bidx")
):
continue
if (xkey is not None and metric_name == xkey) or metric_name in (
"iter",
"draw",
):
continue
self._tplot_metric_group(
metric_name,
metric_vars,
warmup=warmup,
outdir=outdir_path,
plot_type=plot_type,
verbose=verbose,
logfreq=logfreq,
)
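# Example (illustrative): rendering text plots for every non-counter
# metric group. The EZPZ_TPLOT_TYPE / EZPZ_TPLOT_MARKER environment
# variables override the defaults, per the helpers above.
#
#     history.tplot_all(outdir="outputs/tplot", xkey="train/iter")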
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def plot_all(
self,
num_chains: int = 128,
warmup: Optional[float | int] = 0.0,
title: Optional[str] = None,
verbose: bool = False,
outdir: Optional[os.PathLike] = None,
subplots_kwargs: Optional[dict[str, Any]] = None,
plot_kwargs: Optional[dict[str, Any]] = None,
dataset: Optional[xr.Dataset] = None,
data: Optional[dict] = None,
):
"""Render matplotlib figures (plus ridgeplots) for every metric group."""
plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
subplots_kwargs = (
{} if subplots_kwargs is None else dict(subplots_kwargs)
)
dataset = (
dataset
if dataset is not None
else (
self.get_dataset(
data=(data if data is not None else self.history),
warmup=warmup,
)
)
)
_ = ezplot.make_ridgeplots(
dataset,
outdir=outdir,
drop_nans=True,
drop_zeros=False,
num_chains=num_chains,
cmap="viridis",
save_plot=(outdir is not None),
)
groups = self._group_metric_variables(dataset)
for idx, (metric_name, metric_vars) in enumerate(
sorted(groups.items())
):
plot_kwargs["color"] = f"C{idx % 9}"
asset = self._plot_metric_group(
metric_name,
metric_vars,
warmup=warmup,
title=title,
outdir=Path(outdir) if outdir is not None else None,
subplots_kwargs=subplots_kwargs,
plot_kwargs=plot_kwargs,
verbose=verbose,
)
self._wandb_log_matplotlib_asset(
metric_name, asset, kind="matplotlib"
)
if asset is not None and self.report_enabled and asset.exists():
components = sorted(metric_vars.keys())
sample_series = self._series_from_dataarray(
metric_vars[components[0]]
)
self._write_plot_report(
metric_name,
asset,
kind="matplotlib",
metadata={
"components": ", ".join(components),
"points": len(sample_series),
},
)
return dataset
def history_to_dict(self) -> dict:
"""Convert the in-memory history to a dict of numpy arrays."""
return {
k: torch.Tensor(v).numpy(force=True)
for k, v in self.history.items()
}
def to_DataArray(
self,
x: Union[list, np.ndarray, torch.Tensor],
warmup: Optional[float] = 0.0,
) -> xr.DataArray:
"""Convert a sequence of samples into an ``xr.DataArray``.
Dims are inferred from shape; see the convention sketched below.
"""
if isinstance(x, tuple):
x = list(x)
if (
isinstance(x, list)
and len(x) > 0
and isinstance(x[0], torch.Tensor)
):
x = torch.Tensor(x).numpy(force=True)
try:
arr = grab_tensor(x)
except ValueError:
arr = np.array(x).real
# arr = np.array(x)
logger.info(f"len(x): {len(x)}")
x0_shape = getattr(x[0], "shape", None) if len(x) > 0 else None
logger.info(f"x[0].shape: {x0_shape}")
logger.info(f"arr.shape: {arr.shape}")
assert isinstance(arr, np.ndarray)
if warmup is not None and warmup > 0 and len(arr) > 0:
if isinstance(warmup, int):
warmup = warmup / len(arr)
# drop = int(warmup * arr.shape[0])
drop = int(warmup * len(arr))
arr = arr[drop:]
# steps = np.arange(len(arr))
if len(arr.shape) == 1: # [ndraws]
ndraws = arr.shape[0]
dims = ["draw"]
coords = [np.arange(len(arr))]
return xr.DataArray(arr, dims=dims, coords=coords) # type:ignore
if len(arr.shape) == 2: # [nchains, ndraws]
arr = arr.T
nchains, ndraws = arr.shape
dims = ("chain", "draw")
coords = [np.arange(nchains), np.arange(ndraws)]
return xr.DataArray(arr, dims=dims, coords=coords) # type:ignore
if len(arr.shape) == 3: # [nchains, nlf, ndraws]
arr = arr.T
nchains, nlf, ndraws = arr.shape
dims = ("chain", "leapfrog", "draw")
coords = [np.arange(nchains), np.arange(nlf), np.arange(ndraws)]
return xr.DataArray(arr, dims=dims, coords=coords) # type:ignore
else:
print(f"arr.shape: {arr.shape}")
raise ValueError("Invalid shape encountered")
def get_dataset(
self,
data: Optional[
dict[str, Union[list, np.ndarray, torch.Tensor]]
] = None,
warmup: Optional[float] = 0.0,
):
"""Build an ``xr.Dataset`` from history data, one variable per metric."""
data = self.history_to_dict() if data is None else data
data_vars = {}
for key, val in data.items():
name = key.replace("/", "_")
try:
data_vars[name] = self.to_DataArray(val, warmup)
except ValueError:
logger.error(
f"Unable to create DataArray for {key}! Skipping!"
)
logger.error(f"{key}.shape= {np.stack(val).shape}") # type:ignore
return xr.Dataset(data_vars)
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def save_dataset(
self,
outdir: PathLike,
fname: str = "dataset",
use_hdf5: bool = True,
data: Optional[
dict[str, Union[list, np.ndarray, torch.Tensor]]
] = None,
dataset: Optional[xr.Dataset] = None,
warmup: Optional[int | float] = None,
**kwargs,
) -> Path:
"""Save the dataset to disk (HDF5 or netCDF) and, if active, to W&B."""
dataset = (
dataset
if dataset is not None
else (
self.get_dataset(
data=(data if data is not None else self.history),
warmup=warmup,
)
)
)
if (
ENABLE_WANDB
and dataset is not None
and wandb is not None
and wandb.run is not None
):
dset_name = f"{fname}_dataset" if fname != "dataset" else fname
try:
wandb.log(
{f"{dset_name}": wandb.Table(dataset.to_dataframe())}
)
except Exception:
logger.warning("Unable to save dataset to W&B, skipping!")
return save_dataset(
dataset,
outdir=outdir,
fname=fname,
use_hdf5=use_hdf5,
**kwargs,
)
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def finalize(
self,
outdir: Optional[PathLike] = None,
run_name: Optional[str] = None,
dataset_fname: Optional[str] = None,
num_chains: int = 128,
warmup: Optional[int | float] = 0.0,
verbose: bool = False,
save: bool = True,
plot: bool = True,
append_tplot: bool = True,
title: Optional[str] = None,
data: Optional[
dict[str, Union[list, np.ndarray, torch.Tensor]]
] = None,
dataset: Optional[xr.Dataset] = None,
xkey: Optional[str] = None,
plot_kwargs: Optional[dict[str, Any]] = None,
subplots_kwargs: Optional[dict[str, Any]] = None,
tplot_type: Optional[str] = None,
env_info: Optional[dict[str, Any]] = None,
) -> xr.Dataset:
"""Finalize the run: build the dataset, render plots, write the
markdown report, and (optionally) save the dataset under ``outdir``.
"""
dataset = (
dataset
if dataset is not None
else (
self.get_dataset(
data=(data if data is not None else self.history),
warmup=warmup,
)
)
)
run_name = (
f"History-{get_timestamp()}" if run_name is None else run_name
)
if outdir is None:
base_dir = (
Path(os.getcwd())
.joinpath("outputs", run_name, get_timestamp())
.expanduser()
.resolve()
)
else:
base_dir = Path(outdir).expanduser().resolve()
base_dir.mkdir(parents=True, exist_ok=True)
self._configure_report_destination(base_dir)
env_details = (
env_info
if env_info is not None
else self._default_environment_info()
)
self._write_environment_section(env_details)
self._write_metric_summary(dataset)
if plot:
logger.info(
"Saving plots to %s (matplotlib) and %s (tplot)",
base_dir.joinpath("plots", "mplot"),
base_dir.joinpath("plots", "tplot"),
)
plotdir = base_dir.joinpath("plots")
tplotdir = plotdir.joinpath("tplot")
mplotdir = plotdir.joinpath("mplot")
tplotdir.mkdir(exist_ok=True, parents=True)
mplotdir.mkdir(exist_ok=True, parents=True)
_ = self.plot_all(
dataset=dataset,
outdir=mplotdir,
verbose=verbose,
num_chains=num_chains,
warmup=warmup,
title=title,
plot_kwargs=plot_kwargs,
subplots_kwargs=subplots_kwargs,
)
_ = self.tplot_all(
dataset=dataset,
outdir=tplotdir,
warmup=warmup,
append=append_tplot,
plot_type=tplot_type,
xkey=xkey,
verbose=verbose,
)
if save:
try:
import h5py  # noqa: F401 -- availability check only
use_hdf5 = True
except ImportError:
logger.warning(
"h5py not found! Saving dataset as netCDF instead."
)
use_hdf5 = False
fname = "dataset" if dataset_fname is None else dataset_fname
_ = self.save_dataset(
dataset=dataset,
outdir=base_dir,
fname=fname,
use_hdf5=use_hdf5,
)
if self.report_enabled:
logger.info(
"Saving history report to %s",
base_dir.joinpath(self._report_filename),
)
return dataset
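# End-to-end sketch (illustrative; names are hypothetical): a minimal run
# that records metrics, then finalizes into a dataset + report directory.
#
#     history = History()
#     for step in range(100):
#         history.update({"train/step": step, "train/loss": 1.0 / (step + 1)})
#     dataset = history.finalize(
#         outdir="outputs/run-0",  # report, plots, and dataset land here
#         save=True,
#         plot=True,
#     )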