Skip to content

ezpz.configs

ezpz/configs.py

HfDataTrainingArguments dataclass

Arguments pertaining to what data we are going to input our model for training and eval.

Source code in src/ezpz/configs.py
@dataclass
class HfDataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Each field carries its own description in ``metadata["help"]``.

    Validation performed in ``__post_init__``:

    - at least one of ``dataset_name``, ``data_path``, ``train_file`` or
      ``validation_file`` must be set;
    - ``train_file`` / ``validation_file``, when given, must end in a
      ``csv``, ``json`` or ``txt`` extension;
    - ``streaming=True`` requires ``datasets>=2.0.0``.
    """

    data_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the training data."},
    )
    dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the dataset to use (via the datasets library)."
        },
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the dataset to use (via the datasets library)."
        },
    )
    train_split_str: Optional[str] = field(
        default=None,
        metadata={
            "help": "The split string to use for the train split (via the datasets library)."
        },
    )
    train_split_name: Optional[str] = field(
        default="train",
        metadata={
            "help": "The name of the train split to use (via the datasets library)."
        },
    )
    validation_split_name: Optional[str] = field(
        default="validation",
        metadata={
            "help": "The name of the validation split to use (via the datasets library)."
        },
    )
    validation_split_str: Optional[str] = field(
        default=None,
        metadata={
            "help": "The split string to use for the validation split (via the datasets library)."
        },
    )
    test_split_name: Optional[str] = field(
        default="test",
        metadata={
            "help": "The name of the test split to use (via the datasets library)."
        },
    )
    test_split_str: Optional[str] = field(
        default=None,
        metadata={
            "help": "The split string to use for the test split (via the datasets library)."
        },
    )
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "The input training data file (a text file)."},
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    streaming: bool = field(
        default=False, metadata={"help": "Enable streaming mode"}
    )
    block_size: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Optional input sequence length after tokenization. "
                "The training dataset will be truncated in block of this size for training. "
                "Default to the model max input length for single sentence inputs (take into account special tokens)."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={
            "help": "The number of processes to use for the preprocessing."
        },
    )
    keep_linebreaks: bool = field(
        default=True,
        metadata={
            "help": "Whether to keep line breaks when using TXT files or not."
        },
    )

    def __post_init__(self):
        """Validate the combination of arguments after dataclass init.

        Raises:
            ValueError: If none of ``dataset_name``, ``data_path``,
                ``train_file`` or ``validation_file`` is provided.
            AssertionError: If ``train_file`` or ``validation_file`` does not
                end in a ``csv``, ``json`` or ``txt`` extension.
        """
        if self.streaming:
            require_version(
                "datasets>=2.0.0",
                "The streaming feature requires `datasets>=2.0.0`",
            )

        if (
            self.dataset_name is None
            and self.data_path is None
            and self.train_file is None
            and self.validation_file is None
        ):
            raise ValueError(
                "You must specify at least one of the following: "
                "a dataset name, a data path, a training file, or a validation file."
            )

        # The same extension rule applies to both files. AssertionError is
        # raised explicitly (rather than via `assert`) so the check is not
        # stripped when Python runs with -O, while callers still see the same
        # exception type and message as before.
        for arg_name, file_path in (
            ("train_file", self.train_file),
            ("validation_file", self.validation_file),
        ):
            if file_path is None:
                continue
            extension = file_path.split(".")[-1]
            if extension not in ("csv", "json", "txt"):
                raise AssertionError(
                    f"`{arg_name}` should be a csv, a json or a txt file."
                )

HfModelArguments dataclass

Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

Source code in src/ezpz/configs.py
@dataclass
class HfModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

    Each field carries its own description in ``metadata["help"]``.
    ``__post_init__`` rejects ``config_overrides`` when it is combined with
    ``config_name`` or ``model_name_or_path``.
    """

    wandb_project_name: Optional[str] = field(  # type:ignore
        default=None,
        metadata={
            "help": (
                "The name of the wandb project to use. If not specified, will use the model name."
            )
        },
    )

    model_name_or_path: Optional[str] = field(  # type:ignore
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    # `Optional[str]` already includes None; spelled like the other fields so
    # the annotation does not depend on the `str | None` runtime syntax.
    model_type: Optional[str] = field(
        default=None,
        metadata={
            "help": "If training from scratch, pass a model type from the list: "
            + ", ".join(CAUSAL_LM_MODEL_TYPES)
        },
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where do you want to store the pretrained models downloaded from huggingface.co"
        },
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={
            "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
        },
    )
    model_revision: str = field(
        default="main",
        metadata={
            "help": "The specific model version to use (can be a branch name, tag name or commit id)."
        },
    )
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
                "set True will benefit LLM loading time and RAM consumption."
            )
        },
    )

    def __post_init__(self):
        """Reject mutually exclusive argument combinations.

        Raises:
            ValueError: If ``config_overrides`` is set together with
                ``config_name`` or ``model_name_or_path``.
        """
        names_pretrained_source = (
            self.config_name is not None or self.model_name_or_path is not None
        )
        if self.config_overrides is not None and names_pretrained_source:
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )

cmd_exists(cmd)

Check whether command exists.

Examples: cmd_exists("ls") returns True; cmd_exists("hostname") returns True.

Source code in src/ezpz/configs.py
def cmd_exists(cmd: str) -> bool:
    """Report whether ``cmd`` can be located with ``shutil.which``.

    >>> cmd_exists("ls")
    True
    >>> cmd_exists("hostname")
    True
    """
    resolved_path = shutil.which(cmd)
    return resolved_path is not None

print_config_tree(cfg, resolve=True, save_to_file=True, verbose=True, style='tree', print_order=None, highlight=True, outfile=None)

Prints the contents of a DictConfig as a tree structure using the Rich library.

  • cfg: A DictConfig composed by Hydra.
  • print_order: Determines in what order config components are printed.
  • resolve: Whether to resolve reference fields of DictConfig.
  • save_to_file: Whether to export config to the hydra output folder.
Source code in src/ezpz/configs.py
def print_config_tree(
    cfg: DictConfig,
    resolve: bool = True,
    save_to_file: bool = True,
    verbose: bool = True,
    style: str = "tree",
    print_order: Optional[Sequence[str]] = None,
    highlight: bool = True,
    outfile: Optional[Union[str, os.PathLike, Path]] = None,
) -> Tree:
    """Prints the contents of a DictConfig as a tree structure using the Rich
    library.

    - cfg: A DictConfig composed by Hydra.
    - resolve: Whether to resolve reference fields of DictConfig.
    - save_to_file: Whether to write the rendered tree to ``outfile``
      (``config_tree.log`` in the current working directory when ``outfile``
      is None).
    - verbose: Whether to print the tree to the console. The tree is also
      printed whenever ``save_to_file`` is True.
    - style: Currently unused; kept for backward compatibility.
    - print_order: Determines in what order config components are printed.
      Names missing from ``cfg`` are skipped with a warning; repeated names
      are rendered once. Remaining components follow in ``cfg`` order.
    - highlight: Passed to the Rich ``Tree`` and each of its branches.
    - outfile: Destination path used when ``save_to_file`` is True.

    Returns the constructed Rich ``Tree``.
    """
    from rich.console import Console
    from ezpz.log.config import STYLES
    from rich.theme import Theme

    name = cfg.get("_target_", "cfg")
    console = Console(record=True, theme=Theme(STYLES))
    tree = Tree(label=name, highlight=highlight)
    queue = []
    # add fields from `print_order` to queue (once each, in the given order)
    if print_order is not None:
        for key in print_order:
            if key not in cfg:
                log.warning(
                    f"Field '{key}' not found in config. "
                    f"Skipping '{key}' config printing..."
                )
            elif key not in queue:
                queue.append(key)
    # add all the other fields to queue (not specified in `print_order`)
    for key in cfg:
        if key not in queue:
            queue.append(key)
    # generate config tree from queue
    for key in queue:
        branch = tree.add(key, highlight=highlight)
        config_group = cfg[key]
        if isinstance(config_group, DictConfig):
            # nested groups are rendered as a YAML dump
            branch_content = str(
                OmegaConf.to_yaml(config_group, resolve=resolve)
            )
            branch.add(Text(branch_content, style="red"))
        else:
            # leaf values are rendered via str()
            branch_content = str(config_group)
            branch.add(Text(branch_content, style="blue"))
    if verbose or save_to_file:
        # The console records what it prints; `save_text` below exports that
        # recording, so the tree is printed even when only saving.
        console.print(tree)
        if save_to_file:
            outfpath = (
                Path(os.getcwd()).joinpath("config_tree.log")
                if outfile is None
                else Path(outfile)
            )
            console.save_text(outfpath.as_posix())
    return tree

print_json(json_str=None, console=None, *, data=None, indent=2, highlight=True, skip_keys=False, ensure_ascii=False, check_circular=True, allow_nan=True, default=None, sort_keys=False)

Pretty prints JSON. Output will be valid JSON.

Parameters:

Name Type Description Default
json_str Optional[str]

A string containing JSON.

None
data Any

If `json_str` is not supplied, then encode this data.

None
indent Union[None, int, str]

Number of spaces to indent. Defaults to 2.

2
highlight bool

Enable highlighting of output: Defaults to True.

True
skip_keys bool

Skip keys not of a basic type. Defaults to False.

False
ensure_ascii bool

Escape all non-ascii characters. Defaults to False.

False
check_circular bool

Check for circular references. Defaults to True.

True
allow_nan bool

Allow NaN and Infinity values. Defaults to True.

True
default Callable

A callable that converts values that can not be encoded in to something that can be JSON encoded. Defaults to None.

None
sort_keys bool

Sort dictionary keys. Defaults to False.

False
Source code in src/ezpz/configs.py
def print_json(
    json_str: Optional[str] = None,
    console: Optional[Console] = None,
    *,
    data: Any = None,
    indent: Union[None, int, str] = 2,
    highlight: bool = True,
    skip_keys: bool = False,
    ensure_ascii: bool = False,
    check_circular: bool = True,
    allow_nan: bool = True,
    default: Optional[Callable[[Any], Any]] = None,
    sort_keys: bool = False,
) -> None:
    """Pretty prints JSON. Output will be valid JSON.

    Exactly one of ``json_str`` or ``data`` must be provided.

    Args:
        json_str (Optional[str]): A string containing JSON.
        console (Optional[Console]): Console used for rendering. When None,
            the console returned by ``ezpz.log.console.get_console`` is used.
        data (Any): If `json_str` is not supplied, then encode this data.
        indent (Union[None, int, str], optional): Number of spaces to indent.
            Defaults to 2.
        highlight (bool, optional): Enable highlighting of output:
            Defaults to True.
        skip_keys (bool, optional): Skip keys not of a basic type.
            Defaults to False.
        ensure_ascii (bool, optional): Escape all non-ascii characters.
            Defaults to False.
        check_circular (bool, optional): Check for circular references.
            Defaults to True.
        allow_nan (bool, optional): Allow NaN and Infinity values.
            Defaults to True.
        default (Callable, optional): A callable that converts values
            that can not be encoded in to something that can be JSON
            encoded.
            Defaults to None.
        sort_keys (bool, optional): Sort dictionary keys. Defaults to False.

    Raises:
        ValueError: If neither or both of ``json_str`` and ``data`` are given.
    """
    if json_str is None and data is None:
        # Both values are None here, so there is nothing useful to interpolate.
        raise ValueError(
            "Either `json_str` or `data` must be provided. "
            "Did you mean print_json(data=...) ?"
        )
    if json_str is not None and data is not None:
        raise ValueError(
            " ".join(
                [
                    "Only one of `json_str` or `data` should be provided.",
                    f"Did you mean print_json(json_str={json_str!r}) ?",
                    f"Or print_json(data={data!r}) ?",
                    "Received both:",
                    f"json_str={json_str!r}",
                    f"data={data!r}",
                ]
            )
        )
    from ezpz.log.console import get_console
    from rich.json import JSON

    console = get_console() if console is None else console
    if json_str is None:
        json_renderable = JSON.from_data(
            data,
            indent=indent,
            highlight=highlight,
            skip_keys=skip_keys,
            ensure_ascii=ensure_ascii,
            check_circular=check_circular,
            allow_nan=allow_nan,
            default=default,
            sort_keys=sort_keys,
        )
    else:
        json_renderable = JSON(
            json_str,
            indent=indent,
            highlight=highlight,
            skip_keys=skip_keys,
            ensure_ascii=ensure_ascii,
            check_circular=check_circular,
            allow_nan=allow_nan,
            default=default,
            sort_keys=sort_keys,
        )
    assert console is not None and isinstance(console, Console)
    # NOTE(review): this logs the result of `Text(str(json_renderable)).render(...)`.
    # Confirm that `str()` of a rich `JSON` object yields the JSON text and
    # that `log` formats the rendered segments as intended; rich's own
    # `Console.print_json` prints the renderable directly instead.
    log.info(Text(str(json_renderable)).render(console=console))