dist module

dist.py

Contains methods for initializing distributed communication.

barrier(device=None, group=tdist.GroupMember.WORLD, async_op=False, device_ids=None)

Barrier for all processes in the group. This collective blocks processes until the whole group enters this function (if async_op is False), or until wait() is called on the returned async work handle (if async_op is True).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| group | ProcessGroup | The process group to work on. If None, the default process group will be used. | GroupMember.WORLD |
| async_op | bool | Whether this op should be an async op. | False |
| device_ids | [int] | List of device/GPU ids. | None |

Returns:

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.

Source code in src/ezpz/dist.py
def barrier(
    device: Optional[torch.device | int | str] = None,
    group: tdist.ProcessGroup | None = tdist.GroupMember.WORLD,
    async_op: bool = False,
    device_ids: str | Iterable | None = None,
) -> tdist.Work | None:
    """Barrier for all processes in the group
    This collective blocks processes until the whole group enters this function,
    if async_op is False, or if async work handle is called on wait().

    Args:
    group (ProcessGroup, optional): The process group to work on. If None,
    the default process group will be used.
    async_op (bool, optional): Whether this op should be an async op
    device_ids ([int], optional): List of device/GPU ids.

    Returns:
    Async work handle, if async_op is set to True.
    None, if not async_op or if not part of the group

    - `[group: ProcessGroup | None = GroupMember.WORLD]` group (ProcessGroup, optional): The process group to work on. If None,
    - `[async_op: bool = False]`
    - `[device_ids: Unknown | None = None]`
    """
    try:
        tdist.barrier(group=group, async_op=async_op, device_ids=device_ids)
    except Exception:
        logger.warning(
            "Unable to use `torch.distributed.barrier` "
            "for this process group. "
            "Falling back to `mpi4py` barrier."
        )
        MPI.COMM_WORLD.barrier()
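
A minimal usage sketch (assuming a multi-rank job launched with `mpiexec`, `srun`, or similar); if `torch.distributed.barrier` fails for the group, the call falls back to `MPI.COMM_WORLD.barrier()` as shown in the source above:

```python
import ezpz.dist as dist

rank = dist.setup_torch(backend="DDP")  # initialize torch.distributed
# ... per-rank work ...
dist.barrier()  # block until every rank in the (default) group arrives here
```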

check(framework='pytorch', backend='deepspeed', port='5432')

Check if the framework is installed and working

Source code in src/ezpz/dist.py
def check(
    framework: str = "pytorch",
    backend: str = "deepspeed",
    port: int | str = "5432",
):
    """Check if the framework is installed and working"""
    from ezpz.configs import FRAMEWORKS

    if framework in FRAMEWORKS["pytorch"]:
        _ = setup_torch(
            backend=backend,
            port=str(port),
        )
    elif framework in FRAMEWORKS["tensorflow"]:
        _ = setup_tensorflow()
    else:
        raise ValueError(f"Unable to parse framework: {framework}")
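
A quick sanity-check sketch (meant to be run across the ranks of an allocation; the backend and port values here are illustrative):

```python
import ezpz.dist as dist

# Verify that PyTorch can be initialized with the DDP backend on this allocation.
dist.check(framework="pytorch", backend="DDP", port=1234)
```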

get_cpus_per_node()

Get the number of CPUs per node

Source code in src/ezpz/dist.py
def get_cpus_per_node() -> int:
    """Get the number of CPUs per node"""
    from sh import getconf as sh_getconf  # type:ignore noqa

    return int(sh_getconf("_NPROCESSORS_ONLN").rstrip("\n"))

get_dist_info(framework=None, verbose=None, hostfile=None)

Get distributed info.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| framework | str | Framework to use. Defaults to None. | None |
| verbose | bool | Whether to print the info. Defaults to None. | None |
| hostfile | PathLike | Path to the hostfile. Defaults to None. | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| dict | dict[str, str \| int \| list] | Dictionary containing the distributed info. |

Source code in src/ezpz/dist.py
def get_dist_info(
    framework: Optional[str] = None,
    verbose: Optional[bool] = None,
    hostfile: Optional[PathLike] = None,
) -> dict[str, str | int | list]:
    """Get distributed info.

    Args:
        framework (str, optional): Framework to use. Defaults to None.
        verbose (bool, optional): Whether to print the info. Defaults to None.
        hostfile (PathLike, optional): Path to the hostfile. Defaults to None.

    Returns:
        dict: Dictionary containing the distributed info.
    """
    dist_info = _get_dist_info(
        hostfile=hostfile,
        framework=framework,
    )
    if verbose:
        import json

        logger.info(
            f"DistInfo={json.dumps(dist_info, indent=4, sort_keys=True)}"
        )
    return dist_info
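
A minimal sketch; the exact keys of the returned dictionary depend on the scheduler and framework detected at runtime (an assumption based on the scheduler-specific helpers in this module):

```python
import ezpz.dist as dist

info = dist.get_dist_info(framework="pytorch", verbose=True)  # verbose=True also logs DistInfo=...
print(sorted(info.keys()))
```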

get_gpus_per_node()

Get the number of GPUs per node

Source code in src/ezpz/dist.py
def get_gpus_per_node() -> int:
    """Get the number of GPUs per node"""
    # return torch.cuda.device_count() if torch.cuda.is_available() else (
    #     (
    #         ipex.xpu.device_count() if ipex is not None else (
    #             get_cpus_per_node()
    #         )
    #     )
    # )
    # if _assert:
    #     raise RuntimeError(
    #         'No {X, G}pus found; but _assert specified. Returning !!'
    #     )
    # logger.warning('No {x,g}-pus found, returning' + f'{cpus_per_node}')
    ngpu_per_host = os.environ.get("NGPU_PER_HOST", None)
    if ngpu_per_host is not None:
        return int(ngpu_per_host)
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    if torch.xpu.is_available():
        return torch.xpu.device_count()
    if ipex is not None:
        return ipex.xpu.device_count()  # type:ignore
    return get_cpus_per_node()
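
A short sketch of the lookup order (environment override first, then CUDA/XPU device counts, then CPU count):

```python
import os
import ezpz.dist as dist

print(dist.get_gpus_per_node())     # accelerator count, or CPU count if no accelerator

# NGPU_PER_HOST takes precedence over any device query:
os.environ["NGPU_PER_HOST"] = "4"
assert dist.get_gpus_per_node() == 4
```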

get_hostfile_with_fallback(hostfile=None)

Get the hostfile from the environment or create one if it doesn't exist

Source code in src/ezpz/dist.py
def get_hostfile_with_fallback(hostfile: Optional[PathLike] = None) -> Path:
    """Get the hostfile from the environment or create one if it doesn't exist"""
    from ezpz.configs import get_scheduler

    scheduler = get_scheduler()
    if scheduler.lower() == "unknown":
        logger.debug("Unknown scheduler")
        hostfile = Path(os.getcwd()).joinpath("hostfile")
    if scheduler.lower() == "slurm":
        hostfile = make_hostfile_from_slurm_env()
        assert Path(hostfile).is_file()
    if hostfile is None:
        hfp = os.environ.get(
            "PBS_NODEFILE",
            os.environ.get(
                "HOSTFILE",
                None,  # fallback_hostfile.as_posix()
            ),
        )
        if (
            hfp is None or not Path(hfp).is_file()
            # and scheduler == 'PBS'
        ):
            if scheduler == "PBS":
                hfp = Path(get_pbs_nodefile_from_qstat())
            else:
                # create makeshift hostfile containing 'localhost'
                hfp = Path(os.getcwd()).joinpath("hostfile")
                hfp.touch(exist_ok=True)
                write_localhost_to_hostfile(hfp)
    else:
        hfp = Path(hostfile)
    assert hfp is not None and Path(hfp).is_file()
    assert Path(hfp).is_file()
    hostfile = Path(hfp).as_posix()
    # if hfp is not None:
    # hostfile, hosts = get_hosts_from_hostfile(hostfile)
    # hosts = [h.split('.')[0] for h in hosts]
    # if scheduler == 'PBS':
    #     os.environ['PBS_NODEFILE'] = hostfile  # hfp.as_posix()
    hfname = f"{scheduler.upper()}_NODEFILE"
    if hfname not in os.environ:
        os.environ |= {hfname: hostfile}
    # os.environ[f'{scheduler.upper()}_NODEFILE'] = hostfile
    return Path(hfp)
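
A usage sketch; the exact resolution path depends on the scheduler environment (PBS/SLURM nodefiles are preferred, otherwise a hostfile in the current working directory is used or created):

```python
from pathlib import Path
import ezpz.dist as dist

hostfile: Path = dist.get_hostfile_with_fallback()   # scheduler nodefile, or local fallback
print(hostfile.read_text())
```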

get_local_rank()

Return get_rank() % get_gpus_per_node()

Source code in src/ezpz/dist.py
def get_local_rank() -> int:
    """Return `get_rank() % get_gpus_per_node()`"""
    return int(get_rank() % get_gpus_per_node())

get_machine(hostname=None)

Get the machine name from the hostname.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| hostname | str | The hostname to check. Defaults to None. | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| str | str | The machine name. |

Example:

>>> get_machine("frontier")
"Frontier"

Source code in src/ezpz/dist.py
def get_machine(hostname: Optional[str] = None) -> str:
    """Get the machine name from the hostname.

    Args:
        hostname (str, optional): The hostname to check. Defaults to None.

    Returns:
        str: The machine name.

    Example:
        >>> get_machine("frontier")
        "Frontier"
    """

    if hostname is None:
        try:
            hostname = socket.gethostbyaddr(socket.gethostname())[0]
        except Exception:
            try:
                hostname = socket.gethostname()
            except Exception:
                logger.warning("Unable to determine hostname!")
                hostname = "unknown"
    if hostname.startswith("frontier"):
        return "Frontier"
    if hostname.startswith("sophia"):
        return "Sophia"
    if hostname.startswith("theta"):
        return "ThetaGPU"
    if hostname.startswith("x1"):
        return "SunSpot"
    if hostname.startswith("x3"):
        if (pbs_host := os.environ.get("PBS_O_HOST", None)) is not None:
            if str(pbs_host).startswith("sirius"):
                return "Sirius"
            return "Polaris"
        return "Polaris"
    if hostname.startswith("x4"):
        return "Aurora"
    if hostname.startswith("login"):
        return "Perlmutter"
    if hostname.startswith("nid"):
        return "Perlmutter"
    return f"{hostname}"

get_node_index()

Get the index of the current node in the hostfile

Source code in src/ezpz/dist.py
def get_node_index() -> int:
    """Get the index of the current node in the hostfile"""
    return get_rank() % get_num_nodes()

get_num_nodes(hostfile=None)

Get the number of nodes from the hostfile

Source code in src/ezpz/dist.py
def get_num_nodes(hostfile: Optional[PathLike] = None) -> int:
    """Get the number of nodes from the hostfile"""
    num_nodes = os.environ.get("SLURM_NNODES", None)
    if num_nodes is not None:
        return int(num_nodes)
    hfp = get_hostfile_with_fallback(hostfile)
    hosts = [h.split(".")[0] for h in get_nodes_from_hostfile(hfp)]
    return len(hosts)

get_pbs_env(hostfile=None, verbose=None)

Get the PBS environment variables

Source code in src/ezpz/dist.py
def get_pbs_env(
    hostfile: Optional[Union[str, Path]] = None,
    verbose: Optional[bool] = None,
) -> dict[str, str]:
    """Get the PBS environment variables"""
    from ezpz.configs import get_scheduler

    assert get_scheduler() == "PBS"
    pbsenv = {k: v for k, v in dict(os.environ).items() if "PBS" in k}
    if hostfile is None:
        hostfile = pbsenv.get("PBS_NODEFILE", get_pbs_nodefile_from_qstat())
    if (hfp := Path(hostfile)).is_file():
        pbsenv |= {
            f"{k.upper()}": f"{v}" for k, v in get_pbs_launch_info(hfp).items()
        }
        pbsenv |= {"LAUNCH_CMD": get_pbs_launch_cmd(hostfile=hostfile)}
    os.environ |= pbsenv
    if verbose and get_rank() == 0:
        # logger.debug(f'pbsenv={json.dumps(pbsenv, indent=4, sort_keys=True)}')
        log_dict_as_bulleted_list(pbsenv, name="pbsenv")
    return pbsenv

get_pbs_jobid_from_qstat()

Get the PBS job ID from qstat

Source code in src/ezpz/dist.py
def get_pbs_jobid_from_qstat() -> int:
    """Get the PBS job ID from qstat"""
    from ezpz.configs import get_scheduler

    assert get_scheduler() == "PBS"
    try:
        from sh import qstat as sh_qstat  # pyright:ignore
    except Exception as exc:
        raise exc
    qstat_out = sh_qstat("-u", os.environ.get("USER")).split("\n")[2:-1]
    return int(qstat_out[-1].split(".")[0])

get_pbs_launch_cmd(ngpus=None, nhosts=None, ngpu_per_host=None, hostfile=None)

Get the PBS launch command

Source code in src/ezpz/dist.py
def get_pbs_launch_cmd(
    ngpus: Optional[int] = None,
    nhosts: Optional[int] = None,
    ngpu_per_host: Optional[int] = None,
    hostfile: Optional[PathLike] = None,
) -> str:
    """Get the PBS launch command"""
    nhosts = get_num_nodes(hostfile=hostfile) if nhosts is None else nhosts
    ngpu_per_host = (
        get_gpus_per_node() if ngpu_per_host is None else ngpu_per_host
    )
    ngpus_available = get_world_size_total() if ngpus is None else ngpus
    ngpus_in_use = nhosts * ngpu_per_host
    hfp = Path(
        get_hostfile_with_fallback(hostfile) if hostfile is None else hostfile
    )
    if ngpus_available != (ngpus_in_use):
        logger.warning(
            "Mismatch in `ngpus_in_use` and `ngpus_available` "
            f"{ngpus_in_use=} vs. {ngpus_available=}"
        )
    # ncpus_per_host = get_cpus_per_node()
    return " ".join(
        [
            "mpiexec",
            "--verbose",
            "--envall",
            # f'-n {ngpus}',
            f"--np={ngpus_in_use}",
            f"--ppn={ngpu_per_host}",
            f"--hostfile={hfp.as_posix()}",
            "--cpu-bind=depth",
            "--depth=8",
        ]
    )
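
A sketch of the returned string for a hypothetical 2-node, 4-GPU-per-node job (the actual values come from the hostfile and device counts on your allocation):

```python
import ezpz.dist as dist

cmd = dist.get_pbs_launch_cmd()
print(cmd)
# mpiexec --verbose --envall --np=8 --ppn=4 --hostfile=/path/to/hostfile --cpu-bind=depth --depth=8
```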

get_pbs_launch_info(hostfile=None)

Get the PBS launch info

Source code in src/ezpz/dist.py
def get_pbs_launch_info(
    hostfile: Optional[str | Path] = None,  # type:ignore[reportDeprecated]
) -> dict[str, str]:
    """Get the PBS launch info"""
    from ezpz.configs import get_scheduler

    assert get_scheduler() == "PBS"
    if hostfile is None:
        hostfile = os.environ.get("PBS_NODEFILE", get_pbs_nodefile_from_qstat())
    assert hostfile is not None
    hfp = Path(hostfile)
    # hostfile = os.environ.get("PBS_NODEFILE", None)
    # if hostfile is None:
    #     hostfile = (
    #             get_pbs_nodefile_from_qstat() if hostfile is None else
    #             Path(hostfile)
    #     )
    # assert hostfile is not None
    # hf = Path(hostfile)
    # assert hostfile is not None and hf.is_file()
    # hfp = Path(hostfile)
    hosts = get_nodes_from_hostfile(hfp)
    hosts = [h.split(".")[0] for h in hosts]
    nhosts = len(hosts)
    ngpu_per_host = get_gpus_per_node()
    # ngpus = nhosts * ngpu_per_host
    ngpus_available = get_world_size(total=True)
    ngpus = nhosts * ngpu_per_host
    world_size_total = get_world_size_total()
    # if ngpus != world_size_total:
    #     logger.warning('Disagreement in total world size!!')
    #     logger.warning(' '.join([
    #         f'{get_world_size(total=True)=}',
    #         f' vs. {get_world_size_total()=}'
    #     ]))
    #     logger.warning(' '.join([
    #         'Mismatch in: ',
    #         f'{ngpus=} vs. {ngpu_per_host=} * {nhosts=}'
    #     ]))
    launch_cmd = get_pbs_launch_cmd(hostfile=hostfile)
    return {
        "HOSTFILE": hfp.as_posix(),
        "HOSTS": (
            f"[{', '.join(hosts)}]"
            if nhosts < 1000
            else "[truncated (>1000 nodes)]"
        ),
        "NHOSTS": f"{nhosts}",
        "NGPU_PER_HOST": f"{ngpu_per_host}",
        "NGPUS": f"{ngpus}",
        "NGPUS_AVAILABLE": f"{ngpus_available}",
        "MACHINE": get_machine(),
        "DEVICE": get_torch_device_type(),
        "BACKEND": get_torch_backend(),
        "LAUNCH_CMD": launch_cmd,
        "world_size_total": f"{world_size_total}",
    }

get_pbs_nodefile_from_qstat()

Get the PBS nodefile from qstat

Source code in src/ezpz/dist.py
def get_pbs_nodefile_from_qstat() -> Path:
    """Get the PBS nodefile from qstat"""
    from ezpz.configs import get_scheduler

    assert get_scheduler() == "PBS"
    nodefile = os.environ.get("PBS_NODEFILE", None)
    if nodefile is not None and (nf := Path(nodefile)).is_file():
        return nf
    pbs_jobid = get_pbs_jobid_from_qstat()
    matches = [
        i
        for i in Path("/var/spool/pbs/aux/").rglob(f"*{pbs_jobid}*")
        if i.is_file()
    ]
    assert len(matches) == 1
    return matches[0]

get_rank()

Get current MPI rank

Source code in src/ezpz/dist.py
def get_rank() -> int:
    """Get current MPI rank"""
    return int(MPI.COMM_WORLD.Get_rank())

get_running_jobs_from_qstat()

Get the running jobs from qstat

Source code in src/ezpz/dist.py
def get_running_jobs_from_qstat() -> list[int]:
    """Get the running jobs from qstat"""
    try:
        from sh import qstat as shqstat  # type: ignore
    except Exception as e:
        raise e
    return [
        int(i.split(".")[0])
        for i in shqstat("-u", os.environ.get("USER")).split("\n")[2:-1]
        if " R " in i
    ]

get_torch_backend_on_xpu()

Deal with breaking change introduced in torch 2.6:

See: https://github.com/pytorch/pytorch/pull/141856

Example:

```python
>>> torch_version = float('.'.join(torch.__version__.split('.')[:2]))
>>> if torch_version > 2.5:
...     backend = 'xccl'
... else:
...     backend = 'ccl'
```
Source code in src/ezpz/dist.py
def get_torch_backend_on_xpu() -> str:
    """Deal with breaking change introduced in torch 2.6:

    See: https://github.com/pytorch/pytorch/pull/141856

    Example:

        ```python
        >>> torch_version = float('.'.join(torch.__version__.split('.')[:2]))
        >>> if torch_version > 2.5:
        ...     backend = 'xccl'
        ... else:
        ...     backend = 'ccl'
        ```
    """
    torch_version = get_torch_version_as_float()
    assert torch.xpu.is_available()
    return "xccl" if torch_version > 2.5 else "ccl"

get_world_size_in_use()

Get number of currently in use MPI ranks

Source code in src/ezpz/dist.py
def get_world_size_in_use() -> int:
    """Get number of currently in use MPI ranks"""
    return int(MPI.COMM_WORLD.Get_size())

get_world_size_total()

Calculate total AVAILABLE *PUs as:

total = [num_hosts] * [num_*pu_per_host]

Source code in src/ezpz/dist.py
def get_world_size_total() -> int:
    """Calculate total AVAILABLE *PUs as:

    total = [num_hosts] * [num_*pu_per_host]
    """
    # nhosts = get_num_nodes()
    # ngpu_per_host = get_gpus_per_node()
    # return ngpu_per_host * nhosts
    return get_gpus_per_node() * get_num_nodes()

log_dict_as_bulleted_list(d, name=None)

Print dictionary as list

Source code in src/ezpz/dist.py
def log_dict_as_bulleted_list(d: dict, name: Optional[str] = None):
    """Print dictionary as list"""
    # `dict` instances have no `__qualname__`; fall back to the type name when no name is given
    tag = name if name is not None else type(d).__name__
    logger.info(
        "\n".join(
            ["\n", f"[{tag}]:"]
            + [f"  β€’ {k}={v}" for k, v in d.items()]
            + ["\n"]
        )
    )
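
A small usage sketch (passing `name=` explicitly, since a plain dict carries no useful name of its own):

```python
import ezpz.dist as dist

# Logs something like:
#   [pbsenv]:
#     • NHOSTS=2
#     • NGPU_PER_HOST=4
dist.log_dict_as_bulleted_list({"NHOSTS": 2, "NGPU_PER_HOST": 4}, name="pbsenv")
```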

make_hostfile_from_slurm_env(outfile=None)

Make a hostfile from the SLURM_NODELIST environment variable

Source code in src/ezpz/dist.py
def make_hostfile_from_slurm_env(outfile: Optional[PathLike] = None) -> Path:
    """Make a hostfile from the SLURM_NODELIST environment variable"""
    nodes = os.environ.get("SLURM_NODELIST", None)
    # if nodes is not None:
    assert nodes is not None
    # machine = get_machine()
    prefix, idxs = nodes.split("[")
    idxs = idxs.rstrip("]")
    idxs = "-".join(idxs.split(",")).split("-")
    nodelist = [f"{prefix}{i}" for i in idxs]
    # idxs = (
    #     nodes.split
    # )
    # idxs = (
    #     nodes.lstrip('frontier').replace('[', '').replace(']', '').split('-')
    # )
    # nodelist = [f'frontier{i}' for i in idxs]
    if outfile is None:
        outfile = Path(os.getcwd()).joinpath("hostfile")
    else:
        outfile = Path(outfile)
    with outfile.open("w") as f:
        for node in nodelist:
            f.write(f"{node}\n")
    return outfile
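
A sketch of the parsing behavior; note that `SLURM_NODELIST` is split on `,` and `-`, so only the listed indices and range endpoints are written (ranges are not expanded):

```python
import os
import ezpz.dist as dist

os.environ["SLURM_NODELIST"] = "frontier[00001,00002,00003]"
hostfile = dist.make_hostfile_from_slurm_env("hostfile")
print(hostfile.read_text())
# frontier00001
# frontier00002
# frontier00003
```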

print_dist_setup(framework=None, hostfile=None)

Print distributed setup.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| framework | str | Framework to use. Defaults to None. | None |
| hostfile | PathLike | Path to the hostfile. Defaults to None. | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| str | str | String containing the distributed setup. |

Source code in src/ezpz/dist.py
def print_dist_setup(
    framework: Optional[str] = None,
    hostfile: Optional[PathLike] = None,
) -> str:
    """Print distributed setup.

    Args:
        framework (str, optional): Framework to use. Defaults to None.
        hostfile (PathLike, optional): Path to the hostfile. Defaults to None.

    Returns:
        str: String containing the distributed setup.
    """
    rank = get_rank()
    wst = get_world_size(total=True)
    wsa = get_world_size(in_use=True)
    # world_size = get_world_size()
    local_rank = get_local_rank()
    gpus_per_node = get_gpus_per_node()
    hostfile = get_hostfile_with_fallback(hostfile)
    # NOTE:
    # We ensure that num_nodes is AT LEAST 1
    # since if gpus_per_node > wsa, wsa // gpus_per_node = 0
    # if gpus_per_node > wsa, wsa // gpus_per_node = 0
    num_nodes = max((wsa // gpus_per_node, 1))
    num_nodes_from_hostfile = get_num_nodes()
    # assert num_nodes_from_hostfile == num_nodes
    # if num_nodes != num_nodes_from_hostfile:
    #     logger.critical(f'{num_nodes=} vs. {num_nodes_from_hostfile=} ??')
    node = get_node_index()
    device = None
    # if framework.lower() in {'pt', 'torch', 'pytorch'}:
    device = get_torch_device_type()
    rank_len = len(str(rank))
    ws_len = len(str(wsa))
    lr_len = len(str(local_rank))
    gpn_len = len(str(gpus_per_node))
    node_len = len(str(node))
    num_nodes_len = len(str(num_nodes))
    dist_list = [
        f"[{device=}]",
        f"[{rank=:>{rank_len}}/{(wsa - 1):<{ws_len}}]",
        f"[{local_rank=:>{lr_len}}/{gpus_per_node - 1:<{gpn_len}}]",
        f"[{node=:>{node_len}}/{(num_nodes - 1):<{num_nodes_len}}]",
    ]
    if framework is not None:
        dist_list.append(f"[{framework=}]")
    dist_str = "".join(dist_list)
    logger.info(f"{dist_str}")
    if rank == 0:
        if wsa > 1000:
            logger.warning(
                f"WORLD_SIZE={wsa} > 1000, only printing on RANK={rank}"
            )
        logger.warning(f'Using [{wsa} / {wst}] available "{device}" devices !!')
        if num_nodes_from_hostfile != num_nodes:
            logger.critical(
                f"num_nodes_from_hostfile = [{num_nodes_from_hostfile=}]"
                f"vs."
                f"[{wsa=} // {gpus_per_node=}] = {num_nodes}"
                r"Β―\_(ツ)_/Β― ??"
            )
    return dist_str

query_environment()

Query environment variables for info about distributed setup

Source code in src/ezpz/dist.py
def query_environment() -> dict[str, int]:
    """Query environment variables for info about distributed setup"""
    ws = os.environ.get("WORLD_SIZE", None)
    r = os.environ.get("RANK", None)
    lr = os.environ.get("LOCAL_RANK", None)
    if ws is not None and r is not None and lr is not None:
        return {
            "world_size": int(ws),
            "rank": int(r),
            "local_rank": int(lr),
            # 'machine': machine,
        }
    return {
        "world_size": int(get_world_size()),
        "rank": int(get_rank()),
        "local_rank": int(get_local_rank()),
    }
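
A minimal sketch: values come from `WORLD_SIZE`/`RANK`/`LOCAL_RANK` when a launcher has exported them, otherwise from the MPI-derived helpers in this module:

```python
import ezpz.dist as dist

env = dist.query_environment()
print(env["world_size"], env["rank"], env["local_rank"])
```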

seed_everything(seed)

Set random seed for reproducibility.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| seed | int | Random seed to set. | required |

Source code in src/ezpz/dist.py
def seed_everything(seed: int) -> None:
    """Set random seed for reproducibility.

    Args:
        seed (int): Random seed to set.
    """
    import torch
    import numpy as np
    import random

    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    _ = torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    if torch.xpu.is_available():
        torch.xpu.manual_seed(seed)
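
A quick reproducibility sketch:

```python
import torch
import ezpz.dist as dist

dist.seed_everything(42)
a = torch.randn(3)
dist.seed_everything(42)
b = torch.randn(3)
assert torch.equal(a, b)  # same seed, same draws
```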

setup_tensorflow(precision=None, ngpus=None)

Initialize TensorFlow + Horovod for Distributed Training

Source code in src/ezpz/dist.py
def setup_tensorflow(
    precision: Optional[str] = None,
    ngpus: Optional[int] = None,
) -> int:
    """Initialize TensorFlow + Horovod for Distributed Training"""
    try:
        import tensorflow as tf  # type:ignore noqa

        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        import horovod.tensorflow as hvd  # type:ignore noqa
    except Exception:
        logger.warning(
            "Unable to import `tensorflow` or `horovod.tensorflow`. "
            "Install with `pip install tensorflow horovod`"
        )
        raise

    _ = None if hvd.is_initialized() else hvd.init()
    # hvd.init() if not hvd.is_initialized() else None
    if precision in [
        "fp16",
        "float16",
        "half",
        "16",
        "mixed_float16",
        # 'mixed_bfloat16'
    ]:
        tf.keras.mixed_precision.set_global_policy(  # pyright:ignore
            "mixed_float16"
        )
    TF_FLOAT = tf.keras.backend.floatx()  # pyright:ignore
    eager_mode = os.environ.get("TF_EAGER", None)
    if eager_mode is not None:
        logger.info("Detected `TF_EAGER` from env. Running eagerly.")
        tf.config.run_functions_eagerly(True)

    gpus = tf.config.experimental.list_physical_devices("GPU")
    cpus = tf.config.experimental.list_physical_devices("CPU")
    if gpus:
        try:
            # Currently memory growth needs to be the same across GPUs
            if ngpus is not None:
                gpus = gpus[-ngpus:]

            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            tf.config.experimental.set_visible_devices(
                gpus[hvd.local_rank()],
                "GPU",
            )
            _ = (  # pyright:ignore
                tf.config.experimental.list_logical_devices("GPU")
            )
        except RuntimeError as e:
            logger.info(e)
    elif cpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            logical_cpus = tf.config.experimental.list_logical_devices("CPU")
            logger.info(
                f"{len(cpus)}, Physical CPUs and "
                f"{len(logical_cpus)} Logical CPUs"
            )
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            logger.info(e)
    RANK = hvd.rank()
    WORLD_SIZE = hvd.size()
    LOCAL_RANK = hvd.local_rank()
    # LOCAL_SIZE = hvd.local_size()
    os.environ["RANK"] = str(RANK)
    os.environ["WORLD_SIZE"] = str(WORLD_SIZE)
    os.environ["LOCAL_RANK"] = str(LOCAL_RANK)
    # logger.info(f'RANK: {RANK} / {WORLD_SIZE-1}')
    if RANK == 0:
        logger.info(f"Using {TF_FLOAT} precision")
    return RANK

setup_torch(backend=None, port=None, seed=None, timeout=None, verbose=False, tensor_parallel_size=1, pipeline_parallel_size=1, context_parallel_size=1, tensor_parallel_backend=None, pipeline_parallel_backend=None, context_parallel_backend=None, data_parallel_backend=None)

Setup torch.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| backend | str | Backend to use. Defaults to None. | None |
| port | str \| int | Port to use. Defaults to None. | None |
| seed | int | Seed to use. Defaults to None. | None |
| timeout | str \| int | Timeout to use. Defaults to None. | None |
| verbose | bool | Whether to print the info. Defaults to False. | False |
| tensor_parallel_size | int | Tensor parallel size. Defaults to 1. | 1 |
| pipeline_parallel_size | int | Pipeline parallel size. Defaults to 1. | 1 |
| context_parallel_size | int | Context parallel size. Defaults to 1. | 1 |
| tensor_parallel_backend | str | Tensor parallel backend. Defaults to None. | None |
| pipeline_parallel_backend | str | Pipeline parallel backend. Defaults to None. | None |
| context_parallel_backend | str | Context parallel backend. Defaults to None. | None |
| data_parallel_backend | str | Data parallel backend. Defaults to None. | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| int | int | Rank of the process. |

Source code in src/ezpz/dist.py
def setup_torch(
    backend: Optional[str] = None,
    port: Optional[str | int] = None,
    seed: Optional[int] = None,
    timeout: Optional[str | int] = None,
    verbose: Optional[bool] = False,
    tensor_parallel_size: int = 1,
    pipeline_parallel_size: int = 1,
    context_parallel_size: int = 1,
    tensor_parallel_backend: Optional[str] = None,
    pipeline_parallel_backend: Optional[str] = None,
    context_parallel_backend: Optional[str] = None,
    data_parallel_backend: Optional[str] = None,
) -> int:
    """Setup torch.

    Args:
        backend (str, optional): Backend to use. Defaults to None.
        port (str | int, optional): Port to use. Defaults to None.
        seed (int, optional): Seed to use. Defaults to None.
        timeout (str | int, optional): Timeout to use. Defaults to None.
        verbose (bool, optional): Whether to print the info. Defaults to False.
        tensor_parallel_size (int, optional): Tensor parallel size. Defaults to 1.
        pipeline_parallel_size (int, optional): Pipeline parallel size. Defaults to 1.
        context_parallel_size (int, optional): Context parallel size. Defaults to 1.
        tensor_parallel_backend (str, optional): Tensor parallel backend. Defaults to None.
        pipeline_parallel_backend (str, optional): Pipeline parallel backend. Defaults to None.
        context_parallel_backend (str, optional): Context parallel backend. Defaults to None.
        data_parallel_backend (str, optional): Data parallel backend. Defaults to None.

    Returns:
        int: Rank of the process.
    """
    device = get_torch_device()
    # if ACCELERATOR_TYPE == 'NvidiaGPU' and device == 'cuda':
    #     os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    #     torch.backends.cudnn.deterministic = True     # type:ignore
    #     torch.backends.cudnn.benchmark = True         # type:ignore
    #     torch.backends.cudnn.allow_tf32 = True        # type:ignore
    #     torch.backends.cuda.matmul.allow_tf32 = True  # type:ignore
    # torch.use_deterministic_algorithms(True)
    ws_from_env = os.environ.get("WORLD_SIZE", None)
    backend = "DDP" if backend is None else backend
    backend = backend.lower()
    if ws_from_env is not None and ws_from_env == "1":
        logger.info(
            f"Running on a single {device}, not initializing torch.distributed!"
        )
        rank = 0
        world_size = 1
        local_rank = 0
        local_size = 1
        num_nodes = 1
    else:
        dsetup = setup_torch_distributed(
            backend=backend,
            port=port,
            timeout=timeout,
            tensor_parallel_size=int(tensor_parallel_size),
            pipeline_parallel_size=int(pipeline_parallel_size),
            context_parallel_size=int(context_parallel_size),
            tensor_parallel_backend=tensor_parallel_backend,
            pipeline_parallel_backend=pipeline_parallel_backend,
            context_parallel_backend=context_parallel_backend,
            data_parallel_backend=data_parallel_backend,
        )
        rank = dsetup["rank"]
        world_size = dsetup["world_size"]
        local_rank = dsetup["local_rank"]
        local_size = get_gpus_per_node()
        num_nodes = get_num_nodes()
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["NUM_NODES"] = str(num_nodes)
    os.environ["LOCAL_SIZE"] = str(local_size)
    os.environ["WORLD_SIZE"] = str(world_size)
    # nthreads = os.environ.get('OMP_NUM_THREADS', None)
    # if ACCELERATOR_TYPE == "IntelGPU" and device == "xpu":
    if torch.xpu.is_available():
        torch.xpu.set_device(local_rank)
        # try:
        #     import intel_extension_for_pytorch as ipex  # type:ignore[missingTypeStubs]
        # except Exception:
        #     ipex = None
        # if ipex is not None:
        #     logger.debug(f"Using ipex from: {ipex.__file__}")
        #
        # try:
        #     import oneccl_bindings_for_pytorch as oneccl_bpt  # type:ignore[missingTypeStubs]
        # except Exception:
        #     oneccl_bpt = None
        # if oneccl_bpt is not None:
        #     logger.debug(f"Using oneccl_bindings from: {oneccl_bpt.__file__}")
        #
        #     # logger.warning(f'Using {get_torch_device()}:{get_local_rank()}')
        #     # os.environ['CCL_LOCAL_RANK'] = str(local_rank)
        #     # os.environ['CCL_LOCAL_SIZE'] = str(local_size)
    if seed is not None:
        seed_everything(seed * (rank + 1) * (local_rank + 1))
    if rank == 0:
        if backend in {"ds", "deepspeed", "dspeed"}:
            from ezpz.configs import git_ds_info

            git_ds_info()
        _ = get_dist_info(verbose=verbose)
        if verbose:
            _ = print_dist_setup()
    # if world_size > 1:
    #     tdist.barrier()

    if rank == 0:
        logger.info(
            f"Using {device=} with {backend=} "
            f"+ '{get_torch_backend()}' "
            "for distributed training."
        )
    lrank = len(str(world_size - 1))
    # nz = lrank - len(str(rank))
    hn = socket.gethostname()
    psizes = [f"['{hn}']" + f"[{rank:>{lrank}}/{world_size - 1:<{lrank}}] "]
    if (
        tensor_parallel_size > 1
        or context_parallel_size > 1
        or pipeline_parallel_size > 1
    ):
        import ezpz.tp

        tpsize = ezpz.tp.get_tensor_parallel_world_size()
        cpsize = ezpz.tp.get_context_parallel_world_size()
        ppsize = ezpz.tp.get_pipeline_parallel_world_size()
        dpsize = ezpz.tp.get_data_parallel_world_size()
        if cpsize > 1 or ppsize > 1 or tpsize > 1:
            if cpsize > 1:
                lcp = len(str(cpsize - 1))
                cprank = ezpz.tp.get_context_parallel_rank()
                # cpranks = ezpz.tp.get_context_parallel_ranks()
                psizes.append(f"[cp:{cprank:>{lcp}}/{cpsize - 1:<{lcp}}]")
                barrier(group=ezpz.tp.get_context_parallel_group())
            if ppsize > 1:
                pprank = ezpz.tp.get_pipeline_parallel_rank()
                # ppranks = ezpz.tp.get_pipeline_parallel_ranks()
                lpp = len(str(ppsize - 1))
                psizes.append(f"[pp:{pprank:>{lpp}}/{ppsize - 1:<{lpp}}]")
                barrier(group=ezpz.tp.get_pipeline_parallel_group())
                # tdist.barrier(group=ezpz.tp.get_pipeline_parallel_group())
            if tpsize > 1:
                ltp = len(str(tpsize - 1))
                tprank = ezpz.tp.get_tensor_parallel_rank()
                # tpranks = ezpz.tp.get_tensor_parallel_ranks()
                psizes.append(f"[tp:{tprank:>{ltp}}/{tpsize - 1:<{ltp}}]")
                barrier(group=ezpz.tp.get_tensor_parallel_group())
            if dpsize > 1:
                ldp = len(str(dpsize - 1))
                dprank = ezpz.tp.get_data_parallel_rank()
                # dpranks = ezpz.tp.get_data_parallel_ranks()
                psizes.append(f"[dp:{dprank:>{ldp}}/{dpsize - 1:<{ldp}}]")
                barrier(group=ezpz.tp.get_data_parallel_group())
    # tdist.all_gather(psizes)
    logger.info("".join(psizes))
    barrier()
    return rank
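
A minimal entry-point sketch (assumes the script is started by a multi-process launcher such as `mpiexec`, `srun`, or `torchrun`; the backend name is illustrative):

```python
import ezpz.dist as dist

rank = dist.setup_torch(backend="DDP", seed=1234)
if rank == 0:
    print(f"initialized {dist.get_world_size()} ranks")
```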

setup_torch_distributed(backend=None, tensor_parallel_size=1, pipeline_parallel_size=1, context_parallel_size=1, tensor_parallel_backend=None, pipeline_parallel_backend=None, context_parallel_backend=None, data_parallel_backend=None, port=None, timeout=None)

Returns {'world_size': int, 'rank': int, 'local_rank': int}

Source code in src/ezpz/dist.py
def setup_torch_distributed(
    backend: Optional[str] = None,
    tensor_parallel_size: int = 1,
    pipeline_parallel_size: int = 1,
    context_parallel_size: int = 1,
    tensor_parallel_backend: Optional[str] = None,
    pipeline_parallel_backend: Optional[str] = None,
    context_parallel_backend: Optional[str] = None,
    data_parallel_backend: Optional[str] = None,
    port: Optional[str | int] = None,
    timeout: Optional[str | int] = None,
) -> dict[str, int]:
    """Returns {'world_size': int, 'rank': int, 'local_rank': int}"""
    backend = "DDP" if backend is None else backend
    assert backend.lower() in {
        "ddp",
        "ds",
        "deepspeed",
        "horovod",
        "hvd",
    }
    timeout = (
        3600
        if timeout is None
        else int(timeout)
        if isinstance(timeout, str)
        else timeout
    )
    port = (
        "1234" if port is None else str(port) if isinstance(port, int) else port
    )
    rank = get_rank()
    world_size = get_world_size()
    local_rank = get_local_rank()
    be = backend.lower()
    # assert be in BACKENDS['pytorch']
    if be == "ddp":
        dsetup = setup_torch_DDP(port, timeout)
        world_size = dsetup["world_size"]
        rank = dsetup["rank"]
        local_rank = dsetup["local_rank"]
        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
    elif be in {"deepspeed", "ds"}:
        init_deepspeed(timeout=timeout)
        world_size = get_world_size()
        rank = get_rank()
        local_rank = get_local_rank()
    elif be in {"horovod", "hvd"}:
        import horovod.torch as hvd  # type:ignore noqa

        _ = None if hvd.is_initialized() else hvd.init()
        # hvd.init() if not hvd.is_initialized() else None
        rank = hvd.rank()
        world_size = hvd.size()
        local_rank = hvd.local_rank()
        if torch.cuda.is_available():
            torch.cuda.set_device(hvd.local_rank())
    else:
        raise ValueError(f"Unable to parse backend: {be=}")

    if (
        tensor_parallel_size > 1
        or context_parallel_size > 1
        or pipeline_parallel_size > 1
    ):
        ezpz.tp.initialize_tensor_parallel(
            tensor_parallel_size=tensor_parallel_size,
            pipeline_parallel_size=pipeline_parallel_size,
            context_parallel_size=context_parallel_size,
            tensor_parallel_backend=tensor_parallel_backend,
            pipeline_parallel_backend=pipeline_parallel_backend,
            context_parallel_backend=context_parallel_backend,
            data_parallel_backend=data_parallel_backend,
            timeout=timedelta(seconds=timeout),
        )

    os.environ["world_size"] = str(world_size)
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(local_rank)

    return {"world_size": world_size, "rank": rank, "local_rank": local_rank}

setup_wandb(project_name=None, entity=None, config=None, start_method='thread', outdir=None, init_timeout=300)

Setup wandb for logging.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| project_name | str | The name of the project. Defaults to None. | None |
| entity | str | The entity name. Defaults to None. | None |
| config | dict \| DictConfig | The configuration dictionary. Defaults to None. | None |
| start_method | str | The start method for wandb. Defaults to "thread". | 'thread' |
| outdir | str \| Path \| PathLike | The output directory. Defaults to None. | None |
| init_timeout | int | The timeout for wandb initialization. Defaults to 300. | 300 |

Example:

>>> setup_wandb(project_name="my_project", entity="my_entity")

Source code in src/ezpz/dist.py
def setup_wandb(
    project_name: Optional[str] = None,
    entity: Optional[str] = None,
    config: Optional[dict | DictConfig] = None,
    start_method: str = "thread",
    outdir: Optional[str | Path | os.PathLike] = None,
    init_timeout: int = 300,
):
    """Setup wandb for logging.

    Args:
        project_name (str, optional): The name of the project. Defaults to None.
        entity (str, optional): The entity name. Defaults to None.
        config (dict | DictConfig, optional): The configuration dictionary. Defaults to None.
        start_method (str, optional): The start method for wandb. Defaults to "thread".
        outdir (str | Path | os.PathLike, optional): The output directory. Defaults to None.
        init_timeout (int, optional): The timeout for wandb initialization. Defaults to 300.

    Example:
        >>> setup_wandb(project_name="my_project", entity="my_entity")
    """
    # try:
    #     import wandb
    # except Exception:
    #     wandb = None
    # try:
    #     import wandb
    #
    WANDB_DISABLED = os.environ.get("WANDB_DISABLED", False)
    WANDB_MODE = os.environ.get("WANDB_MODE", "").lower()
    # except Exception:
    #     wandb = None
    #     WANDB_DISABLED = True

    if WANDB_DISABLED or WANDB_MODE == "disabled":
        logger.warning(
            f"Logging with W&B is disabled!, caught: {WANDB_DISABLED=}"
        )
        return None

    try:
        import wandb
    except (ImportError, ModuleNotFoundError) as e:
        logger.warning(
            "Unable to import `wandb`. Install with `pip install wandb`"
        )
        raise e

    outdir = (
        Path(os.getcwd()).as_posix()
        if outdir is None
        else Path(outdir).as_posix()
    )
    rank = get_rank()
    project_name = (
        project_name
        if project_name is not None
        else os.environ.get(
            "WB_PROJECT",
            os.environ.get(
                "WANDB_PROJECT",
                os.environ.get("WB_PROJECT_NAME", None),
            ),
        )
    )
    if project_name is None:
        import sys

        frame = sys._getframe().f_back
        assert frame is not None
        calling_module = frame.f_code.co_filename
        fp = Path(calling_module)
        project_name = f"{fp.parent.stem}.{fp.stem}"

    logger.info(f"Setting up wandb from {rank=}")
    logger.info(f"Using WB_PROJECT={project_name}")
    tensorboard_dir = (
        os.environ.get("TENSORBOARD_DIR", None)
        if config is None
        else config.get("tensorboard_dir", None)
    )
    if tensorboard_dir is not None:
        logger.info(f"Patching tensorboard from {tensorboard_dir}")
        try:
            wandb.tensorboard.patch(root_logdir=tensorboard_dir)  # type:ignore
        except Exception as exc:
            logger.exception(exc)
    # wbrun_id = wandb.util.generate_id()
    now = datetime.datetime.now()
    dstr = now.strftime("%Y-%m-%d-%H%M%S")
    run = wandb.init(
        entity=entity,
        # resume='allow',
        dir=outdir,
        sync_tensorboard=(tensorboard_dir is not None),  # True,
        project=(project_name if project_name is not None else None),
        # dir=(tensorboard_dir if tensorboard_dir is not None else None),
        settings=wandb.Settings(
            start_method=start_method, init_timeout=init_timeout
        ),
    )
    assert run is not None and run is wandb.run
    # run.log_code(HERE.as_posix(), include_fn=include_file)
    logger.info(f"wandb.run=[{run.name}]({run.url})")
    if (
        wandb is not None
        and wandb.run is not None
        and "DIST_INFO" not in wandb.run.config
    ):
        wandb.run.config.update({"DIST_INFO": get_dist_info()})
    torch_version = torch.__version__
    torch_file = torch.__file__
    run.config.update(
        {
            "created_at": dstr,
            "day": ezpz.get_timestamp("%d"),
            "month": ezpz.get_timestamp("%m"),
            "outdir": os.getcwd(),
            "torch_version": torch_version,
            "torch_file": torch_file,
            "world_size": get_world_size(),
            "year": ezpz.get_timestamp("%Y"),
            "ezpz_version": ezpz.__version__,
            "ezpz_file": ezpz.__file__,
        }
    )
    if config is not None:
        if isinstance(config, DictConfig):
            cfg = OmegaConf.to_container(
                config, resolve=True, throw_on_missing=True
            )
            run.config.update({"config": cfg})
        else:
            run.config.update({"config": config})
    env = {
        k: v
        for k, v in dict(os.environ).items()
        if not k.startswith("_ModuleTable")
    }
    _ = env.pop("LS_COLORS", None)
    _ = env.pop("PS1", None)
    run.config.update({"env": env})
    machine = get_machine()
    logger.info(f"Running on {machine=}")
    run.config.update({"machine": machine})
    model_size = os.environ.get("MODEL_SIZE", None)
    if model_size is not None:
        run.config.update({"MODEL_SIZE": model_size})
    return wandb.run

synchronize(device=None)

Synchronize the given device.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| device | torch.device \| int \| str | The device to synchronize. If None, the default device will be used. Defaults to None. | None |

Returns:

None

Source code in src/ezpz/dist.py
def synchronize(device: Optional[torch.device | int | str] = None):
    """
    Synchronize the given device.

    Args:
        device (torch.device | int | str, optional): The device to synchronize.
            If None, the default device will be used. Defaults to None.

    Returns:
        None
    """
    return (
        torch.cuda.synchronize(device)
        if torch.cuda.is_available()
        else (
            torch.xpu.synchronize(device)
            if torch.xpu.is_available()
            else torch.mps.synchronize()
            if torch.backends.mps.is_available()
            else torch.cpu.synchronize(device)
        )
    )
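
A timing sketch showing why you might synchronize before reading a timer; it assumes `get_torch_device_type()` (used elsewhere in this module) returns the active device-type string:

```python
import time
import torch
import ezpz.dist as dist

device = dist.get_torch_device_type()   # e.g. 'cuda', 'xpu', or 'cpu'
x = torch.randn(2048, 2048, device=device)
t0 = time.perf_counter()
y = x @ x                    # kernels may be queued asynchronously on an accelerator
dist.synchronize()           # wait for queued device work before reading the clock
print(f"matmul took {time.perf_counter() - t0:.4f}s")
```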

timeitlogit(rank=None, verbose=True)

Decorator to time a function and log the time taken.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| rank | int | Rank of the process. Defaults to None. | None |
| verbose | bool | Whether to log the time taken. Defaults to True. | True |

Source code in src/ezpz/dist.py
def timeitlogit(rank: Optional[int] = None, verbose: bool = True):
    """Decorator to time a function and log the time taken.

    Args:
        rank (int, optional): Rank of the process. Defaults to None.
        verbose (bool, optional): Whether to log the time taken. Defaults to True.
    """
    rank = get_rank() if rank is None else rank
    try:
        import wandb
    except Exception:
        wandb = None

    def decorator(func: Callable):
        """Decorator to time a function and log the time taken.

        Args:
            func (Callable): Function to be timed.
        """

        @wraps(func)
        def wrapper(*args, **kwargs):
            t0 = time.perf_counter()
            assert isinstance(rank, int)
            result = func(*args, **kwargs)
            dt = time.perf_counter() - t0
            if verbose:
                if rank == 0:
                    tstr = [f"`{func.__name__}`"]
                    if len(args) > 0:
                        tstr.append(f"({args}")
                    # _ = tstr.append(f"({args}") if len(args) > 0 else None
                    _ = (
                        tstr.append(f", {kwargs})")
                        if len(kwargs) > 0
                        else (tstr.append(")") if len(args) > 0 else "")
                    )
                    _ = tstr.append(f" took: {dt=:.4f}s")
                    logger.info("".join(tstr))
                if wandb is not None and wandb.run is not None:
                    # logger.info(
                    #     f'Logging timeit/{func.__name__}/{dt=:.4f} to W&B'
                    # )
                    wandb.run.log({f"timeit/{func.__name__}": dt}, commit=False)
            return result

        return wrapper

    return decorator
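
A decorator usage sketch (on rank 0 the elapsed time is logged, and `timeit/<func name>` is sent to W&B when a run is active):

```python
import ezpz.dist as dist

@dist.timeitlogit(verbose=True)
def train_step(n: int) -> int:
    return sum(i * i for i in range(n))

train_step(1_000_000)   # rank 0 logs: `train_step`((1000000,)) took: dt=...s
```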

write_hostfile_from_list_of_hosts(hosts, hostfile=None, rank_zero_only=True)

Write a list of hosts to the hostfile

Source code in src/ezpz/dist.py
def write_hostfile_from_list_of_hosts(
    hosts: list[str],
    hostfile: Optional[PathLike] = None,
    rank_zero_only: bool = True,
):
    """Write a list of hosts to the hostfile"""
    hostfile = (
        Path(hostfile).as_posix()
        if hostfile is not None
        else Path(os.getcwd()).joinpath("hostfile").as_posix()
    )
    if (rank_zero_only and get_rank() == 0) or not rank_zero_only:
        logger.info(f"Writing to {hostfile}")
        with Path(hostfile).open("w") as f:
            for host in hosts:
                f.write(f"{host}\n")
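
A short sketch (by default only rank 0 writes, so this is safe to call from every rank):

```python
import ezpz.dist as dist

dist.write_hostfile_from_list_of_hosts(
    ["node-001", "node-002"],      # hypothetical host names
    hostfile="hostfile",           # defaults to ./hostfile when omitted
)
```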

write_localhost_to_hostfile(hostfile)

Write 'localhost' to the hostfile

Source code in src/ezpz/dist.py
def write_localhost_to_hostfile(hostfile: PathLike):
    """Write 'localhost' to the hostfile"""
    if get_rank() == 0:
        logger.debug(
            f"Writing {(hostname := get_hostname())} "
            f"to {Path(hostfile).as_posix()}"
        )
        hostname = get_hostname()
        with Path(hostfile).open("w") as f:
            f.write(f"{hostname}")