Skip to content

ezpz.examples.deepspeed.tp.utils ¶

Utility helpers for DeepSpeed tensor-parallel scripts (OpenAI completions, IO).

This module is imported by other examples; it is not a standalone CLI.

OpenAIDecodingArguments dataclass ¶

Configurable decoding parameters for OpenAI API completions.

Source code in src/ezpz/examples/deepspeed/tp/utils.py
@dataclasses.dataclass
class OpenAIDecodingArguments:
    """Decoding settings sent along with each OpenAI completion request.

    ``openai_completion`` expands an instance's ``__dict__`` into keyword
    arguments of the API call, so every field name below must be a parameter
    the target endpoint accepts. Field order is also the positional order of
    the generated ``__init__``.
    """

    # Length and sampling controls.
    max_tokens: int = dataclasses.field(default=1800)
    temperature: float = dataclasses.field(default=0.2)
    top_p: float = dataclasses.field(default=1.0)
    # Number of completions requested per prompt.
    n: int = dataclasses.field(default=1)
    stream: bool = dataclasses.field(default=False)
    stop: Optional[Sequence[str]] = dataclasses.field(default=None)
    # Penalty terms.
    presence_penalty: float = dataclasses.field(default=0.0)
    frequency_penalty: float = dataclasses.field(default=0.0)
    # Remaining options; None/False here are forwarded as-is to the API call.
    suffix: Optional[str] = dataclasses.field(default=None)
    logprobs: Optional[int] = dataclasses.field(default=None)
    echo: bool = dataclasses.field(default=False)

jdump(obj, f, mode='w', indent=4, default=None) ¶

Dump a str or dictionary to a file in json format.

Parameters:

Name Type Description Default
obj Any

An object to be written.

required
f str | PathLike

A string path to the location on disk.

required
mode str

Mode for opening the file.

'w'
indent int

Indent for storing json dictionaries.

4
default Any | None

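A function passed to json.dump to handle non-serializable entries. When None (the default), non-serializable entries raise TypeError; pass str to stringify them instead.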

None
Source code in src/ezpz/examples/deepspeed/tp/utils.py
def jdump(
    obj: Any,
    f: str | os.PathLike,
    mode: str = "w",
    indent: int = 4,
    default: Any | None = None,
):
    """Dump a str or dictionary to a file in json format.

    Args:
        obj: An object to be written.
        f: A string path to the location on disk.
        mode: Mode for opening the file.
        indent: Indent for storing json dictionaries.
        default: A function to handle non-serializable entries; defaults to `str`.
    """
    fout = _make_w_io_base(f, mode)
    if isinstance(obj, (dict, list)):
        json.dump(obj, fout, indent=indent, default=default)
    elif isinstance(obj, str):
        fout.write(obj)
    else:
        raise ValueError(f"Unexpected type: {type(obj)}")
    fout.close()

jload(f, mode='r') ¶

Load a .json file and return the parsed object (usually a dictionary; a top-level JSON array yields a list).

Source code in src/ezpz/examples/deepspeed/tp/utils.py
def jload(f, mode: str = "r") -> Any:
    """Load a .json file and return the parsed object.

    Args:
        f: The file to read, opened via ``_make_r_io_base``.
        mode: Mode for opening the file.

    Returns:
        The value produced by ``json.load``: usually a ``dict``, but a
        top-level JSON array yields a ``list``, and scalars yield the matching
        Python type.
    """
    # ``with`` closes the handle even when the content is not valid JSON.
    with _make_r_io_base(f, mode) as fin:
        return json.load(fin)

openai_completion(prompts, decoding_args, model_name='text-davinci-003', sleep_time=2, batch_size=1, max_instances=sys.maxsize, max_batches=sys.maxsize, return_text=False, **decoding_kwargs) ¶

Decode with OpenAI API.

Parameters:

Name Type Description Default
prompts str | Sequence[str] | Sequence[dict[str, str]] | dict[str, str]

A string or a list of strings to complete. If it is a chat model the strings should be formatted as explained here: https://github.com/openai/openai-python/blob/main/chatml.md. If it is a chat model it can also be a dictionary (or list thereof) as explained here: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb

required
decoding_args OpenAIDecodingArguments

Decoding arguments.

required
model_name str

Model name. Can be either in the format of "org/model" or just "model".

'text-davinci-003'
sleep_time int

Time to sleep once the rate-limit is hit.

2
batch_size int

Number of prompts to send in a single request. Only for non-chat models.

1
max_instances int

Maximum number of prompts to decode.

sys.maxsize
max_batches int

Maximum number of batches to decode. This argument will be deprecated in the future.

sys.maxsize
return_text bool

If True, return text instead of full completion object (which contains things like logprob).

False
decoding_kwargs dict[Any, Any]

Additional decoding arguments. Pass in best_of and logit_bias if you need them.

{}

Returns:

Type Description
Union[Union[str, Any], Sequence[StrOrOpenAIObject], Sequence[Sequence[StrOrOpenAIObject]]]

A completion or a list of completions.

Union[Union[str, Any], Sequence[StrOrOpenAIObject], Sequence[Sequence[StrOrOpenAIObject]]]

Depending on return_text and decoding_args.n, each completion is one of: a string (if return_text is True); an openai_object.OpenAIObject (if return_text is False); or a list of decoding_args.n objects of the above types (if decoding_args.n > 1). For a single prompt (a str or dict), that one completion is returned directly rather than wrapped in a list.

Source code in src/ezpz/examples/deepspeed/tp/utils.py
def openai_completion(
    prompts: str | Sequence[str] | Sequence[dict[str, str]] | dict[str, str],
    decoding_args: OpenAIDecodingArguments,
    model_name: str = "text-davinci-003",
    sleep_time: int = 2,
    batch_size: int = 1,
    max_instances: int = sys.maxsize,
    max_batches: int = sys.maxsize,
    return_text: bool = False,
    **decoding_kwargs: Any,
) -> Union[
    Union[str, Any],
    Sequence[StrOrOpenAIObject],
    Sequence[Sequence[StrOrOpenAIObject]],
]:
    """Decode with OpenAI API.

    Prompts are split into batches of ``batch_size`` and each batch is sent in
    one ``openai.Completion.create`` request. Failed requests are retried:

    * if the error message contains "Please reduce your prompt", the batch's
      ``max_tokens`` is cut to 80% and the request is retried immediately;
    * any other ``openai.error.OpenAIError`` is retried after sleeping
      ``sleep_time`` seconds.

    NOTE(review): the second case retries forever, including on errors that
    are not transient (e.g. authentication failures). Consider a retry cap.

    Args:
        prompts: A string or a sequence of strings to complete. If it is a chat model the strings should be
            formatted as explained here: https://github.com/openai/openai-python/blob/main/chatml.md. If it is a
            chat model it can also be a dictionary (or list thereof) as explained here:
            https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
            NOTE(review): requests always go through ``openai.Completion.create``; confirm the model in use
            accepts chat-formatted or dict prompts on that endpoint.
        decoding_args: Decoding arguments. A deep copy is made per batch, so ``max_tokens`` reductions made
            while retrying one batch affect neither later batches nor the caller's object.
        model_name: Model name. Can be either in the format of "org/model" or just "model".
        sleep_time: Seconds to sleep before retrying after a non-length OpenAI error.
        batch_size: Number of prompts to send in a single request (must be >= 1). Only for non-chat models.
        max_instances: Maximum number of prompts to decode.
        max_batches: Maximum number of batches to decode. This argument will be deprecated in the future;
            when set, it overrides ``max_instances`` with ``max_batches * batch_size``.
        return_text: If True, return text instead of full completion object (which contains things like logprob).
        decoding_kwargs: Additional keyword arguments forwarded to the API call. Pass in `best_of` and
            `logit_bias` if you need them. Keys must not repeat ``model`` or any ``decoding_args`` field.

    Returns:
        A completion or a list of completions.
        Depending on return_text and decoding_args.n, each completion is one of
            - a string (if return_text is True)
            - an openai_object.OpenAIObject object (if return_text is False)
            - a list of decoding_args.n objects of the above types (if decoding_args.n > 1)
        For a single prompt (a ``str`` or ``dict``), that one completion is returned directly instead of a list.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        openai.error.OpenAIError: If a "Please reduce your prompt" error persists after ``max_tokens`` can no
            longer be reduced (it would otherwise loop forever).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    is_single_prompt = isinstance(prompts, (str, dict))
    # Normalize to a list; any sequence (e.g. a tuple) is accepted, not only lists.
    prompt_list = [prompts] if is_single_prompt else list(prompts)

    if max_batches < sys.maxsize:
        logging.warning(
            "`max_batches` will be deprecated in the future, please use `max_instances` instead. "
            "Setting `max_instances` to `max_batches * batch_size` for now."
        )
        max_instances = max_batches * batch_size

    prompt_list = prompt_list[:max_instances]
    prompt_batches = [
        prompt_list[start : start + batch_size]
        for start in range(0, len(prompt_list), batch_size)
    ]

    completions = []
    for prompt_batch in tqdm.tqdm(
        prompt_batches,
        desc="prompt_batches",
        total=len(prompt_batches),
    ):
        # Per-batch copy so max_tokens shrinking below stays local to this batch.
        batch_decoding_args = copy.deepcopy(decoding_args)

        while True:
            # Rebuilt each attempt because max_tokens may have been reduced.
            shared_kwargs = dict(
                model=model_name,
                **batch_decoding_args.__dict__,
                **decoding_kwargs,
            )
            try:
                completion_batch = openai.Completion.create(
                    prompt=prompt_batch, **shared_kwargs
                )
            except openai.error.OpenAIError as e:
                logging.warning(f"OpenAIError: {e}.")
                if "Please reduce your prompt" in str(e):
                    reduced_max_tokens = int(batch_decoding_args.max_tokens * 0.8)
                    if reduced_max_tokens == batch_decoding_args.max_tokens:
                        # No further reduction possible (e.g. already 0):
                        # retrying would spin forever, so surface the error.
                        raise
                    batch_decoding_args.max_tokens = reduced_max_tokens
                    logging.warning(
                        f"Reducing target length to {batch_decoding_args.max_tokens}, Retrying..."
                    )
                else:
                    # NOTE(review): every other OpenAIError is treated as a rate limit.
                    logging.warning("Hit request rate limit; retrying...")
                    time.sleep(sleep_time)  # Annoying rate limit on requests.
            else:
                choices = completion_batch.choices
                # Record the request's total token usage on each returned choice.
                for choice in choices:
                    choice["total_tokens"] = completion_batch.usage.total_tokens
                completions.extend(choices)
                break

    if return_text:
        completions = [completion.text for completion in completions]
    if decoding_args.n > 1:
        # Regroup the flat list so each entry holds the n consecutive choices of one prompt.
        completions = [
            completions[i : i + decoding_args.n]
            for i in range(0, len(completions), decoding_args.n)
        ]
    if is_single_prompt:
        # Unwrap the outer list for a single prompt (the item is itself a list when n > 1).
        (completions,) = completions
    return completions