ezpz.data.distributed

ezpz/data/distributed.py

TPBroadcastDataLoader

Wrapper that ensures only the TP leader samples/loads, then broadcasts each batch to the other TP ranks.

Source code in src/ezpz/data/distributed.py
class TPBroadcastDataLoader:
    """
    Wrapper that ensures only the TP leader samples/loads, then broadcasts
    each batch to the other TP ranks.
    """

    def __init__(
        self, dl: DataLoader, tp_group: torch.distributed.ProcessGroup
    ):
        self.dl = dl
        self.tp_group = tp_group
        self.leader = _tp_is_leader(tp_group)

    def __iter__(self) -> Iterator:
        it: Iterable = iter(self.dl) if self.leader else range(len(self.dl))
        # Non-leaders iterate dummy range to keep step counts aligned
        for maybe_batch in it:
            batch = maybe_batch if self.leader else None
            batch = _broadcast_batch(batch, self.tp_group)
            yield batch

    def __len__(self) -> int:
        return len(self.dl)
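
The class depends on two helpers, _tp_is_leader and _broadcast_batch, whose implementations are not shown above. Below is a minimal sketch of what they might look like, assuming the leader is rank 0 within the TP group and that batches can be broadcast as picklable objects; the actual helpers in src/ezpz/data/distributed.py may broadcast tensors directly and differ in detail.

import torch.distributed as dist


def _tp_is_leader(tp_group: dist.ProcessGroup) -> bool:
    # Treat rank 0 *within* the TP group as the leader that touches the data.
    return dist.get_rank(group=tp_group) == 0


def _broadcast_batch(batch, tp_group: dist.ProcessGroup):
    # Send the leader's batch to every other rank in the TP group.
    # broadcast_object_list pickles arbitrary Python objects and expects the
    # *global* rank of the source, so map group-rank 0 to its global rank.
    src = dist.get_global_rank(tp_group, 0)
    obj = [batch]
    dist.broadcast_object_list(obj, src=src, group=tp_group)
    return obj[0]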

get_random_dataset_fsdp_tp(batch_size, vocab_size, seq_length, *, num_workers=0, pin_memory=True, dp_group=None, tp_group=None, broadcast_within_tp=False, drop_last=True, seed=1337)

Build dataset/sampler/dataloader for FSDP (DP) + Tensor Parallel (TP).

Key idea:
  • Shard the dataset ONLY across the DP group (the FSDP replica group); see the group-setup sketch below.
  • Optionally broadcast each batch within TP so that only the TP leader does I/O.
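
Both dp_group and tp_group are ordinary torch.distributed process groups. A minimal sketch of one common way to obtain them, assuming PyTorch >= 2.2 with its public device-mesh API, an already-initialized process group, and an illustrative TP degree of 2 (none of this is prescribed by ezpz):

import torch
from torch.distributed.device_mesh import init_device_mesh

tp_size = 2  # illustrative tensor-parallel degree
world_size = torch.distributed.get_world_size()
mesh = init_device_mesh(
    "cuda",
    (world_size // tp_size, tp_size),
    mesh_dim_names=("dp", "tp"),
)
dp_group = mesh.get_group("dp")  # FSDP replica group -> pass as dp_group=...
tp_group = mesh.get_group("tp")  # tensor-parallel group -> pass as tp_group=...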

Parameters:

  dp_group (Optional[ProcessGroup], default: None)
      Process group that defines the FSDP data-parallel replicas.

  tp_group (Optional[ProcessGroup], default: None)
      Process group that defines the tensor-parallel group.

  broadcast_within_tp (bool, default: False)
      If True, the TP leader loads batches and broadcasts them within the TP group.

  drop_last (bool, default: True)
      Prefer True to keep batch shapes static across DP replicas.

  seed (int, default: 1337)
      Base seed for shuffling; the epoch number is added to it each epoch (via sampler.set_epoch).

Returns:

  Dict[str, Any]
      dict with 'dataset', 'sampler', and 'dataloader' keys.

Source code in src/ezpz/data/distributed.py
def get_random_dataset_fsdp_tp(
    batch_size: int,
    vocab_size: int,
    seq_length: int,
    *,
    num_workers: int = 0,
    pin_memory: bool = True,
    dp_group: Optional[torch.distributed.ProcessGroup] = None,
    tp_group: Optional[torch.distributed.ProcessGroup] = None,
    broadcast_within_tp: bool = False,
    drop_last: bool = True,
    seed: int = 1337,
) -> Dict[str, Any]:
    """
    Build dataset/sampler/dataloader for FSDP (DP) + Tensor Parallel (TP).

    Key idea:
      - Shard the dataset ONLY across the **DP group** (FSDP replica group).
      - Optionally broadcast each batch within TP so only TP-leader does I/O.

    Args:
      dp_group: Process group that defines FSDP data-parallel replicas.
      tp_group: Process group that defines tensor parallel group.
      broadcast_within_tp: If True, TP leader loads and broadcasts batches.
      drop_last: Prefer True for static shapes across DP replicas.
      seed: Base seed for shuffling (the epoch is added to this each epoch).

    Returns:
      dict with 'dataset', 'sampler', 'dataloader'
    """
    from ezpz.data.text import RandomTokenDataset

    dset = RandomTokenDataset(vocab_size=vocab_size, seq_length=seq_length)

    use_dist = _is_dist()
    sampler = None

    if use_dist:
        # Determine DP rank/world_size; TP is ignored by the sampler.
        dp_rank, dp_world = _rank_ws(dp_group)
        # Important: num_replicas/rank are DP-based, not global.
        sampler = DistributedSampler(
            dset,
            num_replicas=dp_world,
            rank=dp_rank,
            shuffle=True,
            drop_last=drop_last,
            seed=seed,
        )

    dl = DataLoader(
        dset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=(sampler is None),  # never shuffle when a sampler is provided
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
        persistent_workers=(num_workers > 0),
    )

    if use_dist and broadcast_within_tp and tp_group is not None:
        dl = TPBroadcastDataLoader(dl, tp_group)

    return {
        "dataset": dset,
        "sampler": sampler,
        "dataloader": dl,
    }
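
Continuing from the group-setup sketch above, a hedged example of how the returned dict might be wired into a training loop (the batch size, vocabulary size, sequence length, and epoch count are placeholders):

from ezpz.data.distributed import get_random_dataset_fsdp_tp

data = get_random_dataset_fsdp_tp(
    batch_size=8,
    vocab_size=32_000,
    seq_length=1024,
    dp_group=dp_group,          # from the device-mesh sketch above
    tp_group=tp_group,
    broadcast_within_tp=True,   # only the TP leader reads the dataset
)

for epoch in range(3):
    if data["sampler"] is not None:
        # Re-seed the per-epoch shuffle; DistributedSampler shuffles with seed + epoch.
        data["sampler"].set_epoch(epoch)
    for batch in data["dataloader"]:
        ...  # forward / backward / optimizer step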