def train(
args: argparse.Namespace,
outdir: Path | str | os.PathLike,
) -> int:
"""Run TP/SP + FSDP training and optionally log metrics."""
world_size = ezpz.dist.get_world_size()
assert world_size % args.tp == 0, "WORLD_SIZE must be divisible by TP"
dpsize = world_size // args.tp
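    # Build a 2-D device mesh: data-parallel groups along "dp" and
    # tensor-parallel groups along "tp", so dpsize * args.tp == world_size.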
device_mesh = init_device_mesh(
str(ezpz.get_torch_device()),
(dpsize, args.tp),
mesh_dim_names=("dp", "tp"),
)
logger.info(f"Device mesh created:\n{device_mesh=}")
hf_dataset = None
hf_tokenizer = None
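    # Any dataset other than the synthetic "mnist"/"random" options is
    # treated as a HuggingFace text dataset and tokenized below.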
if args.dataset.lower() not in {"mnist", "random"}:
from ezpz.data.hf import get_hf_text_dataset
seed = int(os.environ.get("EZPZ_HF_SAMPLE_SEED", "1337"))
hf_dataset, hf_tokenizer = get_hf_text_dataset(
dataset_name=args.dataset,
split=args.hf_split,
text_column=args.hf_text_column,
tokenizer_name=args.tokenizer_name,
seq_len=args.seq_len,
limit=args.hf_limit,
seed=seed,
)
if hf_tokenizer.vocab_size != args.vocab_size:
logger.warning(
"Overriding vocab_size from %s to tokenizer vocab_size=%s",
args.vocab_size,
hf_tokenizer.vocab_size,
)
args.vocab_size = hf_tokenizer.vocab_size
config = ModelArgs(
dim=args.dim,
n_layers=args.n_layers,
n_heads=args.n_heads,
n_kv_heads=args.n_kv_heads,
batch_size=args.batch_size,
vocab_size=args.vocab_size,
multiple_of=args.multiple_of,
)
logger.info(f"config:\n{config}")
metrics_every = int(os.environ.get("EZPZ_METRICS_EVERY", "1"))
track_logits = os.environ.get("EZPZ_TRACK_LOGITS", "0") == "1"
track_hist = os.environ.get("EZPZ_TRACK_HIST", "0") == "1"
track_act_hist = os.environ.get("EZPZ_TRACK_ACT_HIST", "1") == "1"
hist_bins = int(os.environ.get("EZPZ_HIST_BINS", "64"))
hist_samples = int(os.environ.get("EZPZ_HIST_SAMPLES", "20000"))
dataset_tag = args.dataset.lower().replace("/", "_")
    if ezpz.get_rank() == 0 and not os.environ.get("WANDB_DISABLED", ""):
run = ezpz.dist.setup_wandb(project_name=WBPROJ_NAME)
if wandb is not None:
assert run is not None and run is wandb.run
from dataclasses import asdict
wandb.config.update(ezpz.get_dist_info())
wandb.config.update(asdict(config)) # type:ignore
device_type = str(ezpz.get_torch_device(as_torch_device=False))
device_id = f"{device_type}:{ezpz.get_local_rank()}"
model = Transformer.from_model_args(config)
mstr = summarize_model(
model,
verbose=False,
depth=2,
)
logger.info(f"\n{mstr}")
model.to(device_id)
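    # Unless fp32 is forced, keep parameters in bf16 but reduce gradients in
    # fp32, which keeps the gradient all-reduce numerically stable.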
mp_config: Optional[MixedPrecision] = None
if not args.fp32:
mp_config = MixedPrecision(
param_dtype=torch.bfloat16,
cast_forward_inputs=True,
reduce_dtype=torch.float32,
)
model = parallelize(
model,
device_mesh,
mp_config,
sharding_strategy=args.sharding_strategy,
)
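    # FSDP may wrap the module; unwrap it to reach the underlying `.layers`
    # attribute needed for the per-layer activation hooks below.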
base_model = model
if not hasattr(base_model, "layers"):
base_model = getattr(model, "_fsdp_wrapped_module", model)
act_activations: dict[str, torch.Tensor] = {}
act_handles: list[torch.utils.hooks.RemovableHandle] = []
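    # Optionally register forward hooks (rank 0 only) that capture
    # activations from selected layers for histogram logging.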
if track_hist and track_act_hist and ezpz.get_rank() == 0:
hist_layers_spec = os.environ.get(
"EZPZ_HIST_LAYERS", f"0,{config.n_layers - 1}"
)
layer_ids = _parse_hist_layers(hist_layers_spec, config.n_layers)
act_activations, act_handles = _register_activation_hooks(
base_model, layer_ids
)
logger.info(f"Creating optimizer=AdamW with lr={args.lr}")
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, foreach=True)
tp_group = device_mesh.get_group("tp")
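    # Three data paths: MNIST (vision sanity check), synthetic random tokens,
    # or the HuggingFace text dataset prepared above.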
if args.dataset.lower() == "mnist":
        data_prefix = Path(os.getcwd()).joinpath(
            ".cache", "ezpz", "data", args.dataset.lower()
        )
from ezpz.data.vision import get_mnist
from ezpz.data.distributed import TPBroadcastDataLoader
data = get_mnist(
outdir=Path(data_prefix),
train_batch_size=args.batch_size,
test_batch_size=args.test_batch_size,
num_replicas=dpsize,
rank=device_mesh.get_local_rank("dp"),
pin_memory=True,
num_workers=args.num_workers,
)
dataset = data["dataset"]
sampler = data["sampler"]
dataloader = data["dataloader"]
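        # TP ranks must see identical inputs; broadcast each batch from the
        # first rank of the TP group.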
if args.tp > 1:
dataloader = TPBroadcastDataLoader(dataloader, tp_group)
elif args.dataset.lower() == "random":
from ezpz.data.distributed import get_random_dataset_fsdp_tp
data = get_random_dataset_fsdp_tp(
batch_size=args.batch_size,
vocab_size=args.vocab_size,
seq_length=args.seq_length,
dp_group=device_mesh.get_group("dp"),
tp_group=tp_group,
broadcast_within_tp=True,
drop_last=True,
)
dataset = data["dataset"]
sampler = data["sampler"]
dataloader = data["dataloader"]
else:
from ezpz.data.distributed import TPBroadcastDataLoader
assert hf_dataset is not None
dataset = hf_dataset
sampler = (
DistributedSampler(
dataset=dataset,
num_replicas=dpsize,
rank=device_mesh.get_local_rank("dp"),
)
            if world_size > 1
else None
)
dataloader = DataLoader(
dataset,
sampler=sampler,
batch_size=args.batch_size,
shuffle=(sampler is None),
drop_last=False,
)
if args.tp > 1:
dataloader = TPBroadcastDataLoader(dataloader, tp_group)
logger.info("Starting 2D training...")
model.train()
metrics_path = Path(outdir).joinpath(
f"metrics-{ezpz.dist.get_rank()}.jsonl"
)
Path(outdir).mkdir(parents=True, exist_ok=True)
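    # Each rank writes its own JSONL metrics file; distributed history
    # aggregation is only enabled for world sizes up to 384.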
history = ezpz.history.History(
report_dir=outdir,
report_enabled=True,
jsonl_path=metrics_path,
jsonl_overwrite=True,
        distributed_history=(1 < world_size <= 384),
)
    # For TP, inputs must be identical across all TP ranks, while for SP they
    # may differ across ranks. We use the dp_rank to seed the RNG, mimicking
    # the dataloader's behavior.
    # Placeholder; `x` is overwritten with the first batch inside the loop.
    x = torch.tensor(0)
global_step = 0
for epoch in range(args.epochs):
if sampler is not None:
sampler.set_epoch(epoch)
for idx, batch in enumerate(dataloader):
ezpz.dist.synchronize()
t0 = perf_counter()
attn_mask = None
if isinstance(batch, dict) and "input_ids" in batch:
x = batch["input_ids"]
attn_mask = batch.get("attention_mask")
else:
x = batch
assert isinstance(x, torch.Tensor)
x = x.to(device_id)
x = x.to(torch.long)
            # Next-token prediction: shift inputs and labels by one token.
            inp = x[:, :-1]
            labels = x[:, 1:]
inp = inp.to(device_id)
labels = labels.to(device_id)
if attn_mask is not None:
attn_mask = attn_mask.to(device_id)
pred = model(inp)
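            # With sequence parallelism each rank produces only a shard of
            # the sequence dimension; slice the labels to match.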
local_seq_len = pred.shape[1]
if labels.shape[1] != local_seq_len:
labels = _slice_for_sequence_parallel(labels, local_seq_len)
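            # Align the attention mask with the shifted labels and mask out
            # padded positions with the ignore_index (-100).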
if attn_mask is not None:
if attn_mask.shape[1] > 1:
attn_labels = attn_mask[:, 1:]
else:
attn_labels = attn_mask
if attn_labels.shape[1] != local_seq_len:
attn_labels = _slice_for_sequence_parallel(
attn_labels, local_seq_len
)
labels = labels.clone()
labels[attn_labels == 0] = -100
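            # Some datasets expose a pad token id; mask those labels too.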
pad_id = getattr(dataset, "pad_id", None)
if pad_id is not None:
labels = labels.clone()
labels[labels == int(pad_id)] = -100
ezpz.dist.synchronize()
t1 = perf_counter()
tp_mod = getattr(ezpz, "tp", None)
tp_rank = (
getattr(tp_mod, "get_tensor_parallel_rank", lambda: 0)()
if tp_mod is not None
else 0
)
if epoch == 0 and idx == 0:
pred_finite = torch.isfinite(pred)
pred_nonfinite = int((~pred_finite).sum().item())
pred_max = float(pred.abs().max().item())
logger.info(
"pred_stats rank=%s tp=%s shape=%s nonfinite=%s max_abs=%s",
ezpz.get_rank(),
tp_rank,
tuple(pred.shape),
pred_nonfinite,
f"{pred_max:.6f}",
)
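            # Flatten to (batch * seq, vocab) for token-level cross-entropy;
            # positions labeled -100 are ignored by the loss.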
loss = F.cross_entropy(
pred.reshape(-1, pred.size(-1)),
labels.reshape(-1),
ignore_index=-100,
)
if epoch == 0 and idx == 0:
valid_labels = int((labels != -100).sum().item())
logger.info(
"loss_inputs rank=%s tp=%s local_seq_len=%s labels=%s valid_labels=%s",
ezpz.get_rank(),
tp_rank,
local_seq_len,
tuple(labels.shape),
valid_labels,
)
optimizer.zero_grad(set_to_none=True)
loss.backward()
grad_norm_preclip = None
if args.max_grad_norm > 0:
grad_norm_preclip = torch.nn.utils.clip_grad_norm_(
model.parameters(), args.max_grad_norm
)
optimizer.step()
ezpz.dist.synchronize()
t2 = perf_counter()
global_step += 1
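            # Timing breakdown: dt is the full step, dtf covers data movement
            # and the forward pass, dtb covers loss, backward, clipping, and
            # the optimizer step.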
metrics: dict[str, object] = {
"train/iter": global_step,
"train/epoch": epoch,
"train/bidx": idx,
"train/loss": loss.item(),
"train/dt": t2 - t0,
"train/dtf": t1 - t0,
"train/dtb": t2 - t1,
}
if grad_norm_preclip is not None:
metrics["grad/norm_preclip"] = float(grad_norm_preclip)
if global_step % max(metrics_every, 1) == 0:
metrics.update(_collect_param_grad_stats(model, device_id))
metrics["opt/iter"] = (global_step,)
metrics["opt/lr"] = float(optimizer.param_groups[0]["lr"])
metrics["input/iter"] = (global_step,)
metrics["input/max"] = float(x.max().item())
metrics["input/min"] = float(x.min().item())
metrics["labels/valid"] = float((labels != -100).sum().item())
if track_logits:
pred_finite = torch.isfinite(pred)
metrics["logits/nonfinite"] = float(
(~pred_finite).sum().item()
)
metrics["logits/max_abs"] = float(pred.abs().max().item())
if track_hist and ezpz.get_rank() == 0:
logits_sample = _sample_tensor_values(pred, hist_samples)
if logits_sample is not None:
logits_hist = _histogram_dict(logits_sample, hist_bins)
if logits_hist is not None:
metrics[f"hist/{dataset_tag}/logits"] = logits_hist
layer_grad_norms = _collect_layer_grad_norms(base_model)
if layer_grad_norms:
layer_grad_hist = _histogram_dict(
torch.tensor(layer_grad_norms), hist_bins
)
if layer_grad_hist is not None:
metrics[
f"hist/{dataset_tag}/grad_norm_per_layer"
] = layer_grad_hist
                    if track_act_hist and act_activations:
                        for act_key, act_tensor in act_activations.items():
                            act_sample = _sample_tensor_values(
                                act_tensor, hist_samples
                            )
                            # Guard against an empty sample, mirroring the
                            # logits path above.
                            if act_sample is None:
                                continue
                            act_hist = _histogram_dict(act_sample, hist_bins)
                            if act_hist is not None:
                                metrics[
                                    f"hist/{dataset_tag}/activations/{act_key}"
                                ] = act_hist
_wandb_log_histograms(
metrics, step=global_step, enabled=track_hist
)
history.update(metrics, summarize=False)
history.log_metrics(
metrics,
logger=logger,
debug_prefixes=("hist/",),
include_summary=True,
rank0_only_summary=True,
)
if epoch == 0 and idx == 0:
logger.info(f"{x.shape}")
if act_handles:
for handle in act_handles:
handle.remove()
ezpz.dist.barrier()
logger.info("Finished 2D training")
    if ezpz.get_rank() == 0:
        # `finalize` returns the collected metrics history, not the training
        # dataset; use a distinct name to avoid shadowing `dataset` from the
        # data-loading section above.
        metrics_dataset = history.finalize(
            run_name=WBRUN_NAME,
            dataset_fname="train",
            warmup=0.1,
        )
        logger.info(f"{metrics_dataset=}")
return 0
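

# Example invocation (a sketch only; `parse_args` is a hypothetical CLI
# helper, not defined in this file -- wire it up however the surrounding
# module actually builds its argparse.Namespace):
#
#     args = parse_args()
#     raise SystemExit(train(args, outdir=Path("outputs")))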