Minimal synthetic training loop for testing distributed setup and logging.
This example builds a tiny MLP that learns to reconstruct random inputs.
Launch it with:
```bash
ezpz launch -m ezpz.examples.minimal
```
Running `python3 -m ezpz.examples.minimal --help` prints:

```
usage: ezpz.examples.minimal --help
```

Set environment variables such as `PRINT_FREQ=100 TRAIN_ITERS=1000 INPUT_SIZE=128 OUTPUT_SIZE=128 LAYER_SIZES="128,256,128"` before calling `ezpz launch` to configure the run.
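For reference, here is a sketch of the knobs the functions below actually read, together with the defaults they fall back to. The names and defaults are taken from the source listings further down; the dict itself is illustrative and not part of the module.

```python
import os

# Illustrative only: env vars consumed by setup()/train(), with the code's defaults.
knobs = {
    "SEED": int(os.environ.get("SEED", 0)),                # RNG seed passed to ezpz.setup_torch
    "INPUT_SIZE": int(os.environ.get("INPUT_SIZE", 128)),  # model input width
    "OUTPUT_SIZE": int(os.environ.get("OUTPUT_SIZE", 128)),  # model output width
    "LAYER_SIZES": os.environ.get(                          # comma-separated hidden layer sizes
        "LAYER_SIZES", "256,512,1024,2048,1024,512,256,128"
    ),
    "BATCH_SIZE": int(os.environ.get("BATCH_SIZE", 64)),
    "TRAIN_ITERS": int(os.environ.get("TRAIN_ITERS", 500)),
    "WARMUP_ITERS": int(os.environ.get("WARMUP_ITERS", 10)),
    "LOG_FREQ": int(os.environ.get("LOG_FREQ", 1)),         # history.update() every N steps
    "PRINT_FREQ": int(os.environ.get("PRINT_FREQ", 10)),    # logger.info() every N steps
}
print(knobs)
```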
main()
Entrypoint for launching the minimal synthetic training example.
Source code in src/ezpz/examples/minimal.py
```python
def main():
    """Entrypoint for launching the minimal synthetic training example."""
    model, optimizer = setup()
    history = train(model, optimizer)
    if ezpz.get_rank() == 0:
        dataset = history.finalize()
        logger.info(f"{dataset=}")
```
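For a quick single-process smoke test, a sketch like the following can work, assuming `ezpz.setup_torch()` (called inside `setup()`) falls back cleanly to a one-process run when no launcher is present; `WANDB_DISABLED` and `TRAIN_ITERS` are the same environment variables described above.

```python
# Hedged sketch: run the example without the launcher for a quick local check.
import os

os.environ.setdefault("WANDB_DISABLED", "1")  # skip wandb init on rank 0
os.environ.setdefault("TRAIN_ITERS", "50")    # keep the loop short

from ezpz.examples.minimal import main

main()
```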
setup()
Initialize distributed runtime, model, and optimizer.
Source code in src/ezpz/examples/minimal.py
```python
@ezpz.timeitlogit(rank=ezpz.get_rank())
def setup():
    """Initialize distributed runtime, model, and optimizer."""
    rank = ezpz.setup_torch(seed=int(os.environ.get("SEED", 0)))
    if os.environ.get("WANDB_DISABLED", False):
        logger.info("WANDB_DISABLED is set, not initializing wandb")
    elif rank == 0:
        try:
            _ = ezpz.setup_wandb(
                project_name=os.environ.get("PROJECT_NAME", "ezpz.examples.minimal")
            )
        except Exception:
            logger.exception("Failed to initialize wandb, continuing without it")
    device_type = ezpz.get_torch_device_type()
    from ezpz.models.minimal import SequentialLinearNet

    model = SequentialLinearNet(
        input_dim=int((os.environ.get("INPUT_SIZE", 128))),
        output_dim=int(os.environ.get("OUTPUT_SIZE", 128)),
        sizes=[
            int(x)
            for x in os.environ.get(
                "LAYER_SIZES", "256,512,1024,2048,1024,512,256,128"
            ).split(",")
        ],
    )
    model.to(device_type)
    model.to((os.environ.get("DTYPE", torch.bfloat16)))
    try:
        from ezpz.utils import model_summary

        model_summary(model)
    except Exception:
        logger.exception("Failed to summarize model")
    logger.info(f"{model=}")
    optimizer = torch.optim.Adam(model.parameters())
    if ezpz.get_world_size() > 1:
        model = ezpz.dist.wrap_model_for_ddp(model)
        # from torch.nn.parallel import DistributedDataParallel as DDP
        #
        # model = DDP(model, device_ids=[ezpz.get_local_rank()])
    return model, optimizer
```
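`SequentialLinearNet` is ezpz's own helper; a rough, hypothetical stand-in is a plain `torch.nn.Sequential` stack of `Linear` layers sized `INPUT_SIZE -> LAYER_SIZES... -> OUTPUT_SIZE`. The sketch below is for illustration only: whether the real class interleaves activations is an assumption, and it is not a drop-in replacement (for example, `train()` below reads a `.layers` attribute that this stand-in does not expose).

```python
import torch

def make_mlp(input_dim: int, output_dim: int, sizes: list[int]) -> torch.nn.Sequential:
    """Hypothetical stand-in for ezpz.models.minimal.SequentialLinearNet."""
    dims = [input_dim, *sizes, output_dim]
    layers: list[torch.nn.Module] = []
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        layers.append(torch.nn.Linear(d_in, d_out))
        layers.append(torch.nn.ReLU())
    return torch.nn.Sequential(*layers[:-1])  # drop the activation after the final layer

model = make_mlp(128, 128, [256, 512, 1024, 2048, 1024, 512, 256, 128])
print(model)
```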
train(model, optimizer)
Run a synthetic training loop on random data.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | Model to train (wrapped or unwrapped). | *required* |
| `optimizer` | `Optimizer` | Optimizer configured for the model. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `History` | Training history with timing and loss metrics. |
Source code in src/ezpz/examples/minimal.py
```python
@ezpz.timeitlogit(rank=ezpz.get_rank())
def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer) -> ezpz.History:
    """Run a synthetic training loop on random data.

    Args:
        model: Model to train (wrapped or unwrapped).
        optimizer: Optimizer configured for the model.

    Returns:
        Training history with timing and loss metrics.
    """
    unwrapped_model = (
        model.module
        if isinstance(model, torch.nn.parallel.DistributedDataParallel)
        else model
    )
    history = ezpz.History()
    device_type = ezpz.get_torch_device_type()
    dtype = unwrapped_model.layers[0].weight.dtype
    bsize = int(os.environ.get("BATCH_SIZE", 64))
    isize = unwrapped_model.layers[0].in_features
    warmup = int(os.environ.get("WARMUP_ITERS", 10))
    log_freq = int(os.environ.get("LOG_FREQ", 1))
    print_freq = int(os.environ.get("PRINT_FREQ", 10))
    model.train()
    summary = ""
    for step in range(int(os.environ.get("TRAIN_ITERS", 500))):
        with torch.autocast(
            device_type=device_type,
            dtype=dtype,
        ):
            t0 = time.perf_counter()
            x = torch.rand((bsize, isize), dtype=dtype).to(device_type)
            y = model(x)
            loss = ((y - x) ** 2).sum()
            dtf = (t1 := time.perf_counter()) - t0
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            dtb = time.perf_counter() - t1
        if step % log_freq == 0 and step > warmup:
            summary = history.update(
                {
                    "iter": step,
                    "loss": loss.item(),
                    "dt": dtf + dtb,
                    "dtf": dtf,
                    "dtb": dtb,
                }
            )
        if step % print_freq == 0 and step > warmup:
            logger.info(summary)
    return history
```
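To sanity-check the loop's loss/timing pattern without a launcher or accelerator, a stripped-down, CPU-only sketch (plain PyTorch, no ezpz, no autocast; layer and batch sizes are arbitrary) looks like this:

```python
import time

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(128, 256), torch.nn.ReLU(), torch.nn.Linear(256, 128)
)
optimizer = torch.optim.Adam(model.parameters())
model.train()

for step in range(100):
    t0 = time.perf_counter()
    x = torch.rand((64, 128))
    loss = ((model(x) - x) ** 2).sum()  # same reconstruction objective as train()
    dtf = (t1 := time.perf_counter()) - t0
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    dtb = time.perf_counter() - t1
    if step % 10 == 0:
        # dtf ~ forward time, dtb ~ backward + optimizer time, mirroring the metrics above
        print(f"step={step} loss={loss.item():.3f} dtf={dtf * 1e3:.2f}ms dtb={dtb * 1e3:.2f}ms")

print(f"throughput ~ {64 / (dtf + dtb):.1f} samples/s (last step)")
```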