Skip to content

ezpz.profile

profile.py

Sam Foreman [2024-06-21]

Contains implementation of:

  • get_context_manager
  • PyInstrumentProfiler

which can be used as a context manager to profile a block of code, e.g.

# test.py


def work():
    """The code whose execution we want to profile."""
    print("Hello!")


def main():
    """Run `work()` inside the profiling context manager."""
    import ezpz
    from ezpz.profile import get_context_manager

    # NOTE:
    # 1. if `rank` is passed to `get_context_manager`:
    #        - it will ONLY be instantiated if rank == 0,
    #          otherwise, it will return a contextlib.nullcontext() instance.
    # 2. if `strict=True`:
    #        - only run if "PYINSTRUMENT_PROFILER" is set in the environment
    cm = get_context_manager(rank=ezpz.get_rank(), strict=False)
    with cm:
        # Profile a separate function: calling `main()` here would recurse
        # forever.
        work()


if __name__ == "__main__":
    main()

get_context_manager(rank=None, outdir=None, strict=True, *, profiler_type='pyinstrument', rank_zero_only=True, **profile_kwargs)

Returns a context manager for profiling code blocks using PyInstrument.

Parameters:

Name Type Description Default
rank Optional[int]

The rank of the process (default: None). If provided, the profiler will only run if rank is 0.

None
outdir Optional[str]

The output directory for saving profiles. Defaults to ezpz.OUTPUTS_DIR.

None
strict Optional[bool]

If True, the profiler will only run if "PYINSTRUMENT_PROFILER" is set in the environment. Defaults to True.

True

Returns:

Name Type Description
AbstractContextManager AbstractContextManager

A context manager that starts and stops the PyInstrument profiler.

Source code in src/ezpz/profile.py
def get_context_manager(
    rank: Optional[int] = None,
    outdir: Optional[str] = None,
    strict: Optional[bool] = True,
    *,
    profiler_type: str = "pyinstrument",
    rank_zero_only: bool = True,
    **profile_kwargs: Any,
) -> AbstractContextManager:
    """
    Returns a context manager for profiling code blocks.

    PyInstrument is used by default; any other ``profiler_type`` is
    delegated to `get_profiling_context`.

    Args:
        rank (Optional[int]): The rank of the process (default: None).
            On the PyInstrument path, when ``rank_zero_only`` is True and a
            rank other than 0 is given, a ``contextlib.nullcontext()`` is
            returned. ``None`` does not disable profiling here.
        outdir (Optional[str]): The output directory for saving profiles.
            Defaults to `ezpz.OUTPUTS_DIR`. PyInstrument profiles go under
            ``<outdir>/ezpz/pyinstrument_profiles``.
        strict (Optional[bool]): If True, the PyInstrument profiler will
            only run if "PYINSTRUMENT_PROFILER" is set in the environment
            (any value). Defaults to True. Forwarded unchanged to
            `get_profiling_context` for other profiler types.
        profiler_type (str): "pyinstrument" (default); any other value is
            passed to `get_profiling_context`, which accepts "torch",
            "pytorch", "pt" and "pyinstrument".
        rank_zero_only (bool): If True (default), only rank 0 profiles.
        **profile_kwargs: Torch-profiler options forwarded to
            `get_profiling_context`: ``wait`` (0), ``warmup`` (0),
            ``active`` (1), ``repeat`` (1), ``record_shapes`` (True),
            ``with_stack`` (True), ``with_flops`` (True),
            ``with_modules`` (True), ``acc_events`` (False),
            ``profile_memory`` (False). Ignored for PyInstrument.

    Returns:
        AbstractContextManager: A context manager that starts and stops
            the selected profiler, or a ``nullcontext`` when profiling is
            disabled for this process.
    """
    if profiler_type != "pyinstrument":
        return get_profiling_context(
            profiler_type=profiler_type,
            wait=profile_kwargs.get("wait", 0),
            warmup=profile_kwargs.get("warmup", 0),
            active=profile_kwargs.get("active", 1),
            repeat=profile_kwargs.get("repeat", 1),
            rank_zero_only=rank_zero_only,
            record_shapes=profile_kwargs.get("record_shapes", True),
            with_stack=profile_kwargs.get("with_stack", True),
            with_flops=profile_kwargs.get("with_flops", True),
            with_modules=profile_kwargs.get("with_modules", True),
            acc_events=profile_kwargs.get("acc_events", False),
            profile_memory=profile_kwargs.get("profile_memory", False),
            outdir=outdir,
            strict=strict,
        )

    if rank_zero_only and rank not in (None, 0):
        return nullcontext()

    # Evaluate the environment gate before resolving any paths, so a
    # disabled profiler does no work (and never touches ezpz.OUTPUTS_DIR).
    if strict and os.environ.get("PYINSTRUMENT_PROFILER") is None:
        return nullcontext()

    base_dir = ezpz.OUTPUTS_DIR if outdir is None else outdir
    profile_dir = Path(base_dir).joinpath("ezpz", "pyinstrument_profiles")
    return PyInstrumentProfiler(
        rank=rank,
        outdir=profile_dir.as_posix(),
        rank_zero_only=rank_zero_only,
    )

get_profiling_context(profiler_type, wait, warmup, active, repeat, rank_zero_only, record_shapes=True, with_stack=True, with_flops=True, with_modules=True, acc_events=False, profile_memory=False, outdir=None, strict=True)

Returns a context manager for profiling code blocks using either PyTorch Profiler or PyInstrument.

Parameters:

Name Type Description Default
profiler_type str

The type of profiler to use. Must be one of ['torch', 'pyinstrument'].

required
wait int

The number of steps to wait before starting profiling.

required
warmup int

The number of warmup steps before profiling starts.

required
active int

The number of active profiling steps.

required
repeat int

The number of times to repeat the profiling schedule.

required
rank_zero_only bool

If True, the profiler will only run on rank 0. This parameter is required (it has no default).

required
record_shapes bool

If True, shapes of tensors are recorded. Defaults to True.

True
with_stack bool

If True, stack traces are recorded. Defaults to True.

True
with_flops bool

If True, FLOPs are recorded. Defaults to True.

True
with_modules bool

If True, module information is recorded. Defaults to True.

True
acc_events bool

If True, accumulated events are recorded. Defaults to False.

False
profile_memory bool

If True, memory profiling is enabled. Defaults to False.

False
outdir Optional[str | Path | PathLike]

The output directory for saving profiles. Defaults to ezpz.OUTPUTS_DIR.

None
strict Optional[bool]

If True, the PyInstrument profiler will only run if "PYINSTRUMENT_PROFILER" is set in the environment; the torch profiler ignores this option. Defaults to True.

True

Returns: AbstractContextManager: A context manager that starts and stops the profiler.

Source code in src/ezpz/profile.py
def get_profiling_context(
    profiler_type: str,
    wait: int,
    warmup: int,
    active: int,
    repeat: int,
    rank_zero_only: bool,
    record_shapes: bool = True,
    with_stack: bool = True,
    with_flops: bool = True,
    with_modules: bool = True,
    acc_events: bool = False,
    profile_memory: bool = False,
    outdir: Optional[str | Path | os.PathLike] = None,
    strict: Optional[bool] = True,
) -> AbstractContextManager:
    """
    Returns a context manager for profiling code blocks using either
    PyTorch Profiler or PyInstrument.

    Args:
        profiler_type (str): The type of profiler to use.
            Must be one of ['torch', 'pyinstrument'].
        wait (int): The number of steps to wait before starting profiling.
        warmup (int): The number of warmup steps before profiling starts.
        active (int): The number of active profiling steps.
        repeat (int): The number of times to repeat the profiling schedule.
        rank_zero_only (bool): If True, the profiler will only run on rank 0.
            Defaults to True.
        record_shapes (bool): If True, shapes of tensors are recorded.
            Defaults to True.
        with_stack (bool): If True, stack traces are recorded.
            Defaults to True.
        with_flops (bool): If True, FLOPs are recorded.
            Defaults to True.
        with_modules (bool): If True, module information is recorded.
            Defaults to True.
        acc_events (bool): If True, accumulated events are recorded.
            Defaults to False.
        profile_memory (bool): If True, memory profiling is enabled.
            Defaults to False.
        outdir (Optional[str | Path | os.PathLike]): The output directory
            for saving profiles. Defaults to `ezpz.OUTPUTS_DIR`.
        strict (Optional[bool]): If True, the profiler will only run if
            "PYINSTRUMENT_PROFILER" is set in the environment. Defaults to True.
    Returns:
        AbstractContextManager: A context manager that starts and stops
            the profiler.
    """
    if profiler_type not in {"pt", "pytorch", "torch", "pyinstrument"}:
        raise ValueError(
            f"Invalid profiling type: {profiler_type}. "
            "Must be one of ['torch', 'pyinstrument']"
        )
    outdir_fallback = Path(os.getcwd()).joinpath("ezpz", "torch_profiles")
    outdir = outdir_fallback if outdir is None else outdir
    _ = Path(outdir).mkdir(parents=True, exist_ok=True)
    if profiler_type in {"torch", "pytorch", "pt"}:

        def trace_handler(p: torch.profiler.profile):
            """
            Callback function to handle the trace when it is ready.
            """
            logger.info(
                "\n"
                + p.key_averages().table(
                    sort_by=(f"self_{ezpz.get_torch_device_type()}_time_total"),
                    row_limit=-1,
                )
            )
            fname: str = "-".join(
                [
                    "torch-profiler",
                    f"rank{ezpz.get_rank()}",
                    f"step{p.step_num}",
                    f"{ezpz.get_timestamp()}",
                ]
            )
            trace_output = Path(outdir).joinpath(f"{fname}.json")
            logger.info(f"Saving torch profiler trace to: {trace_output.as_posix()}")
            p.export_chrome_trace(trace_output.as_posix())

        schedule = torch.profiler.schedule(
            wait=wait,
            warmup=warmup,
            active=active,
            repeat=repeat,
        )

        return get_torch_profiler(
            rank=ezpz.get_rank(),
            schedule=schedule,
            on_trace_ready=trace_handler,
            rank_zero_only=rank_zero_only,
            profile_memory=profile_memory,
            record_shapes=record_shapes,
            with_stack=with_stack,
            with_flops=with_flops,
            with_modules=with_modules,
            acc_events=acc_events,
        )

    if profiler_type == "pyinstrument":
        return get_context_manager(rank=ezpz.get_rank(), strict=strict)

    raise ValueError(
        f"Invalid profiling type: {profiler_type}. "
        "Must be one of ['torch', 'pyinstrument']"
    )

get_torch_profiler(rank=None, schedule=None, on_trace_ready=None, rank_zero_only=True, profile_memory=False, record_shapes=True, with_stack=True, with_flops=True, with_modules=True, acc_events=False)

A thin wrapper around torch.profiler.profile that:

  1. Supports automatic device detection {CPU, CUDA, XPU}
  2. Runs on rank 0 only by default; to run on all ranks, set rank_zero_only=False

Parameters:

Name Type Description Default
rank Optional[int]

The rank of the process (default: None). When rank_zero_only is True, the profiler only runs if rank is 0; a rank of None also disables it.

None
schedule Optional[Callable[[int], ProfilerAction]]

A callable that returns a ProfilerAction for the profiler schedule.

None
on_trace_ready Optional[Callable]

A callback function that is called when the trace is ready.

None
rank_zero_only bool

If True, the profiler will only run on rank 0. Defaults to True.

True
profile_memory bool

If True, memory profiling is enabled. Defaults to False.

False
record_shapes bool

If True, shapes of tensors are recorded. Defaults to True.

True
with_stack bool

If True, stack traces are recorded. Defaults to True.

True
with_flops bool

If True, FLOPs are recorded. Defaults to True.

True
with_modules bool

If True, module information is recorded. Defaults to True.

True
acc_events bool

If True, accumulated events are recorded. Defaults to False.

False

Returns: torch.profiler.profile: A profiler context manager that can be used to profile code blocks.

Source code in src/ezpz/profile.py
def get_torch_profiler(
    rank: Optional[int] = None,
    schedule: Optional[Callable[[int], ProfilerAction]] = None,
    on_trace_ready: Optional[Callable] = None,
    rank_zero_only: bool = True,
    profile_memory: bool = False,
    record_shapes: bool = True,
    with_stack: bool = True,
    with_flops: bool = True,
    with_modules: bool = True,
    acc_events: bool = False,
):
    """
    A thin wrapper around `torch.profiler.profile` that:

    1. Supports automatic device detection {CPU, CUDA, XPU}
    2. Runs on rank 0 only (by default)
       - To run from _all_ ranks, set `rank_zero_only=False`

    Args:
        rank (Optional[int]): The rank of the process (default: None).
            When ``rank_zero_only`` is True, the profiler only runs if
            ``rank == 0``; ``None`` is treated as "not rank 0" and also
            disables profiling.
        schedule (Optional[Callable[[int], ProfilerAction]]): A callable
            that returns a `ProfilerAction` for the profiler schedule.
        on_trace_ready (Optional[Callable]): A callback function that is
            called when the trace is ready.
        rank_zero_only (bool): If True, the profiler will only run on rank 0.
            Defaults to True.
        profile_memory (bool): If True, memory profiling is enabled.
            Defaults to False.
        record_shapes (bool): If True, shapes of tensors are recorded.
            Defaults to True.
        with_stack (bool): If True, stack traces are recorded.
            Defaults to True.
        with_flops (bool): If True, FLOPs are recorded.
            Defaults to True.
        with_modules (bool): If True, module information is recorded.
            Defaults to True.
        acc_events (bool): If True, forwarded to `torch.profiler.profile`
            when the installed torch accepts an ``acc_events`` argument;
            otherwise a warning is emitted and the option is ignored.
            Defaults to False.

    Returns:
        torch.profiler.profile | contextlib.nullcontext: A profiler context
            manager, or a ``nullcontext`` when this rank should not profile.
    """
    # ``None != 0`` is True, so an unknown rank also skips profiling here.
    if rank_zero_only and rank != 0:
        return nullcontext()

    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        activities.append(ProfilerActivity.XPU)

    extra_kwargs = {}
    if acc_events:
        # `acc_events` only exists in newer torch releases. Forward it only
        # when it is both requested and supported, so the default path
        # keeps working on every torch version.
        import inspect
        import warnings

        if "acc_events" in inspect.signature(profile).parameters:
            extra_kwargs["acc_events"] = True
        else:
            warnings.warn(
                "acc_events=True is not supported by this version of "
                "torch.profiler.profile; ignoring it.",
                stacklevel=2,
            )

    return profile(
        activities=activities,
        schedule=schedule,
        on_trace_ready=on_trace_ready,
        record_shapes=record_shapes,
        profile_memory=profile_memory,
        with_stack=with_stack,
        with_flops=with_flops,
        with_modules=with_modules,
        **extra_kwargs,
    )