
ezpz.test_dist

This module is part of the ezpz package.

test_dist.py

  • to launch:
$ source ezpz/src/ezpz/bin/savejobenv
$ BACKEND=DDP launch python3 ezpz_ddp.py
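
The same smoke test can also be started through the module entry point described under main() below; a hedged example, assuming the ezpz launch wrapper shown above is available on your PATH:

$ launch python3 -m ezpz.test_dist --backend DDP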

TrainConfig dataclass

Runtime configuration for the ezpz.test_dist distributed smoke test.

Source code in src/ezpz/test_dist.py
@dataclass
class TrainConfig:
    """Runtime configuration for the ``ezpz.test_dist`` distributed smoke test."""
    warmup: int
    tp: int
    pp: int
    cp: int
    batch_size: int
    input_size: int
    output_size: int
    train_iters: int
    log_freq: int
    backend: str
    dtype: str
    print_freq: int
    pyinstrument_profiler: bool
    pytorch_profiler: bool
    pytorch_profiler_wait: int
    pytorch_profiler_warmup: int
    pytorch_profiler_active: int
    pytorch_profiler_repeat: int
    profile_memory: bool
    rank_zero_only: bool
    record_shapes: bool
    with_stack: bool
    with_flops: bool
    with_modules: bool
    acc_events: bool
    layer_sizes: list = field(default_factory=lambda: [1024, 512, 256, 128])

    def __post_init__(self):
        """Initialise output paths and configure profiling context managers."""
        self._created_at = ezpz.get_timestamp() if ezpz.get_rank() == 0 else None
        self._created_at = ezpz.dist.broadcast(self._created_at, root=0)
        self.outdir = Path(os.getcwd()).joinpath(
            "outputs", "ezpz.test_dist", f"{self._created_at}"
        )
        self.outdir.mkdir(parents=True, exist_ok=True)
        profiler_type = "torch" if self.pytorch_profiler else "pyinstrument"
        self.ctx = get_profiling_context(
            profiler_type=profiler_type,
            rank_zero_only=self.rank_zero_only,
            record_shapes=self.record_shapes,
            with_stack=self.with_stack,
            with_flops=self.with_flops,
            with_modules=self.with_modules,
            acc_events=self.acc_events,
            profile_memory=self.profile_memory,
            wait=self.pytorch_profiler_wait,
            warmup=self.pytorch_profiler_warmup,
            active=self.pytorch_profiler_active,
            repeat=self.pytorch_profiler_repeat,
            outdir=self.outdir,
        )
        logger.info(f"Outputs will be saved to {self.outdir}")

    def get_torch_dtype(self) -> torch.dtype:
        """Return the torch dtype requested by this configuration."""
        if self.dtype is None:
            return torch.get_default_dtype()
        if self.dtype in {
            "fp16",
            "half",
            "float16",
        }:
            return torch.float16
        if self.dtype in {
            "bfloat16",
            "bf16",
        }:
            return torch.bfloat16
        logger.warning(f"Unknown dtype: {self.dtype=}, using float32")
        return torch.float32

__post_init__()

Initialise output paths and configure profiling context managers.

Source code in src/ezpz/test_dist.py
def __post_init__(self):
    """Initialise output paths and configure profiling context managers."""
    self._created_at = ezpz.get_timestamp() if ezpz.get_rank() == 0 else None
    self._created_at = ezpz.dist.broadcast(self._created_at, root=0)
    self.outdir = Path(os.getcwd()).joinpath(
        "outputs", "ezpz.test_dist", f"{self._created_at}"
    )
    self.outdir.mkdir(parents=True, exist_ok=True)
    profiler_type = "torch" if self.pytorch_profiler else "pyinstrument"
    self.ctx = get_profiling_context(
        profiler_type=profiler_type,
        rank_zero_only=self.rank_zero_only,
        record_shapes=self.record_shapes,
        with_stack=self.with_stack,
        with_flops=self.with_flops,
        with_modules=self.with_modules,
        acc_events=self.acc_events,
        profile_memory=self.profile_memory,
        wait=self.pytorch_profiler_wait,
        warmup=self.pytorch_profiler_warmup,
        active=self.pytorch_profiler_active,
        repeat=self.pytorch_profiler_repeat,
        outdir=self.outdir,
    )
    logger.info(f"Outputs will be saved to {self.outdir}")

get_torch_dtype()

Return the torch dtype requested by this configuration.

Source code in src/ezpz/test_dist.py
def get_torch_dtype(self) -> torch.dtype:
    """Return the torch dtype requested by this configuration."""
    if self.dtype is None:
        return torch.get_default_dtype()
    if self.dtype in {
        "fp16",
        "half",
        "float16",
    }:
        return torch.float16
    if self.dtype in {
        "bfloat16",
        "bf16",
    }:
        return torch.bfloat16
    logger.warning(f"Unknown dtype: {self.dtype=}, using float32")
    return torch.float32
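
For illustration, a minimal sketch of how the returned dtype is typically consumed (assuming a fully-populated TrainConfig instance named config; this mirrors the tensor creation inside Trainer._forward_step):

dtype = config.get_torch_dtype()  # e.g. "bf16" -> torch.bfloat16
x = torch.rand(config.batch_size, config.input_size, dtype=dtype)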

Trainer dataclass

Co-ordinate training loops, logging, and profiling for the test model.

Source code in src/ezpz/test_dist.py
@dataclass
class Trainer:
    """Co-ordinate training loops, logging, and profiling for the test model."""
    config: TrainConfig
    model: torch.nn.Module
    optimizer: torch.optim.Optimizer
    history: ezpz.History = field(default_factory=ezpz.History)
    train_iter: int = 0
    rank: int = ezpz.get_rank()
    # device_type: str = ezpz.get_torch_device_type()
    device_type = os.environ.get("TORCH_DEVICE", ezpz.get_torch_device())
    world_size = ezpz.get_world_size()
    local_rank = ezpz.get_local_rank()
    device_id = f"{device_type}:{local_rank}"

    def __post_init__(self):
        """Move the model to the target device and register logging hooks."""
        self.device_id = f"{self.device_type}:{self.local_rank}"
        self.dtype = self.config.get_torch_dtype()
        self.model.to(self.device_id)
        self.model.to(self.dtype)

        if self.config.tp > 1 or self.config.pp > 1 or self.config.cp > 1:
            ezpz.dist.barrier(group=ezpz.tp.get_tensor_parallel_group())
            ezpz.dist.barrier(group=ezpz.tp.get_data_parallel_group())
            ezpz.dist.barrier(group=ezpz.tp.get_pipeline_parallel_group())
            ezpz.dist.barrier(group=ezpz.tp.get_context_parallel_group())

        if self.rank == 0 and not WANDB_DISABLED:
            import wandb

            logger.debug("Setting up wandb")
            wbconfig = {}
            wbconfig |= asdict(self.config)
            wbconfig |= ezpz.get_dist_info()
            _ = ezpz.setup_wandb(
                project_name="ezpz.test_dist",
                config=wbconfig,
            )
            if (wbrun := getattr(wandb, "run", None)) is not None and callable(
                wbrun.watch
            ):
                wbrun.watch(self.model, log="all")

        if self.world_size > 1:
            logger.debug("Hit torch.distributed.barrier()")
            ezpz.dist.barrier()

    @ezpz.timeitlogit(rank=ezpz.get_rank())
    def _forward_step(self) -> dict:
        """Execute a forward pass returning loss and timing metrics."""
        t0 = time.perf_counter()
        x = torch.rand(
            *(self.config.batch_size, self.config.input_size),
            device=self.device_type,
            dtype=self.config.get_torch_dtype(),
        )
        y = self.model(x)
        return {"loss": calc_loss(x, y), "dtf": (time.perf_counter() - t0)}

    @ezpz.timeitlogit(rank=ezpz.get_rank())
    def _backward_step(self, loss: torch.Tensor) -> float:
        """Perform the backwards/optimiser step and return elapsed seconds."""
        t0 = time.perf_counter()
        if self.config.backend == "deepspeed":
            self.model.backward(loss)  # type:ignore
            self.model.step(loss)  # type:ignore
        else:
            loss.backward()
            self.optimizer.step()
        return time.perf_counter() - t0

    @ezpz.timeitlogit(rank=ezpz.get_rank())
    def train_step(self) -> dict:
        """Run one optimiser step, emitting periodic logs/metrics."""
        self.train_iter += 1
        metrics = self._forward_step()
        metrics["dtb"] = self._backward_step(metrics["loss"])
        self.optimizer.zero_grad()
        if self.train_iter == self.config.train_iters:
            return metrics
        if (
            self.train_iter % self.config.log_freq == 0
            or self.train_iter % self.config.print_freq == 0
        ):
            summary = self.history.update({"iter": self.train_iter, **metrics})
            if self.train_iter % self.config.print_freq == 0:
                logger.info(f"{summary}")
        return metrics

    @ezpz.timeitlogit(rank=ezpz.get_rank())
    def finalize(self, outdir: Optional[str | Path | os.PathLike] = None) -> Dataset:
        """Flush profilers and return the aggregated training dataset."""
        import ambivalent
        import matplotlib.pyplot as plt

        plt.style.use(ambivalent.STYLES["ambivalent"])
        outdir = Path(outdir) if outdir is not None else self.config.outdir
        dataset = self.history.finalize(
            run_name="ezpz.test_dist",
            dataset_fname="train",
            warmup=self.config.warmup,
            save=False,  # XXX: don't bother saving test data
            plot=(self.rank == 0),
            outdir=outdir,
        )
        logger.info(f"{dataset=}")
        return dataset

    @ezpz.timeitlogit(rank=ezpz.get_rank())
    def train(self, profiler: Optional[torch.profiler.profile] = None) -> Dataset:
        """Loop over all training iterations and return the final dataset."""
        for step in range(self.config.train_iters):
            if step == self.config.warmup:
                logger.info(f"Warmup complete at step {step}")
            _ = self.train_step()
            if profiler is not None:
                profiler.step()

        return (
            self.finalize()
            if self.rank == 0
            else self.history.get_dataset(warmup=self.config.warmup)
        )
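
A minimal sketch of wiring a Trainer by hand, mirroring what the module-level train() function does further below. This assumes ezpz.setup_torch() has already been called (as in main()) and that some_model is a plain torch.nn.Module supplied by you (the name is a placeholder):

config = get_config_from_args(parse_args())
model, optimizer = build_model_and_optimizer(some_model, backend=config.backend)
trainer = Trainer(config=config, model=model, optimizer=optimizer)
dataset = trainer.train()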

__post_init__()

Move the model to the target device and register logging hooks.

Source code in src/ezpz/test_dist.py
def __post_init__(self):
    """Move the model to the target device and register logging hooks."""
    self.device_id = f"{self.device_type}:{self.local_rank}"
    self.dtype = self.config.get_torch_dtype()
    self.model.to(self.device_id)
    self.model.to(self.dtype)

    if self.config.tp > 1 or self.config.pp > 1 or self.config.cp > 1:
        ezpz.dist.barrier(group=ezpz.tp.get_tensor_parallel_group())
        ezpz.dist.barrier(group=ezpz.tp.get_data_parallel_group())
        ezpz.dist.barrier(group=ezpz.tp.get_pipeline_parallel_group())
        ezpz.dist.barrier(group=ezpz.tp.get_context_parallel_group())

    if self.rank == 0 and not WANDB_DISABLED:
        import wandb

        logger.debug("Setting up wandb")
        wbconfig = {}
        wbconfig |= asdict(self.config)
        wbconfig |= ezpz.get_dist_info()
        _ = ezpz.setup_wandb(
            project_name="ezpz.test_dist",
            config=wbconfig,
        )
        if (wbrun := getattr(wandb, "run", None)) is not None and callable(
            wbrun.watch
        ):
            wbrun.watch(self.model, log="all")

    if self.world_size > 1:
        logger.debug("Hit torch.distributed.barrier()")
        ezpz.dist.barrier()

finalize(outdir=None)

Flush profilers and return the aggregated training dataset.

Source code in src/ezpz/test_dist.py
@ezpz.timeitlogit(rank=ezpz.get_rank())
def finalize(self, outdir: Optional[str | Path | os.PathLike] = None) -> Dataset:
    """Flush profilers and return the aggregated training dataset."""
    import ambivalent
    import matplotlib.pyplot as plt

    plt.style.use(ambivalent.STYLES["ambivalent"])
    outdir = Path(outdir) if outdir is not None else self.config.outdir
    dataset = self.history.finalize(
        run_name="ezpz.test_dist",
        dataset_fname="train",
        warmup=self.config.warmup,
        save=False,  # XXX: don't bother saving test data
        plot=(self.rank == 0),
        outdir=outdir,
    )
    logger.info(f"{dataset=}")
    return dataset
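
For example, a sketch of overriding the output directory (by default the config.outdir created in TrainConfig.__post_init__ is used; the path below is hypothetical):

dataset = trainer.finalize(outdir="outputs/custom_run")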

train(profiler=None)

Loop over all training iterations and return the final dataset.

Source code in src/ezpz/test_dist.py
@ezpz.timeitlogit(rank=ezpz.get_rank())
def train(self, profiler: Optional[torch.profiler.profile] = None) -> Dataset:
    """Loop over all training iterations and return the final dataset."""
    for step in range(self.config.train_iters):
        if step == self.config.warmup:
            logger.info(f"Warmup complete at step {step}")
        _ = self.train_step()
        if profiler is not None:
            profiler.step()

    return (
        self.finalize()
        if self.rank == 0
        else self.history.get_dataset(warmup=self.config.warmup)
    )

train_step()

Run one optimiser step, emitting periodic logs/metrics.

Source code in src/ezpz/test_dist.py
@ezpz.timeitlogit(rank=ezpz.get_rank())
def train_step(self) -> dict:
    """Run one optimiser step, emitting periodic logs/metrics."""
    self.train_iter += 1
    metrics = self._forward_step()
    metrics["dtb"] = self._backward_step(metrics["loss"])
    self.optimizer.zero_grad()
    if self.train_iter == self.config.train_iters:
        return metrics
    if (
        self.train_iter % self.config.log_freq == 0
        or self.train_iter % self.config.print_freq == 0
    ):
        summary = self.history.update({"iter": self.train_iter, **metrics})
        if self.train_iter % self.config.print_freq == 0:
            logger.info(f"{summary}")
    return metrics
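
The returned dict carries the loss tensor plus forward ("dtf") and backward ("dtb") timings in seconds; a minimal sketch of inspecting one step (assuming a trainer constructed as above):

metrics = trainer.train_step()
print(metrics["loss"].item(), metrics["dtf"], metrics["dtb"])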

build_model_and_optimizer(model, backend='DDP')

Prepare the model and optimiser for the requested backend.

Source code in src/ezpz/test_dist.py
@ezpz.timeitlogit(rank=ezpz.get_rank())
def build_model_and_optimizer(
    model: torch.nn.Module, backend: str = "DDP"
) -> ModelOptimizerPair:
    """Prepare the model and optimiser for the requested backend."""
    if backend is not None:
        assert backend.lower() in {"ddp", "deepspeed", "ds"}
    device_override = os.environ.get("TORCH_DEVICE")
    device_type = device_override or ezpz.get_torch_device()
    if isinstance(device_type, str) and device_type.startswith("mps"):
        logger.warning(
            "MPS does not support torch.distributed collectives; falling back to CPU"
        )
        device_type = "cpu"
    world_size = ezpz.get_world_size()
    local_rank = ezpz.get_local_rank()
    if isinstance(device_type, str) and device_type in {"cuda", "xpu"}:
        device_type = f"{device_type}:{local_rank}"
    model.to(device_type)
    if isinstance(device_type, str) and device_type.startswith("cuda"):
        model.to(local_rank)
    logger.info(f"model=\n{model}")
    optimizer = torch.optim.Adam(model.parameters())
    if backend.lower() == "ddp":
        if world_size > 1:
            model.to(device_type)
            # model = DDP(model)
            try:
                if isinstance(device_type, str) and device_type.startswith("cuda"):
                    model = DDP(model, device_ids=[local_rank])
                else:
                    model = DDP(model)
            except Exception:
                model = DDP(model)

    elif backend.lower() in ("ds", "deepspeed"):
        parser = argparse.ArgumentParser(
            prog="deepspeed", description="My training script."
        )
        parser.add_argument(
            "--local_rank",
            required=False,
            type=int,
            default=-1,
            help="local rank passed from distributed launcher",
        )
        # parser.add_argument(
        #     '--deepspeed',
        #     action='store_true',
        #     default=True,
        #     help='Use deepspeed',
        # )
        # parser.add_argument(
        #     '--deepspeed_config',
        #     type=str,
        #     default='deepspeed_config.json',
        #     help='Deepspeed config file',
        # )
        try:
            import deepspeed  # type:ignore
        except (ImportError, ModuleNotFoundError) as e:
            logger.error(
                "Deepspeed not available. "
                "Install via `python3 -m pip install deepspeed`"
            )
            raise e

        # Include DeepSpeed configuration arguments
        parser = deepspeed.add_config_arguments(parser)
        cmd_args = parser.parse_args()
        model, optimizer, *_ = deepspeed.initialize(
            args=cmd_args,
            model=model,
            optimizer=optimizer,
        )
        logger.info(f"{cmd_args=}")
    return model, optimizer
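
A brief sketch of the DDP path, assuming ezpz.setup_torch() has already been called (as in main()); the toy Sequential model is only for illustration:

net = torch.nn.Sequential(torch.nn.Linear(2048, 2048))
model, optimizer = build_model_and_optimizer(net, backend="DDP")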

calc_loss(x, y)

Return the squared error loss used by the smoke test.

Source code in src/ezpz/test_dist.py
@ezpz.timeitlogit(rank=ezpz.get_rank())
def calc_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the squared error loss used by the smoke test."""
    return (y - x).pow(2).sum()
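
A quick worked example of the reduction (six elements, each with a squared error of 1):

calc_loss(torch.zeros(2, 3), torch.ones(2, 3))  # tensor(6.)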

get_config_from_args(args)

Translate CLI arguments into a TrainConfig.

Source code in src/ezpz/test_dist.py
def get_config_from_args(args: argparse.Namespace) -> TrainConfig:
    """Translate CLI arguments into a :class:`TrainConfig`."""
    config = TrainConfig(
        acc_events=args.acc_events,
        batch_size=args.batch_size,
        profile_memory=args.profile_memory,
        record_shapes=args.record_shapes,
        with_stack=args.with_stack,
        with_flops=args.with_flops,
        with_modules=args.with_modules,
        rank_zero_only=args.rank_zero_only,
        backend=args.backend,
        dtype=args.dtype,
        log_freq=args.log_freq,
        print_freq=args.print_freq,
        tp=args.tp,
        pp=args.pp,
        cp=args.cp,
        input_size=args.input_size,
        output_size=args.output_size,
        train_iters=args.train_iters,
        layer_sizes=args.layer_sizes,
        pyinstrument_profiler=args.pyinstrument_profiler,
        pytorch_profiler=args.pytorch_profiler,
        pytorch_profiler_wait=args.pytorch_profiler_wait,
        pytorch_profiler_warmup=args.pytorch_profiler_warmup,
        pytorch_profiler_active=args.pytorch_profiler_active,
        pytorch_profiler_repeat=args.pytorch_profiler_repeat,
        warmup=args.warmup,
    )
    return config
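
Typical usage simply chains this with parse_args(), exactly as main() does:

config = get_config_from_args(parse_args())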

main()

Entry point used by python -m ezpz.test_dist.

Source code in src/ezpz/test_dist.py
@ezpz.timeitlogit(rank=ezpz.get_rank())
def main() -> Trainer:
    """Entry point used by ``python -m ezpz.test_dist``."""
    t0 = time.perf_counter()
    args = parse_args()
    config = get_config_from_args(args)
    timings = {}
    with config.ctx as c:
        _ = ezpz.setup_torch(
            backend=config.backend,
            tensor_parallel_size=config.tp,
            pipeline_parallel_size=config.pp,
            context_parallel_size=config.cp,
        )
        t_setup = time.perf_counter()
        logger.info(f"Took: {(t_setup - t0):.2f} seconds to setup torch")
        trainer = train(config, profiler=c)
        t_train = time.perf_counter()
    if trainer.config.backend.lower() in ["ds", "deepspeed"]:
        try:
            import deepspeed.comm

            deepspeed.comm.log_summary()
        except ImportError as e:
            logger.exception(e)
            logger.exception(
                "Deepspeed not available. "
                "Install via `python3 -m pip install deepspeed`"
            )
            logger.info("Continuing without deepspeed summary...")

    logger.info(f"Took: {time.perf_counter() - START_TIME:.2f} seconds")
    t1 = time.perf_counter()
    timings = {
        "main/setup_torch": (t_setup - t0),
        "main/train": (t_train - t_setup),
        "main/total": (t1 - t0),
    }
    if wandb is not None:
        try:
            wandb.log(data=timings)
        except Exception:
            logger.warning("Failed to log timings to wandb")
    return trainer

parse_args()

Parse CLI arguments for ezpz.test_dist.

Source code in src/ezpz/test_dist.py
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for ``ezpz.test_dist``."""
    parser = argparse.ArgumentParser(description="Training configuration parameters")
    parser.add_argument(
        "--warmup",
        type=int,
        default=50,
        help="Warmup iterations",
    )
    parser.add_argument(
        "--tp",
        type=int,
        default=1,
        help="Tensor parallel size",
    )
    parser.add_argument(
        "--pp",
        type=int,
        default=1,
        help="Pipeline length",
    )

    # parser.add_argument(
    #     '--deepspeed',
    #     action='store_true',
    #     default=True,
    #     help='Use deepspeed',
    # )
    parser.add_argument(
        "--deepspeed_config",
        type=str,
        default="deepspeed_config.json",
        help="Deepspeed config file",
    )
    parser.add_argument(
        "--cp",
        type=int,
        default=1,
        help="Context parallel size",
    )
    parser.add_argument(
        "--backend",
        required=False,
        type=str,
        default="DDP",
        help="Backend (DDP, DeepSpeed, etc.)",
    )
    parser.add_argument(
        "--pyinstrument-profiler",
        action="store_true",
        help="Profile the training loop",
    )
    parser.add_argument(
        "-p",
        "--profile",
        default=False,
        dest="pytorch_profiler",
        required=False,
        action="store_true",
        help="Use PyTorch profiler",
    )
    parser.add_argument(
        "--rank-zero-only",
        action="store_true",
        help="Run profiler only on rank 0",
    )
    parser.add_argument(
        "--pytorch-profiler-wait",
        type=int,
        default=1,
        help="Wait time before starting the PyTorch profiler",
    )
    parser.add_argument(
        "--pytorch-profiler-warmup",
        type=int,
        default=2,
        help="Warmup iterations for the PyTorch profiler",
    )
    parser.add_argument(
        "--pytorch-profiler-active",
        type=int,
        default=3,
        help="Active iterations for the PyTorch profiler",
    )
    parser.add_argument(
        "--pytorch-profiler-repeat",
        type=int,
        default=5,
        help="Repeat iterations for the PyTorch profiler",
    )
    parser.add_argument(
        "--profile-memory",
        default=True,
        action="store_true",
        help="Profile memory usage",
    )
    parser.add_argument(
        "--record-shapes",
        default=True,
        action="store_true",
        help="Record shapes in the profiler",
    )
    parser.add_argument(
        "--with-stack",
        default=True,
        action="store_true",
        help="Include stack traces in the profiler",
    )
    parser.add_argument(
        "--with-flops",
        default=True,
        action="store_true",
        help="Include FLOPs in the profiler",
    )
    parser.add_argument(
        "--with-modules",
        default=True,
        action="store_true",
        help="Include module information in the profiler",
    )
    parser.add_argument(
        "--acc-events",
        default=False,
        action="store_true",
        help="Accumulate events in the profiler",
    )
    parser.add_argument(
        "--train-iters",
        "--train_iters",
        type=int,
        default=500,
        help="Number of training iterations",
    )
    parser.add_argument(
        "--log-freq",
        "--log_freq",
        type=int,
        default=1,
        help="Logging frequency",
    )
    parser.add_argument(
        "--print-freq",
        type=int,
        default=25,
        help="Printing frequency",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=256,
        help="Batch size",
    )
    parser.add_argument(
        "--input-size",
        type=int,
        default=2048,
        help="Input size",
    )
    parser.add_argument(
        "--output-size",
        type=int,
        default=2048,
        help="Output size",
    )
    parser.add_argument(
        "--layer-sizes",
        help="Comma-separated list of layer sizes",
        type=lambda s: [int(item) for item in s.split(",")],
        default=[4096, 8192, 16384, 8192, 4096],
        # default=[1024, 512, 256, 128],
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        help="Data type (fp16, float16, bfloat16, bf16, float32, etc.)",
    )

    args = parser.parse_args()
    if args.backend.lower() in {"ds", "deepspeed"}:
        try:
            import deepspeed  # type:ignore

            args.deepspeed = True
        except (ImportError, ModuleNotFoundError) as e:
            logger.error(
                "Deepspeed not available. "
                "Install via `python3 -m pip install deepspeed`"
            )
            raise e
    return args
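
For reference, a representative invocation that overrides a few of the defaults above (flag names are taken from this parser; the launch wrapper is assumed, as in the module docstring):

$ launch python3 -m ezpz.test_dist --train-iters 100 --batch-size 128 --dtype bf16 --backend DDP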

train(config, profiler=None)

Instantiate the model/optimiser and run the training loop.

Source code in src/ezpz/test_dist.py
@ezpz.timeitlogit(rank=ezpz.get_rank())
def train(
    config: TrainConfig, profiler: Optional[torch.profiler.profile] = None
) -> Trainer:
    """Instantiate the model/optimiser and run the training loop."""
    from ezpz.models.minimal import SequentialLinearNet
    from ezpz.utils import model_summary

    # logger.info(f"Setting up torch with {config.backend=}...")
    timings = {}
    t0m = time.perf_counter()
    model = SequentialLinearNet(
        input_dim=config.input_size,
        output_dim=config.output_size,
        sizes=config.layer_sizes,
    )
    logger.info(f"Model size: {sum(p.numel() for p in model.parameters())} parameters")
    try:
        logger.info(f"\n{model_summary(model)}")
    except Exception as e:
        logger.warning(f"Failed to summarize model: {e}, using default summary")
        logger.info(model)
    t1m = time.perf_counter()
    dt_model = t1m - t0m
    logger.info(f"Took: {dt_model} seconds to build model")
    model, optimizer = build_model_and_optimizer(model, backend=config.backend)
    t2m = time.perf_counter()
    dt_optimizer = time.perf_counter() - t1m
    logger.info(f"Took: {dt_optimizer:.2f} seconds to build optimizer")
    trainer = Trainer(config=config, model=model, optimizer=optimizer)
    t1tr = time.perf_counter()
    logger.info(f"Took: {(dt_trainer := t1tr - t2m):.2f} seconds to build trainer")
    jstr = json.dumps(asdict(config), indent=2, sort_keys=True)
    logger.info(f"config:\n{jstr}")
    t1s = time.perf_counter()
    logger.info(f"Took: {(dt_train_start := t1s - START_TIME):.2f} to get here.")

    # -------------------------------------------
    # Main training loop
    t0t = time.perf_counter()
    _ = trainer.train(profiler=profiler)
    t1t = time.perf_counter()
    # -------------------------------------------

    # Record timings and return trainer
    logger.info(
        f"Took: {(dt_train_duration := t1t - t0t):.2f} seconds to finish training"
    )
    timings = {
        "timings/model": dt_model,
        "timings/optimizer": dt_optimizer,
        "timings/trainer": dt_trainer,
        "timings/training_start": dt_train_start,
        "timings/train_duration": dt_train_duration,
    }
    try:
        wandb.log(timings)  # type:ignore
    except Exception:
        pass
    # if not WANDB_DISABLED:
    #     try:
    # if wandb is not None and getattr(wandb, "run", None) is not None:
    #     wandb.log(timings)

    return trainer
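
As in main(), the profiling context created by TrainConfig.__post_init__ can be entered around this call; a minimal sketch:

config = get_config_from_args(parse_args())
with config.ctx as prof:
    _ = ezpz.setup_torch(
        backend=config.backend,
        tensor_parallel_size=config.tp,
        pipeline_parallel_size=config.pp,
        context_parallel_size=config.cp,
    )
    trainer = train(config, profiler=prof)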