
ezpz.tp.__init__

ezpz/tp/__init__.py

modified from: https://github.com/facebookresearch/fairscale/blob/5f484b3545f27eddb19d970fbe1d361b9c5f2b07/fairscale/nn/tensor_parallel/initialize.py

destroy_tensor_parallel()

Set the groups to None.

Source code in src/ezpz/tp/__init__.py
def destroy_tensor_parallel() -> None:
    """Set the groups to none."""
    global _TENSOR_PARALLEL_GROUP
    _TENSOR_PARALLEL_GROUP = None
    global _TENSOR_PARALLEL_RANKS
    _TENSOR_PARALLEL_RANKS = None

    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None
    global _DATA_PARALLEL_RANKS
    _DATA_PARALLEL_RANKS = None

    global _PIPELINE_PARALLEL_GROUP
    _PIPELINE_PARALLEL_GROUP = None
    global _PIPELINE_PARALLEL_RANKS
    _PIPELINE_PARALLEL_RANKS = None

    global _CONTEXT_PARALLEL_GROUP
    _CONTEXT_PARALLEL_GROUP = None
    global _CONTEXT_PARALLEL_GROUP_RANKS
    _CONTEXT_PARALLEL_GROUP_RANKS = None

divide_and_check_no_remainder(numerator, denominator)

Divide the numerator by the denominator and check that there is no remainder.

Source code in src/ezpz/tp/utils.py
def divide_and_check_no_remainder(numerator: int, denominator: int) -> int:
    """Divide the numerator by the denominator and check that there is no remainder."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator

ensure_divisibility(numerator, denominator)

Ensure that the numerator is divisible by the denominator.

Source code in src/ezpz/tp/utils.py
def ensure_divisibility(numerator: int, denominator: int) -> None:
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, '{} is not divisible by {}'.format(
        numerator, denominator
    )
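
A quick illustration of how these two helpers behave, assuming they are importable from ezpz.tp.utils (the module shown above):

from ezpz.tp.utils import divide_and_check_no_remainder, ensure_divisibility

ensure_divisibility(16, 4)                   # passes silently: 16 % 4 == 0
print(divide_and_check_no_remainder(16, 4))  # 4
# divide_and_check_no_remainder(16, 3)       # raises AssertionError: 16 is not divisible by 3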

get_context_parallel_group()

Get the context parallel group the caller rank belongs to.

Source code in src/ezpz/tp/__init__.py
def get_context_parallel_group() -> tdist.ProcessGroup:
    """Get the context parallel group the caller rank belongs to."""
    assert (
        _CONTEXT_PARALLEL_GROUP is not None
    ), "context parallel group is not initialized"
    return _CONTEXT_PARALLEL_GROUP

get_context_parallel_rank()

Return my rank for the context parallel group.

Source code in src/ezpz/tp/__init__.py
def get_context_parallel_rank() -> int:
    """Return my rank for the context parallel group."""
    return tdist.get_rank(group=get_context_parallel_group())

get_context_parallel_ranks()

Return context parallel ranks for the context parallel group.

Source code in src/ezpz/tp/__init__.py
def get_context_parallel_ranks() -> List[int]:
    """Return context parallel ranks for the context parallel group."""
    assert (
        _CONTEXT_PARALLEL_GROUP_RANKS is not None
    ), "context parallel group is not initialized"
    return _CONTEXT_PARALLEL_GROUP_RANKS

get_context_parallel_world_size()

Return world size for the context parallel group.

Source code in src/ezpz/tp/__init__.py
def get_context_parallel_world_size() -> int:
    """Return world size for the context parallel group."""
    return tdist.get_world_size(group=get_context_parallel_group())

get_data_parallel_group()

Get the data parallel group the caller rank belongs to.

Source code in src/ezpz/tp/__init__.py
def get_data_parallel_group() -> tdist.ProcessGroup:
    """Get the data parallel group the caller rank belongs to."""
    assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
    return _DATA_PARALLEL_GROUP

get_data_parallel_rank()

Return my rank for the data parallel group.

Source code in src/ezpz/tp/__init__.py
def get_data_parallel_rank() -> int:
    """Return my rank for the data parallel group."""
    return tdist.get_rank(group=get_data_parallel_group())

get_data_parallel_ranks()

Return the ranks in the data parallel group the caller rank belongs to.

Source code in src/ezpz/tp/__init__.py
def get_data_parallel_ranks() -> List[int]:
    """Get the data parallel group the caller rank belongs to."""
    assert _DATA_PARALLEL_RANKS is not None, "data parallel group is not initialized"
    return _DATA_PARALLEL_RANKS

get_data_parallel_world_size()

Return world size for the data parallel group.

Source code in src/ezpz/tp/__init__.py
def get_data_parallel_world_size() -> int:
    """Return world size for the data parallel group."""
    return tdist.get_world_size(group=get_data_parallel_group())

get_pipeline_parallel_group()

Get the pipeline parallel group the caller rank belongs to.

Source code in src/ezpz/tp/__init__.py
def get_pipeline_parallel_group() -> tdist.ProcessGroup:
    """Get the pipeline parallel group the caller rank belongs to."""
    assert (
        _PIPELINE_PARALLEL_GROUP is not None
    ), "pipeline parallel group is not initialized"
    return _PIPELINE_PARALLEL_GROUP

get_pipeline_parallel_rank()

Return my rank for the pipeline parallel group.

Source code in src/ezpz/tp/__init__.py
def get_pipeline_parallel_rank() -> int:
    """Return my rank for the pipeline parallel group."""
    return tdist.get_rank(group=get_pipeline_parallel_group())

get_pipeline_parallel_ranks()

Return the ranks in the pipeline parallel group the caller rank belongs to.

Source code in src/ezpz/tp/__init__.py
def get_pipeline_parallel_ranks() -> List[int]:
    """Get the pipeline parallel group the caller rank belongs to."""
    assert (
        _PIPELINE_PARALLEL_RANKS is not None
    ), "pipeline parallel group is not initialized"
    return _PIPELINE_PARALLEL_RANKS

get_pipeline_parallel_world_size()

Return world size for the pipeline parallel group.

Source code in src/ezpz/tp/__init__.py
def get_pipeline_parallel_world_size() -> int:
    """Return world size for the context parallel group."""
    return tdist.get_world_size(group=get_pipeline_parallel_group())

get_tensor_parallel_group()

Get the tensor parallel group the caller rank belongs to.

Source code in src/ezpz/tp/__init__.py
def get_tensor_parallel_group() -> tdist.ProcessGroup:
    """Get the tensor parallel group the caller rank belongs to."""
    assert (
        _TENSOR_PARALLEL_GROUP is not None
    ), "tensor parallel group is not initialized"
    return _TENSOR_PARALLEL_GROUP

get_tensor_parallel_rank()

Return my rank for the tensor parallel group.

Source code in src/ezpz/tp/__init__.py
def get_tensor_parallel_rank() -> int:
    """Return my rank for the tensor parallel group."""
    return tdist.get_rank(group=get_tensor_parallel_group())

get_tensor_parallel_ranks()

Return the ranks in the tensor parallel group the caller rank belongs to.

Source code in src/ezpz/tp/__init__.py
def get_tensor_parallel_ranks() -> List[int]:
    """Get the tensor parallel group the caller rank belongs to."""
    assert (
        _TENSOR_PARALLEL_RANKS is not None
    ), "tensor parallel group is not initialized"
    return _TENSOR_PARALLEL_RANKS

get_tensor_parallel_src_rank()

Calculate the global rank corresponding to local rank 0 in the TP group.

Source code in src/ezpz/tp/__init__.py
def get_tensor_parallel_src_rank() -> int:
    """
    Calculate the global rank corresponding to local rank 0 in the TP group.
    """
    global_rank = tdist.get_rank()
    local_world_size = get_tensor_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size
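
A short worked example of the arithmetic above, using hypothetical rank numbers and a tensor parallel world size of 4; every rank maps to the lowest global rank of its TP group:

# Illustration of the rank arithmetic only; no process groups are needed here.
tp_world_size = 4
for global_rank in range(8):
    src = (global_rank // tp_world_size) * tp_world_size
    print(f'global rank {global_rank} -> TP src rank {src}')
# Ranks 0-3 map to src rank 0, ranks 4-7 map to src rank 4.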

get_tensor_parallel_world_size()

Return world size for the tensor parallel group.

Source code in src/ezpz/tp/__init__.py
def get_tensor_parallel_world_size() -> int:
    """Return world size for the tensor parallel group."""
    return tdist.get_world_size(group=get_tensor_parallel_group())

initialize_tensor_parallel(tensor_parallel_size=1, pipeline_parallel_size=1, context_parallel_size=1, tensor_parallel_backend=None, pipeline_parallel_backend=None, context_parallel_backend=None, data_parallel_backend=None, timeout=None)

Initialize tensor, context, pipeline, and data parallel groups.

Parameters:

    tensor_parallel_size (int, default: 1): number of GPUs used to parallelize the model.

Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we use 2 GPUs to parallelize the model. The present function will create 4 tensor parallel groups and 2 data parallel groups as:

  • 4 tensor parallel groups: [g0, g1], [g2, g3], [g4, g5], [g6, g7]
  • 2 data parallel groups: [g0, g2, g4, g6], [g1, g3, g5, g7]

Note that for efficiency, the caller should make sure adjacent ranks are on the same DGX box. For example, if we are using 2 DGX-1 boxes with a total of 16 GPUs, ranks 0 to 7 belong to the first box and ranks 8 to 15 belong to the second box. A minimal code sketch of the 8-GPU grouping above follows.
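
The sketch below (not part of the library) shows how that 8-GPU grouping falls out of the grid construction used in the source further down, with TP = 2, PP = CP = 1, and therefore DP = 4:

import torch

world_size, tp, pp, cp = 8, 2, 1, 1
dp = world_size // (tp * pp * cp)  # 4

# Same layout as the source code: (dp, pp, cp, tp), with TP the fastest-varying axis.
grid = torch.arange(world_size).reshape(dp, pp, cp, tp)

tp_groups = [grid[i, 0, 0, :].tolist() for i in range(dp)]  # [[0, 1], [2, 3], [4, 5], [6, 7]]
dp_groups = [grid[:, 0, 0, k].tolist() for k in range(tp)]  # [[0, 2, 4, 6], [1, 3, 5, 7]]
print(tp_groups, dp_groups)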

Process groups are initialized in the order TP, CP, PP, DP.

Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we use 2 GPUs for tensor parallelism, 2 GPUs for context parallelism (sequence length), and 2 GPUs for pipeline parallelism. The present function will create 8 tensor model-parallel groups, 8 context-parallel groups, 8 pipeline model-parallel groups and 8 data-parallel groups (when alternate_pp_config = False) as:

  • 8 data-parallel groups: [g0, g4], [g1, g5], [g2, g6], [g3, g7], [g8, g12], [g9, g13], [g10, g14], [g11, g15]
  • 8 tensor model-parallel groups: [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
  • 8 context-parallel groups: [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
  • 8 pipeline model-parallel groups: [g0, g8], [g1, g9], [g2, g10], [g3, g11], [g4, g12], [g5, g13], [g6, g14], [g7, g15]
Source code in src/ezpz/tp/__init__.py
def initialize_tensor_parallel(
    tensor_parallel_size: int = 1,
    pipeline_parallel_size: int = 1,
    context_parallel_size: int = 1,
    tensor_parallel_backend: Optional[str] = None,
    pipeline_parallel_backend: Optional[str] = None,
    context_parallel_backend: Optional[str] = None,
    data_parallel_backend: Optional[str] = None,
    timeout: Optional[timedelta] = None,
) -> None:
    """
    Initialize tensor, context, pipeline, and data parallel groups.

    Arguments:
        tensor_parallel_size: number of GPUs used to parallelize the model.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model. The present function will
    create 4 tensor parallel groups and 2 data parallel groups as:

    - 4 tensor parallel groups:

      ```
      [g0, g1], [g2, g3], [g4, g5], [g6, g7]
      ```

    - 2 data parallel groups:

        ```
        [g0, g2, g4, g6], [g1, g3, g5, g7]
        ```

    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example, if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, ranks 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.

    Process groups are initialized in the order TP, CP, PP, DP.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs for tensor parallelism, 2 GPUs for context parallelism
    (sequence length), and 2 GPUs for pipeline parallelism. The present
    function will create 8 tensor model-parallel groups, 8 context-parallel
    groups, 8 pipeline model-parallel groups and 8 data-parallel groups
    (when alternate_pp_config = False) as:

    - 8 data-parallel groups:
        [g0, g4], [g1, g5], [g2, g6], [g3, g7], [g8, g12], [g9, g13], [g10, g14], [g11, g15]
    - 8 tensor model-parallel groups:
        [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
    - 8 context-parallel groups:
        [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
    - 8 pipeline model-parallel groups:
        [g0, g8], [g1, g9], [g2, g10], [g3, g11], [g4, g12], [g5, g13], [g6, g14], [g7, g15]
    """
    # Get world size and rank. Ensure some consistencies.
    assert tdist.is_initialized()
    world_size = tdist.get_world_size()
    tensor_parallel_size = int(min(tensor_parallel_size, world_size))
    ensure_divisibility(world_size, tensor_parallel_size)
    ensure_divisibility(world_size, context_parallel_size)
    ensure_divisibility(
        world_size,
        tensor_parallel_size * pipeline_parallel_size * context_parallel_size,
    )
    rank = tdist.get_rank()

    dpsize = int(
        world_size
        / (tensor_parallel_size * pipeline_parallel_size * context_parallel_size)
    )

    if tdist.get_rank() == 0:
        pstr = ", ".join(
            [
                f"TP: {tensor_parallel_size}",
                f"PP: {pipeline_parallel_size}",
                f"CP: {context_parallel_size}",
                f"DP: {dpsize}",
            ]
        )
        logger.info(pstr)

    groups = torch.LongTensor(range(world_size)).reshape(
        dpsize,
        pipeline_parallel_size,
        context_parallel_size,
        tensor_parallel_size,
    )

    found = torch.where(groups == rank)
    assert all(len(x) == 1 for x in found)
    found = [x[0] for x in found]

    # Build the data parallel groups.
    global _DATA_PARALLEL_GROUP
    global _DATA_PARALLEL_RANKS
    assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
    assert _DATA_PARALLEL_RANKS is None, "data parallel ranks are already initialized"
    for i in range(pipeline_parallel_size):
        for j in range(context_parallel_size):
            for k in range(tensor_parallel_size):
                ranks = groups[:, i, j, k].tolist()
                group = tdist.new_group(
                    groups[:, i, j, k].tolist(),
                    backend=data_parallel_backend,
                    timeout=timeout,
                )
                if i == found[1] and j == found[2] and k == found[3]:
                    _DATA_PARALLEL_GROUP = group
                    _DATA_PARALLEL_RANKS = ranks

    # Build the tensor parallel groups.
    global _TENSOR_PARALLEL_GROUP
    global _TENSOR_PARALLEL_RANKS
    assert (
        _TENSOR_PARALLEL_GROUP is None
    ), "tensor parallel group is already initialized"
    assert (
        _TENSOR_PARALLEL_RANKS is None
    ), "tensor parallel ranks are already initialized"
    for i in range(dpsize):
        for j in range(pipeline_parallel_size):
            for k in range(context_parallel_size):
                ranks = groups[i, j, k, :].tolist()
                group = tdist.new_group(
                    groups[i, j, k, :].tolist(),
                    backend=tensor_parallel_backend,
                    timeout=timeout,
                )
                if i == found[0] and j == found[1] and k == found[2]:
                    _TENSOR_PARALLEL_GROUP = group
                    _TENSOR_PARALLEL_RANKS = ranks

    # Build the pipeline parallel groups.
    global _PIPELINE_PARALLEL_GROUP
    global _PIPELINE_PARALLEL_RANKS
    assert (
        _PIPELINE_PARALLEL_GROUP is None
    ), "Pipeline parallel group is already initialized"
    for i in range(dpsize):
        for j in range(context_parallel_size):
            for k in range(tensor_parallel_size):
                ranks = groups[i, :, j, k].tolist()
                group = tdist.new_group(
                    ranks, backend=pipeline_parallel_backend, timeout=timeout
                )
                if i == found[0] and j == found[2] and k == found[3]:
                    _PIPELINE_PARALLEL_GROUP = group
                    _PIPELINE_PARALLEL_RANKS = ranks

    # Build the context parallel groups.
    global _CONTEXT_PARALLEL_GROUP
    global _CONTEXT_PARALLEL_GROUP_RANKS

    assert (
        _CONTEXT_PARALLEL_GROUP is None
    ), "Context parallelism is already initialized."
    for i in range(dpsize):
        for j in range(pipeline_parallel_size):
            for k in range(tensor_parallel_size):
                ranks = groups[i, j, :, k].tolist()
                group = tdist.new_group(
                    ranks, backend=context_parallel_backend, timeout=timeout
                )
                if i == found[0] and j == found[1] and k == found[3]:
                    _CONTEXT_PARALLEL_GROUP = group
                    _CONTEXT_PARALLEL_GROUP_RANKS = ranks
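
A hedged end-to-end sketch of how these pieces fit together on 8 ranks. The names are the ones documented on this page; the plain init_process_group bootstrap is an assumption (ezpz also ships its own setup helpers), and the backend choice depends on your hardware:

import torch.distributed as tdist
from ezpz.tp import (
    initialize_tensor_parallel,
    get_tensor_parallel_rank,
    get_data_parallel_world_size,
    destroy_tensor_parallel,
)

# Assumes torchrun-style environment variables (RANK, WORLD_SIZE, MASTER_ADDR, ...).
tdist.init_process_group(backend='nccl')

# With 8 ranks and TP=2 this yields the 4 tensor parallel / 2 data parallel groups
# described in the docstring above.
initialize_tensor_parallel(tensor_parallel_size=2)

print(f'TP rank: {get_tensor_parallel_rank()}, DP world size: {get_data_parallel_world_size()}')

destroy_tensor_parallel()
tdist.destroy_process_group()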

split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False)

Split a tensor along its last dimension.

Parameters:

    tensor (Tensor, required): The tensor to split.
    num_partitions (int, required): The number of partitions to split the tensor into.
    contiguous_split_chunks (bool, default: False): Whether to return contiguous split chunks.
Source code in src/ezpz/tp/utils.py
def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> tuple[torch.Tensor, ...]:
    """Split a tensor along its last dimension.

    Arguments:
        tensor: The tensor to split.
        num_partitions: The number of partitions to split the tensor into.
        contiguous_split_chunks: Whether to return contiguous split chunks.
    """
    last_dim = tensor.dim() - 1
    last_dim_size = divide_and_check_no_remainder(
        tensor.size()[last_dim], num_partitions
    )
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)
    return tensor_list
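
A small local usage example (no distributed setup required), assuming the function is importable from ezpz.tp.utils:

import torch
from ezpz.tp.utils import split_tensor_along_last_dim

x = torch.arange(24).reshape(2, 3, 4)              # last dimension has size 4
chunks = split_tensor_along_last_dim(x, num_partitions=2)
print([c.shape for c in chunks])                   # [torch.Size([2, 3, 2]), torch.Size([2, 3, 2])]
# split_tensor_along_last_dim(x, 5)                # raises AssertionError: 4 is not divisible by 5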

tensor_parallel_is_initialized()

Check if the tensor, data, pipeline, and context parallel groups are initialized.

Source code in src/ezpz/tp/__init__.py
def tensor_parallel_is_initialized() -> bool:
    """Check if tensor and data parallel groups are initialized."""
    if (
        _TENSOR_PARALLEL_GROUP is None
        or _DATA_PARALLEL_GROUP is None
        or _PIPELINE_PARALLEL_GROUP is None
        or _CONTEXT_PARALLEL_GROUP is None
    ):
        return False
    return True
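
A common guard pattern, sketched under the assumption that these names are importable from ezpz.tp and that torch.distributed has already been initialized:

from ezpz.tp import tensor_parallel_is_initialized, initialize_tensor_parallel

# Build the parallel groups only once, even if this setup code is reached more than once.
if not tensor_parallel_is_initialized():
    initialize_tensor_parallel(tensor_parallel_size=2)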