Skip to content

ezpz.log.style

src/ezpz/log/style.py

make_layout(ratio=4, visible=True)

Define the layout.

Source code in src/ezpz/log/style.py
def make_layout(ratio: int = 4, visible: bool = True) -> Layout:
    """Build the root layout: a ``main`` pane next to a ``footer`` pane.

    Args:
        ratio: Relative size given to the ``main`` pane when the root is
            split into columns (the ``footer`` pane passes no ratio).
        visible: Visibility flag applied to the root and to both panes.

    Returns:
        The root ``Layout`` (named ``"root"``), split row-wise into
        ``main`` and ``footer``.
    """
    root = Layout(name="root", visible=visible)
    panes = (
        Layout(name="main", ratio=ratio, visible=visible),
        Layout(name="footer", visible=visible),
    )
    root.split_row(*panes)
    return root

print_config(config, resolve=True, print_order=None)

Prints the contents of a DictConfig as a tree using the Rich library, and saves copies of the config to the current working directory.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| config | DictConfig | Configuration composed by Hydra. | required |
| print_order | Sequence[str] | Determines in what order config components are printed. | None |
| resolve | bool | Whether to resolve reference fields of DictConfig. | True |
Source code in src/ezpz/log/style.py
def print_config(
    config: DictConfig | dict | Any,
    resolve: bool = True,
    print_order: Sequence[str] | None = None,
) -> None:
    """Print a config as a Rich tree and save copies of it to the cwd.

    Besides printing, this writes into the current working directory:

    - ``config_tree.log``: the rendered Rich tree.
    - ``config.json``: one entry per top-level key (``DictConfig`` groups
      converted with ``OmegaConf.to_container``, other values stringified).
    - ``config.yaml``: the whole config, always saved with interpolations
      resolved.
    - ``logdirs.csv`` (``logdirs-debug.csv`` when ``debug_mode`` is truthy):
      one row keyed by the working directory, appended if the file exists.

    It also sets the ``LOGDIR`` environment variable to the working directory.

    Args:
        config (DictConfig): Configuration composed by Hydra. A plain
            ``dict`` is converted with ``OmegaConf.create`` first.
        resolve (bool, optional): Whether to resolve reference fields of
            DictConfig in the printed tree and in ``config.json``.
            (``config.yaml`` and the CSV row are always resolved.)
        print_order (Sequence[str], optional): Top-level keys to print
            first, in this order; remaining keys follow in config order.
            Duplicates and keys absent from ``config`` are ignored.
    """
    import pandas as pd
    from rich.console import Console

    if isinstance(config, dict):
        # Wrap plain dicts so nested groups render as YAML and so that
        # `OmegaConf.to_object` below (which rejects non-OmegaConf input)
        # works.
        config = OmegaConf.create(config)

    # `print_order` keys first -- skipping duplicates and keys the config
    # does not contain (`config[key]` would reject them) -- then every
    # remaining key in the config's own order.
    queue = []
    for key in print_order or ():
        if key in config and key not in queue:
            queue.append(key)
    for key in config:
        if key not in queue:
            queue.append(key)

    tree = rich.tree.Tree("CONFIG")
    dconfig = {}
    for key in queue:
        branch = tree.add(key)
        group = config[key]
        if isinstance(group, DictConfig):
            branch_content = OmegaConf.to_yaml(group, resolve=resolve)
            dconfig[key] = OmegaConf.to_container(group, resolve=resolve)
        else:
            branch_content = str(group)
            dconfig[key] = branch_content
        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    cwd = Path(os.getcwd())
    # Explicit UTF-8: the rendered tree and config values may contain
    # characters the platform's default encoding cannot represent.
    with cwd.joinpath("config_tree.log").open("wt", encoding="utf-8") as fh:
        Console(file=fh).print(tree)
    with open("config.json", "w", encoding="utf-8") as fh:
        fh.write(json.dumps(dconfig))
    # NOTE: config.yaml and the CSV record are always fully resolved,
    # independent of `resolve`.
    OmegaConf.save(config, Path("config.yaml"), resolve=True)
    cfgdict = OmegaConf.to_object(config)

    logdir = cwd.resolve().as_posix()
    csv_name = (
        "logdirs-debug.csv"
        if config.get("debug_mode", False)
        else "logdirs.csv"
    )
    dbfpath = cwd.joinpath(csv_name)
    # Append to an existing run database without repeating the header;
    # otherwise start a new file with a header row.
    exists = dbfpath.is_file()
    df = pd.DataFrame({logdir: cfgdict})
    df.T.to_csv(
        dbfpath.resolve().as_posix(),
        mode="a" if exists else "w",
        header=not exists,
    )
    os.environ["LOGDIR"] = logdir

printarr(*arrs, float_width=6)

Print a pretty table giving name, shape, dtype, type, and content information for input tensors or scalars.

Call like: printarr(my_arr, some_other_arr, maybe_a_scalar). Accepts a variable number of arguments.

Inputs can be:
  • Numpy tensor arrays
  • Pytorch tensor arrays
  • Jax tensor arrays
  • Python ints / floats
  • None

It may also work with other array-like types, but they have not been tested.

Use the float_width option to specify the minimum field width used when printing floating-point values.

Author: Nicholas Sharp (nmwsharp.com) Canonical source: https://gist.github.com/nmwsharp/54d04af87872a4988809f128e1a1d233 License: This snippet may be used under an MIT license, and it is also released into the public domain. Please retain this docstring as a reference.

Source code in src/ezpz/log/style.py
def printarr(*arrs, float_width=6):
    """
    Print a pretty table giving name, shape, dtype, type, and content
    information for input tensors or scalars.

    Call like: printarr(my_arr, some_other_arr, maybe_a_scalar). Accepts a
    variable number of arguments.

    Inputs can be:
        - Numpy tensor arrays
        - Pytorch tensor arrays
        - Jax tensor arrays
        - Python ints / floats
        - None

    It may also work with other array-like types, but they have not been
    tested. Objects without a ``dtype`` / ``shape`` attribute show "N/A" in
    those columns.

    The "name" column is the first local variable in the *caller's* scope
    that refers to the very same object (identity, ``is``); arguments not
    bound to a caller variable are shown as "[temporary]".

    Use the `float_width` option to specify the minimum field width used when
    printing floating point values (``g`` format).

    Author: Nicholas Sharp (nmwsharp.com)
    Canonical source:
        https://gist.github.com/nmwsharp/54d04af87872a4988809f128e1a1d233
    License: This snippet may be used under an MIT license, and it is also
    released into the public domain. Please retain this docstring as a
    reference.
    """
    import inspect

    # Snapshot the caller's locals to label each argument. This must be the
    # caller's frame (`f_back`): our own locals would only ever match the loop
    # variable below. `currentframe()` may return None on interpreters
    # without stack-frame support; then every argument is a "[temporary]".
    frame = inspect.currentframe()
    caller = frame.f_back if frame is not None else None
    try:
        caller_locals = dict(caller.f_locals) if caller is not None else {}
    finally:
        # Release frame references promptly to avoid reference cycles.
        del frame, caller
    default_name = "[temporary]"

    # helpers to gather data about each array

    def name_from_outer_scope(a):
        if a is None:
            return "[None]"
        return next(
            (k for k, v in caller_locals.items() if v is a), default_name
        )

    def dtype_str(a):
        if a is None:
            return "None"
        if isinstance(a, int):
            return "int"
        if isinstance(a, float):
            return "float"
        dtype = getattr(a, "dtype", None)
        return "N/A" if dtype is None else str(dtype)

    def shape_str(a):
        if a is None:
            return "N/A"
        if isinstance(a, (int, float)):
            return "scalar"
        shape = getattr(a, "shape", None)
        return "N/A" if shape is None else str(list(shape))

    def type_str(a):
        # Fully-qualified class name, as shown by the default class repr
        # ("module.QualName"), with the module omitted for builtins.
        cls = type(a)
        if cls.__module__ == "builtins":
            return cls.__qualname__
        return f"{cls.__module__}.{cls.__qualname__}"

    def device_str(a):
        if hasattr(a, "device"):
            dev = str(a.device)
            # heuristic: jax returns some goofy long string we don't want,
            # ignore it
            if len(dev) < 10:
                return dev
        return ""

    def format_float(x):
        return f"{x:{float_width}g}"

    def minmaxmean_str(a):
        if a is None:
            return ("N/A", "N/A", "N/A")
        if isinstance(a, (int, float)):
            s = format_float(a)
            return (s, s, s)

        def reduce_str(method):
            # Best effort: if the reduction or formatting fails, print 'N/A'.
            try:
                return format_float(getattr(a, method)())
            except Exception:
                return "N/A"

        return (reduce_str("min"), reduce_str("max"), reduce_str("mean"))

    props = ["name", "dtype", "shape", "type", "device", "min", "max", "mean"]

    # precompute all of the properties for each input
    str_props = []
    for a in arrs:
        min_s, max_s, mean_s = minmaxmean_str(a)
        str_props.append(
            {
                "name": name_from_outer_scope(a),
                "dtype": dtype_str(a),
                "shape": shape_str(a),
                "type": type_str(a),
                "device": device_str(a),
                "min": min_s,
                "max": max_s,
                "mean": mean_s,
            }
        )

    # width of each column = longest value in it
    maxlen = {
        p: max((len(sp[p]) for sp in str_props), default=0) for p in props
    }
    # if any property got all empty strings, don't bother printing it
    props = [p for p in props if maxlen[p] > 0]
    # widen columns to fit their header labels so header and rows line up
    for p in props:
        maxlen[p] = max(maxlen[p], len(p))

    def fmt_cell(p, text):
        prefix = "" if p == "name" else " | "
        align = ">" if p == "name" else "<"
        return f"{prefix}{text:{align}{maxlen[p]}}"

    # print a header
    header_str = "".join(fmt_cell(p, p) for p in props)
    print(header_str)
    print("-" * len(header_str))
    # now print the actual arrays
    for sp in str_props:
        print("".join(fmt_cell(p, sp[p]) for p in props))