ezpz.examples.diffusion

Tiny diffusion example for short text generation.

This script trains a tiny denoising diffusion model on a handful of toy sentences and then samples new sentences by running the reverse process. The goal is to keep the code minimal while showcasing the full flow, sketched below:

1. Build a small vocabulary from a list of prompts.
2. Train a denoising network to predict noise on token embeddings.
3. Sample text by iterating the reverse diffusion process.
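
The sketch below shows how the first two steps fit together on a single process. It is a minimal outline using only the classes and helpers documented on this page; the optimizer loop, FSDP wrapping, and logging from train() are omitted, and the batch size and model width are arbitrary.

import torch

from ezpz.examples.diffusion import (
    DiffusionSchedule,
    DiffusionTextModel,
    ToyTextDataset,
    add_noise,
    build_vocab,
    get_default_texts,
    sample_timesteps,
)

# 1. Build a small vocabulary from the toy corpus.
texts = get_default_texts()
vocab, inv_vocab = build_vocab(texts)
dataset = ToyTextDataset(texts, vocab, seq_len=12)

# 2. One denoising training step: predict the injected noise.
schedule = DiffusionSchedule(timesteps=64)
model = DiffusionTextModel(
    vocab_size=len(vocab), hidden_size=128, max_seq_len=12, timesteps=64
)
tokens = torch.stack([dataset[i] for i in range(4)])  # tiny batch of token ids
x0 = model.embed_tokens(tokens)                       # clean embeddings
t = sample_timesteps(tokens.size(0), schedule, device=tokens.device)
xt, noise = add_noise(x0, t, schedule)                # forward diffusion
loss = torch.mean((model(xt, t) - noise) ** 2)        # same MSE objective as train()
loss.backward()                                       # no optimizer step shown

# 3. Sampling is handled by generate_text(), which iterates p_sample()
#    from t = timesteps - 1 down to 0 (documented further down this page).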

Typical usage (customize with args as needed):

ezpz-launch -m ezpz.examples.diffusion --timesteps 64 --train-steps 500 --batch-size 16
# with FSDP and a HF dataset slice:
WORLD_SIZE=2 ezpz-launch -m ezpz.examples.diffusion --hf-dataset ag_news --fsdp

Launch with:

ezpz launch -m ezpz.examples.diffusion --timesteps 64 --train-steps 500

Help output (python3 -m ezpz.examples.diffusion --help):

usage: diffusion.py [-h] [--batch-size BATCH_SIZE] [--dtype DTYPE]
                    [--extra-text [EXTRA_TEXT ...]] [--fsdp]
                    [--fsdp-mixed-precision] [--hidden HIDDEN]
                    [--hf-dataset HF_DATASET] [--hf-split HF_SPLIT]
                    [--hf-text-column HF_TEXT_COLUMN] [--hf-limit HF_LIMIT]
                    [--log_freq LOG_FREQ] [--outdir OUTDIR]
                    [--samples SAMPLES] [--seed SEED] [--seq-len SEQ_LEN]
                    [--timesteps TIMESTEPS] [--train-steps TRAIN_STEPS]
                    [--lr LR]

Tiny diffusion example for text generation.

options:
  -h, --help            show this help message and exit
  --batch-size BATCH_SIZE
  --dtype DTYPE
  --extra-text [EXTRA_TEXT ...]
                        Additional sentences to add to the tiny corpus.
  --fsdp                Enable FSDP wrapping (requires WORLD_SIZE>1 and torch.distributed init).
  --fsdp-mixed-precision
                        Use bfloat16 parameters with FSDP for speed (defaults to float32).
  --hidden HIDDEN
  --hf-dataset HF_DATASET
                        Optional Hugging Face dataset name (e.g., 'ag_news'). When set, replaces the toy corpus.
  --hf-split HF_SPLIT   Dataset split to load.
  --hf-text-column HF_TEXT_COLUMN
                        Column containing raw text in the dataset.
  --hf-limit HF_LIMIT   Number of rows to sample from the HF dataset for quick experiments.
  --log_freq LOG_FREQ
  --outdir OUTDIR
  --samples SAMPLES
  --seed SEED
  --seq-len SEQ_LEN
  --timesteps TIMESTEPS
  --train-steps TRAIN_STEPS
  --lr LR

DiffusionSchedule dataclass

Precompute alpha/beta schedule values for DDPM style updates.

Source code in src/ezpz/examples/diffusion.py
@dataclass
class DiffusionSchedule:
    """Precompute alpha/beta schedule values for DDPM style updates."""

    timesteps: int = 64
    beta_start: float = 1e-4
    beta_end: float = 0.02

    def __post_init__(self) -> None:
        """Precompute alpha and alpha_bar schedules for diffusion steps."""
        self.betas = torch.linspace(
            self.beta_start, self.beta_end, self.timesteps
        )
        self.alphas = 1.0 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
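
A quick check of the precomputed tensors; the shapes and the monotone decay of alpha_bars follow directly from the definition above:

import torch
from ezpz.examples.diffusion import DiffusionSchedule

schedule = DiffusionSchedule(timesteps=64)
print(schedule.betas.shape)      # torch.Size([64]); linear ramp from 1e-4 to 0.02
print(schedule.alpha_bars[0])    # close to 1.0: almost no noise at t=0
print(schedule.alpha_bars[-1])   # smallest value: the most accumulated noise at the last step
assert torch.all(schedule.alpha_bars[1:] <= schedule.alpha_bars[:-1])  # decreasing in t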

__post_init__()

Precompute alpha and alpha_bar schedules for diffusion steps.

Source code in src/ezpz/examples/diffusion.py
def __post_init__(self) -> None:
    """Precompute alpha and alpha_bar schedules for diffusion steps."""
    self.betas = torch.linspace(
        self.beta_start, self.beta_end, self.timesteps
    )
    self.alphas = 1.0 - self.betas
    self.alpha_bars = torch.cumprod(self.alphas, dim=0)

DiffusionTextModel

Bases: Module

Simple transformer that predicts noise on token embeddings.

Source code in src/ezpz/examples/diffusion.py
class DiffusionTextModel(nn.Module):
    """Simple transformer that predicts noise on token embeddings."""

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        max_seq_len: int,
        timesteps: int,
        n_layers: int = 2,
        n_heads: int = 4,
    ) -> None:
        """Initialize embeddings and transformer encoder.

        Args:
            vocab_size: Size of the token vocabulary.
            hidden_size: Dimensionality of embeddings and model width.
            max_seq_len: Maximum sequence length.
            timesteps: Number of diffusion steps.
            n_layers: Number of transformer encoder layers.
            n_heads: Attention heads per layer.
        """
        super().__init__()
        self.hidden_size = hidden_size  # type:ignore
        self.token_emb = nn.Embedding(vocab_size, hidden_size)
        self.pos_emb = nn.Embedding(max_seq_len, hidden_size)
        self.time_emb = nn.Embedding(timesteps, hidden_size)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=n_heads,
            dim_feedforward=4 * hidden_size,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=n_layers
        )
        self.proj = nn.Linear(hidden_size, hidden_size)

    def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embed token ids and scale them for transformer input."""
        # Clone avoids autograd complaints about views when using sharded params.
        return self.token_emb(tokens).clone() * math.sqrt(self.hidden_size)

    def forward(
        self, noisy_embs: torch.Tensor, t: torch.Tensor
    ) -> torch.Tensor:
        """Predict noise residuals given noisy embeddings and timestep."""
        _, seq_len, _ = noisy_embs.shape
        pos = self.pos_emb(torch.arange(seq_len, device=noisy_embs.device))
        temb = self.time_emb(t).unsqueeze(1)
        h = noisy_embs + pos.unsqueeze(0) + temb
        h = self.encoder(h)
        return self.proj(h)

    def decode_tokens(self, embs: torch.Tensor) -> torch.Tensor:
        """Project embeddings back to token ids via tied embeddings."""
        weights = self.token_emb.weight  # (vocab, hidden)
        logits = torch.einsum("bld,vd->blv", embs, weights)
        return logits.argmax(dim=-1)
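
A hedged shape check on random inputs (the sizes here are arbitrary; hidden_size must be divisible by n_heads):

import torch
from ezpz.examples.diffusion import DiffusionTextModel

model = DiffusionTextModel(vocab_size=32, hidden_size=64, max_seq_len=12, timesteps=64)
tokens = torch.randint(0, 32, (2, 12))   # (batch, seq_len) token ids
embs = model.embed_tokens(tokens)        # (2, 12, 64) scaled embeddings
t = torch.randint(0, 64, (2,))           # one timestep per batch element
pred = model(embs, t)                    # predicted noise, same shape as the embeddings
assert pred.shape == embs.shape
token_ids = model.decode_tokens(pred)    # (2, 12) ids via the tied embedding matrix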

__init__(vocab_size, hidden_size, max_seq_len, timesteps, n_layers=2, n_heads=4)

Initialize embeddings and transformer encoder.

Parameters:

  vocab_size (int): Size of the token vocabulary. Required.
  hidden_size (int): Dimensionality of embeddings and model width. Required.
  max_seq_len (int): Maximum sequence length. Required.
  timesteps (int): Number of diffusion steps. Required.
  n_layers (int): Number of transformer encoder layers. Default: 2.
  n_heads (int): Attention heads per layer. Default: 4.
Source code in src/ezpz/examples/diffusion.py
def __init__(
    self,
    vocab_size: int,
    hidden_size: int,
    max_seq_len: int,
    timesteps: int,
    n_layers: int = 2,
    n_heads: int = 4,
) -> None:
    """Initialize embeddings and transformer encoder.

    Args:
        vocab_size: Size of the token vocabulary.
        hidden_size: Dimensionality of embeddings and model width.
        max_seq_len: Maximum sequence length.
        timesteps: Number of diffusion steps.
        n_layers: Number of transformer encoder layers.
        n_heads: Attention heads per layer.
    """
    super().__init__()
    self.hidden_size = hidden_size  # type:ignore
    self.token_emb = nn.Embedding(vocab_size, hidden_size)
    self.pos_emb = nn.Embedding(max_seq_len, hidden_size)
    self.time_emb = nn.Embedding(timesteps, hidden_size)
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=hidden_size,
        nhead=n_heads,
        dim_feedforward=4 * hidden_size,
        batch_first=True,
    )
    self.encoder = nn.TransformerEncoder(
        encoder_layer, num_layers=n_layers
    )
    self.proj = nn.Linear(hidden_size, hidden_size)

decode_tokens(embs)

Project embeddings back to token ids via tied embeddings.

Source code in src/ezpz/examples/diffusion.py
def decode_tokens(self, embs: torch.Tensor) -> torch.Tensor:
    """Project embeddings back to token ids via tied embeddings."""
    weights = self.token_emb.weight  # (vocab, hidden)
    logits = torch.einsum("bld,vd->blv", embs, weights)
    return logits.argmax(dim=-1)

embed_tokens(tokens)

Embed token ids and scale them for transformer input.

Source code in src/ezpz/examples/diffusion.py
def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
    """Embed token ids and scale them for transformer input."""
    # Clone avoids autograd complaints about views when using sharded params.
    return self.token_emb(tokens).clone() * math.sqrt(self.hidden_size)

forward(noisy_embs, t)

Predict noise residuals given noisy embeddings and timestep.

Source code in src/ezpz/examples/diffusion.py
def forward(
    self, noisy_embs: torch.Tensor, t: torch.Tensor
) -> torch.Tensor:
    """Predict noise residuals given noisy embeddings and timestep."""
    _, seq_len, _ = noisy_embs.shape
    pos = self.pos_emb(torch.arange(seq_len, device=noisy_embs.device))
    temb = self.time_emb(t).unsqueeze(1)
    h = noisy_embs + pos.unsqueeze(0) + temb
    h = self.encoder(h)
    return self.proj(h)

ToyTextDataset

Bases: Dataset

Pads or truncates sentences to a fixed length.

Source code in src/ezpz/examples/diffusion.py
class ToyTextDataset(Dataset):
    """Pads or truncates sentences to a fixed length."""

    def __init__(
        self, texts: List[str], vocab: Dict[str, int], seq_len: int = 12
    ):
        """Store corpus and vocabulary for encoding.

        Args:
            texts: Raw sentences.
            vocab: Token-to-id mapping.
            seq_len: Target sequence length for padding/truncation.
        """
        self.texts = texts
        self.vocab = vocab
        self.seq_len = seq_len
        self.pad_id = vocab["<pad>"]
        self.unk_id = vocab["<unk>"]

    def __len__(self) -> int:
        """Return number of sentences in the corpus."""
        return len(self.texts)

    def _encode(self, text: str) -> torch.Tensor:
        """Convert a sentence to a fixed-length tensor of token ids."""
        tokens = [
            self.vocab.get(tok, self.unk_id) for tok in text.lower().split()
        ]
        tokens = tokens[: self.seq_len]
        tokens += [self.pad_id] * (self.seq_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def __getitem__(self, idx: int) -> torch.Tensor:  # type:ignore
        """Return encoded tokens for the indexed sentence."""
        return self._encode(self.texts[idx])
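
A small illustration of the padding and unknown-token handling, using a hand-built vocabulary (the ids here are illustrative; build_vocab below produces the real mapping):

from ezpz.examples.diffusion import ToyTextDataset

vocab = {"<pad>": 0, "<unk>": 1, "hello": 2, "world": 3}
ds = ToyTextDataset(["hello world", "hello brave new world"], vocab, seq_len=4)
print(ds[0])  # tensor([2, 3, 0, 0])  -> padded with <pad>
print(ds[1])  # tensor([2, 1, 1, 3])  -> out-of-vocabulary words map to <unk>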

__getitem__(idx)

Return encoded tokens for the indexed sentence.

Source code in src/ezpz/examples/diffusion.py
def __getitem__(self, idx: int) -> torch.Tensor:  # type:ignore
    """Return encoded tokens for the indexed sentence."""
    return self._encode(self.texts[idx])

__init__(texts, vocab, seq_len=12)

Store corpus and vocabulary for encoding.

Parameters:

  texts (List[str]): Raw sentences. Required.
  vocab (Dict[str, int]): Token-to-id mapping. Required.
  seq_len (int): Target sequence length for padding/truncation. Default: 12.
Source code in src/ezpz/examples/diffusion.py
def __init__(
    self, texts: List[str], vocab: Dict[str, int], seq_len: int = 12
):
    """Store corpus and vocabulary for encoding.

    Args:
        texts: Raw sentences.
        vocab: Token-to-id mapping.
        seq_len: Target sequence length for padding/truncation.
    """
    self.texts = texts
    self.vocab = vocab
    self.seq_len = seq_len
    self.pad_id = vocab["<pad>"]
    self.unk_id = vocab["<unk>"]

__len__()

Return number of sentences in the corpus.

Source code in src/ezpz/examples/diffusion.py
def __len__(self) -> int:
    """Return number of sentences in the corpus."""
    return len(self.texts)

add_noise(x0, t, schedule)

Apply forward diffusion noise to clean embeddings.

Source code in src/ezpz/examples/diffusion.py
def add_noise(
    x0: torch.Tensor, t: torch.Tensor, schedule: DiffusionSchedule
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply forward diffusion noise to clean embeddings."""
    noise = torch.randn_like(x0)
    alpha_bar = schedule.alpha_bars.to(x0.device)[t].view(-1, 1, 1)
    noisy = torch.sqrt(alpha_bar) * x0 + torch.sqrt(1 - alpha_bar) * noise
    return noisy, noise
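
This is the usual closed-form forward process x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, applied per sample with the timesteps in t. A quick sanity check:

import torch
from ezpz.examples.diffusion import DiffusionSchedule, add_noise

schedule = DiffusionSchedule(timesteps=64)
x0 = torch.randn(2, 12, 64)    # pretend clean embeddings
t = torch.tensor([0, 63])      # almost no noise vs. the most noise
noisy, noise = add_noise(x0, t, schedule)
ab = schedule.alpha_bars[t].view(-1, 1, 1)
assert torch.allclose(noisy, torch.sqrt(ab) * x0 + torch.sqrt(1 - ab) * noise)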

build_vocab(texts)

Create a tiny vocabulary from a list of strings.

Source code in src/ezpz/examples/diffusion.py
def build_vocab(texts: Iterable[str]) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Create a tiny vocabulary from a list of strings."""
    specials = ["<pad>", "<unk>"]
    words = sorted({word for text in texts for word in text.lower().split()})
    vocab = {tok: idx for idx, tok in enumerate(specials + words)}
    inv_vocab = {idx: tok for tok, idx in vocab.items()}
    return vocab, inv_vocab
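
The specials come first, followed by the sorted, lowercased words, so the mapping for a given corpus is deterministic. For example:

from ezpz.examples.diffusion import build_vocab

vocab, inv_vocab = build_vocab(["Data pipelines", "pipelines need monitoring"])
print(vocab)
# {'<pad>': 0, '<unk>': 1, 'data': 2, 'monitoring': 3, 'need': 4, 'pipelines': 5}
print(inv_vocab[5])  # 'pipelines'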

generate_text(model, schedule, inv_vocab, seq_len, num_samples, skip_tokens=('<pad>', '<unk>'))

Sample sequences from the trained diffusion model.

Parameters:

  model (DiffusionTextModel): Trained diffusion network (possibly FSDP wrapped). Required.
  schedule (DiffusionSchedule): Noise schedule with precomputed alphas. Required.
  inv_vocab (Dict[int, str]): Mapping from token ids back to string tokens. Required.
  seq_len (int): Maximum sequence length. Required.
  num_samples (int): Number of sentences to generate. Required.
  skip_tokens (Tuple[str, ...]): Tokens to drop from decoded outputs. Default: ('<pad>', '<unk>').

Returns:

  List[str]: List of generated text strings.

Source code in src/ezpz/examples/diffusion.py
def generate_text(
    model: DiffusionTextModel,
    schedule: DiffusionSchedule,
    inv_vocab: Dict[int, str],
    seq_len: int,
    num_samples: int,
    skip_tokens: Tuple[str, ...] = ("<pad>", "<unk>"),
) -> List[str]:
    """Sample sequences from the trained diffusion model.

    Args:
        model: Trained diffusion network (possibly FSDP wrapped).
        schedule: Noise schedule with precomputed alphas.
        inv_vocab: Mapping from token ids back to string tokens.
        seq_len: Maximum sequence length.
        num_samples: Number of sentences to generate.
        skip_tokens: Tokens to drop from decoded outputs.

    Returns:
        List of generated text strings.
    """
    model.eval()
    samples: List[str] = []
    do_sample = ezpz.get_rank() == 0
    is_fsdp = isinstance(model, FSDP)
    base_model = model.module if hasattr(model, "module") else model
    full_param_ctx = (
        FSDP.summon_full_params(model)  # , recursive=True)
        if is_fsdp
        else nullcontext()
    )

    with torch.no_grad():
        with full_param_ctx:
            if not do_sample:
                return samples
            token_emb_weight = base_model.token_emb.weight  # type:ignore
            for _ in range(num_samples):
                xt = torch.randn(
                    (1, seq_len, base_model.hidden_size),
                    device=token_emb_weight.device,
                )
                for t in reversed(range(schedule.timesteps)):
                    xt = p_sample(base_model, xt, t, schedule)
                logits = torch.einsum("bld,vd->blv", xt, token_emb_weight)
                token_ids = logits.argmax(dim=-1)[0].tolist()
                words = [
                    inv_vocab.get(tok_id, "<unk>") for tok_id in token_ids
                ]
                words = [w for w in words if w not in skip_tokens]
                samples.append(" ".join(words))
    return samples
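
A hedged usage sketch for after training, assuming ezpz.setup_torch() has already been called and that model, schedule, and inv_vocab come from the training flow shown in main() below. Only rank 0 produces output; other ranks return an empty list.

samples = generate_text(
    model,           # trained, possibly FSDP-wrapped, DiffusionTextModel
    schedule,
    inv_vocab,
    seq_len=12,
    num_samples=3,
)
for idx, text in enumerate(samples):
    print(idx, text)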

get_default_texts()

Return a small corpus of seed sentences for toy training.

Source code in src/ezpz/examples/diffusion.py
def get_default_texts() -> List[str]:
    """Return a small corpus of seed sentences for toy training."""
    return [
        "the product team ships updates every week",
        "customers ask for faster onboarding",
        "the service autoscaling keeps latency steady",
        "data pipelines need reliable monitoring",
        "large language models assist with code reviews",
        "cloud costs drop when workloads are right sized",
        "edge devices sync logs during quiet hours",
        "dashboards show live metrics for incidents",
    ]

load_hf_texts(dataset_name, split, text_column, limit)

Pull a small slice of text from a Hugging Face dataset for quick experiments.

This uses only a limited number of rows (limit) to keep the example light.

Source code in src/ezpz/examples/diffusion.py
def load_hf_texts(
    dataset_name: str,
    split: str,
    text_column: str,
    limit: int,
) -> List[str]:
    """
    Pull a small slice of text from a Hugging Face dataset for quick experiments.

    This uses only a limited number of rows (`limit`) to keep the example light.
    """
    try:
        from datasets import load_dataset  # type: ignore
    except Exception as exc:  # pragma: no cover - best-effort import
        raise RuntimeError(
            "datasets package is required for --hf-dataset usage"
        ) from exc

    logger.info(
        "Loading HF dataset %s split=%s column=%s limit=%s",
        dataset_name,
        split,
        text_column,
        limit,
    )
    dataset = load_dataset(dataset_name, split=split)
    if text_column not in list(dataset.column_names):
        raise ValueError(
            f"text_column '{text_column}' not in dataset columns {dataset.column_names}"
        )
    texts = [str(row[text_column]) for row in dataset.select(range(limit))]
    if not texts:
        raise ValueError("No text rows found from HF dataset.")
    return texts
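
A usage example matching the CLI defaults (requires the optional datasets package; only the first limit rows are read):

from ezpz.examples.diffusion import load_hf_texts

texts = load_hf_texts("ag_news", split="train", text_column="text", limit=512)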

main(args)

Set up distributed training, fit the model, and log samples.

Source code in src/ezpz/examples/diffusion.py
def main(args: argparse.Namespace) -> None:
    """Set up distributed training, fit the model, and log samples."""
    rank = ezpz.setup_torch(seed=args.seed)
    if rank == 0:
        outdir = args.outdir if args.outdir is not None else OUTDIR
    else:
        outdir = None
    outdir = ezpz.dist.broadcast(outdir, root=0)
    logger.info(f"Using {outdir=}")
    # self._created_at = ezpz.dist.broadcast(self._created_at, root=0)
    if ezpz.get_rank() == 0:
        run = ezpz.dist.setup_wandb(
            project_name=WBPROJ_NAME,
            # outdir=outdir,
        )
        assert run is not None and run is wandb.run
        # wandb.config.update(ezpz.dist.get_dist_info())
        wandb.config.update({"outdir": outdir, "args": {**vars(args)}})
        # wandb.config.update({"args": {**vars(args)}})

    base_texts: List[str]
    if args.hf_dataset:
        base_texts = load_hf_texts(
            dataset_name=args.hf_dataset,
            split=args.hf_split,
            text_column=args.hf_text_column,
            limit=args.hf_limit,
        )
    else:
        base_texts = get_default_texts()
        if args.extra_text:
            base_texts = base_texts + args.extra_text

    vocab, inv_vocab = build_vocab(base_texts)
    dataset = ToyTextDataset(base_texts, vocab, seq_len=args.seq_len)
    sampler = (
        DistributedSampler(dataset) if ezpz.get_world_size() > 1 else None
    )
    loader = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        shuffle=(sampler is None),
        drop_last=False,
    )

    schedule = DiffusionSchedule(timesteps=args.timesteps)
    model = DiffusionTextModel(
        vocab_size=len(vocab),
        hidden_size=args.hidden,
        max_seq_len=args.seq_len,
        timesteps=args.timesteps,
    )
    device = ezpz.get_torch_device(as_torch_device=True)
    model.to(device)

    history, wrapped_model = train(
        model=model,
        loader=loader,
        schedule=schedule,
        args=args,
        steps=args.train_steps,
        lr=args.lr,
        outdir=outdir,
    )

    if ezpz.get_rank() == 0:
        dataset = history.finalize(
            run_name=WBRUN_NAME,
            dataset_fname="train",
            warmup=0.1,
        )
    samples = generate_text(
        wrapped_model,
        schedule,
        inv_vocab,
        seq_len=args.seq_len,
        num_samples=args.samples,
    )
    if ezpz.get_rank() == 0:
        for idx, text in enumerate(samples):
            logger.info("sample %s: %s", idx, text)

p_sample(model, xt, t, schedule)

Predict one reverse-diffusion step at timestep t.

Source code in src/ezpz/examples/diffusion.py
def p_sample(
    model: DiffusionTextModel,
    xt: torch.Tensor,
    t: int,
    schedule: DiffusionSchedule,
) -> torch.Tensor:
    """Predict one reverse-diffusion step at timestep ``t``."""
    t_batch = torch.full((xt.size(0),), t, device=xt.device, dtype=torch.long)
    beta = schedule.betas.to(xt.device)[t]
    alpha = schedule.alphas.to(xt.device)[t]
    alpha_bar = schedule.alpha_bars.to(xt.device)[t]
    eps = model(xt, t_batch)
    mean = (xt - (beta / torch.sqrt(1 - alpha_bar)) * eps) / torch.sqrt(alpha)
    if t == 0:
        return mean
    noise = torch.randn_like(xt)
    return mean + torch.sqrt(beta) * noise
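
Written out, this is the standard DDPM reverse update with a fixed variance of beta_t, where eps_theta is the model's noise prediction and the schedule quantities match the tensors computed above:

mean    = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) / sqrt(alpha_t)
x_{t-1} = mean                                    if t == 0
x_{t-1} = mean + sqrt(beta_t) * z,  z ~ N(0, I)   otherwise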

parse_args()

Parse CLI arguments for the diffusion text example.

Source code in src/ezpz/examples/diffusion.py
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for the diffusion text example."""
    parser = argparse.ArgumentParser(
        description="Tiny diffusion example for text generation."
    )
    parser.add_argument(
        "--batch-size", type=int, default=int(os.environ.get("BATCH_SIZE", 8))
    )
    parser.add_argument(
        "--dtype", type=str, default=os.environ.get("DTYPE", "float32")
    )
    parser.add_argument(
        "--extra-text",
        type=str,
        nargs="*",
        default=None,
        help="Additional sentences to add to the tiny corpus.",
    )
    parser.add_argument(
        "--fsdp",
        action="store_true",
        help="Enable FSDP wrapping (requires WORLD_SIZE>1 and torch.distributed init).",
    )
    parser.add_argument(
        "--fsdp-mixed-precision",
        action="store_true",
        help="Use bfloat16 parameters with FSDP for speed (defaults to float32).",
    )
    parser.add_argument(
        "--hidden", type=int, default=int(os.environ.get("HIDDEN", 128))
    )
    parser.add_argument(
        "--hf-dataset",
        type=str,
        default=None,
        help="Optional Hugging Face dataset name (e.g., 'ag_news'). When set, replaces the toy corpus.",
    )
    parser.add_argument(
        "--hf-split",
        type=str,
        default="train",
        help="Dataset split to load.",
    )
    parser.add_argument(
        "--hf-text-column",
        type=str,
        default="text",
        help="Column containing raw text in the dataset.",
    )
    parser.add_argument(
        "--hf-limit",
        type=int,
        default=512,
        help="Number of rows to sample from the HF dataset for quick experiments.",
    )
    parser.add_argument(
        "--log_freq", type=int, default=int(os.environ.get("LOG_FREQ", 1))
    )
    parser.add_argument("--outdir", type=str, default=None)
    parser.add_argument(
        "--samples", type=int, default=int(os.environ.get("SAMPLES", 3))
    )
    parser.add_argument(
        "--seed", type=int, default=int(os.environ.get("SEED", 0))
    )
    parser.add_argument(
        "--seq-len", type=int, default=int(os.environ.get("SEQ_LEN", 12))
    )
    parser.add_argument(
        "--timesteps", type=int, default=int(os.environ.get("TIMESTEPS", 64))
    )
    parser.add_argument(
        "--train-steps",
        type=int,
        default=int(os.environ.get("TRAIN_STEPS", 400)),
    )
    parser.add_argument(
        "--lr", type=float, default=float(os.environ.get("LR", 3e-3))
    )
    # parser.add_argument(
    #     "--ddp",
    #     action="store_true",
    #     help="Enable DDP wrapping (requires WORLD_SIZE>1 and torch.distributed init).",
    # )
    return parser.parse_args()

sample_timesteps(batch_size, schedule, device)

Uniformly sample diffusion steps for a batch.

Source code in src/ezpz/examples/diffusion.py
def sample_timesteps(
    batch_size: int, schedule: DiffusionSchedule, device: torch.device
) -> torch.Tensor:
    """Uniformly sample diffusion steps for a batch."""
    return torch.randint(0, schedule.timesteps, (batch_size,), device=device)

test(model, test_loader)

Evaluate the classifier outputs on a held-out loader.

Source code in src/ezpz/examples/diffusion.py
@ezpz.timeitlogit(rank=ezpz.get_rank())
def test(model, test_loader):
    """Evaluate the classifier outputs on a held-out loader."""
    DEVICE = ezpz.get_torch_device()
    DEVICE_ID = f"{DEVICE}:{ezpz.get_local_rank()}"
    model.eval()
    # correct = 0
    ddp_loss = torch.zeros(3).to(DEVICE_ID)
    with torch.no_grad():
        for batch, target in test_loader:
            batch, target = batch.to(DEVICE_ID), target.to(DEVICE_ID)
            output = model(batch)
            ddp_loss[0] += F.nll_loss(
                output, target, reduction="sum"
            ).item()  # sum up batch loss
            pred = output.argmax(
                dim=1, keepdim=True
            )  # get the index of the max log-probability
            ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
            ddp_loss[2] += len(batch)

    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)  # type:ignore

    test_loss = ddp_loss[0] / ddp_loss[2]

    return {
        "test_loss": test_loss,
        "test_acc": 100.0 * ddp_loss[1] / ddp_loss[2],
    }

train(model, loader, schedule, args, steps, outdir, lr=0.001)

Train the diffusion text model for a fixed number of steps.

Source code in src/ezpz/examples/diffusion.py
@ezpz.timeitlogit(rank=ezpz.get_rank())
def train(
    model: DiffusionTextModel,
    loader: DataLoader,
    schedule: DiffusionSchedule,
    args: argparse.Namespace,
    steps: int,
    outdir: Path | os.PathLike | str,
    lr: float = 1e-3,
) -> tuple[ezpz.History, torch.nn.Module]:
    """Train the diffusion text model for a fixed number of steps."""
    device = ezpz.get_torch_device(as_torch_device=True)
    # if not isinstance(model, (DistributeFSDP):
    model.to(device)
    model.train()
    wrapped_model = ezpz.dist.wrap_model(
        model, use_fsdp=args.fsdp, dtype=args.dtype
    )
    optim = torch.optim.AdamW(wrapped_model.parameters(), lr=lr)
    mstr = ezpz.models.summarize_model(
        wrapped_model,
        verbose=False,
        depth=2,
        # input_size=(
        #     torch.tensor((int(args.batch_size), int(args.seq_length))).to(
        #         torch.long
        #     )
        # ).shape,
    )
    logger.info("Model summary:\n%s", mstr)

    # outdir = Path(os.getcwd()) if outdir is None else outdir
    # outdir_parent = Path(outdir).joinpath(ezpz.utils.get_timestamp())
    # outdir = Path(outdir).as_posix()
    metrics_path = Path(outdir).joinpath(f"metrics-{ezpz.get_rank()}.jsonl")
    history = ezpz.history.History(
        report_dir=outdir,
        report_enabled=True,
        jsonl_path=metrics_path,
        jsonl_overwrite=True,
        distributed_history=(
            1 < ezpz.get_world_size() <= 384  # and not config.pytorch_profiler
        ),
    )

    # log_freq = max(1, steps // 100)
    assert isinstance(
        wrapped_model, (nn.Module, FSDP, DistributedDataParallel)
    ), "Model should be wrapped for training."
    base_model = (
        wrapped_model.module
        if hasattr(wrapped_model, "module")
        else wrapped_model
    )
    assert callable(getattr(base_model, "embed_tokens", None)), (
        "Model should have embed_tokens method."
    )
    is_fsdp = isinstance(wrapped_model, FSDP)
    loader_iter = iter(loader)
    for step in range(steps):
        t0 = time.perf_counter()
        try:
            tokens = next(loader_iter)
        except StopIteration:
            loader_iter = iter(loader)
            tokens = next(loader_iter)
        tokens = tokens.to(device)
        t1 = time.perf_counter()
        ezpz.dist.synchronize()
        full_param_ctx = (
            FSDP.summon_full_params(wrapped_model)
            if is_fsdp
            else nullcontext()
        )
        with full_param_ctx:
            x0 = base_model.embed_tokens(tokens)
        t = sample_timesteps(tokens.size(0), schedule, device=device)
        xt, noise = add_noise(x0, t, schedule)
        pred_noise = wrapped_model(xt, t)
        loss = torch.mean((pred_noise - noise) ** 2)
        t2 = time.perf_counter()
        ezpz.dist.synchronize()

        loss.backward()
        optim.step()
        optim.zero_grad(set_to_none=True)
        t3 = time.perf_counter()
        ezpz.dist.synchronize()

        if step % args.log_freq == 0 or step == steps - 1:
            logger.info(
                history.update(
                    {
                        "train/step": step,
                        "train/loss": loss.item(),
                        "train/dt": t3 - t0,
                        "train/dtd": t1 - t0,
                        "train/dtf": t2 - t1,
                        "train/dtb": t3 - t2,
                    }
                ).replace("train/", "")
            )

    # loader_iter = iter(loader)
    # for step in range(steps):
    #     try:
    #         tokens = next(loader_iter)
    #     except StopIteration:
    #         loader_iter = iter(loader)
    #         tokens = next(loader_iter)
    #     tokens = tokens.to(device)
    #     x0 = model.embed_tokens(tokens)
    #     t = sample_timesteps(tokens.size(0), schedule, device=device)
    #     xt, noise = add_noise(x0, t, schedule)
    #     pred_noise = model(xt, t)
    #     loss = torch.mean((pred_noise - noise) ** 2)
    #
    #     loss.backward()
    #     optim.step()
    #     optim.zero_grad(set_to_none=True)
    #
    #     if step % log_freq == 0 or step == steps - 1:
    #         summary = history.update({"step": step, "loss": loss.item()})
    #         logger.info(summary)
    return history, wrapped_model