
ezpz.examples.vit¶

Train a lightweight Vision Transformer on fake or MNIST data.

Launch with:

ezpz launch -m ezpz.examples.vit --dataset mnist --batch_size 256
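
For a quick single-process smoke test without the launcher, the module can also be run directly (a minimal invocation sketch; the flags mirror the help output below):

python3 -m ezpz.examples.vit --dataset fake --batch_size 32 --max_iters 10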

Help output (python3 -m ezpz.examples.vit --help):

usage: ezpz.examples.vit [-h] [--img_size IMG_SIZE] [--batch_size BATCH_SIZE]
                         [--num_heads NUM_HEADS] [--head_dim HEAD_DIM]
                         [--hidden-dim HIDDEN_DIM] [--mlp-dim MLP_DIM]
                         [--dropout DROPOUT]
                         [--attention-dropout ATTENTION_DROPOUT]
                         [--num_classes NUM_CLASSES] [--dataset {fake,mnist}]
                         [--depth DEPTH] [--patch_size PATCH_SIZE]
                         [--dtype DTYPE] [--compile]
                         [--num_workers NUM_WORKERS] [--max_iters MAX_ITERS]
                         [--warmup WARMUP] [--attn_type {native,sdpa}]
                         [--cuda_sdpa_backend {flash_sdp,mem_efficient_sdp,math_sdp,cudnn_sdp,all}]
                         [--fsdp]

Train a simple ViT

options:
  -h, --help            show this help message and exit
  --img_size IMG_SIZE, --img-size IMG_SIZE
                        Image size
  --batch_size BATCH_SIZE, --batch-size BATCH_SIZE
                        Batch size
  --num_heads NUM_HEADS, --num-heads NUM_HEADS
                        Number of heads
  --head_dim HEAD_DIM, --head-dim HEAD_DIM
                        Head dimension
  --hidden-dim HIDDEN_DIM, --hidden_dim HIDDEN_DIM
                        Hidden Dimension
  --mlp-dim MLP_DIM, --mlp_dim MLP_DIM
                        MLP Dimension
  --dropout DROPOUT     Dropout rate
  --attention-dropout ATTENTION_DROPOUT, --attention_dropout ATTENTION_DROPOUT
                        Attention Dropout rate
  --num_classes NUM_CLASSES, --num-classes NUM_CLASSES
                        Number of classes
  --dataset {fake,mnist}
                        Dataset to use
  --depth DEPTH         Depth
  --patch_size PATCH_SIZE, --patch-size PATCH_SIZE
                        Patch size
  --dtype DTYPE         Data type
  --compile             Compile model
  --num_workers NUM_WORKERS, --num-workers NUM_WORKERS
                        Number of workers
  --max_iters MAX_ITERS, --max-iters MAX_ITERS
                        Maximum iterations
  --warmup WARMUP       Warmup iterations (or fraction) before starting to collect metrics.
  --attn_type {native,sdpa}, --attn-type {native,sdpa}
                        Attention function to use.
  --cuda_sdpa_backend {flash_sdp,mem_efficient_sdp,math_sdp,cudnn_sdp,all}, --cuda-sdpa-backend {flash_sdp,mem_efficient_sdp,math_sdp,cudnn_sdp,all}
                        CUDA SDPA backend to use.
  --fsdp                Use FSDP
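
Note on --warmup: as implemented in train_fn below, a value >= 1 is taken as an absolute number of warmup iterations, while a fractional value is interpreted relative to --max_iters (or to the length of the training loader when --max_iters is unset). For example, --warmup 0.1 with --max_iters 100 skips metric collection for the first 10 iterations.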

VitTrainArgs dataclass ¶

Structured configuration for Vision Transformer training.

Source code in src/ezpz/examples/vit.py
@dataclass
class VitTrainArgs:
    """Structured configuration for Vision Transformer training."""

    img_size: int = 224
    batch_size: int = 128
    num_heads: int = 16
    compile: bool = False
    dtype: str = "bf16"
    head_dim: int = 64
    hidden_dim: int = 1024
    mlp_dim: int = 2048
    max_iters: int = 1000
    dropout: float = 0.1
    attention_dropout: float = 0.0
    num_classes: int = 1000
    dataset: str = "fake"
    depth: int = 24
    patch_size: int = 16
    num_workers: int = 0
    warmup: float = 0.1
    attn_type: str = "native"
    fsdp: Optional[bool] = None
    format: Optional[str] = field(default_factory=str)
    cuda_sdpa_backend: Optional[str] = "all"
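
A minimal usage sketch for the dataclass (field values here are illustrative, not tuned settings; asdict comes from the standard dataclasses module):

from dataclasses import asdict

args = VitTrainArgs(img_size=28, patch_size=7, num_classes=10, dataset="mnist")
print(asdict(args)["img_size"])  # -> 28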

get_device_type() ¶

Resolve the torch device type, falling back if MPS lacks collectives.

Source code in src/ezpz/examples/vit.py
def get_device_type():
    """Resolve the torch device type, falling back if MPS lacks collectives."""
    import os

    device_override = os.environ.get("TORCH_DEVICE")
    device_type = device_override or ezpz.get_torch_device()
    if isinstance(device_type, str) and device_type.startswith("mps"):
        logger.warning(
            "MPS does not support torch.distributed collectives; falling back to CPU"
        )
        return "cpu"
    return ezpz.get_torch_device_type()
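
The TORCH_DEVICE environment variable read above can pin the device explicitly; note that a non-MPS override only feeds the MPS check, and the final return value still comes from ezpz.get_torch_device_type(). A small sketch of the MPS fallback path, assuming the variable is set before the call:

import os

os.environ["TORCH_DEVICE"] = "mps:0"  # pretend MPS was selected
print(get_device_type())  # logs a warning and returns "cpu"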

main(args) ¶

CLI entrypoint to configure logging and launch ViT training.

Source code in src/ezpz/examples/vit.py
def main(args: argparse.Namespace):
    """CLI entrypoint to configure logging and launch ViT training."""
    rank = ezpz.dist.setup_torch()
    if rank == 0:
        try:
            fp = Path(__file__).resolve()
            run = ezpz.setup_wandb(
                project_name=f"ezpz.{fp.parent.name}.{fp.stem}"
            )
            if wandb is not None:
                assert run is not None and run is wandb.run
                wandb.config.update(ezpz.get_dist_info())
                wandb.config.update({**vars(args)})  # type:ignore
        except Exception:
            logger.warning("Failed to setup wandb, continuing without!")

    targs = dict(**vars(args))
    targs.pop("dataset", None)
    targs.pop("use_timm", None)
    train_args = VitTrainArgs(**targs)
    # train_args:  = (**targs)
    config = timmViTConfig(
        img_size=args.img_size,
        batch_size=args.batch_size,
        num_heads=args.num_heads,
        head_dim=args.head_dim,
        depth=args.depth,
        patch_size=int(args.patch_size),
    )

    def attn_fn(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        """Scaled dot-product attention with configurable backend."""
        scale = config.head_dim ** (-0.5)
        q = q * scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        x = attn @ v
        return x

    logger.info(f"Using {args.attn_type} for SDPA backend")
    if args.attn_type == "native":
        block_fn = functools.partial(AttentionBlock, attn_fn=attn_fn)
    # if args.sdpa_backend == 'by_hand':
    elif args.attn_type == "sdpa":
        if torch.cuda.is_available():
            torch.backends.cuda.enable_flash_sdp(False)
            torch.backends.cuda.enable_mem_efficient_sdp(False)
            torch.backends.cuda.enable_math_sdp(False)
            torch.backends.cuda.enable_cudnn_sdp(False)

            if args.cuda_sdpa_backend in ["flash_sdp", "all"]:
                torch.backends.cuda.enable_flash_sdp(True)
            if args.cuda_sdpa_backend in ["mem_efficient_sdp", "all"]:
                torch.backends.cuda.enable_mem_efficient_sdp(True)
            if args.cuda_sdpa_backend in ["math_sdp", "all"]:
                torch.backends.cuda.enable_math_sdp(True)
            if args.cuda_sdpa_backend in ["cudnn_sdp", "all"]:
                torch.backends.cuda.enable_cudnn_sdp(True)

        block_fn = functools.partial(
            AttentionBlock,
            attn_fn=torch.nn.functional.scaled_dot_product_attention,
        )
    else:
        raise ValueError(f"Unknown attention type: {args.attn_type}")
    logger.info(f"Using AttentionBlock Attention with {args.compile=}")
    train_fn(block_fn, args=train_args, dataset=args.dataset)
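
The two --attn_type choices compute the same attention up to numerics: "native" uses the hand-written softmax(q @ k^T / sqrt(d)) @ v closure above, while "sdpa" delegates to torch.nn.functional.scaled_dot_product_attention. A standalone sketch of the equivalence (shapes and tolerance are illustrative only):

import torch
import torch.nn.functional as F

q = torch.randn(2, 16, 197, 64)  # (batch, heads, tokens, head_dim)
k = torch.randn(2, 16, 197, 64)
v = torch.randn(2, 16, 197, 64)

scale = q.shape[-1] ** -0.5
manual = ((q * scale) @ k.transpose(-2, -1)).softmax(dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(manual, fused, atol=1e-5))  # True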

parse_args() ¶

Parse CLI arguments for ViT training.

Source code in src/ezpz/examples/vit.py
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for ViT training."""
    parser = argparse.ArgumentParser(
        prog="ezpz.examples.vit",
        description="Train a simple ViT",
    )
    parser.add_argument(
        "--img_size", "--img-size", type=int, default=224, help="Image size"
    )
    parser.add_argument(
        "--batch_size",
        "--batch-size",
        type=int,
        default=128,
        help="Batch size",
    )
    parser.add_argument(
        "--num_heads",
        "--num-heads",
        type=int,
        default=16,
        help="Number of heads",
    )
    parser.add_argument(
        "--head_dim",
        "--head-dim",
        type=int,
        default=64,
        help="Hidden Dimension",
    )
    parser.add_argument(
        "--hidden-dim",
        "--hidden_dim",
        type=int,
        default=1024,
        help="Hidden Dimension",
    )
    parser.add_argument(
        "--mlp-dim", "--mlp_dim", type=int, default=2048, help="MLP Dimension"
    )
    parser.add_argument(
        "--dropout", type=float, default=0.1, help="Dropout rate"
    )
    parser.add_argument(
        "--attention-dropout",
        "--attention_dropout",
        type=float,
        default=0.0,
        help="Attention Dropout rate",
    )
    parser.add_argument(
        "--num_classes",
        "--num-classes",
        type=int,
        default=1000,
        help="Number of classes",
    )
    parser.add_argument(
        "--dataset",
        default="fake",
        choices=["fake", "mnist"],
        help="Dataset to use",
    )
    parser.add_argument("--depth", type=int, default=24, help="Depth")
    parser.add_argument(
        "--patch_size", "--patch-size", type=int, default=16, help="Patch size"
    )
    parser.add_argument("--dtype", default="bf16", help="Data type")
    parser.add_argument("--compile", action="store_true", help="Compile model")
    parser.add_argument(
        "--num_workers",
        "--num-workers",
        type=int,
        default=0,
        help="Number of workers",
    )
    parser.add_argument(
        "--max_iters", "--max-iters", type=int, default=100, help="Maximum iterations"
    )
    parser.add_argument(
        "--warmup",
        type=float,
        default=0.1,
        help="Warmup iterations (or fraction) before starting to collect metrics.",
    )
    parser.add_argument(
        "--attn_type",
        "--attn-type",
        default="native",
        choices=["native", "sdpa"],
        help="Attention function to use.",
    )
    parser.add_argument(
        "--cuda_sdpa_backend",
        "--cuda-sdpa-backend",
        default="all",
        choices=[
            "flash_sdp",
            "mem_efficient_sdp",
            "math_sdp",
            "cudnn_sdp",
            "all",
        ],
        help="CUDA SDPA backend to use.",
    )
    parser.add_argument("--fsdp", action="store_true", help="Use FSDP")
    return parser.parse_args()
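
Each option registers both an underscore and a hyphen spelling; argparse derives the attribute name from the first long option, so either spelling populates the same field. A quick check (sys.argv is patched here only for illustration):

import sys

sys.argv = ["ezpz.examples.vit", "--batch-size", "64", "--img_size", "128"]
args = parse_args()
print(args.batch_size, args.img_size)  # -> 64 128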

train_fn(block_fn, args, dataset='fake') ¶

Train the Vision Transformer on fake or MNIST data.

Parameters:

    block_fn (Any): Attention block constructor with attn_fn injected. Required.
    args (VitTrainArgs): Training hyperparameters. Required.
    dataset (Optional[str]): Dataset choice, either fake or mnist. Default: 'fake'.

Returns:

    History: History of training metrics.
Source code in src/ezpz/examples/vit.py
def train_fn(
    block_fn: Any,
    args: VitTrainArgs,
    dataset: Optional[str] = "fake",
) -> ezpz.History:
    """Train the Vision Transformer on fake or MNIST data.

    Args:
        block_fn: Attention block constructor with attn_fn injected.
        args: Training hyperparameters.
        dataset: Dataset choice, either ``fake`` or ``mnist``.

    Returns:
        History of training metrics.
    """
    # seed = int(os.environ.get('SEED', '0'))
    # rank = ezpz.setup(backend='DDP', seed=seed)
    world_size = ezpz.dist.get_world_size()

    local_rank = ezpz.dist.get_local_rank()
    # device_type = str(ezpz.get_torch_device(as_torch_device=False))
    device_type = ezpz.dist.get_torch_device_type()
    device = torch.device(f"{device_type}:{local_rank}")
    # torch.set_default_device(device)
    config = timmViTConfig(
        img_size=args.img_size,
        batch_size=args.batch_size,
        num_heads=args.num_heads,
        head_dim=args.head_dim,
        depth=args.depth,
        patch_size=args.patch_size,
    )

    logger.info(f"{asdict(config)=}")

    if dataset == "fake":
        data = get_fake_data(
            img_size=args.img_size,
            batch_size=args.batch_size,
        )
    elif dataset == "mnist":
        data = get_mnist(
            train_batch_size=args.batch_size,
            test_batch_size=args.batch_size,
            download=(ezpz.dist.get_rank() == 0),
        )
    else:
        raise ValueError(
            f"Unknown dataset: {dataset}. Expected 'fake' or 'mnist'."
        )

    # data = get

    # train_set = FakeImageDataset(config.img_size)
    # logger.info(f'{len(train_set)=}')
    # train_loader = DataLoader(
    #     train_set,
    #     batch_size=config.batch_size,
    #     num_workers=args.num_workers,
    #     pin_memory=True,
    #     drop_last=True,
    # )

    model = VisionTransformer(
        img_size=config.img_size,
        patch_size=config.patch_size,
        embed_dim=(config.num_heads * config.head_dim),
        depth=config.depth,
        num_heads=config.num_heads,
        class_token=False,
        global_pool="avg",
        block_fn=block_fn,
    )

    mstr = summarize_model(
        model,
        verbose=False,
        depth=1,
        input_size=(
            config.batch_size,
            3,
            config.img_size,
            config.img_size,
        ),
    )
    model.to(device)
    num_params = sum(
        [
            sum(
                [
                    getattr(p, "ds_numel", 0)
                    if hasattr(p, "ds_id")
                    else p.nelement()
                    for p in model_module.parameters()
                ]
            )
            for model_module in model.modules()
        ]
    )
    model_size_in_billions = num_params / 1e9
    logger.info(f"\n{mstr}")
    logger.info(f"Model size: nparams={model_size_in_billions:.2f} B")
    if wandb is not None:
        if wandb.run is not None:
            wandb.run.watch(model, log="all")

    model = ezpz.dist.wrap_model(
        model=model,
        use_fsdp=args.fsdp,
        dtype=args.dtype,
    )
    if world_size > 1:
        model = ezpz.dist.wrap_model(
            model=model,
            use_fsdp=args.fsdp,
            dtype=args.dtype,
        )
        # if args.fsdp:
        #     logger.info("Using FSDP for distributed training")
        #     if args.dtype in {"fp16", "bf16", "fp32"}:
        #         try:
        #             model = FSDP(
        #                 model,
        #                 mixed_precision=MixedPrecision(
        #                     param_dtype=TORCH_DTYPES_MAP[args.dtype],
        #                     reduce_dtype=torch.float32,
        #                     cast_forward_inputs=True,
        #                 ),
        #             )
        #         except Exception as exc:
        #             logger.warning(f"Encountered exception: {exc}")
        #             logger.warning(
        #                 "Unable to wrap model with FSDP. Falling back to DDP..."
        #             )
        #             model = ezpz.dist.wrap_model(model=model, f)
        #     else:
        #         try:
        #             model = FSDP(model)
        #         except Exception:
        #             model = ezpz.dist.wrap_model(args=args, model=model)
        # else:
        #     logger.info("Using DDP for distributed training")
        #     model = ezpz.dist.prepare_model_for_ddp(model)

    if args.compile:
        logger.info("Compiling model")
        model = torch.compile(model)

    torch_dtype = ezpz.dist.TORCH_DTYPES_MAP[args.dtype]
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters())  # type:ignore
    model.train()  # type:ignore

    history = ezpz.history.History()
    logger.info(
        f"Training with {world_size} x {device_type} (s), using {torch_dtype=}"
    )
    warmup_iters = (
        int(args.warmup)
        if args.warmup >= 1.0
        else int(
            args.warmup
            * (
                args.max_iters
                if args.max_iters is not None
                else len(data["train"]["loader"])
            )
        )
    )
    # data["train"].to(ezpz.dist.get_torch_device_type())
    for step, data in enumerate(data["train"]["loader"]):
        if args.max_iters is not None and step > int(args.max_iters):
            break
        t0 = time.perf_counter()
        inputs = data[0].to(device=device, non_blocking=True)
        label = data[1].to(device=device, non_blocking=True)
        ezpz.dist.synchronize()
        with torch.autocast(device_type=device_type, dtype=torch_dtype):
            t1 = time.perf_counter()
            outputs = model(inputs)
            loss = criterion(outputs, label)
            t2 = time.perf_counter()
        ezpz.dist.synchronize()
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        ezpz.dist.synchronize()
        t3 = time.perf_counter()
        optimizer.step()
        ezpz.dist.synchronize()
        t4 = time.perf_counter()
        if step >= warmup_iters:
            logger.info(
                history.update(
                    {
                        "train/iter": step,
                        "train/loss": loss,
                        "train/dt": t4 - t0,
                        "train/dtd": t1 - t0,
                        "train/dtf": t2 - t1,
                        "train/dto": t3 - t2,
                        "train/dtb": t4 - t3,
                    }
                ).replace("train/", "")
            )

    if ezpz.dist.get_rank() == 0:
        dataset = history.finalize(
            run_name=WBRUN_NAME, dataset_fname="train", verbose=False
        )
        logger.info(f"{dataset=}")

    return history
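
Putting the pieces together without the CLI, a hedged sketch of the "native" attention path that mirrors main() above (AttentionBlock, VitTrainArgs, and train_fn come from this module; the small sizes are only for illustration):

import functools

import ezpz

ezpz.dist.setup_torch()

head_dim = 64

def attn_fn(q, k, v):
    # hand-written scaled dot-product attention, as used for --attn_type native
    attn = ((q * head_dim**-0.5) @ k.transpose(-2, -1)).softmax(dim=-1)
    return attn @ v

block_fn = functools.partial(AttentionBlock, attn_fn=attn_fn)
args = VitTrainArgs(img_size=64, batch_size=8, depth=2, max_iters=10, head_dim=head_dim)
history = train_fn(block_fn, args=args, dataset="fake")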