ezpz.generate

generate.py

parse_args()

Parse command line arguments.

Source code in src/ezpz/generate.py
def parse_args():
    """
    Parse command line arguments.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Generate text using a model.")
    parser.add_argument(
        "--model_name",
        type=str,
        default="meta-llama/Llama-3.2-1B",
        help="Name of the model to use.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float16", "bfloat16", "float32"],
        help="Data type to use for the model.",
    )
    return parser.parse_args()
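
Below is a minimal sketch of how these arguments might be consumed downstream; the main wrapper and the dtype-string-to-torch-dtype mapping are illustrative assumptions, not part of ezpz:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def main():
    args = parse_args()
    # The --dtype choices match torch attribute names, so the string
    # maps directly onto the corresponding torch.dtype.
    dtype = getattr(torch, args.dtype)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name, torch_dtype=dtype
    )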

prompt_model(model, tokenizer, prompt, max_length=64, **kwargs)

Generate text using a model and tokenizer.

Parameters:

Name         Type   Description                                                    Default
model               The model to use for generation.                              required
tokenizer           The tokenizer to use for encoding and decoding.               required
prompt              The input prompt to generate text from.                       required
max_length   int    The maximum length of the generated text.                     64
**kwargs            Additional arguments to pass to the model's generate method.  {}

Example:

>>> import ezpz
>>> from transformers import AutoModelForCausalLM, AutoTokenizer #, Trainer, TrainingArguments
>>> import torch
>>> model_name = "argonne-private/AuroraGPT-7B"
>>> tokenizer = AutoTokenizer.from_pretrained(model_name)
>>> model = AutoModelForCausalLM.from_pretrained(model_name)
>>> model.to(ezpz.get_torch_device_type())
>>> model.to(torch.bfloat16)
>>> result = tokenizer.batch_decode(model.generate(**tokenizer("Who are you?", return_tensors="pt").to(ezpz.get_torch_device_type()), max_length=128))
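
The example above inlines the tokenize/generate/batch_decode pipeline; with prompt_model the same call collapses to one line (a sketch reusing the model and tokenizer from the example):

>>> result = prompt_model(model, tokenizer, "Who are you?", max_length=128)
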
Source code in src/ezpz/generate.py
def prompt_model(
    model, tokenizer, prompt, max_length: int = 64, **kwargs
) -> list[str]:
    """
    Generate text using a model and tokenizer.

    Args:
        model: The model to use for generation.
        tokenizer: The tokenizer to use for encoding and decoding.
        prompt: The input prompt to generate text from.
        max_length: The maximum length of the generated text.
        **kwargs: Additional arguments to pass to the model's generate method.

    Example:

        >>> import ezpz
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer #, Trainer, TrainingArguments
        >>> import torch
        >>> model_name = "argonne-private/AuroraGPT-7B"
        >>> tokenizer = AutoTokenizer.from_pretrained(model_name)
        >>> model = AutoModelForCausalLM.from_pretrained(model_name)
        >>> model.to(ezpz.get_torch_device_type())
        >>> model.to(torch.bfloat16)
        >>> result = tokenizer.batch_decode(model.generate(**tokenizer("Who are you?", return_tensors="pt").to(ezpz.get_torch_device_type()), max_length=128))
    """
    return tokenizer.batch_decode(
        model.generate(
            **tokenizer(prompt, return_tensors="pt").to(
                ezpz.get_torch_device_type()
            ),
            max_length=max_length,
            **kwargs,
        )
    )
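
Putting the pieces together, a driver script might look like the sketch below. Only parse_args and prompt_model come from this module; the dtype mapping and the sampling options (do_sample and temperature, standard Hugging Face generate kwargs forwarded through **kwargs) are illustrative assumptions:

import ezpz
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from ezpz.generate import parse_args, prompt_model

args = parse_args()
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
model = AutoModelForCausalLM.from_pretrained(args.model_name)
model.to(ezpz.get_torch_device_type())
model.to(getattr(torch, args.dtype))
# Extra kwargs are forwarded verbatim to model.generate.
outputs = prompt_model(
    model,
    tokenizer,
    "Who are you?",
    max_length=128,
    do_sample=True,
    temperature=0.7,
)
print(outputs[0])  # batch_decode returns a list of decoded strings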