ezpz.examples.fsdp_tp

ezpz/examples/fsdp_tp.py

2D tensor/sequence parallel + FSDP training demo on a Llama-style model.

Sam Foreman 2025-09-08

Modified from: https://pytorch.org/tutorials/intermediate/TP_tutorial.html

This script tests 2D parallelism, combining Tensor/Sequence Parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on an example Llama2 model. It demonstrates an end-to-end working flow through forward, backward, and optimization.

We enable Fully Sharded Data Parallel + Tensor Parallel in separate parallel dimensions:

  • Data Parallel ("dp") across hosts
  • Tensor Parallel ("tp") within each host

The diagrams below illustrate this layout:

+-----+-----+-----+-----+
|  0  |  1  |  2  |  3  |
|     |     |     |     |
+-----+-----+-----+-----+
|  4  |  5  |  6  |  7  |
|     |     |     |     |
+-----+-----+-----+-----+
|  8  |  9  |  10 |  11 |
|     |     |     |     |
+-----+-----+-----+-----+

+----------+ +------------+ +----------+ +------------+
| Host 1   | | Host 2     | |          | | Host N     |
| 8 GPUs   | | 8 GPUs     | |          | | 8 GPUs     |
|          | |            | |   ...    | |            |
| (TP)     | | (TP)       | |          | | (TP)       |
|[0,1,..,7]| | [8,9..,15] | |          | | [8N-8,8N-7 |
|          | |            | |          | |  .., 8N-1] |
|          | |            | |          | |            |
+----------+ +------------+ +----------+ +------------+

  • FSDP:

[0, 8, ..., 8N-8], [1, 9, ..., 8N-7], ..., [7, 15, ..., 8N-1]
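
Putting the two dimensions together, the sketch below shows how such a 2D device mesh can be built with torch.distributed. It is illustrative only: the script builds its own mesh inside train(), and the 8-GPUs-per-host figure comes from the diagram above.

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Assumes the default process group is already initialized
# (e.g. by torchrun or `ezpz launch`).
world_size = dist.get_world_size()
tp_size = 8                       # GPUs per host, as in the diagram above
dp_size = world_size // tp_size   # number of hosts

mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
dp_mesh = mesh["dp"]  # FSDP shards parameters across hosts along this dimension
tp_mesh = mesh["tp"]  # tensor/sequence parallel within each host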

Launch with:

ezpz launch -m ezpz.examples.fsdp_tp --tp 2 --batch-size 8

Help output (python3 -m ezpz.examples.fsdp_tp --help):

usage: fsdp_tp.py [-h] [--dim DIM] [--n-layers N_LAYERS] [--n-heads N_HEADS]
                  [--n-kv-heads N_KV_HEADS] [--multiple-of MULTIPLE_OF]
                  [--ffn-dim-multiplier FFN_DIM_MULTIPLIER]
                  [--norm-eps NORM_EPS] [--vocab-size VOCAB_SIZE]
                  [--seq-length SEQ_LENGTH] [--lr LR] [--epochs EPOCHS]
                  [--batch-size BATCH_SIZE]
                  [--test-batch-size TEST_BATCH_SIZE]
                  [--num-workers NUM_WORKERS] [--seed SEED] [--tp TP]
                  [--sharding-strategy SHARDING_STRATEGY]
                  [--max-grad-norm MAX_GRAD_NORM] [--outdir OUTDIR]
                  [--dataset DATASET] [--tokenizer_name TOKENIZER_NAME]
                  [--model_name_or_path MODEL_NAME_OR_PATH]
                  [--hf-split HF_SPLIT] [--hf-text-column HF_TEXT_COLUMN]
                  [--hf-limit HF_LIMIT] [--seq-len SEQ_LEN]
                  [--max-seq-len MAX_SEQ_LEN] [--depth-init DEPTH_INIT]
                  [--fp32]

2D Parallel Training

options:
  -h, --help            show this help message and exit
  --dim DIM
  --n-layers N_LAYERS
  --n-heads N_HEADS
  --n-kv-heads N_KV_HEADS
  --multiple-of MULTIPLE_OF
  --ffn-dim-multiplier FFN_DIM_MULTIPLIER
  --norm-eps NORM_EPS
  --vocab-size VOCAB_SIZE
  --seq-length SEQ_LENGTH
  --lr LR
  --epochs EPOCHS
  --batch-size BATCH_SIZE
  --test-batch-size TEST_BATCH_SIZE
  --num-workers NUM_WORKERS
  --seed SEED
  --tp TP
  --sharding-strategy SHARDING_STRATEGY
  --max-grad-norm MAX_GRAD_NORM
  --outdir OUTDIR
  --dataset DATASET
  --tokenizer_name TOKENIZER_NAME
  --model_name_or_path MODEL_NAME_OR_PATH
  --hf-split HF_SPLIT, --hf_split HF_SPLIT
                        Dataset split to load.
  --hf-text-column HF_TEXT_COLUMN, --hf_text_column HF_TEXT_COLUMN
                        Column containing raw text in the dataset.
  --hf-limit HF_LIMIT, --hf_limit HF_LIMIT
                        Number of rows to sample from the HF dataset for quick
                        experiments.
  --seq-len SEQ_LEN
  --max-seq-len MAX_SEQ_LEN
  --depth-init DEPTH_INIT
  --fp32                Disable mixed precision (use fp32) for debugging NaNs.
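
For example, a quick smoke test on the built-in synthetic dataset could look like the following (flag values are chosen only for illustration; every flag appears in the listing above):

ezpz launch -m ezpz.examples.fsdp_tp \
    --tp 2 --batch-size 8 \
    --dataset random --seq-length 1024 --epochs 1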

The remaining comments outline the parallel layout used to combine TP/SP with FSDP.

main(args)

Entrypoint to set up distributed context and dispatch training.

Source code in src/ezpz/examples/fsdp_tp.py
def main(args: argparse.Namespace) -> int:
    """Entrypoint to set up distributed context and dispatch training."""
    rank = ezpz.dist.setup_torch(tensor_parallel_size=args.tp, seed=args.seed)
    if rank == 0:
        outdir = args.outdir if args.outdir is not None else OUTDIR
    else:
        outdir = None
    outdir = ezpz.dist.broadcast(outdir, root=0)
    logger.info(f"Using {outdir=}")
    train(args=args, outdir=outdir)
    return 0
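
A typical way to wire this up when the module is executed directly is sketched below. The actual `__main__` guard is not shown on this page, so treat this as an assumption rather than the script's exact code.

if __name__ == "__main__":
    # Parse the CLI flags documented above and hand off to main(),
    # which sets up the distributed context and runs train().
    raise SystemExit(main(parse_args()))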

parallelize(model, device_mesh, mixed_precision, sharding_strategy=None)

Wrap the model with tensor-parallel and FSDP sharding strategies.

Source code in src/ezpz/examples/fsdp_tp.py
def parallelize(
    model: nn.Module,
    device_mesh: DeviceMesh,
    mixed_precision: Optional[MixedPrecision],
    sharding_strategy: Optional[ShardingStrategy | str] = None,
) -> nn.Module:
    """Wrap the model with tensor-parallel and FSDP sharding strategies."""
    tp_mesh = device_mesh["tp"]
    dp_mesh = device_mesh["dp"]

    if isinstance(sharding_strategy, str):
        sharding_strategy = SHARDING_STRATEGIES.get(sharding_strategy, None)

    model.init_weights()  # type: ignore
    model = parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "norm": SequenceParallel(),
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Replicate(),
                # use DTensor as the output
                # use_local_output=False,
            ),
        },
    )

    assert isinstance(model.layers, Iterable)
    for _, transformer_block in enumerate(model.layers):
        layer_tp_plan = {
            "attention_norm": SequenceParallel(),
            "attention": PrepareModuleInput(
                input_layouts=(Shard(1), None),  # type:ignore
                desired_input_layouts=(Replicate(), None),  # type:ignore
            ),
            "attention.wq": ColwiseParallel(),
            "attention.wk": ColwiseParallel(),
            "attention.wv": ColwiseParallel(),
            "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
            "ffn_norm": SequenceParallel(),
            "feed_forward": PrepareModuleInput(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "feed_forward.w1": ColwiseParallel(),
            "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
            "feed_forward.w3": ColwiseParallel(),
        }

        attn_layer = transformer_block.attention  # type: ignore
        attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
        parallelize_module(
            module=transformer_block,  # type: ignore
            device_mesh=tp_mesh,
            parallelize_plan=layer_tp_plan,
        )

    # from torch.distributed.fsdp import fully_shard

    # ShardingStrategy.NO_SHARD: HandleShardingStrategy.NO_SHARD,
    # ShardingStrategy.FULL_SHARD: HandleShardingStrategy.FULL_SHARD,
    # ShardingStrategy.SHARD_GRAD_OP: HandleShardingStrategy.SHARD_GRAD_OP,
    # ShardingStrategy.HYBRID_SHARD: HandleShardingStrategy.HYBRID_SHARD,
    # ShardingStrategy._HYBRID_SHARD_ZERO2: HandleShardingStrategy._HYBRID_SHARD_ZERO2,
    sharded_model = FSDP(
        model,
        mixed_precision=mixed_precision,
        device_mesh=dp_mesh,
        sharding_strategy=sharding_strategy,
    )
    logger.info(f"Model after parallelization:\n{sharded_model=}\n")
    return sharded_model
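
A hedged usage sketch of parallelize() follows. The names args and model mirror what train() passes in; this is illustrative, not a second copy of the script.

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecision

# Assumes the process group is initialized and `model` was created via
# Transformer.from_model_args(config), as in train().
tp_size = args.tp
dp_size = dist.get_world_size() // tp_size
mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))

mp = MixedPrecision(
    param_dtype=torch.bfloat16,   # forward/backward in bf16
    reduce_dtype=torch.float32,   # gradient reductions in fp32
    cast_forward_inputs=True,
)
sharded_model = parallelize(model, mesh, mp, sharding_strategy="full_shard")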

parse_args()

CLI parser for 2D parallel (TP/SP + FSDP) training.

Source code in src/ezpz/examples/fsdp_tp.py
def parse_args():
    """CLI parser for 2D parallel (TP/SP + FSDP) training."""
    parser = argparse.ArgumentParser(description="2D Parallel Training")
    parser.add_argument("--dim", type=int, default=256)
    parser.add_argument("--n-layers", type=int, default=32)
    parser.add_argument("--n-heads", type=int, default=32)
    parser.add_argument("--n-kv-heads", type=int, default=4)
    parser.add_argument("--multiple-of", type=int, default=360)
    parser.add_argument("--ffn-dim-multiplier", type=float, default=None)
    parser.add_argument("--norm-eps", type=float, default=1e-5)
    parser.add_argument("--vocab-size", type=int, default=32_000)
    parser.add_argument("--seq-length", type=int, default=2048)
    parser.add_argument("--lr", type=float, default=3e-3)
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--test-batch-size", type=int, default=1000)
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--tp", type=int, default=2)
    parser.add_argument("--sharding-strategy", type=str, default="full_shard")
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    parser.add_argument("--outdir", type=str, default="outputs/fsdp_tp")
    # parser.add_argument('--dataset', type=str, default='random')
    parser.add_argument(
        "--dataset", type=str, default="eliplutchok/fineweb-small-sample"
    )
    parser.add_argument(
        "--tokenizer_name", type=str, default="meta-llama/llama-2-7b-hf"
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--hf-split",
        "--hf_split",
        type=str,
        default="train",
        help="Dataset split to load.",
    )
    parser.add_argument(
        "--hf-text-column",
        "--hf_text_column",
        type=str,
        default="text",
        help="Column containing raw text in the dataset.",
    )
    parser.add_argument(
        "--hf-limit",
        "--hf_limit",
        type=int,
        default=512,
        help="Number of rows to sample from the HF dataset for quick experiments.",
    )
    # parser.add_argument('--max_batch_size', type=int, default=None)
    parser.add_argument(
        "--seq-len", type=int, default=int(os.environ.get("SEQ_LEN", 1024))
    )
    parser.add_argument("--max-seq-len", type=int, default=32768)
    parser.add_argument("--depth-init", type=bool, default=True)
    parser.add_argument(
        "--fp32",
        action="store_true",
        help="Disable mixed precision (use fp32) for debugging NaNs.",
    )
    # max_batch_size: int = 32
    # max_seq_len: int = 32768
    # depth_init: bool = True
    return parser.parse_args()

train(args, outdir)

Run TP/SP + FSDP training and optionally log metrics.

Source code in src/ezpz/examples/fsdp_tp.py
def train(
    args: argparse.Namespace,
    outdir: Path | str | os.PathLike,
) -> int:
    """Run TP/SP + FSDP training and optionally log metrics."""
    world_size = ezpz.dist.get_world_size()
    assert world_size % args.tp == 0, "WORLD_SIZE must be divisible by TP"
    dpsize = world_size // args.tp
    device_mesh = init_device_mesh(
        str(ezpz.get_torch_device()),
        (dpsize, args.tp),
        mesh_dim_names=("dp", "tp"),
    )
    logger.info(f"Device mesh created:\n{device_mesh=}")

    hf_dataset = None
    hf_tokenizer = None
    if args.dataset.lower() not in {"mnist", "random"}:
        from ezpz.data.hf import get_hf_text_dataset

        seed = int(os.environ.get("EZPZ_HF_SAMPLE_SEED", "1337"))
        hf_dataset, hf_tokenizer = get_hf_text_dataset(
            dataset_name=args.dataset,
            split=args.hf_split,
            text_column=args.hf_text_column,
            tokenizer_name=args.tokenizer_name,
            seq_len=args.seq_len,
            limit=args.hf_limit,
            seed=seed,
        )
        if hf_tokenizer.vocab_size != args.vocab_size:
            logger.warning(
                "Overriding vocab_size from %s to tokenizer vocab_size=%s",
                args.vocab_size,
                hf_tokenizer.vocab_size,
            )
            args.vocab_size = hf_tokenizer.vocab_size

    config = ModelArgs(
        dim=args.dim,
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        n_kv_heads=args.n_kv_heads,
        batch_size=args.batch_size,
        vocab_size=args.vocab_size,
        multiple_of=args.multiple_of,
    )
    logger.info(f"config:\n{config}")
    metrics_every = int(os.environ.get("EZPZ_METRICS_EVERY", "1"))
    track_logits = os.environ.get("EZPZ_TRACK_LOGITS", "0") == "1"
    track_hist = os.environ.get("EZPZ_TRACK_HIST", "0") == "1"
    track_act_hist = os.environ.get("EZPZ_TRACK_ACT_HIST", "1") == "1"
    hist_bins = int(os.environ.get("EZPZ_HIST_BINS", "64"))
    hist_samples = int(os.environ.get("EZPZ_HIST_SAMPLES", "20000"))
    dataset_tag = args.dataset.lower().replace("/", "_")
    if ezpz.get_rank() == 0 and not os.environ.get("WANDB_DISABLED", False):
        run = ezpz.dist.setup_wandb(project_name=WBPROJ_NAME)
        if wandb is not None:
            assert run is not None and run is wandb.run
            from dataclasses import asdict

            wandb.config.update(ezpz.get_dist_info())
            wandb.config.update(asdict(config))  # type:ignore

    device_type = str(ezpz.get_torch_device(as_torch_device=False))
    device_id = f"{device_type}:{ezpz.get_local_rank()}"
    model = Transformer.from_model_args(config)
    mstr = summarize_model(
        model,
        verbose=False,
        depth=2,
        # input_size=(
        #     torch.tensor((int(args.batch_size), int(args.seq_length))).to(
        #         torch.long
        #     )
        # ).shape,
    )
    logger.info(f"\n{mstr}")
    model.to(device_id)
    mp_config: Optional[MixedPrecision] = None
    if not args.fp32:
        mp_config = MixedPrecision(
            param_dtype=torch.bfloat16,
            cast_forward_inputs=True,
            reduce_dtype=torch.float32,
        )
    model = parallelize(
        model,
        device_mesh,
        mp_config,
        sharding_strategy=args.sharding_strategy,
    )
    base_model = model
    if not hasattr(base_model, "layers"):
        base_model = getattr(model, "_fsdp_wrapped_module", model)
    act_activations: dict[str, torch.Tensor] = {}
    act_handles: list[torch.utils.hooks.RemovableHandle] = []
    if track_hist and track_act_hist and ezpz.get_rank() == 0:
        hist_layers_spec = os.environ.get(
            "EZPZ_HIST_LAYERS", f"0,{config.n_layers - 1}"
        )
        layer_ids = _parse_hist_layers(hist_layers_spec, config.n_layers)
        act_activations, act_handles = _register_activation_hooks(
            base_model, layer_ids
        )
    logger.info(f"Creating optimizer=AdamW with lr={args.lr}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, foreach=True)

    device = ezpz.get_torch_device(as_torch_device=False)

    tp_group = device_mesh.get_group("tp")
    if args.dataset.lower() == "mnist":
        data_prefix = Path(os.getcwd()).joinpath(
            ".cache", "ezpz", "data", f"{args.dataset.lower()}"
        )
        from ezpz.data.vision import get_mnist
        from ezpz.data.distributed import TPBroadcastDataLoader

        data = get_mnist(
            outdir=Path(data_prefix),
            train_batch_size=args.batch_size,
            test_batch_size=args.test_batch_size,
            num_replicas=dpsize,
            rank=device_mesh.get_local_rank("dp"),
            pin_memory=True,
            num_workers=args.num_workers,
        )
        dataset = data["dataset"]
        sampler = data["sampler"]
        dataloader = data["dataloader"]
        if args.tp > 1:
            dataloader = TPBroadcastDataLoader(dataloader, tp_group)
    elif args.dataset.lower() == "random":
        from ezpz.data.distributed import get_random_dataset_fsdp_tp

        data = get_random_dataset_fsdp_tp(
            batch_size=args.batch_size,
            vocab_size=args.vocab_size,
            seq_length=args.seq_length,
            dp_group=device_mesh.get_group("dp"),
            tp_group=tp_group,
            broadcast_within_tp=True,
            drop_last=True,
        )
        dataset = data["dataset"]
        sampler = data["sampler"]
        dataloader = data["dataloader"]
    # if args.dataset.lower() != "random":
    else:
        from ezpz.data.distributed import TPBroadcastDataLoader

        assert hf_dataset is not None
        dataset = hf_dataset
        sampler = (
            DistributedSampler(
                dataset=dataset,
                num_replicas=dpsize,
                rank=device_mesh.get_local_rank("dp"),
            )
            if ezpz.get_world_size() > 1
            else None
        )
        dataloader = DataLoader(
            dataset,
            sampler=sampler,
            batch_size=args.batch_size,
            shuffle=(sampler is None),
            drop_last=False,
        )
        if args.tp > 1:
            dataloader = TPBroadcastDataLoader(dataloader, tp_group)

    # ezpz.breakpoint(0)
    logger.info("Starting 2D training...")
    model.train()

    # outdir = Path(args.outdir).joinpath(ezpz.utils.get_timestamp())
    metrics_path = Path(outdir).joinpath(
        f"metrics-{ezpz.dist.get_rank()}.jsonl"
    )
    Path(outdir).mkdir(parents=True, exist_ok=True)
    history = ezpz.history.History(
        report_dir=outdir,
        report_enabled=True,
        jsonl_path=metrics_path,
        jsonl_overwrite=True,
        distributed_history=(
            1 < world_size <= 384  # and not config.pytorch_profiler
        ),
    )

    # For TP, input needs to be the same across all TP ranks.
    # while for SP, input can be different across all ranks
    # We will use dp_rank for setting the random seed
    # to mimic the behavior of the dataloader
    # x = torch.tensor((args.batch_size, args.seq_len))
    x = torch.tensor(0)
    global_step = 0
    for epoch in range(args.epochs):
        if sampler is not None:
            sampler.set_epoch(epoch)
        for idx, batch in enumerate(dataloader):
            ezpz.dist.synchronize()
            t0 = perf_counter()
            attn_mask = None
            if isinstance(batch, dict) and "input_ids" in batch:
                x = batch["input_ids"]
                attn_mask = batch.get("attention_mask")
            else:
                x = batch
            assert isinstance(x, torch.Tensor)
            x = x.to(device_id)
            x = x.to(torch.long)
            # Next-token prediction: shift inputs and labels by one token.
            inp = x[:, :-1]
            labels = x[:, 1:]
            inp = inp.to(device_id)
            labels = labels.to(device_id)
            if attn_mask is not None:
                attn_mask = attn_mask.to(device_id)
            pred = model(inp)
            local_seq_len = pred.shape[1]
            if labels.shape[1] != local_seq_len:
                labels = _slice_for_sequence_parallel(labels, local_seq_len)
            if attn_mask is not None:
                if attn_mask.shape[1] > 1:
                    attn_labels = attn_mask[:, 1:]
                else:
                    attn_labels = attn_mask
                if attn_labels.shape[1] != local_seq_len:
                    attn_labels = _slice_for_sequence_parallel(
                        attn_labels, local_seq_len
                    )
                labels = labels.clone()
                labels[attn_labels == 0] = -100
            pad_id = getattr(dataset, "pad_id", None)
            if pad_id is not None:
                labels = labels.clone()
                labels[labels == int(pad_id)] = -100
            ezpz.dist.synchronize()
            t1 = perf_counter()
            tp_mod = getattr(ezpz, "tp", None)
            tp_rank = (
                getattr(tp_mod, "get_tensor_parallel_rank", lambda: 0)()
                if tp_mod is not None
                else 0
            )
            if epoch == 0 and idx == 0:
                pred_finite = torch.isfinite(pred)
                pred_nonfinite = int((~pred_finite).sum().item())
                pred_max = float(pred.abs().max().item())
                logger.info(
                    "pred_stats rank=%s tp=%s shape=%s nonfinite=%s max_abs=%s",
                    ezpz.get_rank(),
                    tp_rank,
                    tuple(pred.shape),
                    pred_nonfinite,
                    f"{pred_max:.6f}",
                )
            loss = F.cross_entropy(
                pred.reshape(-1, pred.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )
            if epoch == 0 and idx == 0:
                valid_labels = int((labels != -100).sum().item())
                logger.info(
                    "loss_inputs rank=%s tp=%s local_seq_len=%s labels=%s valid_labels=%s",
                    ezpz.get_rank(),
                    tp_rank,
                    local_seq_len,
                    tuple(labels.shape),
                    valid_labels,
                )
                # loss = F.cross_entropy(
                #     pred.flatten(0, 1),
                #     labels.flatten(0, 1),
                # )
                # loss = output.loss
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            grad_norm_preclip = None
            if args.max_grad_norm > 0:
                grad_norm_preclip = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.max_grad_norm
                )
            optimizer.step()
            ezpz.dist.synchronize()
            t2 = perf_counter()
            global_step += 1
            metrics: dict[str, object] = {
                "train/iter": global_step,
                "train/epoch": epoch,
                "train/bidx": idx,
                "train/loss": loss.item(),
                "train/dt": t2 - t0,
                "train/dtf": t1 - t0,
                "train/dtb": t2 - t1,
            }
            if grad_norm_preclip is not None:
                metrics["grad/norm_preclip"] = float(grad_norm_preclip)
            if global_step % max(metrics_every, 1) == 0:
                metrics.update(_collect_param_grad_stats(model, device_id))
                metrics["opt/iter"] = (global_step,)
                metrics["opt/lr"] = float(optimizer.param_groups[0]["lr"])
                metrics["input/iter"] = (global_step,)
                metrics["input/max"] = float(x.max().item())
                metrics["input/min"] = float(x.min().item())
                metrics["labels/valid"] = float((labels != -100).sum().item())
                if track_logits:
                    pred_finite = torch.isfinite(pred)
                    metrics["logits/nonfinite"] = float(
                        (~pred_finite).sum().item()
                    )
                    metrics["logits/max_abs"] = float(pred.abs().max().item())
                if track_hist and ezpz.get_rank() == 0:
                    logits_sample = _sample_tensor_values(pred, hist_samples)
                    if logits_sample is not None:
                        logits_hist = _histogram_dict(logits_sample, hist_bins)
                        if logits_hist is not None:
                            metrics[f"hist/{dataset_tag}/logits"] = logits_hist
                    layer_grad_norms = _collect_layer_grad_norms(base_model)
                    if layer_grad_norms:
                        layer_grad_hist = _histogram_dict(
                            torch.tensor(layer_grad_norms), hist_bins
                        )
                        if layer_grad_hist is not None:
                            metrics[
                                f"hist/{dataset_tag}/grad_norm_per_layer"
                            ] = layer_grad_hist
                    if track_act_hist and act_activations:
                        for act_key, act_tensor in act_activations.items():
                            act_sample = _sample_tensor_values(
                                act_tensor, hist_samples
                            )
                            act_hist = _histogram_dict(act_sample, hist_bins)
                            if act_hist is not None:
                                metrics[
                                    f"hist/{dataset_tag}/activations/{act_key}"
                                ] = act_hist
                    _wandb_log_histograms(
                        metrics, step=global_step, enabled=track_hist
                    )
            history.update(metrics, summarize=False)
            history.log_metrics(
                metrics,
                logger=logger,
                debug_prefixes=("hist/",),
                include_summary=True,
                rank0_only_summary=True,
            )
            if epoch == 0 and idx == 0:
                logger.info(f"{x.shape}")
    if act_handles:
        for handle in act_handles:
            handle.remove()
    ezpz.dist.barrier()
    logger.info("Finished 2D training")
    if ezpz.get_rank() == 0:
        dataset = history.finalize(
            run_name=WBRUN_NAME,
            dataset_fname="train",
            warmup=0.1,
        )
        logger.info(f"{dataset=}")

    return 0
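
The heart of each training step above is a standard causal language-modeling loss: input and label tensors are the same token sequence shifted by one position, padded or masked positions are set to -100, and F.cross_entropy skips them via ignore_index. A minimal, self-contained sketch of that pattern (shapes and the pad id are illustrative, not taken from the script):

import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size, pad_id = 2, 16, 32_000, 0

# Toy batch of token ids with some trailing padding.
tokens = torch.randint(1, vocab_size, (batch_size, seq_len))
tokens[:, -4:] = pad_id

inp = tokens[:, :-1]             # model input
labels = tokens[:, 1:].clone()   # next-token targets
labels[labels == pad_id] = -100  # masked out of the loss

# Stand-in for `model(inp)`: logits over the vocabulary at each position.
logits = torch.randn(batch_size, seq_len - 1, vocab_size)

loss = F.cross_entropy(
    logits.reshape(-1, vocab_size),
    labels.reshape(-1),
    ignore_index=-100,
)
print(float(loss))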