def train_fn(block_fn: Any, args: TrainArgs) -> ezpz.History:
"""Train the timm Vision Transformer with the provided attention block."""
seed = int(os.environ.get("SEED", "0"))
rank = ezpz.setup(backend="DDP", seed=seed)
world_size = ezpz.get_world_size()
local_rank = ezpz.get_local_rank()
device_type = str(ezpz.get_torch_device(as_torch_device=False))
device = torch.device(f"{device_type}:{local_rank}")
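    # Collect the model hyperparameters from the provided TrainArgs.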
config = ViTConfig(
img_size=args.img_size,
batch_size=args.batch_size,
num_heads=args.num_heads,
head_dim=args.head_dim,
depth=args.depth,
patch_size=args.patch_size,
)
logger.info(f"{config=}")
data = get_fake_data(
img_size=args.img_size,
batch_size=args.batch_size,
)
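    # Build the timm VisionTransformer, injecting the provided attention block
    # implementation through `block_fn`.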
model = VisionTransformer(
img_size=config.img_size,
patch_size=config.patch_size,
embed_dim=(config.num_heads * config.head_dim),
depth=config.depth,
num_heads=config.num_heads,
class_token=False,
global_pool="avg",
block_fn=block_fn,
)
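    # Log a shallow (depth=1) summary of the model for the configured input shape.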
mstr = summarize_model(
model,
verbose=False,
depth=1,
input_size=(
config.batch_size,
3,
config.img_size,
config.img_size,
),
)
logger.info(f"\n{mstr}")
model.to(device)
    # Count each parameter exactly once by iterating over `model.parameters()`;
    # `ds_numel` covers DeepSpeed ZeRO-partitioned parameters (flagged by
    # `ds_id`), while regular tensors report their local element count.
    num_params = sum(
        getattr(p, "ds_numel", 0) if hasattr(p, "ds_id") else p.nelement()
        for p in model.parameters()
    )
model_size_in_billions = num_params / 1e9
logger.info(f"Model size: nparams={model_size_in_billions:.2f} B")
if world_size > 1:
if args.dtype in {"fp16", "bf16", "fp32"}:
model = FSDP(
model,
mixed_precision=MixedPrecision(
param_dtype=TORCH_DTYPES_MAP[args.dtype],
reduce_dtype=torch.float32,
cast_forward_inputs=True,
),
)
else:
model = FSDP(model)
if args.compile:
logger.info("Compiling model")
model = torch.compile(model)
torch_dtype = TORCH_DTYPES_MAP[args.dtype]
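    # Plain cross-entropy loss and AdamW with its default hyperparameters.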
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters()) # type:ignore
model.train() # type:ignore
history = ezpz.History()
    logger.info(
        f"Training with {world_size} x {device_type} device(s), using {torch_dtype=}"
    )
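    # Training loop: each iteration is timed end-to-end (`dt`), with the forward
    # (`dtf`) and backward + optimizer (`dtb`) phases reported separately.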
    for step, batch in enumerate(data["train"]["loader"]):
        if args.max_iters is not None and step > int(args.max_iters):
            break
        t0 = time.perf_counter()
        # Host-to-device copies; `non_blocking=True` lets them run
        # asynchronously when the source tensors live in pinned memory.
        inputs = batch[0].to(device=device, non_blocking=True)
        label = batch[1].to(device=device, non_blocking=True)
t1 = time.perf_counter()
with torch.autocast(device_type=device_type, dtype=torch_dtype):
outputs = model(inputs)
loss = criterion(outputs, label)
t2 = time.perf_counter()
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
t3 = time.perf_counter()
logger.info(
history.update(
{
"train/iter": step,
"train/loss": loss.item(),
"train/dt": t3 - t0,
"train/dtf": t2 - t1,
"train/dtb": t3 - t2,
}
).replace("train/", "")
)
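    # Only rank 0 finalizes the recorded history and logs the resulting dataset.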
if rank == 0:
dataset = history.finalize(
run_name="mmm-vit", dataset_fname="train", verbose=False
)
logger.info(f"{dataset=}")
return history
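

# A minimal usage sketch (an assumption, not the project's CLI entry point):
# it shows how `train_fn` might be driven with timm's stock attention `Block`.
# The `TrainArgs` keyword arguments mirror the attributes referenced in
# `train_fn` above, but the actual constructor signature, its defaults, and any
# additional required fields are assumptions.
def _example_run() -> ezpz.History:
    from timm.models.vision_transformer import Block  # timm's default attention block

    args = TrainArgs(  # hypothetical construction; adjust to the real dataclass
        img_size=224,
        batch_size=16,
        num_heads=6,
        head_dim=64,
        depth=6,
        patch_size=16,
        dtype="bf16",
        compile=False,
        max_iters=10,
    )
    return train_fn(block_fn=Block, args=args)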