
ezpz.data.hf

ezpz/data/hf.py

HuggingFace Datasets loading and tokenization.

ToyTextDataset

Bases: Dataset

Pads or truncates sentences to a fixed length.

Source code in src/ezpz/data/hf.py
class ToyTextDataset(Dataset):
    """Pads or truncates sentences to a fixed length."""

    def __init__(
        self, texts: List[str], vocab: Dict[str, int], seq_len: int = 12
    ):
        self.texts = texts
        self.vocab = vocab
        self.seq_len = seq_len
        self.pad_id = vocab["<pad>"]
        self.unk_id = vocab["<unk>"]

    def __len__(self) -> int:
        return len(self.texts)

    def _encode(self, text: str) -> torch.Tensor:
        tokens = [
            self.vocab.get(tok, self.unk_id) for tok in text.lower().split()
        ]
        tokens = tokens[: self.seq_len]
        tokens += [self.pad_id] * (self.seq_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def __getitem__(self, idx: int) -> torch.Tensor:  # type:ignore
        return self._encode(self.texts[idx])
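
Example (illustrative, not part of the module source): encoding behaviour with a hand-built vocabulary, assuming ToyTextDataset is importable from ezpz.data.hf:

from ezpz.data.hf import ToyTextDataset

# Minimal vocabulary; ids 0 and 1 are reserved for the two special tokens.
vocab = {"<pad>": 0, "<unk>": 1, "hello": 2, "world": 3}
texts = ["Hello world", "completely out of vocabulary sentence"]

ds = ToyTextDataset(texts, vocab, seq_len=6)
print(ds[0])  # tensor([2, 3, 0, 0, 0, 0]) -- lowercased, then padded to seq_len
print(ds[1])  # every word maps to <unk> (id 1), then padded with <pad> (id 0)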

build_vocab(texts)

Create a tiny vocabulary from a list of strings.

Source code in src/ezpz/data/hf.py
def build_vocab(texts: Iterable[str]) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Create a tiny vocabulary from a list of strings."""
    specials = ["<pad>", "<unk>"]
    words = sorted({word for text in texts for word in text.lower().split()})
    vocab = {tok: idx for idx, tok in enumerate(specials + words)}
    inv_vocab = {idx: tok for tok, idx in vocab.items()}
    return vocab, inv_vocab
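
A short sketch of how build_vocab and ToyTextDataset fit together (illustrative only; the DataLoader settings are arbitrary):

from torch.utils.data import DataLoader
from ezpz.data.hf import ToyTextDataset, build_vocab

texts = ["the quick brown fox", "jumps over the lazy dog"]
vocab, inv_vocab = build_vocab(texts)  # "<pad>" -> 0, "<unk>" -> 1, sorted words after
ds = ToyTextDataset(texts, vocab, seq_len=8)
loader = DataLoader(ds, batch_size=2, shuffle=True)

for batch in loader:
    print(batch.shape)  # torch.Size([2, 8]) -- padded token ids
    print([inv_vocab[int(i)] for i in batch[0]])  # decode one row back to tokens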

get_hf_text_dataset(*, dataset_name, split, text_column, tokenizer_name, seq_len, limit, seed)

Build a tokenized HF dataset with input_ids + attention_mask.

Returns:

    tuple[Dataset, AutoTokenizer]: the tokenized dataset (torch formatted) and the tokenizer.

Source code in src/ezpz/data/hf.py
def get_hf_text_dataset(
    *,
    dataset_name: str,
    split: str,
    text_column: str,
    tokenizer_name: str,
    seq_len: int,
    limit: int,
    seed: int,
) -> tuple[datasets.Dataset, AutoTokenizer]:
    """
    Build a tokenized HF dataset with input_ids + attention_mask.

    Returns:
        tokenized dataset (torch formatted) and tokenizer.
    """
    if seq_len <= 0:
        raise ValueError("seq_len must be > 0 for HF dataset tokenization.")
    logger.info(
        "Tokenizing HF dataset %s split=%s column=%s limit=%s seq_len=%s",
        dataset_name,
        split,
        text_column,
        limit,
        seq_len,
    )
    dataset = datasets.load_dataset(dataset_name, split=split)
    if (
        cnames := getattr(dataset, "column_names")
    ) and text_column not in list(cnames):
        raise ValueError(
            f"text_column '{text_column}' not in dataset columns {dataset.column_names}"
        )

    if limit > 0 and limit < len(dataset):
        dataset = dataset.shuffle(seed=seed)
        dataset = dataset.select(range(limit))

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # One extra token beyond seq_len, leaving room for a one-token input/target shift.
    max_length = seq_len + 1

    def tokenize_function(examples):
        return tokenizer(
            examples[text_column],
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_attention_mask=True,
        )

    tokenized = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
        desc="Tokenizing HF dataset",
    )
    tokenized.set_format(
        type="torch", columns=["input_ids", "attention_mask"]
    )
    tokenized.pad_id = tokenizer.pad_token_id  # type: ignore[attr-defined]
    tokenized.vocab_size = tokenizer.vocab_size  # type: ignore[attr-defined]
    return tokenized, tokenizer
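
A hedged usage sketch; the dataset ("ag_news") and tokenizer ("gpt2") names are just examples and downloading them requires network access. Because max_length is seq_len + 1, a one-token shift of input_ids gives aligned inputs and next-token targets of length seq_len:

from torch.utils.data import DataLoader
from ezpz.data.hf import get_hf_text_dataset

tokenized, tokenizer = get_hf_text_dataset(
    dataset_name="ag_news",   # example dataset with a "text" column
    split="train",
    text_column="text",
    tokenizer_name="gpt2",    # example tokenizer; pad token falls back to EOS
    seq_len=128,
    limit=1000,               # shuffle with `seed`, then keep 1000 rows
    seed=42,
)

loader = DataLoader(tokenized, batch_size=8)
batch = next(iter(loader))            # dict of torch tensors
inputs = batch["input_ids"][:, :-1]   # first seq_len tokens
targets = batch["input_ids"][:, 1:]   # shifted by one position
print(inputs.shape, targets.shape)    # torch.Size([8, 128]) twice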

load_hf_texts(dataset_name, split, text_column, limit)

Pull a small slice of text from a Hugging Face dataset for quick experiments.

This uses only a limited number of rows (limit) to keep the example light.

Source code in src/ezpz/data/hf.py
def load_hf_texts(
    dataset_name: str,
    split: str,
    text_column: str,
    limit: int,
) -> list[str]:
    """
    Pull a small slice of text from a Hugging Face dataset for quick experiments.

    This uses only a limited number of rows (`limit`) to keep the example light.
    """
    try:
        from datasets import load_dataset  # type: ignore
    except Exception as exc:  # pragma: no cover - best-effort import
        raise RuntimeError(
            "datasets package is required for --hf-dataset usage"
        ) from exc

    logger.info(
        "Loading HF dataset %s split=%s column=%s limit=%s",
        dataset_name,
        split,
        text_column,
        limit,
    )
    dataset = load_dataset(dataset_name, split=split)
    if (
        cnames := getattr(dataset, "column_names")
    ) and text_column not in list(cnames):
        raise ValueError(
            f"text_column '{text_column}' not in dataset columns {dataset.column_names}"
        )
    else:
        assert callable(getattr(dataset, "select"))
        total = len(dataset)
        if limit <= 0:
            raise ValueError("limit must be > 0 for HF dataset sampling.")
        if limit >= total:
            indices = list(range(total))
        else:
            seed = int(os.environ.get("EZPZ_HF_SAMPLE_SEED", "1337"))
            try:
                dataset = dataset.shuffle(seed=seed)
                indices = list(range(limit))
            except Exception:
                rng = torch.Generator().manual_seed(seed)
                indices = torch.randperm(total, generator=rng)[:limit].tolist()
        texts = [
            str(row[text_column]) for row in dataset.select(indices)
            if str(row.get(text_column, "")).strip()
        ]
        if not texts:
            raise ValueError("No text rows found from HF dataset.")
    return texts
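
An illustrative sketch chaining load_hf_texts with the toy helpers above; the dataset name and column are example values, and the sampling seed can be overridden via EZPZ_HF_SAMPLE_SEED:

from ezpz.data.hf import ToyTextDataset, build_vocab, load_hf_texts

texts = load_hf_texts(
    dataset_name="ag_news",   # example; any dataset with a text column works
    split="train",
    text_column="text",
    limit=256,
)
vocab, _ = build_vocab(texts)
ds = ToyTextDataset(texts, vocab, seq_len=32)
print(len(ds), ds[0].shape)   # up to 256 samples, each a length-32 LongTensor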

split_dataset(data_args, train_split_name='train', validation_split_name=None, cache_dir=None, token=None, trust_remote_code=False)

Splits the dataset into training and validation sets based on the provided split names.

Args:

    data_args: HfDataTrainingArguments carrying the dataset name, optional config name, validation split percentage, and streaming flag.
    train_split_name: Split to carve the validation slice from (default "train").
    validation_split_name: Key under which the validation slice is stored; if None, no splits are loaded.
    cache_dir: Hugging Face datasets cache directory (defaults to ./.cache/hf/datasets).
    token: Optional Hugging Face Hub token.
    trust_remote_code: Whether to allow datasets that execute remote code.

Source code in src/ezpz/data/hf.py
def split_dataset(
    data_args: HfDataTrainingArguments,
    train_split_name: str = "train",
    validation_split_name: Optional[str] = None,
    cache_dir: Optional[str | os.PathLike | Path] = None,
    token: Optional[str] = None,
    trust_remote_code: bool = False,
) -> datasets.IterableDatasetDict | datasets.DatasetDict:
    """
    Splits the dataset into training and validation sets based on the provided split names.

    Args:
        data_args: Dataset name, optional config name, validation split
            percentage, and streaming flag.
        train_split_name: Split to carve the validation slice from.
        validation_split_name: Key under which the validation slice is stored;
            if None, no splits are loaded.
        cache_dir: Hugging Face datasets cache directory
            (defaults to ./.cache/hf/datasets).
        token: Optional Hugging Face Hub token.
        trust_remote_code: Whether to allow datasets that execute remote code.
    """
    dsets = {}
    dataset_name = data_args.dataset_name
    assert dataset_name is not None, (
        "dataset_name must be provided to split the dataset."
    )
    cache_dir = (
        Path("./.cache/hf/datasets") if cache_dir is None else cache_dir
    )
    assert cache_dir is not None and isinstance(cache_dir, (str, os.PathLike))
    cache_dir = Path(cache_dir).as_posix()
    if validation_split_name is not None:
        try:
            dsets[validation_split_name] = datasets.load_dataset(  # type:ignore
                dataset_name,
                data_args.dataset_config_name,
                split=f"{train_split_name}[:{data_args.validation_split_percentage}%]",
                cache_dir=cache_dir,
                token=token,
                streaming=data_args.streaming,
                trust_remote_code=trust_remote_code,
            )
            dsets[train_split_name] = datasets.load_dataset(  # type: ignore
                dataset_name,
                data_args.dataset_config_name,
                split=f"{train_split_name}[{data_args.validation_split_percentage}%:]",
                cache_dir=cache_dir,
                token=token,
                streaming=data_args.streaming,
                trust_remote_code=trust_remote_code,
            )
        except ValueError:
            # In some cases, the dataset doesn't support slicing.
            # In this case, we just use the full training set as validation set.
            dsets[validation_split_name] = datasets.load_dataset(  # type:ignore
                dataset_name,
                data_args.dataset_config_name,
                split=train_split_name,
                cache_dir=cache_dir,
                token=token,
                streaming=data_args.streaming,
                trust_remote_code=trust_remote_code,
            )
            try:
                dsets[train_split_name] = datasets.load_dataset(  # type:ignore
                    dataset_name,
                    data_args.dataset_config_name,
                    split=f"{train_split_name}[:{data_args.validation_split_percentage}%]",
                    cache_dir=cache_dir,
                    token=token,
                    streaming=data_args.streaming,
                    trust_remote_code=trust_remote_code,
                )
            except Exception:
                # Slicing is not supported here either; fall back to using
                # the full training split as the training set as well.
                dsets[train_split_name] = datasets.load_dataset(  # type:ignore
                    dataset_name,
                    data_args.dataset_config_name,
                    split=train_split_name,
                    cache_dir=cache_dir,
                    token=token,
                    streaming=data_args.streaming,
                    trust_remote_code=trust_remote_code,
                )

    if data_args.streaming:
        return datasets.IterableDatasetDict(dsets)
    return datasets.DatasetDict(dsets)
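
A minimal sketch of calling split_dataset. HfDataTrainingArguments is defined elsewhere in ezpz, so a SimpleNamespace stand-in with only the attributes the function reads keeps the example self-contained; the dataset name and 5% validation split are arbitrary choices:

from types import SimpleNamespace
from ezpz.data.hf import split_dataset

# Stand-in for HfDataTrainingArguments (illustrative only).
data_args = SimpleNamespace(
    dataset_name="wikitext",
    dataset_config_name="wikitext-2-raw-v1",
    validation_split_percentage=5,
    streaming=False,
)

dsets = split_dataset(
    data_args,  # type: ignore[arg-type]
    train_split_name="train",
    validation_split_name="validation",
    cache_dir="./.cache/hf/datasets",
)
print(dsets)  # DatasetDict with a 95%/5% train/validation split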