ezpz.data.distributed
ezpz/data/distributed.py
TPBroadcastDataLoader
ΒΆ
Wrapper that ensures only the TP leader samples and loads data, then broadcasts each batch to the other TP ranks (see the sketch below).
Source code in src/ezpz/data/distributed.py
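The broadcast pattern can be summarized with the minimal sketch below. It is illustrative only: the class name TPLeaderBroadcastLoader, its constructor arguments, and the assumption of fixed-shape tensor batches (as produced with drop_last=True) are assumptions, not the actual implementation in ezpz.

```python
import torch
import torch.distributed as dist


class TPLeaderBroadcastLoader:
    """Sketch only: the TP-group leader iterates the wrapped loader and
    broadcasts each batch to the rest of the TP group, so non-leader ranks
    never touch the dataset or do I/O."""

    def __init__(self, loader, tp_group, batch_shape, dtype=torch.long, device="cuda"):
        self.loader = loader          # assumed to define __len__
        self.tp_group = tp_group
        self.batch_shape = batch_shape
        self.dtype = dtype
        self.device = device
        # Global rank of group-rank 0 inside the TP group (the "leader").
        self.leader = dist.get_global_rank(tp_group, 0)
        self.is_leader = dist.get_rank() == self.leader

    def __len__(self):
        return len(self.loader)

    def __iter__(self):
        it = iter(self.loader) if self.is_leader else None
        for _ in range(len(self.loader)):
            if self.is_leader:
                batch = next(it).to(self.device, non_blocking=True)
            else:
                # Non-leaders only allocate a receive buffer; no sampling or I/O here.
                batch = torch.empty(self.batch_shape, dtype=self.dtype, device=self.device)
            dist.broadcast(batch, src=self.leader, group=self.tp_group)
            yield batch
```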
get_random_dataset_fsdp_tp(batch_size, vocab_size, seq_length, *, num_workers=0, pin_memory=True, dp_group=None, tp_group=None, broadcast_within_tp=False, drop_last=True, seed=1337)
ΒΆ
Build dataset/sampler/dataloader for FSDP (DP) + Tensor Parallel (TP).
Key idea:
- Shard the dataset ONLY across the DP group (the FSDP replica group).
- Optionally broadcast each batch within the TP group so that only the TP leader does I/O (see the sketch after this list).
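To make the first point concrete, a DP-only sharded loader can be built roughly as follows. This is a hedged sketch, not the module's code: the helper name _dp_sharded_loader and its arguments are illustrative; the key detail is that DistributedSampler is sized by the DP group rather than the global world size.

```python
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler


def _dp_sharded_loader(dataset, batch_size, dp_group, seed=1337):
    # Shard only across the DP (FSDP replica) group: every TP rank inside a
    # replica sees the same shard instead of a different slice of the data.
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(group=dp_group),  # DP size, not global world size
        rank=dist.get_rank(group=dp_group),                # rank within the DP group
        shuffle=True,
        seed=seed,
        drop_last=True,  # keeps batch shapes static across DP replicas
    )
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, drop_last=True)
```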
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dp_group | Optional[ProcessGroup] | Process group that defines the FSDP data-parallel replicas. | None |
| tp_group | Optional[ProcessGroup] | Process group that defines the tensor parallel group. | None |
| broadcast_within_tp | bool | If True, the TP leader loads and broadcasts batches. | False |
| drop_last | bool | Prefer True for static shapes across DP replicas. | True |
| seed | int | Base seed for shuffling (add the epoch to this each epoch). | 1337 |
Returns:

| Type | Description |
|---|---|
| Dict[str, Any] | dict with 'dataset', 'sampler', 'dataloader' |
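A possible end-to-end usage might look like the following. The 2x4 DP x TP device mesh, the NCCL backend, and the hyperparameter values are assumptions for illustration; only get_random_dataset_fsdp_tp, its parameters, and the returned keys come from the documentation above.

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from ezpz.data.distributed import get_random_dataset_fsdp_tp

dist.init_process_group(backend="nccl")

# Assumed layout: 2-way data parallel (FSDP replicas) x 4-way tensor parallel.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

out = get_random_dataset_fsdp_tp(
    batch_size=8,
    vocab_size=32000,
    seq_length=2048,
    dp_group=mesh.get_group("dp"),
    tp_group=mesh.get_group("tp"),
    broadcast_within_tp=True,  # only the TP leader touches the data
)
dataloader = out["dataloader"]

for epoch in range(2):
    # Assumes the returned sampler exposes set_epoch(), as DistributedSampler does.
    out["sampler"].set_epoch(epoch)
    for batch in dataloader:
        ...  # forward/backward with FSDP + TP
```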