ezpz.examples.generate

Interactive text generation loop for Hugging Face causal language models.

Launch with:

ezpz launch -m ezpz.examples.generate --model_name <repo/model>

Help output (python3 -m ezpz.examples.generate --help):

usage: generate.py [-h] [--model_name MODEL_NAME]
                   [--dtype {float16,bfloat16,float32}]

Generate text using a model.

options:
  -h, --help            show this help message and exit
  --model_name MODEL_NAME
                        Name of the model to use.
  --dtype {float16,bfloat16,float32}
                        Data type to use for the model.
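
For example, to pin the precision explicitly (the checkpoint id below is illustrative only; any Hugging Face causal-LM repo id works, and ezpz launch passes the extra flags through to the module just as it does --model_name above):

ezpz launch -m ezpz.examples.generate --model_name gpt2 --dtype bfloat16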

main()

Load a model and enter an interactive text generation REPL.

Source code in src/ezpz/examples/generate.py
def main():
    """Load a model and enter an interactive text generation REPL."""
    args = parse_args()
    model_name = args.model_name
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
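    # Many causal-LM tokenizers ship without a pad token; reuse EOS so
    # generate() can pad inputs if it ever needs to.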
    tokenizer.pad_token = tokenizer.eos_token
    # model.resize_token_embeddings(len(tokenizer))
    model.to(ezpz.get_torch_device_type())
    if args.dtype in {"bfloat16", "bf16", "b16"}:
        model.to(torch.bfloat16)
    elif args.dtype in {"float16", "fp16", "f16"}:
        model.to(torch.float16)
    elif args.dtype in {"float32", "fp32", "f32"}:
        model.to(torch.float32)
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")
    print(f"Model loaded: {model_name}")
    print("Enter a prompt. Type 'exit' to quit.")
    while True:
        try:
            prompt = input("Enter a prompt: ")
            if prompt.strip().lower() == "exit":
                print("Exiting!")
                break
            max_length = int(input("Enter max length: "))
            print(prompt_model(model, tokenizer, prompt, max_length))
        except ValueError:
            print("Invalid input. Please enter a number.")
        except KeyboardInterrupt:
            print("\nExiting...")
            break
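
The chain of dtype checks in main() could equally be written as a lookup table. A minimal sketch of that alternative, assuming only the alias sets shown above (a hypothetical refactor, not ezpz source):

import torch

# Hypothetical refactor of main()'s alias handling; the alias sets are
# copied verbatim from the source above.
_DTYPES: dict[str, torch.dtype] = {
    "bfloat16": torch.bfloat16, "bf16": torch.bfloat16, "b16": torch.bfloat16,
    "float16": torch.float16, "fp16": torch.float16, "f16": torch.float16,
    "float32": torch.float32, "fp32": torch.float32, "f32": torch.float32,
}

def resolve_dtype(name: str) -> torch.dtype:
    """Map a dtype alias to a torch.dtype, mirroring main()'s checks."""
    try:
        return _DTYPES[name]
    except KeyError:
        raise ValueError(f"Unsupported dtype: {name}") from None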

parse_args()

Parse CLI arguments for interactive generation.

Source code in src/ezpz/examples/generate.py
def parse_args():
    """Parse CLI arguments for interactive generation."""
    parser = build_generate_parser()
    return parser.parse_args()
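
build_generate_parser is imported from elsewhere in ezpz, and its source is not shown on this page. A minimal stand-in, reconstructed from the --help output above (the default dtype is an assumption, since --help does not show one):

import argparse

def build_generate_parser() -> argparse.ArgumentParser:
    # Reconstructed from the --help text above; not the actual ezpz source.
    parser = argparse.ArgumentParser(description="Generate text using a model.")
    parser.add_argument("--model_name", help="Name of the model to use.")
    parser.add_argument(
        "--dtype",
        choices=["float16", "bfloat16", "float32"],
        default="bfloat16",  # assumption: the real default is not shown by --help
        help="Data type to use for the model.",
    )
    return parser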

prompt_model(model, tokenizer, prompt, max_length=64, **kwargs)

Generate text using a model and tokenizer.

Parameters:

Name        Type                  Description                                           Default
model       AutoModelForCausalLM  Causal LM used for generation.                        required
tokenizer   AutoTokenizer         Tokenizer that encodes/decodes text.                  required
prompt      str                   Input prompt to seed generation.                      required
max_length  int                   Maximum total length, prompt plus generated tokens.  64
**kwargs    object                Extra parameters forwarded to model.generate.        {}

Returns:

Type  Description
str   Decoded text, prompt plus continuation, returned by the model.

Examples:

>>> model_name = "argonne-private/AuroraGPT-7B"
>>> tokenizer = AutoTokenizer.from_pretrained(model_name)
>>> model = AutoModelForCausalLM.from_pretrained(model_name)
>>> _ = prompt_model(model, tokenizer, "Who are you?", max_length=32)
Source code in src/ezpz/examples/generate.py
def prompt_model(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
    max_length: int = 64,
    **kwargs: object,
) -> str:
    """Generate text using a model and tokenizer.

    Args:
        model: Causal LM used for generation.
        tokenizer: Tokenizer that encodes/decodes text.
        prompt: Input prompt to seed generation.
        max_length: Maximum total length, prompt plus generated tokens.
        **kwargs: Extra parameters forwarded to ``model.generate``.

    Returns:
        Decoded text, prompt plus continuation, returned by the model.

    Examples:
        >>> model_name = \"argonne-private/AuroraGPT-7B\"
        >>> tokenizer = AutoTokenizer.from_pretrained(model_name)
        >>> model = AutoModelForCausalLM.from_pretrained(model_name)
        >>> _ = prompt_model(model, tokenizer, \"Who are you?\", max_length=32)
    """
    # batch_decode returns a list of decoded sequences; take the first
    # (and only) one so the declared ``str`` return type holds.
    return tokenizer.batch_decode(
        model.generate(
            **tokenizer(prompt, return_tensors="pt").to(ezpz.get_torch_device_type()),
            max_length=max_length,
            **kwargs,
        )
    )[0]
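
Because **kwargs is forwarded straight to model.generate, standard transformers generation options such as do_sample and temperature pass through unchanged. A hedged usage sketch, with "gpt2" standing in as a purely illustrative checkpoint:

import ezpz
from transformers import AutoModelForCausalLM, AutoTokenizer

from ezpz.examples.generate import prompt_model

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2").to(ezpz.get_torch_device_type())
text = prompt_model(
    model,
    tokenizer,
    "Hello, world!",
    max_length=48,    # total length: prompt plus generated tokens
    do_sample=True,   # standard transformers generate() kwargs
    temperature=0.8,
)
print(text)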