
ezpz.__init__

ezpz/__init__.py

History

A class to track and log metrics during training or evaluation.
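Below is a minimal usage sketch based on the methods documented on this page. It assumes `History` is importable from the top-level `ezpz` package (as this page suggests); the metric names, loop, and run name are purely illustrative.

```python
import numpy as np
from ezpz import History  # assumed top-level re-export; ezpz.history.History otherwise

history = History()

# Record one dictionary of metrics per training step.
for step in range(100):
    summary = history.update(
        {"train/loss": float(np.exp(-step / 25.0)), "train/iter": step},
        use_wandb=False,  # skip Weights & Biases logging in this sketch
    )
    if step % 20 == 0:
        # `update` returns a one-line summary string (via summarize_dict)
        print(summary)

# Build an xarray.Dataset from the accumulated history, save it,
# and write plots under an auto-generated output directory.
dataset = history.finalize(run_name="toy-run", plot=True, save=True)
```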

Source code in src/ezpz/history.py
class History:
    """
    A class to track and log metrics during training or evaluation.
    """

    def __init__(self, keys: Optional[list[str]] = None) -> None:
        """
        Initialize the History object.

        Args:
            keys (Optional[list[str]]): List of keys to initialize the history with.
                If None, initializes with an empty list.
        """
        self.keys = [] if keys is None else keys
        self.history = {}

    @timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
    def _update(
        self,
        key: str,
        val: Union[Any, ScalarLike, list, torch.Tensor, np.ndarray],
    ):
        """
        Update the history with a new key-value pair.

        Args:
            key (str): The key to update in the history.
            val (Union[Any, ScalarLike, list, torch.Tensor, np.ndarray]): The value
                to associate with the key.
        """
        try:
            self.history[key].append(val)
        except KeyError:
            self.history[key] = [val]
        return val

    @timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
    def update(
        self,
        metrics: dict,
        precision: int = 6,
        use_wandb: Optional[bool] = True,
        commit: Optional[bool] = True,
        summarize: Optional[bool] = True,
    ) -> str:
        """
        Update the history with a dictionary of metrics.

        Args:
            metrics (dict): Dictionary of metrics to update the history with.
            precision (int): Precision for summarizing the metrics.
            use_wandb (Optional[bool]): Whether to log the metrics to Weights & Biases.
            commit (Optional[bool]): Whether to commit the log to Weights & Biases.
            summarize (Optional[bool]): Whether to summarize the metrics.
        """
        for key, val in metrics.items():
            # if isinstance(val, (list, np.ndarray, torch.Tensor)):
            #     val = grab_tensor(val)
            try:
                self.history[key].append(val)
            except KeyError:
                self.history[key] = [val]
        if (
            wandb is not None
            and use_wandb
            # and not WANDB_DISABLED
            and getattr(wandb, "run", None) is not None
        ):
            wandb.log(metrics, commit=commit)
        if summarize:
            return summarize_dict(metrics, precision=precision)
        return ""

    def _tplot(
        self,
        y: np.ndarray,
        x: Optional[np.ndarray] = None,
        xlabel: Optional[str] = None,
        ylabel: Optional[str] = None,
        append: bool = True,
        title: Optional[str] = None,
        verbose: bool = False,
        outfile: Optional[str] = None,
        logfreq: Optional[int] = None,
        plot_type: Optional[str] = None,
    ):
        """
        Create a text plot of the given data.

        Args:
            y (np.ndarray): The data to plot.
            x (Optional[np.ndarray]): The x-axis data.
            xlabel (Optional[str]): The x-axis label.
            ylabel (Optional[str]): The y-axis label.
            append (bool): Whether to append to an existing plot.
            title (Optional[str]): The title of the plot.
            verbose (bool): Whether to print the plot.
            outfile (Optional[str]): The path to save the plot to.
            logfreq (Optional[int]): The log frequency of the plot.
            plot_type (Optional[str]): The type of plot to create.
        """
        if xlabel is not None and ylabel == xlabel:
            return
        if len(y) > 1:
            x = x if x is not None else np.arange(len(y))
            assert x is not None
            eztplot(
                y=y,
                x=x,
                xlabel=xlabel,
                ylabel=ylabel,
                logfreq=(1 if logfreq is None else logfreq),
                append=append,
                verbose=verbose,
                outfile=outfile,
                plot_type=plot_type,
                title=title,
                # plot_type=('scatter' if 'dt' in ylabel else None),
            )
        if ylabel is not None and "dt" in ylabel:
            of = Path(outfile) if outfile is not None else None
            if of is not None:
                of = Path(of.parent).joinpath(f"{of.stem}-hist{of.suffix}")
            eztplot(
                y=y,
                xlabel=ylabel,
                title=title,
                ylabel="freq",
                append=append,
                verbose=verbose,
                outfile=(of if of is not None else None),
                plot_type="hist",
            )

    @timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
    def plot(
        self,
        val: np.ndarray,
        key: Optional[str] = None,
        warmup: Optional[float] = 0.0,
        num_chains: Optional[int] = 128,
        title: Optional[str] = None,
        outdir: Optional[os.PathLike] = None,
        subplots_kwargs: Optional[dict[str, Any]] = None,
        plot_kwargs: Optional[dict[str, Any]] = None,
    ):
        """
        Plot a single variable from the history.

        NOTE: The `warmup` argument can be used to drop the first `warmup`
        iterations (as a percent of the total number of iterations) from the
        plot.

        Args:
            val (np.ndarray): The data to plot.
            key (Optional[str]): The key for the data.
            warmup (Optional[float]): The percentage of iterations to drop from the
                beginning of the plot.
            num_chains (Optional[int]): The number of chains to plot.
            title (Optional[str]): The title of the plot.
            outdir (Optional[os.PathLike]): The directory to save the plot to.
            subplots_kwargs (Optional[dict[str, Any]]): Additional arguments for
                subplots.
            plot_kwargs (Optional[dict[str, Any]]): Additional arguments for plotting.
        """
        import matplotlib.pyplot as plt

        LW = plt.rcParams.get("axes.linewidth", 1.75)
        plot_kwargs = {} if plot_kwargs is None else plot_kwargs
        subplots_kwargs = {} if subplots_kwargs is None else subplots_kwargs
        figsize = subplots_kwargs.get("figsize", ezplot.set_size())
        subplots_kwargs.update({"figsize": figsize})
        num_chains = 16 if num_chains is None else num_chains

        # tmp = val[0]
        arr = np.array(val)

        subfigs = None
        steps = np.arange(arr.shape[0])
        if warmup is not None and warmup > 0:
            drop = int(warmup * arr.shape[0])
            arr = arr[drop:]
            steps = steps[drop:]

        if len(arr.shape) == 2:
            import seaborn as sns

            _ = subplots_kwargs.pop("constrained_layout", True)
            figsize = (3 * figsize[0], 1.5 * figsize[1])

            fig = plt.figure(figsize=figsize, constrained_layout=True)
            subfigs = fig.subfigures(1, 2)

            gs_kw = {"width_ratios": [1.33, 0.33]}
            (ax, ax1) = subfigs[1].subplots(
                1, 2, sharey=True, gridspec_kw=gs_kw
            )
            ax.grid(alpha=0.2)
            ax1.grid(False)
            color = plot_kwargs.get("color", None)
            label = r"$\langle$" + f" {key} " + r"$\rangle$"
            ax.plot(
                steps, arr.mean(-1), lw=1.5 * LW, label=label, **plot_kwargs
            )
            sns.kdeplot(y=arr.flatten(), ax=ax1, color=color, shade=True)
            ax1.set_xticks([])
            ax1.set_xticklabels([])
            # ax1.set_yticks([])
            # ax1.set_yticklabels([])
            sns.despine(ax=ax, top=True, right=True)
            sns.despine(ax=ax1, top=True, right=True, left=True, bottom=True)
            # ax.legend(loc='best', frameon=False)
            ax1.set_xlabel("")
            # ax1.set_ylabel('')
            # ax.set_yticks(ax.get_yticks())
            # ax.set_yticklabels(ax.get_yticklabels())
            # ax.set_ylabel(key)
            # _ = subfigs[1].subplots_adjust(wspace=-0.75)
            axes = (ax, ax1)
        else:
            if len(arr.shape) == 1:
                fig, ax = plt.subplots(**subplots_kwargs)
                # assert isinstance(ax, plt.Axes)
                ax.plot(steps, arr, **plot_kwargs)
                axes = ax
            elif len(arr.shape) == 3:
                fig, ax = plt.subplots(**subplots_kwargs)
                # assert isinstance(ax, plt.Axes)
                cmap = plt.get_cmap("viridis")
                nlf = arr.shape[1]
                for idx in range(nlf):
                    # y = arr[:, idx, :].mean(-1)
                    # pkwargs = {
                    #     'color': cmap(idx / nlf),
                    #     'label': f'{idx}',
                    # }
                    # ax.plot(steps, y, **pkwargs)
                    label = plot_kwargs.pop("label", None)
                    if label is not None:
                        label = f"{label}-{idx}"
                    y = arr[:, idx, :]
                    color = cmap(idx / y.shape[1])
                    plot_kwargs["color"] = cmap(idx / y.shape[1])
                    if len(y.shape) == 2:
                        # TODO: Plot chains
                        if num_chains > 0:
                            for idx in range(min((num_chains, y.shape[1]))):
                                _ = ax.plot(
                                    steps,
                                    y[:, idx],  # color,
                                    lw=LW / 2.0,
                                    alpha=0.8,
                                    **plot_kwargs,
                                )

                        _ = ax.plot(
                            steps,
                            y.mean(-1),  # color=color,
                            label=label,
                            **plot_kwargs,
                        )
                    else:
                        _ = ax.plot(
                            steps,
                            y,  # color=color,
                            label=label,
                            **plot_kwargs,
                        )
                axes = ax
            else:
                raise ValueError("Unexpected shape encountered")

            ax.set_ylabel(key)

        if num_chains > 0 and len(arr.shape) > 1:
            # lw = LW / 2.
            for idx in range(min(num_chains, arr.shape[1])):
                # ax = subfigs[0].subplots(1, 1)
                # plot values of individual chains, arr[:, idx]
                # where arr[:, idx].shape = [ndraws, 1]
                ax.plot(
                    steps, arr[:, idx], alpha=0.5, lw=LW / 2.0, **plot_kwargs
                )

        ax.set_xlabel("draw")
        if title is not None:
            fig.suptitle(title)

        if outdir is not None:
            # plt.savefig(Path(outdir).joinpath(f'{key}.svg'),
            #             dpi=400, bbox_inches='tight')
            outfile = Path(outdir).joinpath(f"{key}.svg")
            if outfile.is_file():
                tstamp = ezpz.get_timestamp()
                pngdir = Path(outdir).joinpath("pngs")
                pngdir.mkdir(exist_ok=True, parents=True)
                pngfile = pngdir.joinpath(f"{key}-{tstamp}.png")
                svgfile = Path(outdir).joinpath(f"{key}-{tstamp}.svg")
                plt.savefig(pngfile, dpi=400, bbox_inches="tight")
                plt.savefig(svgfile, dpi=400, bbox_inches="tight")

        return fig, subfigs, axes

    @timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
    def plot_dataArray(
        self,
        val: xr.DataArray,
        key: Optional[str] = None,
        warmup: Optional[float] = 0.0,
        num_chains: Optional[int] = 0,
        title: Optional[str] = None,
        outdir: Optional[str] = None,
        subplots_kwargs: Optional[dict[str, Any]] = None,
        plot_kwargs: Optional[dict[str, Any]] = None,
        verbose: bool = False,
        line_labels: bool = False,
        logfreq: Optional[int] = None,
    ):
        """
        Plot a single variable from the history as an xarray DataArray.

        Args:
            val (xr.DataArray): The data to plot.
            key (Optional[str]): The key for the data.
            warmup (Optional[float]): The percentage of iterations to drop from the
                beginning of the plot.
            num_chains (Optional[int]): The number of chains to plot.
            title (Optional[str]): The title of the plot.
            outdir (Optional[str]): The directory to save the plot to.
            subplots_kwargs (Optional[dict[str, Any]]): Additional arguments for
                subplots.
            plot_kwargs (Optional[dict[str, Any]]): Additional arguments for plotting.
            verbose (bool): Whether to print the plot.
            line_labels (bool): Whether to label lines in the plot.
            logfreq (Optional[int]): The log frequency of the plot.
        """
        import matplotlib.pyplot as plt

        plot_kwargs = {} if plot_kwargs is None else plot_kwargs
        subplots_kwargs = {} if subplots_kwargs is None else subplots_kwargs
        ezplot.set_plot_style()
        plt.rcParams["axes.labelcolor"] = "#bdbdbd"
        figsize = subplots_kwargs.get("figsize", ezplot.set_size())
        subplots_kwargs.update({"figsize": figsize})
        subfigs = None
        # if key == 'dt':
        #     warmup = 0.2
        arr = val.values  # shape: [nchains, ndraws]
        # steps = np.arange(len(val.coords['draw']))
        steps = val.coords["draw"]
        if warmup is not None and warmup > 0.0:
            drop = int(warmup * arr.shape[0])
            arr = arr[drop:]
            steps = steps[drop:]
        if len(arr.shape) == 2:
            fig, axes = ezplot.plot_combined(
                val,
                key=key,
                num_chains=num_chains,
                plot_kwargs=plot_kwargs,
                subplots_kwargs=subplots_kwargs,
            )
        else:
            if len(arr.shape) == 1:
                fig, ax = ezplot.subplots(**subplots_kwargs)
                try:
                    ax.plot(steps, arr, **plot_kwargs)
                except ValueError:
                    try:
                        ax.plot(steps, arr[~np.isnan(arr)], **plot_kwargs)
                    except Exception:
                        logger.error(f"Unable to plot {key}! Continuing")
                _ = ax.grid(True, alpha=0.2)
                axes = ax
            elif len(arr.shape) == 3:
                fig, ax = ezplot.subplots(**subplots_kwargs)
                cmap = plt.get_cmap("viridis")
                y = val.mean("chain")
                for idx in range(len(val.coords["leapfrog"])):
                    pkwargs = {
                        "color": cmap(idx / len(val.coords["leapfrog"])),
                        "label": f"{idx}",
                    }
                    ax.plot(steps, y[idx], **pkwargs)
                axes = ax
            else:
                raise ValueError("Unexpected shape encountered")
            ax = plt.gca()
            # assert isinstance(ax, plt.Axes)
            assert key is not None
            _ = ax.set_ylabel(key)
            _ = ax.set_xlabel("draw")
            # if num_chains > 0 and len(arr.shape) > 1:
            #     lw = LW / 2.
            #     #for idx in range(min(num_chains, arr.shape[1])):
            #     nchains = len(val.coords['chains'])
            #     for idx in range(min(nchains, num_chains)):
            #         # ax = subfigs[0].subplots(1, 1)
            #         # plot values of individual chains, arr[:, idx]
            #         # where arr[:, idx].shape = [ndraws, 1]
            #         ax.plot(steps, val
            #                 alpha=0.5, lw=lw/2., **plot_kwargs)
        if title is not None:
            fig = plt.gcf()
            _ = fig.suptitle(title)
        if logfreq is not None:
            ax = plt.gca()
            xticks = ax.get_xticks()  # type: ignore
            _ = ax.set_xticklabels(  # type: ignore
                [f"{logfreq * int(i)}" for i in xticks]  # type: ignore
            )
        if outdir is not None:
            dirs = {
                "png": Path(outdir).joinpath("pngs/"),
                "svg": Path(outdir).joinpath("svgs/"),
            }
            _ = [i.mkdir(exist_ok=True, parents=True) for i in dirs.values()]
            # from l2hmc.configs import PROJECT_DIR
            # from ezpz
            if verbose:
                logger.info(f"Saving {key} plot to: {Path(outdir).resolve()}")
            for ext, d in dirs.items():
                outfile = d.joinpath(f"{key}.{ext}")
                plt.savefig(outfile, dpi=400, bbox_inches="tight")
        return (fig, subfigs, axes)

    @timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
    def plot_dataset(
        self,
        title: Optional[str] = None,
        nchains: Optional[int] = None,
        outdir: Optional[os.PathLike] = None,
        dataset: Optional[xr.Dataset] = None,
        data: Optional[dict] = None,
        warmup: Optional[int | float] = None,
        # subplots_kwargs: Optional[dict[str, Any]] = None,
        # plot_kwargs: Optional[dict[str, Any]] = None,
    ):
        dataset = (
            dataset
            if dataset is not None
            else (
                self.get_dataset(
                    data=(data if data is not None else self.history),
                    warmup=warmup,
                )
            )
        )
        return ezplot.plot_dataset(
            dataset=dataset,
            nchains=nchains,
            title=title,
            outdir=outdir,
        )

    def plot_2d_xarr(
        self,
        xarr: xr.DataArray,
        label: Optional[str] = None,
        num_chains: Optional[int] = None,
        title: Optional[str] = None,
        outdir: Optional[os.PathLike] = None,
        subplots_kwargs: Optional[dict[str, Any]] = None,
        plot_kwargs: Optional[dict[str, Any]] = None,
    ):
        import matplotlib.pyplot as plt
        import seaborn as sns

        LW = plt.rcParams.get("axes.linewidth", 1.75)
        plot_kwargs = {} if plot_kwargs is None else plot_kwargs
        subplots_kwargs = {} if subplots_kwargs is None else subplots_kwargs
        assert len(xarr.shape) == 2
        assert "draw" in xarr.coords and "chain" in xarr.coords
        num_chains = len(xarr.chain) if num_chains is None else num_chains
        # _ = subplots_kwargs.pop('constrained_layout', True)
        figsize = plt.rcParams.get("figure.figsize", (8, 6))
        figsize = (3 * figsize[0], 1.5 * figsize[1])
        fig = plt.figure(figsize=figsize, constrained_layout=True)
        subfigs = fig.subfigures(1, 2)
        gs_kw = {"width_ratios": [1.33, 0.33]}
        (ax, ax1) = subfigs[1].subplots(1, 2, sharey=True, gridspec_kw=gs_kw)
        ax.grid(alpha=0.2)
        ax1.grid(False)
        color = plot_kwargs.get("color", f"C{np.random.randint(6)}")
        label = r"$\langle$" + f" {label} " + r"$\rangle$"
        ax.plot(
            xarr.draw.values,
            xarr.mean("chain"),
            color=color,
            lw=1.5 * LW,
            label=label,
            **plot_kwargs,
        )
        for idx in range(num_chains):
            # ax = subfigs[0].subplots(1, 1)
            # plot values of individual chains, arr[:, idx]
            # where arr[:, idx].shape = [ndraws, 1]
            # ax0.plot(
            #     xarr.draw.values,
            #     xarr[xarr.chain == idx][0],
            #     lw=1.,
            #     alpha=0.7,
            #     color=color
            # )
            ax.plot(
                xarr.draw.values,
                xarr[xarr.chain == idx][0],
                color=color,
                alpha=0.5,
                lw=LW / 2.0,
                **plot_kwargs,
            )

        axes = (ax, ax1)
        sns.kdeplot(y=xarr.values.flatten(), ax=ax1, color=color, shade=True)
        ax1.set_xticks([])
        ax1.set_xticklabels([])
        # ax1.set_yticks([])
        # ax1.set_yticklabels([])
        sns.despine(ax=ax, top=True, right=True)
        sns.despine(ax=ax1, top=True, right=True, left=True, bottom=True)
        # ax.legend(loc='best', frameon=False)
        ax1.set_xlabel("")
        # ax1.set_ylabel('')
        # ax.set_yticks(ax.get_yticks())
        # ax.set_yticklabels(ax.get_yticklabels())
        # ax.set_ylabel(key)
        # _ = subfigs[1].subplots_adjust(wspace=-0.75)
        # if num_chains > 0 and len(arr.shape) > 1:
        # lw = LW / 2.
        # num_chains = np.min([
        #     16,
        #     len(xarr.coords['chain']),
        # ])
        sns.despine(subfigs[0])
        ax0 = subfigs[0].subplots(1, 1)
        im = xarr.plot(ax=ax0)  # type:ignore
        im.colorbar.set_label(label)  # type:ignore
        # ax0.plot(
        #     xarr.draw.values,
        #     xarr.mean('chain'),
        #     lw=2.,
        #     color=color
        # )
        # for idx in range(min(num_chains, i.shape[1])):
        ax.set_xlabel("draw")
        if title is not None:
            fig.suptitle(title)

        if outdir is not None:
            assert label is not None
            # plt.savefig(Path(outdir).joinpath(f'{key}.svg'),
            #             dpi=400, bbox_inches='tight')
            outfile = Path(outdir).joinpath(f"{label}.svg")
            if outfile.is_file():
                tstamp = get_timestamp("%Y-%m-%d-%H%M%S")
                pngdir = Path(outdir).joinpath("pngs")
                pngdir.mkdir(exist_ok=True, parents=True)
                pngfile = pngdir.joinpath(f"{label}-{tstamp}.png")
                svgfile = Path(outdir).joinpath(f"{label}-{tstamp}.svg")
                plt.savefig(pngfile, dpi=400, bbox_inches="tight")
                plt.savefig(svgfile, dpi=400, bbox_inches="tight")

    @timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
    def tplot_all(
        self,
        outdir: Optional[os.PathLike] = None,
        warmup: Optional[float] = 0.0,
        append: bool = True,
        xkey: Optional[str] = None,
        dataset: Optional[xr.Dataset] = None,
        data: Optional[dict] = None,
        logfreq: Optional[int] = None,
        plot_type: Optional[str] = None,
        verbose: bool = False,
    ):
        dataset = (
            dataset
            if dataset is not None
            else (
                self.get_dataset(
                    data=(data if data is not None else self.history),
                    warmup=warmup,
                )
            )
        )

        outdir = Path(os.getcwd()) if outdir is None else Path(outdir)
        logger.info(f"Saving tplots to {outdir.as_posix()}")
        for _, (key, val) in enumerate(dataset.items()):
            if (xkey is not None and key == xkey) or xkey in ["iter", "draw"]:
                continue
            if len(val.values) > 0:
                self._tplot(
                    y=val.values,
                    x=None,
                    xlabel="iter",
                    plot_type=plot_type,
                    ylabel=str(key),
                    append=append,
                    title=f"{key} [{get_timestamp()}]",
                    verbose=verbose,
                    outfile=outdir.joinpath(f"{key}.txt").as_posix(),
                    logfreq=logfreq,
                )
            else:
                logger.warning(
                    f"No data found in {key=}: {len(val.values)=} <= 0"
                )

    @timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
    def plot_all(
        self,
        num_chains: int = 128,
        warmup: Optional[float | int] = 0.0,
        title: Optional[str] = None,
        verbose: bool = False,
        outdir: Optional[os.PathLike] = None,
        subplots_kwargs: Optional[dict[str, Any]] = None,
        plot_kwargs: Optional[dict[str, Any]] = None,
        dataset: Optional[xr.Dataset] = None,
        data: Optional[dict] = None,
    ):
        import matplotlib.pyplot as plt
        import seaborn as sns

        plot_kwargs = {} if plot_kwargs is None else plot_kwargs
        subplots_kwargs = {} if subplots_kwargs is None else subplots_kwargs

        dataset = (
            dataset
            if dataset is not None
            else (
                self.get_dataset(
                    data=(data if data is not None else self.history),
                    warmup=warmup,
                )
            )
        )

        _ = ezplot.make_ridgeplots(
            dataset,
            outdir=outdir,
            drop_nans=True,
            drop_zeros=False,
            num_chains=num_chains,
            cmap="viridis",
            save_plot=(outdir is not None),
        )

        for idx, (key, val) in enumerate(dataset.data_vars.items()):
            color = f"C{idx % 9}"
            plot_kwargs["color"] = color

            fig, subfigs, ax = self.plot(
                val=val.values.T.real,
                key=str(key),
                title=title,
                outdir=outdir,
                warmup=warmup,
                num_chains=num_chains,
                plot_kwargs=plot_kwargs,
                subplots_kwargs=subplots_kwargs,
            )
            if fig is not None:
                _ = sns.despine(
                    fig, top=True, right=True, bottom=True, left=True
                )

            # _ = plt.grid(True, alpha=0.4)
            if subfigs is not None:
                # edgecolor = plt.rcParams['axes.edgecolor']
                plt.rcParams["axes.edgecolor"] = plt.rcParams["axes.facecolor"]
                ax = subfigs[0].subplots(1, 1)
                # ax = fig[1].subplots(constrained_layout=True)
                _ = xplt.pcolormesh(
                    val, "draw", "chain", ax=ax, robust=True, add_colorbar=True
                )
                # im = val.plot(ax=ax, cbar_kwargs=cbar_kwargs)
                # im.colorbar.set_label(f'{key}')  # , labelpad=1.25)
                sns.despine(
                    subfigs[0], top=True, right=True, left=True, bottom=True
                )
            if outdir is not None:
                dirs = {
                    "png": Path(outdir).joinpath("pngs/"),
                    "svg": Path(outdir).joinpath("svgs/"),
                }
                _ = [
                    i.mkdir(exist_ok=True, parents=True) for i in dirs.values()
                ]
                # if verbose:
                logger.info(f"Saving {key} plot to: {Path(outdir).resolve()}")
                for ext, d in dirs.items():
                    outfile = d.joinpath(f"{key}.{ext}")
                    if outfile.is_file():
                        outfile = d.joinpath(f"{key}-subfig.{ext}")
                    # logger.info(f"Saving {key}.ext to: {outfile}")
                    if verbose:
                        logger.info(
                            f"Saving {key} plot to: {outfile.resolve()}"
                        )
                    plt.savefig(outfile, dpi=400, bbox_inches="tight")
            if is_interactive():
                plt.show()

        return dataset

    def history_to_dict(self) -> dict:
        # return {k: np.stack(v).squeeze() for k, v in self.history.items()}
        return {
            k: torch.Tensor(v).numpy(force=True)
            for k, v in self.history.items()
        }

    def to_DataArray(
        self,
        x: Union[list, np.ndarray, torch.Tensor],
        warmup: Optional[float] = 0.0,
    ) -> xr.DataArray:
        if isinstance(x, list) and isinstance(x[0], torch.Tensor):
            x = torch.Tensor(x).numpy(force=True)
        try:
            arr = grab_tensor(x)
        except ValueError:
            arr = np.array(x).real
            # arr = np.array(x)
            logger.info(f"len(x): {len(x)}")
            logger.info(f"x[0].shape: {x[0].shape}")
            logger.info(f"arr.shape: {arr.shape}")
        assert isinstance(arr, np.ndarray)
        if warmup is not None and warmup > 0 and len(arr) > 0:
            if isinstance(warmup, int):
                warmup = warmup / len(arr)
            # drop = int(warmup * arr.shape[0])
            drop = int(warmup * len(arr))
            arr = arr[drop:]
        # steps = np.arange(len(arr))
        if len(arr.shape) == 1:  # [ndraws]
            ndraws = arr.shape[0]
            dims = ["draw"]
            coords = [np.arange(len(arr))]
            return xr.DataArray(arr, dims=dims, coords=coords)  # type:ignore

        if len(arr.shape) == 2:  # [nchains, ndraws]
            arr = arr.T
            nchains, ndraws = arr.shape
            dims = ("chain", "draw")
            coords = [np.arange(nchains), np.arange(ndraws)]
            return xr.DataArray(arr, dims=dims, coords=coords)  # type:ignore

        if len(arr.shape) == 3:  # [nchains, nlf, ndraws]
            arr = arr.T
            nchains, nlf, ndraws = arr.shape
            dims = ("chain", "leapfrog", "draw")
            coords = [np.arange(nchains), np.arange(nlf), np.arange(ndraws)]
            return xr.DataArray(arr, dims=dims, coords=coords)  # type:ignore

        else:
            print(f"arr.shape: {arr.shape}")
            raise ValueError("Invalid shape encountered")

    def get_dataset(
        self,
        data: Optional[
            dict[str, Union[list, np.ndarray, torch.Tensor]]
        ] = None,
        warmup: Optional[float] = 0.0,
    ):
        data = self.history_to_dict() if data is None else data
        data_vars = {}
        for key, val in data.items():
            name = key.replace("/", "_")
            try:
                data_vars[name] = self.to_DataArray(val, warmup)
            except ValueError:
                logger.error(
                    f"Unable to create DataArray for {key}! Skipping!"
                )
                logger.error(f"{key}.shape= {np.stack(val).shape}")  # type:ignore
        return xr.Dataset(data_vars)

    @timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
    def save_dataset(
        self,
        outdir: PathLike,
        fname: str = "dataset",
        use_hdf5: bool = True,
        data: Optional[
            dict[str, Union[list, np.ndarray, torch.Tensor]]
        ] = None,
        dataset: Optional[xr.Dataset] = None,
        warmup: Optional[int | float] = None,
        **kwargs,
    ) -> Path:
        dataset = (
            dataset
            if dataset is not None
            else (
                self.get_dataset(
                    data=(data if data is not None else self.history),
                    warmup=warmup,
                )
            )
        )
        return save_dataset(
            dataset,
            outdir=outdir,
            fname=fname,
            use_hdf5=use_hdf5,
            **kwargs,
        )

    @timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
    def finalize(
        self,
        outdir: Optional[PathLike] = None,
        run_name: Optional[str] = None,
        dataset_fname: Optional[str] = None,
        num_chains: int = 128,
        warmup: Optional[int | float] = 0.0,
        verbose: bool = False,
        save: bool = True,
        plot: bool = True,
        append_tplot: bool = True,
        title: Optional[str] = None,
        data: Optional[
            dict[str, Union[list, np.ndarray, torch.Tensor]]
        ] = None,
        dataset: Optional[xr.Dataset] = None,
        xkey: Optional[str] = None,
        plot_kwargs: Optional[dict[str, Any]] = None,
        subplots_kwargs: Optional[dict[str, Any]] = None,
        tplot_type: Optional[str] = None,
    ) -> xr.Dataset:
        dataset = (
            dataset
            if dataset is not None
            else (
                self.get_dataset(
                    data=(data if data is not None else self.history),
                    warmup=warmup,
                )
            )
        )
        run_name = (
            f"History-{get_timestamp()}" if run_name is None else run_name
        )
        fallback_outdir = Path(os.getcwd()).joinpath("outputs")
        if run_name is not None:
            fallback_outdir = fallback_outdir.joinpath(
                run_name, get_timestamp()
            )
        outdir = (
            # Path(os.getcwd()).joinpath('outputs')
            fallback_outdir if outdir is None else Path(outdir)
        )
        outdir = outdir.joinpath(run_name)
        if plot:
            plotdir = outdir.joinpath("plots")
            tplotdir = plotdir.joinpath("tplot")
            mplotdir = plotdir.joinpath("mplot")
            tplotdir.mkdir(exist_ok=True, parents=True)
            mplotdir.mkdir(exist_ok=True, parents=True)
            _ = self.plot_all(
                dataset=dataset,
                outdir=mplotdir,
                verbose=verbose,
                num_chains=num_chains,
                warmup=warmup,
                title=title,
                plot_kwargs=plot_kwargs,
                subplots_kwargs=subplots_kwargs,
            )
            _ = self.tplot_all(
                dataset=dataset,
                outdir=tplotdir,
                warmup=warmup,
                append=append_tplot,
                plot_type=tplot_type,
                xkey=xkey,
                verbose=verbose,
            )
        if save:
            try:
                import h5py

                use_hdf5 = True
            except ImportError:
                logger.warning(
                    "h5py not found! Saving dataset as netCDF instead."
                )
                use_hdf5 = False

            fname = "dataset" if dataset_fname is None else dataset_fname
            _ = self.save_dataset(
                dataset=dataset, outdir=outdir, fname=fname, use_hdf5=use_hdf5
            )
        return dataset

__init__(keys=None)

Initialize the History object.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| keys | Optional[list[str]] | List of keys to initialize the history with. If None, initializes with an empty list. | None |

Source code in src/ezpz/history.py
def __init__(self, keys: Optional[list[str]] = None) -> None:
    """
    Initialize the History object.

    Args:
        keys (Optional[list[str]]): List of keys to initialize the history with.
            If None, initializes with an empty list.
    """
    self.keys = [] if keys is None else keys
    self.history = {}

plot(val, key=None, warmup=0.0, num_chains=128, title=None, outdir=None, subplots_kwargs=None, plot_kwargs=None)

Plot a single variable from the history.

NOTE: The warmup argument can be used to drop the first warmup iterations (as a percent of the total number of iterations) from the plot.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| val | ndarray | The data to plot. | required |
| key | Optional[str] | The key for the data. | None |
| warmup | Optional[float] | The percentage of iterations to drop from the beginning of the plot. | 0.0 |
| num_chains | Optional[int] | The number of chains to plot. | 128 |
| title | Optional[str] | The title of the plot. | None |
| outdir | Optional[PathLike] | The directory to save the plot to. | None |
| subplots_kwargs | Optional[dict[str, Any]] | Additional arguments for subplots. | None |
| plot_kwargs | Optional[dict[str, Any]] | Additional arguments for plotting. | None |

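As the source below shows, `warmup` is interpreted as a fraction of the recorded draws (`drop = int(warmup * arr.shape[0])`): with 1000 draws and `warmup=0.1`, the first 100 are discarded. A hedged sketch, assuming `history` is a populated `History` instance holding a scalar metric named `"loss"`:

```python
import numpy as np

# hypothetical: history.history["loss"] holds 1000 recorded scalar values
loss = np.array(history.history["loss"])  # shape: [1000]

# warmup=0.1 drops the first 10% (100 draws) before plotting
fig, subfigs, ax = history.plot(
    val=loss,
    key="loss",
    warmup=0.1,
    outdir="plots",  # optional; see the source below for how files are written
)
```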
Source code in src/ezpz/history.py
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def plot(
    self,
    val: np.ndarray,
    key: Optional[str] = None,
    warmup: Optional[float] = 0.0,
    num_chains: Optional[int] = 128,
    title: Optional[str] = None,
    outdir: Optional[os.PathLike] = None,
    subplots_kwargs: Optional[dict[str, Any]] = None,
    plot_kwargs: Optional[dict[str, Any]] = None,
):
    """
    Plot a single variable from the history.

    NOTE: The `warmup` argument can be used to drop the first `warmup`
    iterations (as a percent of the total number of iterations) from the
    plot.

    Args:
        val (np.ndarray): The data to plot.
        key (Optional[str]): The key for the data.
        warmup (Optional[float]): The percentage of iterations to drop from the
            beginning of the plot.
        num_chains (Optional[int]): The number of chains to plot.
        title (Optional[str]): The title of the plot.
        outdir (Optional[os.PathLike]): The directory to save the plot to.
        subplots_kwargs (Optional[dict[str, Any]]): Additional arguments for
            subplots.
        plot_kwargs (Optional[dict[str, Any]]): Additional arguments for plotting.
    """
    import matplotlib.pyplot as plt

    LW = plt.rcParams.get("axes.linewidth", 1.75)
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    subplots_kwargs = {} if subplots_kwargs is None else subplots_kwargs
    figsize = subplots_kwargs.get("figsize", ezplot.set_size())
    subplots_kwargs.update({"figsize": figsize})
    num_chains = 16 if num_chains is None else num_chains

    # tmp = val[0]
    arr = np.array(val)

    subfigs = None
    steps = np.arange(arr.shape[0])
    if warmup is not None and warmup > 0:
        drop = int(warmup * arr.shape[0])
        arr = arr[drop:]
        steps = steps[drop:]

    if len(arr.shape) == 2:
        import seaborn as sns

        _ = subplots_kwargs.pop("constrained_layout", True)
        figsize = (3 * figsize[0], 1.5 * figsize[1])

        fig = plt.figure(figsize=figsize, constrained_layout=True)
        subfigs = fig.subfigures(1, 2)

        gs_kw = {"width_ratios": [1.33, 0.33]}
        (ax, ax1) = subfigs[1].subplots(
            1, 2, sharey=True, gridspec_kw=gs_kw
        )
        ax.grid(alpha=0.2)
        ax1.grid(False)
        color = plot_kwargs.get("color", None)
        label = r"$\langle$" + f" {key} " + r"$\rangle$"
        ax.plot(
            steps, arr.mean(-1), lw=1.5 * LW, label=label, **plot_kwargs
        )
        sns.kdeplot(y=arr.flatten(), ax=ax1, color=color, shade=True)
        ax1.set_xticks([])
        ax1.set_xticklabels([])
        # ax1.set_yticks([])
        # ax1.set_yticklabels([])
        sns.despine(ax=ax, top=True, right=True)
        sns.despine(ax=ax1, top=True, right=True, left=True, bottom=True)
        # ax.legend(loc='best', frameon=False)
        ax1.set_xlabel("")
        # ax1.set_ylabel('')
        # ax.set_yticks(ax.get_yticks())
        # ax.set_yticklabels(ax.get_yticklabels())
        # ax.set_ylabel(key)
        # _ = subfigs[1].subplots_adjust(wspace=-0.75)
        axes = (ax, ax1)
    else:
        if len(arr.shape) == 1:
            fig, ax = plt.subplots(**subplots_kwargs)
            # assert isinstance(ax, plt.Axes)
            ax.plot(steps, arr, **plot_kwargs)
            axes = ax
        elif len(arr.shape) == 3:
            fig, ax = plt.subplots(**subplots_kwargs)
            # assert isinstance(ax, plt.Axes)
            cmap = plt.get_cmap("viridis")
            nlf = arr.shape[1]
            for idx in range(nlf):
                # y = arr[:, idx, :].mean(-1)
                # pkwargs = {
                #     'color': cmap(idx / nlf),
                #     'label': f'{idx}',
                # }
                # ax.plot(steps, y, **pkwargs)
                label = plot_kwargs.pop("label", None)
                if label is not None:
                    label = f"{label}-{idx}"
                y = arr[:, idx, :]
                color = cmap(idx / y.shape[1])
                plot_kwargs["color"] = cmap(idx / y.shape[1])
                if len(y.shape) == 2:
                    # TODO: Plot chains
                    if num_chains > 0:
                        for idx in range(min((num_chains, y.shape[1]))):
                            _ = ax.plot(
                                steps,
                                y[:, idx],  # color,
                                lw=LW / 2.0,
                                alpha=0.8,
                                **plot_kwargs,
                            )

                    _ = ax.plot(
                        steps,
                        y.mean(-1),  # color=color,
                        label=label,
                        **plot_kwargs,
                    )
                else:
                    _ = ax.plot(
                        steps,
                        y,  # color=color,
                        label=label,
                        **plot_kwargs,
                    )
            axes = ax
        else:
            raise ValueError("Unexpected shape encountered")

        ax.set_ylabel(key)

    if num_chains > 0 and len(arr.shape) > 1:
        # lw = LW / 2.
        for idx in range(min(num_chains, arr.shape[1])):
            # ax = subfigs[0].subplots(1, 1)
            # plot values of individual chains, arr[:, idx]
            # where arr[:, idx].shape = [ndraws, 1]
            ax.plot(
                steps, arr[:, idx], alpha=0.5, lw=LW / 2.0, **plot_kwargs
            )

    ax.set_xlabel("draw")
    if title is not None:
        fig.suptitle(title)

    if outdir is not None:
        # plt.savefig(Path(outdir).joinpath(f'{key}.svg'),
        #             dpi=400, bbox_inches='tight')
        outfile = Path(outdir).joinpath(f"{key}.svg")
        if outfile.is_file():
            tstamp = ezpz.get_timestamp()
            pngdir = Path(outdir).joinpath("pngs")
            pngdir.mkdir(exist_ok=True, parents=True)
            pngfile = pngdir.joinpath(f"{key}-{tstamp}.png")
            svgfile = Path(outdir).joinpath(f"{key}-{tstamp}.svg")
            plt.savefig(pngfile, dpi=400, bbox_inches="tight")
            plt.savefig(svgfile, dpi=400, bbox_inches="tight")

    return fig, subfigs, axes

plot_dataArray(val, key=None, warmup=0.0, num_chains=0, title=None, outdir=None, subplots_kwargs=None, plot_kwargs=None, verbose=False, line_labels=False, logfreq=None)

Plot a single variable from the history as an xarray DataArray.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| val | DataArray | The data to plot. | required |
| key | Optional[str] | The key for the data. | None |
| warmup | Optional[float] | The percentage of iterations to drop from the beginning of the plot. | 0.0 |
| num_chains | Optional[int] | The number of chains to plot. | 0 |
| title | Optional[str] | The title of the plot. | None |
| outdir | Optional[str] | The directory to save the plot to. | None |
| subplots_kwargs | Optional[dict[str, Any]] | Additional arguments for subplots. | None |
| plot_kwargs | Optional[dict[str, Any]] | Additional arguments for plotting. | None |
| verbose | bool | Whether to print the plot. | False |
| line_labels | bool | Whether to label lines in the plot. | False |
| logfreq | Optional[int] | The log frequency of the plot. | None |

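A hedged sketch of one way to obtain a `DataArray` to pass in, using `get_dataset` from the class source above (which replaces `"/"` in keys with `"_"`); the metric name and log frequency are illustrative:

```python
# hypothetical: `history` already holds per-step values under "train/loss"
ds = history.get_dataset()   # xr.Dataset built from the full history
darr = ds["train_loss"]      # "/" in keys is replaced with "_"

fig, subfigs, axes = history.plot_dataArray(
    darr,
    key="train_loss",
    warmup=0.1,   # drop the first 10% of draws
    logfreq=10,   # relabel x-ticks as step = 10 * draw
    outdir="plots",
)
```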
Source code in src/ezpz/history.py
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def plot_dataArray(
    self,
    val: xr.DataArray,
    key: Optional[str] = None,
    warmup: Optional[float] = 0.0,
    num_chains: Optional[int] = 0,
    title: Optional[str] = None,
    outdir: Optional[str] = None,
    subplots_kwargs: Optional[dict[str, Any]] = None,
    plot_kwargs: Optional[dict[str, Any]] = None,
    verbose: bool = False,
    line_labels: bool = False,
    logfreq: Optional[int] = None,
):
    """
    Plot a single variable from the history as an xarray DataArray.

    Args:
        val (xr.DataArray): The data to plot.
        key (Optional[str]): The key for the data.
        warmup (Optional[float]): The percentage of iterations to drop from the
            beginning of the plot.
        num_chains (Optional[int]): The number of chains to plot.
        title (Optional[str]): The title of the plot.
        outdir (Optional[str]): The directory to save the plot to.
        subplots_kwargs (Optional[dict[str, Any]]): Additional arguments for
            subplots.
        plot_kwargs (Optional[dict[str, Any]]): Additional arguments for plotting.
        verbose (bool): Whether to print the plot.
        line_labels (bool): Whether to label lines in the plot.
        logfreq (Optional[int]): The log frequency of the plot.
    """
    import matplotlib.pyplot as plt

    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    subplots_kwargs = {} if subplots_kwargs is None else subplots_kwargs
    ezplot.set_plot_style()
    plt.rcParams["axes.labelcolor"] = "#bdbdbd"
    figsize = subplots_kwargs.get("figsize", ezplot.set_size())
    subplots_kwargs.update({"figsize": figsize})
    subfigs = None
    # if key == 'dt':
    #     warmup = 0.2
    arr = val.values  # shape: [nchains, ndraws]
    # steps = np.arange(len(val.coords['draw']))
    steps = val.coords["draw"]
    if warmup is not None and warmup > 0.0:
        drop = int(warmup * arr.shape[0])
        arr = arr[drop:]
        steps = steps[drop:]
    if len(arr.shape) == 2:
        fig, axes = ezplot.plot_combined(
            val,
            key=key,
            num_chains=num_chains,
            plot_kwargs=plot_kwargs,
            subplots_kwargs=subplots_kwargs,
        )
    else:
        if len(arr.shape) == 1:
            fig, ax = ezplot.subplots(**subplots_kwargs)
            try:
                ax.plot(steps, arr, **plot_kwargs)
            except ValueError:
                try:
                    ax.plot(steps, arr[~np.isnan(arr)], **plot_kwargs)
                except Exception:
                    logger.error(f"Unable to plot {key}! Continuing")
            _ = ax.grid(True, alpha=0.2)
            axes = ax
        elif len(arr.shape) == 3:
            fig, ax = ezplot.subplots(**subplots_kwargs)
            cmap = plt.get_cmap("viridis")
            y = val.mean("chain")
            for idx in range(len(val.coords["leapfrog"])):
                pkwargs = {
                    "color": cmap(idx / len(val.coords["leapfrog"])),
                    "label": f"{idx}",
                }
                ax.plot(steps, y[idx], **pkwargs)
            axes = ax
        else:
            raise ValueError("Unexpected shape encountered")
        ax = plt.gca()
        # assert isinstance(ax, plt.Axes)
        assert key is not None
        _ = ax.set_ylabel(key)
        _ = ax.set_xlabel("draw")
        # if num_chains > 0 and len(arr.shape) > 1:
        #     lw = LW / 2.
        #     #for idx in range(min(num_chains, arr.shape[1])):
        #     nchains = len(val.coords['chains'])
        #     for idx in range(min(nchains, num_chains)):
        #         # ax = subfigs[0].subplots(1, 1)
        #         # plot values of individual chains, arr[:, idx]
        #         # where arr[:, idx].shape = [ndraws, 1]
        #         ax.plot(steps, val
        #                 alpha=0.5, lw=lw/2., **plot_kwargs)
    if title is not None:
        fig = plt.gcf()
        _ = fig.suptitle(title)
    if logfreq is not None:
        ax = plt.gca()
        xticks = ax.get_xticks()  # type: ignore
        _ = ax.set_xticklabels(  # type: ignore
            [f"{logfreq * int(i)}" for i in xticks]  # type: ignore
        )
    if outdir is not None:
        dirs = {
            "png": Path(outdir).joinpath("pngs/"),
            "svg": Path(outdir).joinpath("svgs/"),
        }
        _ = [i.mkdir(exist_ok=True, parents=True) for i in dirs.values()]
        # from l2hmc.configs import PROJECT_DIR
        # from ezpz
        if verbose:
            logger.info(f"Saving {key} plot to: {Path(outdir).resolve()}")
        for ext, d in dirs.items():
            outfile = d.joinpath(f"{key}.{ext}")
            plt.savefig(outfile, dpi=400, bbox_inches="tight")
    return (fig, subfigs, axes)

update(metrics, precision=6, use_wandb=True, commit=True, summarize=True)

Update the history with a dictionary of metrics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| metrics | dict | Dictionary of metrics to update the history with. | required |
| precision | int | Precision for summarizing the metrics. | 6 |
| use_wandb | Optional[bool] | Whether to log the metrics to Weights & Biases. | True |
| commit | Optional[bool] | Whether to commit the log to Weights & Biases. | True |
| summarize | Optional[bool] | Whether to summarize the metrics. | True |

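For example (a minimal sketch; the metric names and values are arbitrary, and Weights & Biases logging is disabled here). If a wandb run is active and `use_wandb=True`, the same dictionary is also passed to `wandb.log(metrics, commit=commit)`:

```python
history = History()
summary = history.update(
    {"loss": 0.4231, "lr": 1e-3, "iter": 0},
    precision=4,      # digits used when formatting the returned summary
    use_wandb=False,  # do not forward metrics to wandb.log
)
print(summary)        # one-line string produced by summarize_dict
```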
Source code in src/ezpz/history.py
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def update(
    self,
    metrics: dict,
    precision: int = 6,
    use_wandb: Optional[bool] = True,
    commit: Optional[bool] = True,
    summarize: Optional[bool] = True,
) -> str:
    """
    Update the history with a dictionary of metrics.

    Args:
        metrics (dict): Dictionary of metrics to update the history with.
        precision (int): Precision for summarizing the metrics.
        use_wandb (Optional[bool]): Whether to log the metrics to Weights & Biases.
        commit (Optional[bool]): Whether to commit the log to Weights & Biases.
        summarize (Optional[bool]): Whether to summarize the metrics.
    """
    for key, val in metrics.items():
        # if isinstance(val, (list, np.ndarray, torch.Tensor)):
        #     val = grab_tensor(val)
        try:
            self.history[key].append(val)
        except KeyError:
            self.history[key] = [val]
    if (
        wandb is not None
        and use_wandb
        # and not WANDB_DISABLED
        and getattr(wandb, "run", None) is not None
    ):
        wandb.log(metrics, commit=commit)
    if summarize:
        return summarize_dict(metrics, precision=precision)
    return ""

StopWatch

Bases: ContextDecorator

A simple stopwatch context manager for measuring time taken by a block of code.
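A brief usage sketch (assuming `StopWatch` is importable from the top-level `ezpz` package as documented on this page; the message, tag, and timed blocks are illustrative):

```python
import time
from ezpz import StopWatch  # or: from ezpz.history import StopWatch

# As a context manager: on exit, logs "<msg> took <dt> seconds".
with StopWatch("data loading"):
    time.sleep(0.1)

# As a decorator (StopWatch is a ContextDecorator). If a Weights & Biases run
# is active and `wbtag` is set, the elapsed time is also logged to that run.
@StopWatch("one training step", wbtag="train_step", iter=0)
def train_step():
    time.sleep(0.05)

train_step()
```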

Source code in src/ezpz/history.py
class StopWatch(ContextDecorator):
    """
    A simple stopwatch context manager for measuring time taken by a block of code.
    """

    def __init__(
        self,
        msg: str,
        wbtag: Optional[str] = None,
        iter: Optional[int] = None,
        commit: Optional[bool] = False,
        prefix: str = "StopWatch/",
        log_output: bool = True,
    ) -> None:
        """
        Initialize the StopWatch.

        Args:
            msg (str): Message to log when the stopwatch is started.
            wbtag (Optional[str]): Optional tag for logging to Weights & Biases.
            iter (Optional[int]): Optional iteration number to log.
            commit (Optional[bool]): Whether to commit the log to Weights & Biases.
            prefix (str): Prefix for the log data.
            log_output (bool): Whether to log the output message.
        """
        self.msg = msg
        self.data = {}
        self.iter = iter if iter is not None else None
        self.prefix = prefix
        self.wbtag = wbtag if wbtag is not None else None
        self.log_output = log_output
        self.commit = commit
        if wbtag is not None:
            self.data = {
                f"{self.wbtag}/dt": None,
            }
            if iter is not None:
                self.data |= {
                    f"{self.wbtag}/iter": self.iter,
                }

    def __enter__(self):
        """Start the stopwatch."""
        self.time = time.perf_counter()
        return self

    def __exit__(self, t, v, traceback):
        """Stop the stopwatch and log the time taken."""
        dt = time.perf_counter() - self.time
        # if self.wbtag is not None and wandb.run is not None:
        # if len(self.data) > 0 and wandb.run is not None:
        try:
            if (
                len(self.data) > 0
                and wandb is not None
                and (wbrun := getattr(wandb, "run", None)) is not None
            ):
                self.data |= {f"{self.wbtag}/dt": dt}
                wbrun.log({self.prefix: self.data}, commit=self.commit)
        except Exception as e:
            logger.error(f"Unable to log to wandb: {e}")
        if self.log_output:
            logger.info(f"{self.msg} took {dt:.3f} seconds")

__enter__()

Start the stopwatch.

Source code in src/ezpz/history.py
def __enter__(self):
    """Start the stopwatch."""
    self.time = time.perf_counter()
    return self

__exit__(t, v, traceback)

Stop the stopwatch and log the time taken.

Source code in src/ezpz/history.py
def __exit__(self, t, v, traceback):
    """Stop the stopwatch and log the time taken."""
    dt = time.perf_counter() - self.time
    # if self.wbtag is not None and wandb.run is not None:
    # if len(self.data) > 0 and wandb.run is not None:
    try:
        if (
            len(self.data) > 0
            and wandb is not None
            and (wbrun := getattr(wandb, "run", None)) is not None
        ):
            self.data |= {f"{self.wbtag}/dt": dt}
            wbrun.log({self.prefix: self.data}, commit=self.commit)
    except Exception as e:
        logger.error(f"Unable to log to wandb: {e}")
    if self.log_output:
        logger.info(f"{self.msg} took {dt:.3f} seconds")

__init__(msg, wbtag=None, iter=None, commit=False, prefix='StopWatch/', log_output=True)

Initialize the StopWatch.

Parameters:

    msg (str): Message to log when the stopwatch is started. Required.
    wbtag (Optional[str]): Optional tag for logging to Weights & Biases. Default: None
    iter (Optional[int]): Optional iteration number to log. Default: None
    commit (Optional[bool]): Whether to commit the log to Weights & Biases. Default: False
    prefix (str): Prefix for the log data. Default: 'StopWatch/'
    log_output (bool): Whether to log the output message. Default: True
Source code in src/ezpz/history.py
def __init__(
    self,
    msg: str,
    wbtag: Optional[str] = None,
    iter: Optional[int] = None,
    commit: Optional[bool] = False,
    prefix: str = "StopWatch/",
    log_output: bool = True,
) -> None:
    """
    Initialize the StopWatch.

    Args:
        msg (str): Message to log when the stopwatch is started.
        wbtag (Optional[str]): Optional tag for logging to Weights & Biases.
        iter (Optional[int]): Optional iteration number to log.
        commit (Optional[bool]): Whether to commit the log to Weights & Biases.
        prefix (str): Prefix for the log data.
        log_output (bool): Whether to log the output message.
    """
    self.msg = msg
    self.data = {}
    self.iter = iter if iter is not None else None
    self.prefix = prefix
    self.wbtag = wbtag if wbtag is not None else None
    self.log_output = log_output
    self.commit = commit
    if wbtag is not None:
        self.data = {
            f"{self.wbtag}/dt": None,
        }
        if iter is not None:
            self.data |= {
                f"{self.wbtag}/iter": self.iter,
            }
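
Since StopWatch subclasses ContextDecorator, it can either wrap a with-block or decorate a function. A minimal usage sketch (assuming StopWatch is importable from ezpz.history, per the source path above, and that a logger is configured):

import time

from ezpz.history import StopWatch

with StopWatch("data loading"):
    time.sleep(0.1)  # stand-in for real work; logs "data loading took ... seconds"

@StopWatch("train step", wbtag="train_step", iter=0)
def train_step():
    time.sleep(0.1)

train_step()  # each call is timed; the duration is also logged to wandb if a run is active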

breakpoint(rank=0)

Set a breakpoint, but only on a single rank. All other ranks will wait for you to be done with the breakpoint before continuing.

Parameters:

    rank (int): Which rank to break on. Default: 0
Source code in src/ezpz/utils.py
def breakpoint(rank: int = 0):
    """
    Set a breakpoint, but only on a single rank.  All other ranks will wait for you to be
    done with the breakpoint before continuing.

    Args:
        rank (int): Which rank to break on.  Default: ``0``
    """
    if get_rank() == rank:
        pdb = DistributedPdb()
        pdb.message(
            "\n!!! ATTENTION !!!\n\n"
            f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n"
        )
        pdb.set_trace()
    tdist.barrier()
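
A hypothetical debugging sketch (the loop body is a placeholder; assumes the script is launched with a distributed launcher so every rank reaches the barrier, and that breakpoint and get_rank are re-exported at the package level as this page implies):

import ezpz

rank = ezpz.get_rank()
for step in range(100):
    # ... forward/backward work on every rank ...
    if step == 10:
        # Only rank 0 drops into the interactive DistributedPdb session;
        # the remaining ranks block at the barrier until it finishes.
        ezpz.breakpoint(rank=0)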

check(framework='pytorch', backend='deepspeed', port='5432')

Check if the framework is installed and working

Source code in src/ezpz/dist.py
def check(
    framework: str = "pytorch",
    backend: str = "deepspeed",
    port: int | str = "5432",
):
    """Check if the framework is installed and working"""
    from ezpz.configs import FRAMEWORKS

    if framework in FRAMEWORKS["pytorch"]:
        _ = setup_torch(
            backend=backend,
            port=str(port),
        )
    elif framework in FRAMEWORKS["tensorflow"]:
        _ = setup_tensorflow()
    else:
        raise ValueError(f"Unable to parse framework: {framework}")
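
For a quick smoke test, something like the following should work (a sketch; assumes the job is launched with mpiexec/torchrun so ranks and hosts are discoverable, and that check is re-exported at the package level):

import ezpz

ezpz.check(framework="pytorch", backend="DDP", port="1234")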

cleanup()

Cleanup the distributed environment. This function destroys the process group if it is initialized.

Example:

    >>> cleanup()

Source code in src/ezpz/dist.py
def cleanup() -> None:
    """
    Cleanup the distributed environment.
    This function destroys the process group if it is initialized.

    Example:
        >>> cleanup()
    """
    if torch.distributed.is_initialized():  # type:ignore
        torch.distributed.destroy_process_group()  # type:ignore

destroy_tensor_parallel()

Set the groups to none.

Source code in src/ezpz/tp/__init__.py
def destroy_tensor_parallel() -> None:
    """Set the groups to none."""
    global _TENSOR_PARALLEL_GROUP
    _TENSOR_PARALLEL_GROUP = None
    global _TENSOR_PARALLEL_RANKS
    _TENSOR_PARALLEL_RANKS = None

    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None
    global _DATA_PARALLEL_RANKS
    _DATA_PARALLEL_RANKS = None

    global _PIPELINE_PARALLEL_GROUP
    _PIPELINE_PARALLEL_GROUP = None
    global _PIPELINE_PARALLEL_RANKS
    _PIPELINE_PARALLEL_RANKS = None

    global _CONTEXT_PARALLEL_GROUP
    _CONTEXT_PARALLEL_GROUP = None
    global _CONTEXT_PARALLEL_GROUP_RANKS
    _CONTEXT_PARALLEL_GROUP_RANKS = None

ensure_divisibility(numerator, denominator)

Ensure that numerator is divisible by the denominator.

Source code in src/ezpz/tp/utils.py
def ensure_divisibility(numerator: int, denominator: int) -> None:
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, '{} is not divisible by {}'.format(
        numerator, denominator
    )
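
Behavior sketch of this guard, which is used below when sizing the parallel groups (import path assumed from the source location shown above):

from ezpz.tp.utils import ensure_divisibility

ensure_divisibility(16, 4)       # passes silently: 16 % 4 == 0
try:
    ensure_divisibility(10, 3)   # fails: 10 % 3 != 0
except AssertionError as err:
    print(err)                   # "10 is not divisible by 3"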

get_context_manager(rank=None, outdir=None, strict=True)

Returns a context manager for profiling code blocks using PyInstrument.

Parameters:

    rank (Optional[int]): The rank of the process; if provided, the profiler will only run if rank is 0. Default: None
    outdir (Optional[str]): The output directory for saving profiles; defaults to ezpz.OUTPUTS_DIR when None. Default: None
    strict (Optional[bool]): If True, the profiler will only run if "PYINSTRUMENT_PROFILER" is set in the environment. Default: True

Returns:

    AbstractContextManager: A context manager that starts and stops the PyInstrument profiler.

Source code in src/ezpz/profile.py
def get_context_manager(
    rank: Optional[int] = None,
    outdir: Optional[str] = None,
    strict: Optional[bool] = True,
) -> AbstractContextManager:
    """
    Returns a context manager for profiling code blocks using PyInstrument.

    Args:
        rank (Optional[int]): The rank of the process (default: None).
            If provided, the profiler will only run if rank is 0.
        outdir (Optional[str]): The output directory for saving profiles.
            Defaults to `ezpz.OUTPUTS_DIR`.
        strict (Optional[bool]): If True, the profiler will only run if
            "PYINSTRUMENT_PROFILER" is set in the environment.
            Defaults to True.

    Returns:
        AbstractContextManager: A context manager that starts and stops
            the PyInstrument profiler.
    """
    d = ezpz.OUTPUTS_DIR if outdir is None else outdir
    fp = Path(d)
    fp = fp.joinpath("ezpz", "pyinstrument_profiles")

    if strict:
        if os.environ.get("PYINSTRUMENT_PROFILER", None) is not None:
            return PyInstrumentProfiler(rank=rank, outdir=fp.as_posix())
        return nullcontext()
    if rank is None or rank == 0:
        return PyInstrumentProfiler(rank=rank, outdir=fp.as_posix())
    # if rank == 0:
    #     return PyInstrumentProfiler(rank=rank, outdir=outdir)
    return nullcontext()
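
Usage sketch (assumes pyinstrument is installed and the function is imported from ezpz.profile, per the source path above); with strict=True the profiler only activates when PYINSTRUMENT_PROFILER is set:

import os

from ezpz.profile import get_context_manager

os.environ["PYINSTRUMENT_PROFILER"] = "1"  # opt in; otherwise a nullcontext is returned

with get_context_manager(rank=0, outdir="./outputs", strict=True):
    total = sum(i * i for i in range(1_000_000))  # stand-in workload to profile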

get_context_parallel_group()

Get the context parallel group the caller rank belongs to.

Source code in src/ezpz/tp/__init__.py
def get_context_parallel_group() -> tdist.ProcessGroup:
    """Get the context parallel group the caller rank belongs to."""
    assert (
        _CONTEXT_PARALLEL_GROUP is not None
    ), "context parallel group is not initialized"
    return _CONTEXT_PARALLEL_GROUP

get_context_parallel_rank()

Return my rank for the context parallel group.

Source code in src/ezpz/tp/__init__.py
def get_context_parallel_rank() -> int:
    """Return my rank for the context parallel group."""
    return tdist.get_rank(group=get_context_parallel_group())

get_context_parallel_ranks()

Return context parallel ranks for the context parallel group.

Source code in src/ezpz/tp/__init__.py
def get_context_parallel_ranks() -> List[int]:
    """Return context parallel ranks for the context parallel group."""
    assert (
        _CONTEXT_PARALLEL_GROUP_RANKS is not None
    ), "context parallel group is not initialized"
    return _CONTEXT_PARALLEL_GROUP_RANKS

get_context_parallel_world_size()

Return world size for the context parallel group.

Source code in src/ezpz/tp/__init__.py
def get_context_parallel_world_size() -> int:
    """Return world size for the context parallel group."""
    return tdist.get_world_size(group=get_context_parallel_group())

get_cpus_per_node()

Get the number of CPUs per node

Source code in src/ezpz/dist.py
def get_cpus_per_node() -> int:
    """Get the number of CPUs per node"""
    from sh import getconf as sh_getconf  # type:ignore noqa

    return int(sh_getconf("_NPROCESSORS_ONLN").rstrip("\n"))

get_data_parallel_group()

Get the data parallel group the caller rank belongs to.

Source code in src/ezpz/tp/__init__.py
def get_data_parallel_group() -> tdist.ProcessGroup:
    """Get the data parallel group the caller rank belongs to."""
    assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
    return _DATA_PARALLEL_GROUP

get_data_parallel_rank()

Return my rank for the data parallel group.

Source code in src/ezpz/tp/__init__.py
def get_data_parallel_rank() -> int:
    """Return my rank for the data parallel group."""
    return tdist.get_rank(group=get_data_parallel_group())

get_data_parallel_world_size()

Return world size for the data parallel group.

Source code in src/ezpz/tp/__init__.py
def get_data_parallel_world_size() -> int:
    """Return world size for the data parallel group."""
    return tdist.get_world_size(group=get_data_parallel_group())

get_device(type=None, as_torch_device=None)

Alias for get_torch_device.

Source code in src/ezpz/dist.py
def get_device(type: Optional[str] = None, as_torch_device: Optional[bool] = None) -> str | torch.device:
    """Alias for `get_torch_device`."""
    return get_torch_device(device_type=type, as_torch_device=as_torch_device)

get_dist_info(framework=None, verbose=None, hostfile=None)

Get distributed info.

Parameters:

    framework (str, optional): Framework to use. Defaults to None.
    verbose (bool, optional): Whether to print the info. Defaults to None.
    hostfile (PathLike, optional): Path to the hostfile. Defaults to None.

Returns:

    dict[str, str | int | list]: Dictionary containing the distributed info.

Source code in src/ezpz/dist.py
def get_dist_info(
    framework: Optional[str] = None,
    verbose: Optional[bool] = None,
    hostfile: Optional[PathLike] = None,
) -> dict[str, str | int | list]:
    """Get distributed info.

    Args:
        framework (str, optional): Framework to use. Defaults to None.
        verbose (bool, optional): Whether to print the info. Defaults to None.
        hostfile (PathLike, optional): Path to the hostfile. Defaults to None.

    Returns:
        dict: Dictionary containing the distributed info.
    """
    dist_info = _get_dist_info(
        hostfile=hostfile,
        framework=framework,
    )
    if verbose:
        import json

        logger.info(
            f"DistInfo={json.dumps(dist_info, indent=4, sort_keys=True)}"
        )
    return dist_info

get_gpus_per_node()

Get the number of GPUs per node

Source code in src/ezpz/dist.py
def get_gpus_per_node() -> int:
    """Get the number of GPUs per node"""
    # return torch.cuda.device_count() if torch.cuda.is_available() else (
    #     (
    #         ipex.xpu.device_count() if ipex is not None else (
    #             get_cpus_per_node()
    #         )
    #     )
    # )
    # if _assert:
    #     raise RuntimeError(
    #         'No {X, G}pus found; but _assert specified. Returning !!'
    #     )
    # logger.warning('No {x,g}-pus found, returning' + f'{cpus_per_node}')
    ngpu_per_host = os.environ.get("NGPU_PER_HOST", None)
    if ngpu_per_host is not None:
        return int(ngpu_per_host)
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    if torch.xpu.is_available():
        return torch.xpu.device_count()
    if ipex is not None and torch.xpu.is_available():
        return ipex.xpu.device_count()
    if torch.backends.mps.is_available():
        # XXX: Maybe we're running MPI with multiple MPS devices?
        return get_world_size_in_use()
    return 0

get_hostname()

Get the hostname of the current machine.

Returns:

    str: The hostname of the current machine.

Source code in src/ezpz/dist.py
def get_hostname() -> str:
    """Get the hostname of the current machine.

    Returns:
        str: The hostname of the current machine.
    """
    import socket

    try:
        hostname = socket.gethostbyaddr(socket.gethostname())[0].lower()
    # except socket.herror as exc:
    except Exception:
        from sh import hostname as sh_hostname  # type:ignore[missingTypeStubs]

        hostname = sh_hostname()
        # if get_rank() == 0:
        #     logger.debug('Unable to determine hostname with `socket`.')
        #     logger.debug(f'hostname from`sh`: {hostname}')
        #     # logger.exception(exc)
    return hostname.rstrip("\n")

get_hosts_from_hostfile(hostfile=None)

Get hosts from the hostfile or environment variables.

Parameters:

    hostfile (str | Path, optional): Path to the hostfile. Defaults to None.

Returns:

    tuple[str, list[str]]: Tuple containing the hostfile path and a list of hosts.

Example:

    >>> get_hosts_from_hostfile("/path/to/hostfile")
    ('/path/to/hostfile', ['host1', 'host2', ...])

Source code in src/ezpz/dist.py
def get_hosts_from_hostfile(
    hostfile: Optional[str | Path] = None,  # type:ignore[reportDeprecated]
) -> tuple[str, list[str]]:
    """
    Get hosts from the hostfile or environment variables.

    Args:
        hostfile (str | Path, optional): Path to the hostfile. Defaults to None.

    Returns:
        tuple[str, list[str]]: Tuple containing the hostfile path and a list of hosts.

    Example:
        >>> get_hosts_from_hostfile("/path/to/hostfile")
        ('/path/to/hostfile', ['host1', 'host2', ...])
    """
    # hostfile = '' if hostfile is None else hostfile
    hostname = get_hostname()
    hostfile = os.environ.get(
        "HOSTFILE",
        os.environ.get(
            "PBS_NODEFILE",
            os.environ.get(
                "COBALT_NODEFILE",
                None,
            ),
        ),
    )
    hosts: list[str] = []
    assert hostfile is not None
    if Path(hostfile).is_file():
        if get_rank() == 0:
            logger.debug(f"Reading hosts from {hostfile}")
        hpath = Path(hostfile).resolve().absolute()
        with hpath.open("r") as f:
            hosts.extend([h.rstrip("\n") for h in f.readlines()])
    else:
        hosts.append(hostname)
    return Path(hostfile).as_posix(), hosts

get_local_rank()

Return get_rank() % get_gpus_per_node()

Returns:

    int: The local rank of the current process within its node. This is calculated as the current rank modulo the number of GPUs per node.

Example:

    >>> local_rank = get_local_rank()
    >>> print(f"Local rank of the current process: {local_rank}")

Source code in src/ezpz/dist.py
def get_local_rank() -> int:
    """Return `get_rank() % get_gpus_per_node()`

    Returns:
        int: The local rank of the current process within its node.
            This is calculated as the current rank modulo the number of GPUs per node.

    Example:
        >>> local_rank = get_local_rank()
        >>> print(f"Local rank of the current process: {local_rank}")
    """
    return int(get_rank() % get_gpus_per_node()) if get_world_size() > 1 else 0

get_machine(hostname=None)

Get the machine name from the hostname.

Parameters:

    hostname (str, optional): The hostname to check. Defaults to None.

Returns:

    str: The machine name.

Example:

    >>> get_machine("frontier")
    "Frontier"

Source code in src/ezpz/dist.py
def get_machine(hostname: Optional[str] = None) -> str:
    """Get the machine name from the hostname.

    Args:
        hostname (str, optional): The hostname to check. Defaults to None.

    Returns:
        str: The machine name.

    Example:
        >>> get_machine("frontier")
        "Frontier"
    """

    if hostname is None:
        try:
            hostname = socket.gethostbyaddr(socket.gethostname())[0]
        except Exception:
            try:
                hostname = socket.gethostname()
            except Exception:
                logger.warning("Unable to determine hostname!")
                hostname = "unknown"
    if hostname.startswith("frontier"):
        return "Frontier"
    if hostname.startswith("sophia"):
        return "Sophia"
    if hostname.startswith("theta"):
        return "ThetaGPU"
    if hostname.startswith("x1"):
        return "SunSpot"
    if hostname.startswith("x3"):
        if (pbs_host := os.environ.get("PBS_O_HOST", None)) is not None:
            if str(pbs_host).startswith("sirius"):
                return "Sirius"
            return "Polaris"
        return "Polaris"
    if hostname.startswith("x4"):
        return "Aurora"
    if hostname.startswith("login"):
        return "Perlmutter"
    if hostname.startswith("nid"):
        return "Perlmutter"
    return f"{hostname}"

get_max_memory_allocated(device)

Get the maximum memory allocated on the specified device.

Parameters:

    device (torch.device): The device to check memory allocation for. Required.
Source code in src/ezpz/utils.py
def get_max_memory_allocated(device: torch.device) -> float:
    """
    Get the maximum memory allocated on the specified device.

    Args:
        device (torch.device): The device to check memory allocation for.
    """
    if torch.cuda.is_available():
        return torch.cuda.max_memory_allocated(device)
    elif torch.xpu.is_available():  #  and ipex is not None:
        try:
            import intel_extension_for_pytorch as ipex

            return ipex.xpu.max_memory_allocated(device)
        except ImportError:
            return -1.0
    raise RuntimeError(f"Memory allocation not available for {device=}")

get_node_index()

Get the index of the current node in the hostfile

Source code in src/ezpz/dist.py
def get_node_index() -> int:
    """Get the index of the current node in the hostfile"""
    return get_rank() % get_num_nodes()

get_nodes_from_hostfile(hostfile)

Get the nodes from the hostfile.

Parameters:

    hostfile (PathLike): The path to the hostfile. Required.

Returns:

    list[str]: A list of nodes from the hostfile.

Source code in src/ezpz/dist.py
def get_nodes_from_hostfile(
    hostfile: PathLike,
) -> list[str]:
    """Get the nodes from the hostfile.

    Args:
        hostfile (PathLike): The path to the hostfile.

    Returns:
        list[str]: A list of nodes from the hostfile.
    """
    # cobalt_nodefile = get_cobalt_nodefile()
    fpath = Path(hostfile)
    assert fpath.is_file()
    with fpath.open("r") as f:
        nodes = [i.rstrip("\n") for i in f.readlines()]
    return nodes

get_num_nodes(hostfile=None)

Get the number of nodes from the hostfile

Source code in src/ezpz/dist.py
def get_num_nodes(hostfile: Optional[PathLike] = None) -> int:
    """Get the number of nodes from the hostfile"""
    num_nodes = os.environ.get("SLURM_NNODES", None)
    if num_nodes is not None:
        return int(num_nodes)
    hfp = get_hostfile_with_fallback(hostfile)
    hosts = [h.split(".")[0] for h in get_nodes_from_hostfile(hfp)]
    return len(hosts)

get_pipeline_parallel_group()

Get the pipeline parallel group the caller rank belongs to.

Source code in src/ezpz/tp/__init__.py
def get_pipeline_parallel_group() -> tdist.ProcessGroup:
    """Get the pipeline parallel group the caller rank belongs to."""
    assert (
        _PIPELINE_PARALLEL_GROUP is not None
    ), "pipeline parallel group is not initialized"
    return _PIPELINE_PARALLEL_GROUP

get_pipeline_parallel_ranks()

Get the pipeline parallel group the caller rank belongs to.

Source code in src/ezpz/tp/__init__.py
def get_pipeline_parallel_ranks() -> List[int]:
    """Get the pipeline parallel group the caller rank belongs to."""
    assert (
        _PIPELINE_PARALLEL_RANKS is not None
    ), "pipeline parallel group is not initialized"
    return _PIPELINE_PARALLEL_RANKS

get_rank()

Get current MPI rank.

Returns:

    int: The rank of the current process in the MPI world.

Example:

    >>> rank = get_rank()
    >>> print(f"Current MPI rank: {rank}")

Source code in src/ezpz/dist.py
def get_rank() -> int:
    """Get current MPI rank.

    Returns:
        int: The rank of the current process in the MPI world.

    Example:
        >>> rank = get_rank()
        >>> print(f"Current MPI rank: {rank}")
    """
    return int(MPI.COMM_WORLD.Get_rank())

get_tensor_parallel_group()

Get the tensor parallel group the caller rank belongs to.

Source code in src/ezpz/tp/__init__.py
def get_tensor_parallel_group() -> tdist.ProcessGroup:
    """Get the tensor parallel group the caller rank belongs to."""
    assert (
        _TENSOR_PARALLEL_GROUP is not None
    ), "tensor parallel group is not initialized"
    return _TENSOR_PARALLEL_GROUP

get_tensor_parallel_rank()

Return my rank for the tensor parallel group.

Source code in src/ezpz/tp/__init__.py
def get_tensor_parallel_rank() -> int:
    """Return my rank for the tensor parallel group."""
    return tdist.get_rank(group=get_tensor_parallel_group())

get_tensor_parallel_src_rank()

Calculate the global rank corresponding to local rank 0 in the TP group.

Source code in src/ezpz/tp/__init__.py
def get_tensor_parallel_src_rank() -> int:
    """
    Calculate the global rank corresponding to local rank 0 in the TP group.
    """
    global_rank = tdist.get_rank()
    local_world_size = get_tensor_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size

get_tensor_parallel_world_size()

Return world size for the tensor parallel group.

Source code in src/ezpz/tp/__init__.py
def get_tensor_parallel_world_size() -> int:
    """Return world size for the tensor parallel group."""
    return tdist.get_world_size(group=get_tensor_parallel_group())

get_timestamp(fstr=None)

Get formatted timestamp.

Source code in src/ezpz/utils.py
def get_timestamp(fstr: Optional[str] = None) -> str:
    """Get formatted timestamp."""
    import datetime

    now = datetime.datetime.now()
    return (
        now.strftime("%Y-%m-%d-%H%M%S") if fstr is None else now.strftime(fstr)
    )
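
A small formatting sketch; any strftime pattern can be passed in place of the default "%Y-%m-%d-%H%M%S":

from ezpz.utils import get_timestamp

print(get_timestamp())            # e.g. "2025-01-31-134502"
print(get_timestamp("%Y-%m-%d"))  # e.g. "2025-01-31"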

get_torch_backend()

Get the current PyTorch backend.

Returns:

    str: The current PyTorch backend.

Source code in src/ezpz/dist.py
def get_torch_backend() -> str:
    """
    Get the current PyTorch backend.

    Returns:
        str: The current PyTorch backend.
    """
    backend_from_env = os.environ.get("TORCH_BACKEND", None)
    if backend_from_env is not None:
        return backend_from_env
    return (
        "nccl"
        if torch.cuda.is_available()
        else (
            get_torch_backend_on_xpu() if torch.xpu.is_available() else "gloo"
        )
    )

get_torch_device(*, device_type=None, as_torch_device=None)

Get the current PyTorch device.

Parameters:

    device_type (str, optional): The type of device to return. If None, it will be determined automatically. Defaults to None.
    as_torch_device (bool, optional): If True, return a torch.device object. If False, return a string representing the device type. Defaults to False.

Returns:

    str | torch.device: The current PyTorch device. If as_torch_device is True, returns a torch.device object; otherwise, returns a string representing the device type.

Example:

    >>> get_torch_device()  # Returns the current device type as a string

Source code in src/ezpz/dist.py
def get_torch_device(
    *,
    device_type: Optional[str] = None,
    as_torch_device: Optional[bool] = None,
) -> str | torch.device:
    """Get the current PyTorch device.

    Args:
        device_type (str, optional): The type of device to return.
            If None, it will be determined automatically.
            Defaults to None.
        as_torch_device (bool, optional): If True, return a torch.device object.
            If False, return a string representing the device type.
            Defaults to False.

    Returns:
        str | torch.device: The current PyTorch device.
            If as_torch_device is True, returns a torch.device object.
            If as_torch_device is False, returns a string representing the device type.

    Example:
        >>> get_torch_device()  # Returns the current device type as a string
    """
    device_type = get_torch_device_type(device_type)
    return torch.device(device_type) if as_torch_device else device_type

get_torch_device_type(device_type=None)

Get the current PyTorch device type.

Parameters:

    device_type (str, optional): The type of device to return. If None, it will be determined automatically. Defaults to None.

Returns:

    str: The current PyTorch device type. Possible values are "cpu", "mps", "xpu", or "cuda".

Example:

    >>> get_torch_device_type()  # Returns the current device type as a string

Source code in src/ezpz/dist.py
def get_torch_device_type(device_type: Optional[str] = None) -> str:
    """Get the current PyTorch device type.

    Args:
        device_type (str, optional): The type of device to return.
            If None, it will be determined automatically.
            Defaults to None.

    Returns:
        str: The current PyTorch device type.
            Possible values are "cpu", "mps", "xpu", or "cuda".

    Example:
        >>> get_torch_device_type()  # Returns the current device type as a string
    """
    if device_type is not None:
        assert device_type in (
            "cpu",
            "mps",
            "xpu",
            "cuda",
        )
        logger.warning(
            " ".join(
                [
                    f"device_type: {device_type} passed to",
                    "ezpz.dist.get_torch_device_type",
                ]
            )
        )
        return device_type
    if (tdevice := os.environ.get("TORCH_DEVICE")) is not None:
        if get_rank() == 0:
            logger.warning(f"Caught TORCH_DEVICE={tdevice} from environment!")
        tdevice = tdevice.lower()
        assert tdevice is not None and tdevice in (
            "cpu",
            "mps",
            "xpu",
            "cuda",
        )
        return tdevice.lower()
    return (
        "xpu"
        if torch.xpu.is_available()
        else (
            "cuda"
            if torch.cuda.is_available()
            else (
                "mps"
                if (
                    torch.backends.mps.is_available()
                    and torch.get_default_dtype() != torch.float64
                )
                else "cpu"
            )
        )
    )
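
A selection-order sketch: an explicit argument wins, then the TORCH_DEVICE environment variable, then auto-detection (xpu, then cuda, then mps, then cpu). Assumes the function is re-exported at the package level as on this page (otherwise import it from ezpz.dist) and that mpi4py is available, since get_rank() is consulted when the environment variable is set:

import os

import ezpz

print(ezpz.get_torch_device_type())         # auto-detected device string
print(ezpz.get_torch_device_type("cpu"))    # explicit override (logs a warning)

os.environ["TORCH_DEVICE"] = "cpu"
print(ezpz.get_torch_device_type())         # "cpu", taken from the environment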

get_world_size(total=None, in_use=None)

Get the total number of *PUs available or currently in use.

Parameters:

    total (bool, optional): If True, return the total number of *PUs available. Defaults to None.
    in_use (bool, optional): If True, return the number of *PUs currently in use. Defaults to None.

Returns:

    int: The total number of *PUs available or currently in use. If both total and in_use are None, it returns the number of *PUs currently in use by the MPI communicator.

Example:

    >>> world_size = get_world_size(total=True)
    >>> print(f"Total number of *PUs available: {world_size}")
    >>> world_size_in_use = get_world_size(in_use=True)
    >>> print(f"Number of *PUs currently in use: {world_size_in_use}")

Source code in src/ezpz/dist.py
def get_world_size(
    total: Optional[bool] = None,
    in_use: Optional[bool] = None,
) -> int:
    """
    Get the total number of *PUs available or currently in use.
    Args:
        total (bool, optional): If True, return the total number of *PUs available.
            Defaults to None.
        in_use (bool, optional): If True, return the number of *PUs currently in use.
            Defaults to None.

    Returns:
        int: The total number of *PUs available or currently in use.
            If both `total` and `in_use` are None, it returns the number of *PUs
            currently in use by the MPI communicator.

    Example:
        >>> world_size = get_world_size(total=True)
        >>> print(f"Total number of *PUs available: {world_size}")
        >>> world_size_in_use = get_world_size(in_use=True)
        >>> print(f"Number of *PUs currently in use: {world_size_in_use}")
    """
    if total:
        return get_world_size_total()
    if in_use:
        return get_world_size_in_use()
    # TODO: Deal with subtlety between:
    # 1. 'world_size' == total AVAILABLE gpus (for record keeping)
    # 2. 'world_size' == number of gpus CURRENTLY IN USE (from {`mpi`, ...})
    # Β―\_(ツ)_/Β―
    try:
        world_size = int(MPI.COMM_WORLD.Get_size())
    except Exception:
        num_nodes = get_num_nodes()
        gpus_per_node = get_gpus_per_node()
        world_size = num_nodes * gpus_per_node
        logger.warning(
            "MPI not initialized !!"
            "Calculating (and using!! ??) "
            "[world_size]=[(num_nodes) x (num_*pus_per_node)]=[num_*pus_total]"
            f"[{world_size}]=[({num_nodes}) x ({gpus_per_node})]"
        )
    # if world_size == 1:
    #     gpus_per_node = get_gpus_per_node()
    #     num_nodes = get_num_nodes()
    #     world_size = num_nodes * gpus_per_node
    return world_size

include_file(f)

Check if a file should be included based on its extension.

Parameters:

    f (PathLike): The file path to check. Required.

Returns:

    bool: True if the file should be included, False otherwise.

Source code in src/ezpz/dist.py
def include_file(f: PathLike):
    """
    Check if a file should be included based on its extension.

    Args:
        f (PathLike): The file path to check.

    Returns:
        bool: True if the file should be included, False otherwise.
    """
    fpath = Path(f)
    return fpath.suffix in {
        ".py",
        ".yaml",
        ".sh",
        ".md",
        ".qmd",
        ".yml",
        ".toml",
    }
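
Filtering sketch: keep only the file types ezpz considers worth tracking when walking a directory (import path assumed from the source location shown above):

from pathlib import Path

from ezpz.dist import include_file

tracked = [p for p in Path(".").rglob("*") if p.is_file() and include_file(p)]
print(tracked[:10])  # .py / .yaml / .sh / .md / .qmd / .yml / .toml files only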

init_deepspeed(dist_backend=None, auto_mpi_discovery=True, distributed_port=29500, verbose=True, timeout=None, init_method=None, dist_init_required=None, config=None, rank=None, world_size=None)

Initialize DeepSpeed distributed environment.

Parameters:

    dist_backend (str, optional): The distributed backend to use. Defaults to None.
    auto_mpi_discovery (bool, optional): Whether to automatically discover MPI. Defaults to True.
    distributed_port (int | str, optional): The port for distributed communication. Defaults to 29500.
    verbose (bool, optional): Whether to print verbose logs. Defaults to True.
    timeout (int | None, optional): Timeout in seconds for distributed initialization. Defaults to None.
    init_method (str, optional): Initialization method for distributed training. Defaults to None.
    dist_init_required (bool, optional): Whether distributed initialization is required. Defaults to None.
    config (dict, optional): DeepSpeed configuration dictionary. Defaults to None.
    rank (int | None, optional): Rank of the current process. Defaults to None.
    world_size (int | None, optional): Total number of processes. Defaults to None.

Raises:

    ImportError: If DeepSpeed is not installed.
    Exception: If there is an error during DeepSpeed initialization.

Example:

    >>> init_deepspeed(
    ...     dist_backend="nccl",
    ...     distributed_port=29500,
    ...     verbose=True,
    ...     timeout=3600,
    ...     rank=0,
    ...     world_size=4,
    ... )

Source code in src/ezpz/dist.py
def init_deepspeed(
    dist_backend: Optional[str] = None,
    auto_mpi_discovery: bool = True,
    distributed_port: int | str = 29500,
    verbose: bool = True,
    timeout: Optional[int] = None,
    init_method: Optional[str] = None,
    dist_init_required: Optional[bool] = None,
    config: Optional[dict] = None,
    rank: Optional[int] = None,
    world_size: Optional[int] = None,
):
    """
    Initialize DeepSpeed distributed environment.

    Args:
        dist_backend (str, optional): The distributed backend to use.
            Defaults to None.
        auto_mpi_discovery (bool, optional): Whether to automatically discover MPI.
            Defaults to True.
        distributed_port (int | str, optional): The port for distributed communication.
            Defaults to 29500.
        verbose (bool, optional): Whether to print verbose logs. Defaults to True.
        timeout (int | None, optional): Timeout in seconds for distributed initialization.
            Defaults to None.
        init_method (str, optional): Initialization method for distributed training.
            Defaults to None.
        dist_init_required (bool, optional): Whether distributed initialization is required.
            Defaults to None.
        config (dict, optional): DeepSpeed configuration dictionary. Defaults to None.
        rank (int | None, optional): Rank of the current process. Defaults to None.
        world_size (int | None, optional): Total number of processes. Defaults to None.

    Raises:
        ImportError: If DeepSpeed is not installed.
        Exception: If there is an error during DeepSpeed initialization.

    Example:
        >>> init_deepspeed(
        ...     dist_backend="nccl",
        ...     distributed_port=29500,
        ...     verbose=True,
        ...     timeout=3600,
        ...     rank=0,
        ...     world_size=4,
        ... )
    """
    try:
        import deepspeed  # noqa type:ignore

        os.environ["DEEPSPEED_VERSION"] = deepspeed.__version__
    except ImportError as e:
        logger.warning(
            "Unable to import deepspeed. Please install it to use DeepSpeed features."
        )
        raise ImportError(
            "DeepSpeed is not installed. Install with 'pip install deepspeed'"
        ) from e

    rank = get_rank() if rank is None else rank
    world_size = get_world_size() if world_size is None else world_size
    os.environ["WORLD_SIZE"] = str(world_size)
    try:
        import deepspeed  # type:ignore

        # logger.warning(f'Setting {timeout=}')
        dt = 3600 if timeout is None else timeout
        deepspeed.init_distributed(
            dist_backend=dist_backend,
            auto_mpi_discovery=auto_mpi_discovery,
            distributed_port=int(distributed_port),
            verbose=verbose,
            timeout=datetime.timedelta(seconds=dt),
            init_method=init_method,
            dist_init_required=dist_init_required,
            config=config,
            rank=rank,
            world_size=world_size,
        )
    except Exception as exc:
        logger.warning("Unable to `import deepspeed`. Exiting!")
        logger.exception(exc)
        raise exc

init_process_group(rank, world_size, timeout, backend=None)

Initialize the PyTorch distributed process group.

Parameters:

    rank (int | str): The rank of the current process. Required.
    world_size (int | str): The total number of processes. Required.
    timeout (str | int | timedelta): Timeout for the process group initialization. Required.
    backend (str, optional): The backend to use for distributed training. Defaults to None, which will use the default backend based on the device.
Source code in src/ezpz/dist.py
def init_process_group(
    rank: int | str,
    world_size: int | str,
    timeout: str | int | timedelta,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize the PyTorch distributed process group.

    Args:
        rank (int | str): The rank of the current process.
        world_size (int | str): The total number of processes.
        timeout (str | int | timedelta): Timeout for the process group initialization.
        backend (str, optional): The backend to use for distributed training.
            Defaults to None, which will use the default backend based on the device.
    """
    backend = get_torch_backend() if backend is None else backend
    if get_rank() == 0:
        logger.info(
            " ".join(
                [
                    "Calling torch.distributed.init_process_group_with:",
                    f"rank={rank}",
                    f"world_size={world_size}",
                    f"backend={backend}",
                ]
            )
        )
    if not isinstance(timeout, timedelta):
        env_timeout = os.environ.get("TORCH_DDP_TIMEOUT", timeout)
        timeout = timedelta(
            seconds=int(env_timeout),
        )
    if not torch.distributed.is_initialized():  # type:ignore
        torch.distributed.init_process_group(  # type:ignore
            backend=backend,
            timeout=timeout,
            rank=int(rank),
            world_size=int(world_size),
            init_method="env://",
        )
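
A minimal sketch of calling this directly (most users go through setup_torch instead). It assumes RANK and WORLD_SIZE come from the launcher and that MASTER_ADDR / MASTER_PORT are set, since init_method="env://" reads them:

import os

from ezpz.dist import init_process_group

init_process_group(
    rank=int(os.environ.get("RANK", 0)),
    world_size=int(os.environ.get("WORLD_SIZE", 1)),
    timeout=3600,   # seconds; TORCH_DDP_TIMEOUT overrides this if set
    backend=None,   # falls back to get_torch_backend()
)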

initialize_tensor_parallel(tensor_parallel_size=1, pipeline_parallel_size=1, context_parallel_size=1, tensor_parallel_backend=None, pipeline_parallel_backend=None, context_parallel_backend=None, data_parallel_backend=None, timeout=None)

Initialize tensor data parallel groups.

Parameters:

    tensor_parallel_size (int): Number of GPUs used to parallelize the model. Default: 1

Say we have a total of 8 GPUs denoted by g0 ... g7 and we use 2 GPUs to parallelize the model. This function will create 4 tensor parallel groups and 2 data parallel groups:

  • 4 tensor parallel groups: [g0, g1], [g2, g3], [g4, g5], [g6, g7]
  • 2 data parallel groups: [g0, g2, g4, g6], [g1, g3, g5, g7]

Note that for efficiency, the caller should make sure adjacent ranks are on the same DGX box. For example, if we are using 2 DGX-1 boxes with a total of 16 GPUs, ranks 0 to 7 belong to the first box and ranks 8 to 15 belong to the second box.

Process groups are initialized in the order TP, CP, PP, DP.

Say we have a total of 16 GPUs denoted by g0 ... g15 and we use 2 GPUs for tensor parallelism, 2 GPUs for context (sequence-length) parallelism, and 2 GPUs for pipeline parallelism. This function will create 8 tensor model-parallel groups, 8 context-parallel groups, 8 pipeline model-parallel groups, and 8 data-parallel groups as follows (when alternate_pp_config = False):

  • 8 data-parallel groups: [g0, g4], [g1, g5], [g2, g6], [g3, g7], [g8, g12], [g9, g13], [g10, g14], [g11, g15]
  • 8 tensor model-parallel groups: [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
  • 8 context-parallel groups: [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
  • 8 pipeline model-parallel groups: [g0, g8], [g1, g9], [g2, g10], [g3, g11], [g4, g12], [g5, g13], [g6, g14], [g7, g15]
Source code in src/ezpz/tp/__init__.py
def initialize_tensor_parallel(
    tensor_parallel_size: int = 1,
    pipeline_parallel_size: int = 1,
    context_parallel_size: int = 1,
    tensor_parallel_backend: Optional[str] = None,
    pipeline_parallel_backend: Optional[str] = None,
    context_parallel_backend: Optional[str] = None,
    data_parallel_backend: Optional[str] = None,
    timeout: Optional[timedelta] = None,
) -> None:
    """
    Initialize tensor data parallel groups.

    Arguments:
        tensor_parallel_size: number of GPUs used to parallelize model.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model. The present function will
    create 4 tensor parallel groups and 2 data parallel groups as:

    - 4 tensor parallel groups:

      ```
      [g0, g1], [g2, g3], [g4, g5], [g6, g7]
      ```

    - 2 data parallel groups:

        ```
        [g0, g2, g4, g6], [g1, g3, g5, g7]
        ```

    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.

    process groups initialized in the order of TP, CP, PP, DP.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the tensor tensor, 2 GPUs to parallelize context(seq len), and 2 GPUs to parallelize
    the tensor pipeline. The present function will
    create 8 tensor model-parallel groups, 8 context-parallel group, 8 pipeline model-parallel groups
    and 8 data-parallel groups as:
    when alternate_pp_config = False,

    - 8 data_parallel groups:
        [g0, g4], [g1, g5], [g2, g6], [g3, g7], [g8, g12], [g9, g13], [g10, g14], [g11, g15]
    - 8 tensor model-parallel groups:
        [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
    - 8 context-parallel groups:
        [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
    - 8 pipeline model-parallel groups:
        [g0, g8], [g1, g9], [g2, g10], [g3, g11], [g4, g12], [g5, g13], [g6, g16], [g7, g15]
    """
    # Get world size and rank. Ensure some consistencies.
    assert tdist.is_initialized()
    world_size = tdist.get_world_size()
    tensor_parallel_size = int(min(tensor_parallel_size, world_size))
    ensure_divisibility(world_size, tensor_parallel_size)
    ensure_divisibility(world_size, context_parallel_size)
    ensure_divisibility(
        world_size,
        tensor_parallel_size * pipeline_parallel_size * context_parallel_size,
    )
    rank = tdist.get_rank()

    dpsize = int(
        world_size
        / (tensor_parallel_size * pipeline_parallel_size * context_parallel_size)
    )

    if tdist.get_rank() == 0:
        pstr = ", ".join(
            [
                f"TP: {tensor_parallel_size}",
                f"PP: {pipeline_parallel_size}",
                f"CP: {context_parallel_size}",
                f"DP: {dpsize}",
            ]
        )
        logger.info(pstr)
        # pstr = f'TP: {tensor_parallel_size}, PP: {pipeline_parallel_size}, CP: {context_parallel_size}, DP: {dpsize}'
        # logger.info(
        #     '> initializing tensor parallel with size {}'.format(
        #         tensor_parallel_size
        #     )
        # )
        # logger.info(
        #     '> initializing context parallel with size {}'.format(
        #         context_parallel_size
        #     )
        # )
        # logger.info(
        #     '> initializing pipeline with size {}'.format(
        #         pipeline_parallel_size
        #     )
        # )

    groups = torch.LongTensor(range(world_size)).reshape(
        dpsize,
        pipeline_parallel_size,
        context_parallel_size,
        tensor_parallel_size,
    )

    found = torch.where(groups == rank)
    assert all(len(x) == 1 for x in found)
    found = [x[0] for x in found]

    # Build the data parallel groups.
    global _DATA_PARALLEL_GROUP
    global _DATA_PARALLEL_RANKS
    assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
    assert _DATA_PARALLEL_RANKS is None, "data parallel ranks are already initialized"
    for i in range(pipeline_parallel_size):
        for j in range(context_parallel_size):
            for k in range(tensor_parallel_size):
                ranks = groups[:, i, j, k].tolist()
                group = tdist.new_group(
                    groups[:, i, j, k].tolist(),
                    backend=data_parallel_backend,
                    timeout=timeout,
                )
                if i == found[1] and j == found[2] and k == found[3]:
                    _DATA_PARALLEL_GROUP = group
                    _DATA_PARALLEL_RANKS = ranks

    # Build the tensor parallel groups.
    global _TENSOR_PARALLEL_GROUP
    global _TENSOR_PARALLEL_RANKS
    assert (
        _TENSOR_PARALLEL_GROUP is None
    ), "tensor parallel group is already initialized"
    assert (
        _TENSOR_PARALLEL_RANKS is None
    ), "tensor parallel ranks are already initialized"
    for i in range(dpsize):
        for j in range(pipeline_parallel_size):
            for k in range(context_parallel_size):
                ranks = groups[i, j, k, :].tolist()
                group = tdist.new_group(
                    groups[i, j, k, :].tolist(),
                    backend=tensor_parallel_backend,
                    timeout=timeout,
                )
                if i == found[0] and j == found[1] and k == found[2]:
                    _TENSOR_PARALLEL_GROUP = group
                    _TENSOR_PARALLEL_RANKS = ranks

    # Build the pipeline parallel groups.
    global _PIPELINE_PARALLEL_GROUP
    global _PIPELINE_PARALLEL_RANKS
    assert (
        _PIPELINE_PARALLEL_GROUP is None
    ), "Pipeline parallel group is already initialized"
    for i in range(dpsize):
        for j in range(context_parallel_size):
            for k in range(tensor_parallel_size):
                ranks = groups[i, :, j, k].tolist()
                group = tdist.new_group(
                    ranks, backend=pipeline_parallel_backend, timeout=timeout
                )
                if i == found[0] and j == found[2] and k == found[3]:
                    _PIPELINE_PARALLEL_GROUP = group
                    _PIPELINE_PARALLEL_RANKS = ranks

    # Build the context parallel groups.
    global _CONTEXT_PARALLEL_GROUP
    global _CONTEXT_PARALLEL_GROUP_RANKS

    assert (
        _CONTEXT_PARALLEL_GROUP is None
    ), "Context parallelism is already initialized."
    for i in range(dpsize):
        for j in range(pipeline_parallel_size):
            for k in range(tensor_parallel_size):
                ranks = groups[i, j, :, k].tolist()
                group = tdist.new_group(
                    ranks, backend=context_parallel_backend, timeout=timeout
                )
                if i == found[0] and j == found[1] and k == found[3]:
                    _CONTEXT_PARALLEL_GROUP = group
                    _CONTEXT_PARALLEL_GROUP_RANKS = ranks
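
To make the bookkeeping above concrete, here is a standalone sketch (plain PyTorch, no process groups are created) that reproduces the reshape logic for 16 ranks with TP=2, CP=2, PP=2, and prints which global ranks end up sharing a tensor-parallel and a context-parallel group:

import torch

world_size, tp, cp, pp = 16, 2, 2, 2
dp = world_size // (tp * cp * pp)                       # = 2
grid = torch.arange(world_size).reshape(dp, pp, cp, tp)

tp_groups = [grid[i, j, k, :].tolist()
             for i in range(dp) for j in range(pp) for k in range(cp)]
cp_groups = [grid[i, j, :, k].tolist()
             for i in range(dp) for j in range(pp) for k in range(tp)]

print(tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]]
print(cp_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11], [12, 14], [13, 15]]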

print_config_tree(cfg, resolve=True, save_to_file=True, verbose=True, style='tree', print_order=None, highlight=True, outfile=None)

Prints the contents of a DictConfig as a tree structure using the Rich library.

  • cfg: A DictConfig composed by Hydra.
  • print_order: Determines in what order config components are printed.
  • resolve: Whether to resolve reference fields of DictConfig.
  • save_to_file: Whether to export config to the hydra output folder.
Source code in src/ezpz/configs.py
def print_config_tree(
    cfg: DictConfig,
    resolve: bool = True,
    save_to_file: bool = True,
    verbose: bool = True,
    style: str = "tree",
    print_order: Optional[Sequence[str]] = None,
    highlight: bool = True,
    outfile: Optional[Union[str, os.PathLike, Path]] = None,
) -> Tree:
    """Prints the contents of a DictConfig as a tree structure using the Rich
    library.

    - cfg: A DictConfig composed by Hydra.
    - print_order: Determines in what order config components are printed.
    - resolve: Whether to resolve reference fields of DictConfig.
    - save_to_file: Whether to export config to the hydra output folder.
    """
    from rich.console import Console
    from ezpz.log.config import STYLES
    from rich.theme import Theme

    name = cfg.get("_target_", "cfg")
    console = Console(record=True, theme=Theme(STYLES))
    tree = Tree(label=name, highlight=highlight)
    queue = []
    # add fields from `print_order` to queue
    if print_order is not None:
        for field in print_order:
            (
                queue.append(field)
                if field in cfg
                else log.warning(
                    f"Field '{field}' not found in config. "
                    f"Skipping '{field}' config printing..."
                )
            )
    # add all the other fields to queue (not specified in `print_order`)
    for field in cfg:
        if field not in queue:
            queue.append(field)
    # generate config tree from queue
    for field in queue:
        branch = tree.add(field, highlight=highlight)  # , guide_style=style)
        config_group = cfg[field]
        if isinstance(config_group, DictConfig):
            branch_content = str(
                OmegaConf.to_yaml(config_group, resolve=resolve)
            )
            branch.add(Text(branch_content, style="red"))
        else:
            branch_content = str(config_group)
            branch.add(Text(branch_content, style="blue"))
    if verbose or save_to_file:
        console.print(tree)
        if save_to_file:
            outfpath = (
                Path(os.getcwd()).joinpath("config_tree.log")
                if outfile is None
                else Path(outfile)
            )
            console.save_text(outfpath.as_posix())
    return tree
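
Usage sketch (assumes omegaconf and rich are installed, as the function requires); the config keys below are purely illustrative:

from omegaconf import OmegaConf

from ezpz.configs import print_config_tree

cfg = OmegaConf.create({"model": {"hidden": 128}, "optimizer": {"lr": 1e-3}})
tree = print_config_tree(cfg, save_to_file=False, print_order=["model"])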

print_dist_setup(framework=None, hostfile=None)

Print distributed setup.

Parameters:

    framework (str, optional): Framework to use. Defaults to None.
    hostfile (PathLike, optional): Path to the hostfile. Defaults to None.

Returns:

    str: String containing the distributed setup.

Source code in src/ezpz/dist.py
def print_dist_setup(
    framework: Optional[str] = None,
    hostfile: Optional[PathLike] = None,
) -> str:
    """Print distributed setup.

    Args:
        framework (str, optional): Framework to use. Defaults to None.
        hostfile (PathLike, optional): Path to the hostfile. Defaults to None.

    Returns:
        str: String containing the distributed setup.
    """
    rank = get_rank()
    wst = get_world_size(total=True)
    wsa = get_world_size(in_use=True)
    # world_size = get_world_size()
    local_rank = get_local_rank()
    gpus_per_node = get_gpus_per_node()
    hostfile = get_hostfile_with_fallback(hostfile)
    # NOTE:
    # We ensure that num_nodes is AT LEAST 1
    # since if gpus_per_node > wsa, wsa // gpus_per_node = 0
    # if gpus_per_node > wsa, wsa // gpus_per_node = 0
    num_nodes = max((wsa // gpus_per_node, 1))
    num_nodes_from_hostfile = get_num_nodes()
    # assert num_nodes_from_hostfile == num_nodes
    # if num_nodes != num_nodes_from_hostfile:
    #     logger.critical(f'{num_nodes=} vs. {num_nodes_from_hostfile=} ??')
    node = get_node_index()
    device = None
    # if framework.lower() in {'pt', 'torch', 'pytorch'}:
    device = get_torch_device_type()
    rank_len = len(str(rank))
    ws_len = len(str(wsa))
    lr_len = len(str(local_rank))
    gpn_len = len(str(gpus_per_node))
    node_len = len(str(node))
    num_nodes_len = len(str(num_nodes))
    dist_list = [
        f"[{device=}]",
        f"[{rank=:>{rank_len}}/{(wsa - 1):<{ws_len}}]",
        f"[{local_rank=:>{lr_len}}/{gpus_per_node - 1:<{gpn_len}}]",
        f"[{node=:>{node_len}}/{(num_nodes - 1):<{num_nodes_len}}]",
    ]
    if framework is not None:
        dist_list.append(f"[{framework=}]")
    dist_str = "".join(dist_list)
    logger.info(f"{dist_str}")
    if rank == 0:
        if wsa > 1000:
            logger.warning(
                f"WORLD_SIZE={wsa} > 1000, only printing on RANK={rank}"
            )
        logger.warning(
            f'Using [{wsa} / {wst}] available "{device}" devices !!'
        )
        if num_nodes_from_hostfile != num_nodes:
            logger.critical(
                f"num_nodes_from_hostfile = [{num_nodes_from_hostfile=}]"
                f"vs."
                f"[{wsa=} // {gpus_per_node=}] = {num_nodes}"
                r"Β―\_(ツ)_/Β― ??"
            )
    return dist_str

query_environment()

Query environment variables for info about distributed setup

Returns:

    dict[str, int]: A dictionary containing the distributed setup information. Includes keys like 'world_size', 'rank', and 'local_rank'. If the environment variables are not set, it falls back to using get_world_size(), get_rank(), and get_local_rank().

Example:

    >>> env_info = query_environment()
    >>> print(env_info)
    {'world_size': 4, 'rank': 0, 'local_rank': 0}

Source code in src/ezpz/dist.py
def query_environment() -> dict[str, int]:
    """Query environment variables for info about distributed setup

    Returns:
        dict[str, int]: A dictionary containing the distributed setup information.
            Includes keys like 'world_size', 'rank', and 'local_rank'.
            If the environment variables are not set, it falls back to using
            `get_world_size()`, `get_rank()`, and `get_local_rank()` functions.

    Example:
        >>> env_info = query_environment()
        >>> print(env_info)
        {'world_size': 4, 'rank': 0, 'local_rank': 0}
    """
    ws = os.environ.get("WORLD_SIZE", None)
    r = os.environ.get("RANK", None)
    lr = os.environ.get("LOCAL_RANK", None)
    if ws is not None and r is not None and lr is not None:
        return {
            "world_size": int(ws),
            "rank": int(r),
            "local_rank": int(lr),
            # 'machine': machine,
        }
    return {
        "world_size": int(get_world_size()),
        "rank": int(get_rank()),
        "local_rank": int(get_local_rank()),
    }

run_bash_command(cmd)

Run a bash command and return the output.

Parameters:

    cmd (str): The command to run. Required.

Returns:

    Any: The output of the command.

Source code in src/ezpz/dist.py
def run_bash_command(cmd: str) -> Any:
    """
    Run a bash command and return the output.
    Args:
        cmd (str): The command to run.

    Returns:
        Any: The output of the command.
    """
    import subprocess
    import shlex

    process = subprocess.Popen(
        shlex.split(cmd, posix=True),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    output, error = process.communicate()
    if process.returncode != 0:
        raise Exception(
            f"Command failed with return code {process.returncode}.\n"
            f"stdout: {output.decode().strip()}\n"
            f"stderr: {error.decode().strip()}"
        )
    if error:
        raise Exception(error.decode())
    else:
        return output
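
A quick sketch; note the output is returned as raw bytes from the subprocess, and any stderr output or non-zero return code raises:

from ezpz.dist import run_bash_command

out = run_bash_command("echo hello")
print(out.decode().strip())  # "hello"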

seed_everything(seed)

Set random seed for reproducibility.

Parameters:

    seed (int): Random seed to set. Required.
Source code in src/ezpz/dist.py
def seed_everything(seed: int) -> None:
    """Set random seed for reproducibility.

    Args:
        seed (int): Random seed to set.
    """
    import torch
    import numpy as np
    import random

    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    _ = torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    if torch.xpu.is_available():
        torch.xpu.manual_seed(seed)
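
A reproducibility sketch showing that identical seeds give identical draws (import path assumed from the source location shown above):

import torch

from ezpz.dist import seed_everything

seed_everything(42)
a = torch.randn(3)
seed_everything(42)
b = torch.randn(3)
assert torch.equal(a, b)  # same seed, same draw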

setup(framework='pytorch', backend='DDP', port='5432', seed=None, precision=None, ngpus=None)

Setup distributed environment for the specified framework.

Parameters:

    framework (str): The framework to use for distributed training. Defaults to "pytorch".
    backend (str): The backend to use for distributed training. Defaults to "DDP".
    port (str): The port to use for distributed communication. Defaults to "5432".
    seed (int, optional): Random seed for reproducibility. Defaults to None.
    precision (str, optional): Precision to use for training. Defaults to None.
    ngpus (int, optional): Number of GPUs to use. Defaults to None.

Returns:

    None

Source code in src/ezpz/dist.py
def setup(
    framework: str = "pytorch",
    backend: str = "DDP",
    port: str = "5432",
    seed: Optional[int] = None,
    precision: Optional[str] = None,
    ngpus: Optional[int] = None,
):
    """
    Setup distributed environment for the specified framework.

    Args:
        framework (str): The framework to use for distributed training.
            Defaults to "pytorch".
        backend (str): The backend to use for distributed training.
            Defaults to "DDP".
        port (str): The port to use for distributed communication.
            Defaults to "5432".
        seed (int, optional): Random seed for reproducibility. Defaults to None.
        precision (str, optional): Precision to use for training. Defaults to None.
        ngpus (int, optional): Number of GPUs to use. Defaults to None.

    Returns:
        None
    """
    return (
        setup_tensorflow(precision=precision, ngpus=ngpus)
        if framework in {"tensorflow", "tf", "t"}
        else setup_torch(backend=backend, port=port, seed=seed)
    )

setup_torch(framework=None, backend=None, port=None, seed=None, timeout=None, verbose=False, tensor_parallel_size=1, pipeline_parallel_size=1, context_parallel_size=1, tensor_parallel_backend=None, pipeline_parallel_backend=None, context_parallel_backend=None, data_parallel_backend=None)

Setup torch.

Parameters:

- framework (str, optional): Distributed framework to use ("ddp", "deepspeed", "horovod", ...); falls back to "DDP" when None. Default: None.
- backend (str, optional): Backend to use. Default: None.
- port (str | int, optional): Port to use. Default: None.
- seed (int, optional): Seed to use. Default: None.
- timeout (str | int, optional): Timeout to use. Default: None.
- verbose (bool, optional): Whether to print the info. Default: False.
- tensor_parallel_size (int): Tensor parallel size. Default: 1.
- pipeline_parallel_size (int): Pipeline parallel size. Default: 1.
- context_parallel_size (int): Context parallel size. Default: 1.
- tensor_parallel_backend (str, optional): Tensor parallel backend. Default: None.
- pipeline_parallel_backend (str, optional): Pipeline parallel backend. Default: None.
- context_parallel_backend (str, optional): Context parallel backend. Default: None.
- data_parallel_backend (str, optional): Data parallel backend. Default: None.

Returns:

- int: Rank of the process.

Source code in src/ezpz/dist.py
def setup_torch(
    framework: Optional[str] = None,
    backend: Optional[str] = None,
    port: Optional[str | int] = None,
    seed: Optional[int] = None,
    timeout: Optional[str | int] = None,
    verbose: Optional[bool] = False,
    tensor_parallel_size: int = 1,
    pipeline_parallel_size: int = 1,
    context_parallel_size: int = 1,
    tensor_parallel_backend: Optional[str] = None,
    pipeline_parallel_backend: Optional[str] = None,
    context_parallel_backend: Optional[str] = None,
    data_parallel_backend: Optional[str] = None,
) -> int:
    """Setup torch.

    Args:
        backend (str, optional): Backend to use. Defaults to None.
        port (str | int, optional): Port to use. Defaults to None.
        seed (int, optional): Seed to use. Defaults to None.
        timeout (str | int, optional): Timeout to use. Defaults to None.
        verbose (bool, optional): Whether to print the info. Defaults to False.
        tensor_parallel_size (int, optional): Tensor parallel size. Defaults to 1.
        pipeline_parallel_size (int, optional): Pipeline parallel size. Defaults to 1.
        context_parallel_size (int, optional): Context parallel size. Defaults to 1.
        tensor_parallel_backend (str, optional): Tensor parallel backend. Defaults to None.
        pipeline_parallel_backend (str, optional): Pipeline parallel backend. Defaults to None.
        context_parallel_backend (str, optional): Context parallel backend. Defaults to None.
        data_parallel_backend (str, optional): Data parallel backend. Defaults to None.

    Returns:
        int: Rank of the process.
    """
    device = get_torch_device()
    # if ACCELERATOR_TYPE == 'NvidiaGPU' and device == 'cuda':
    #     os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    #     torch.backends.cudnn.deterministic = True     # type:ignore
    #     torch.backends.cudnn.benchmark = True         # type:ignore
    #     torch.backends.cudnn.allow_tf32 = True        # type:ignore
    #     torch.backends.cuda.matmul.allow_tf32 = True  # type:ignore
    # torch.use_deterministic_algorithms(True)
    ws_from_env = os.environ.get("WORLD_SIZE", None)
    framework = "DDP" if framework is None else framework
    framework = framework.lower()
    backend = str(get_torch_backend()).lower()
    if ws_from_env is not None and ws_from_env == "1":
        logger.info(
            f"Running on a single {device}, not initializing torch.distributed!"
        )
        rank = 0
        world_size = 1
        local_rank = 0
        local_size = 1
        num_nodes = 1
    else:
        dsetup = setup_torch_distributed(
            framework=framework,
            backend=backend,
            port=port,
            timeout=timeout,
            tensor_parallel_size=int(tensor_parallel_size),
            pipeline_parallel_size=int(pipeline_parallel_size),
            context_parallel_size=int(context_parallel_size),
            tensor_parallel_backend=tensor_parallel_backend,
            pipeline_parallel_backend=pipeline_parallel_backend,
            context_parallel_backend=context_parallel_backend,
            data_parallel_backend=data_parallel_backend,
        )
        rank = dsetup["rank"]
        world_size = dsetup["world_size"]
        local_rank = dsetup["local_rank"]
        try:
            local_size = get_gpus_per_node()
        except Exception:
            local_size = 1

        try:
            num_nodes = get_num_nodes()
        except Exception:
            num_nodes = 1
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["NUM_NODES"] = str(num_nodes)
    os.environ["LOCAL_SIZE"] = str(local_size)
    os.environ["WORLD_SIZE"] = str(world_size)
    # nthreads = os.environ.get('OMP_NUM_THREADS', None)
    # if ACCELERATOR_TYPE == "IntelGPU" and device == "xpu":
    if torch.xpu.is_available():
        torch.xpu.set_device(local_rank)
        # try:
        #     import intel_extension_for_pytorch as ipex  # type:ignore[missingTypeStubs]
        # except Exception:
        #     ipex = None
        # if ipex is not None:
        #     logger.debug(f"Using ipex from: {ipex.__file__}")
        #
        # try:
        #     import oneccl_bindings_for_pytorch as oneccl_bpt  # type:ignore[missingTypeStubs]
        # except Exception:
        #     oneccl_bpt = None
        # if oneccl_bpt is not None:
        #     logger.debug(f"Using oneccl_bindings from: {oneccl_bpt.__file__}")
        #
        #     # logger.warning(f'Using {get_torch_device()}:{get_local_rank()}')
        #     # os.environ['CCL_LOCAL_RANK'] = str(local_rank)
        #     # os.environ['CCL_LOCAL_SIZE'] = str(local_size)
    if seed is not None:
        seed_everything(seed * (rank + 1) * (local_rank + 1))
    if rank == 0:
        if backend in {"ds", "deepspeed", "dspeed"}:
            from ezpz.configs import git_ds_info

            git_ds_info()
        _ = get_dist_info(verbose=verbose)
        if verbose:
            _ = print_dist_setup()
    if world_size > 1:
        barrier()

    if rank == 0:
        logger.info(
            f"Using {device=} with {backend=} "
            f"+ '{get_torch_backend()}' "
            "for distributed training."
        )
    lrank = len(str(world_size - 1))
    # nz = lrank - len(str(rank))
    hn = socket.gethostname()
    psizes = [f"['{hn}']" + f"[{rank:>{lrank}}/{world_size - 1:<{lrank}}] "]
    if (
        tensor_parallel_size > 1
        or context_parallel_size > 1
        or pipeline_parallel_size > 1
    ):
        import ezpz.tp

        tpsize = ezpz.tp.get_tensor_parallel_world_size()
        cpsize = ezpz.tp.get_context_parallel_world_size()
        ppsize = ezpz.tp.get_pipeline_parallel_world_size()
        dpsize = ezpz.tp.get_data_parallel_world_size()
        if cpsize > 1 or ppsize > 1 or tpsize > 1:
            if cpsize > 1:
                lcp = len(str(cpsize - 1))
                cprank = ezpz.tp.get_context_parallel_rank()
                # cpranks = ezpz.tp.get_context_parallel_ranks()
                psizes.append(f"[cp:{cprank:>{lcp}}/{cpsize - 1:<{lcp}}]")
                barrier(group=ezpz.tp.get_context_parallel_group())
            if ppsize > 1:
                pprank = ezpz.tp.get_pipeline_parallel_rank()
                # ppranks = ezpz.tp.get_pipeline_parallel_ranks()
                lpp = len(str(ppsize - 1))
                psizes.append(f"[pp:{pprank:>{lpp}}/{ppsize - 1:<{lpp}}]")
                barrier(group=ezpz.tp.get_pipeline_parallel_group())
            if tpsize > 1:
                ltp = len(str(tpsize - 1))
                tprank = ezpz.tp.get_tensor_parallel_rank()
                # tpranks = ezpz.tp.get_tensor_parallel_ranks()
                psizes.append(f"[tp:{tprank:>{ltp}}/{tpsize - 1:<{ltp}}]")
                barrier(group=ezpz.tp.get_tensor_parallel_group())
            if dpsize > 1:
                ldp = len(str(dpsize - 1))
                dprank = ezpz.tp.get_data_parallel_rank()
                # dpranks = ezpz.tp.get_data_parallel_ranks()
                psizes.append(f"[dp:{dprank:>{ldp}}/{dpsize - 1:<{ldp}}]")
                barrier(group=ezpz.tp.get_data_parallel_group())
    logger.info("".join(psizes))
    barrier()
    return rank
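
A minimal sketch of a typical call (illustrative; assumes the script is launched under a job launcher such as `mpiexec` so that rank and world-size discovery work):

    from ezpz.dist import setup_torch

    rank = setup_torch(backend="DDP", seed=1234, verbose=True)
    if rank == 0:
        print("running on the main process")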

setup_torch_DDP(port='2345', timeout=3600, backend=None)

Setup PyTorch Distributed Data Parallel (DDP) environment.

Parameters:

- port (str, optional): The port to use for distributed communication. Default: '2345'.
- timeout (int | str | timedelta, optional): Timeout for the process group initialization. Default: 3600 seconds.
- backend (str, optional): The backend to use for distributed training. Default: None, which uses the default backend for the device.

Returns:

- dict[str, int]: A dictionary containing the distributed setup information, including the keys 'world_size', 'rank', and 'local_rank'.

Source code in src/ezpz/dist.py
def setup_torch_DDP(
    port: str = "2345",
    timeout: int | str | timedelta = 3600,
    backend: Optional[str] = None,
) -> dict[str, int]:
    """
    Setup PyTorch Distributed Data Parallel (DDP) environment.
    Args:
        port (str, optional): The port to use for distributed communication.
            Defaults to "2345".
        timeout (int | str | timedelta, optional): Timeout for the process group initialization.
            Defaults to 3600 seconds.
        backend (str, optional): The backend to use for distributed training.
            Defaults to None, which will use the default backend based on the device.

    Returns:
        dict[str, int]: A dictionary containing the distributed setup information.
            Includes keys like 'world_size', 'rank', and 'local_rank'.
    """
    if not isinstance(timeout, timedelta):
        timeout = timedelta(seconds=int(timeout))
    os_rank = os.environ.get("RANK", None)
    os_world_size = os.environ.get("WORLD_SIZE", None)
    os_local_rank = os.environ.get("LOCAL_RANK", None)
    world_size = int(get_world_size())
    rank = int(get_rank())
    local_rank = int(get_local_rank())
    # ensure there is no funny business going on
    if os_rank and int(os_rank) != int(rank):
        logger.warning(f"Mismatch between {os_rank=} and {rank=}")
    if os_world_size and int(os_world_size) != int(world_size):
        logger.warning(f"Mismatch between {os_world_size=} and {world_size=}")
    if os_local_rank and int(os_local_rank) != int(local_rank):
        logger.warning(f"Mismatch between {os_local_rank=} and {local_rank=}")
    # now, set these variables explicitly in the process' environment
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    # get `hostname` ONLY from rank 0
    master_addr = socket.gethostname() if rank == 0 else None
    if (mn := ezpz.get_machine().lower()) in {
        "aurora",
        "polaris",
        "sirius",
    }:
        master_addr = f"{master_addr}.hsn.cm.{mn}.alcf.anl.gov"
    elif mn == "sophia":
        master_addr = f"{master_addr}.lab.alcf.anl.gov"
    # check if we have specified a 'MASTER_PORT' explicitly, if so, use this
    free_port = str(get_free_port()) if rank == 0 else None
    eport = os.environ.get("MASTER_PORT", free_port)
    if eport is not None:
        _ = (
            logger.info(f"Caught MASTER_PORT={eport} from environment!")
            if rank == 0
            else None
        )
    else:
        eport = port
    # grab it from rank 0
    master_port = eport if rank == 0 else None
    # broadcast it to make sure everyones tapped in
    master_port = MPI.COMM_WORLD.bcast(master_port, root=0)
    master_addr = MPI.COMM_WORLD.bcast(master_addr, root=0)
    # set it explicitly in each process' environment
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port
    # now, torch is ready for us
    if rank == 0:
        logger.info(
            "\n".join(
                [
                    "Using torch.distributed.init_process_group with",
                    f"- {master_addr=}",
                    f"- {master_port=}",
                    f"- {world_size=}",
                    f"- {rank=}",
                    f"- {local_rank=}",
                    f"- {timeout=}",
                    f"- {backend=}",
                ]
            )
        )
    init_process_group(
        rank=rank,
        world_size=world_size,
        timeout=timeout,
        backend=backend,
    )
    return {"world_size": world_size, "rank": rank, "local_rank": local_rank}

setup_torch_distributed(framework=None, backend=None, tensor_parallel_size=1, pipeline_parallel_size=1, context_parallel_size=1, tensor_parallel_backend=None, pipeline_parallel_backend=None, context_parallel_backend=None, data_parallel_backend=None, port=None, timeout=None)

Setup distributed environment for PyTorch.

Parameters:

- framework (str, optional): The framework to use for distributed training. Default: None (uses "ddp").
- backend (str, optional): The backend to use for distributed training. Default: None (uses the default backend for the device).
- tensor_parallel_size (int): Size of tensor parallelism. Default: 1.
- pipeline_parallel_size (int): Size of pipeline parallelism. Default: 1.
- context_parallel_size (int): Size of context parallelism. Default: 1.
- tensor_parallel_backend (str, optional): Backend for tensor parallelism. Default: None.
- pipeline_parallel_backend (str, optional): Backend for pipeline parallelism. Default: None.
- context_parallel_backend (str, optional): Backend for context parallelism. Default: None.
- data_parallel_backend (str, optional): Backend for data parallelism. Default: None.
- port (str | int, optional): Port for distributed communication. Default: None (uses "1234").
- timeout (str | int, optional): Timeout for distributed initialization. Default: None (uses 3600 seconds, or TORCH_DDP_TIMEOUT if set).

Returns:

- dict[str, int]: A dictionary containing the distributed setup information, including the keys 'world_size', 'rank', and 'local_rank'.

Raises:

- AssertionError: If the framework is not one of the supported frameworks: "ddp", "ds", "deepspeed", "horovod", or "hvd".
- ValueError: If the framework cannot be mapped to one of the supported setups ("ddp", "ds", "deepspeed", "horovod", "hvd").

Example

    setup_torch_distributed(
        framework="ddp",
        backend="nccl",
        tensor_parallel_size=2,
        pipeline_parallel_size=1,
        context_parallel_size=1,
        port=1234,
        timeout=3600,
    )

Source code in src/ezpz/dist.py
def setup_torch_distributed(
    framework: Optional[str] = None,
    backend: Optional[str] = None,
    tensor_parallel_size: int = 1,
    pipeline_parallel_size: int = 1,
    context_parallel_size: int = 1,
    tensor_parallel_backend: Optional[str] = None,
    pipeline_parallel_backend: Optional[str] = None,
    context_parallel_backend: Optional[str] = None,
    data_parallel_backend: Optional[str] = None,
    port: Optional[str | int] = None,
    timeout: Optional[str | int] = None,
) -> dict[str, int]:
    """
    Setup distributed environment for PyTorch.

    Args:
        framework (str, optional): The framework to use for distributed training.
            Defaults to None, which will use "ddp".
        backend (str, optional): The backend to use for distributed training.
            Defaults to None, which will use the default backend based on the device.
        tensor_parallel_size (int, optional): Size of tensor parallelism. Defaults to 1.
        pipeline_parallel_size (int, optional): Size of pipeline parallelism. Defaults to 1.
        context_parallel_size (int, optional): Size of context parallelism. Defaults to 1.
        tensor_parallel_backend (str, optional): Backend for tensor parallelism. Defaults to None.
        pipeline_parallel_backend (str, optional): Backend for pipeline parallelism. Defaults to None.
        context_parallel_backend (str, optional): Backend for context parallelism. Defaults to None.
        data_parallel_backend (str, optional): Backend for data parallelism. Defaults to None.
        port (str | int, optional): Port for distributed communication. Defaults to "1234".
        timeout (str | int, optional): Timeout for distributed initialization. Defaults to 3600 seconds.

    Returns:
        dict[str, int]: A dictionary containing the distributed setup information.
            Includes keys like 'world_size', 'rank', and 'local_rank'.

    Raises:
        AssertionError: If the framework is not one of the supported frameworks.
            Supported frameworks are "ddp", "ds", "deepspeed", "horovod", and "hvd".
        ValueError: If the backend is not one of the supported backends.
            Supported backends are "ddp", "ds", "deepspeed", "horovod", and "hvd".

    Example:
        >>> setup_torch_distributed(
        ...     framework="ddp",
        ...     backend="nccl",
        ...     tensor_parallel_size=2,
        ...     pipeline_parallel_size=1,
        ...     context_parallel_size=1,
        ...     port=1234,
        ...     timeout=3600,
        ... )
    """
    framework = "ddp" if framework is None else framework
    # if str(framework).lower() not in {"ddp", "ds", "deepspeed", "horovod", "hvd"}:
    assert str(framework).lower() in {
        "ddp",
        "ds",
        "deepspeed",
        "horovod",
        "hvd",
    }, (
        f"Invalid framework: {framework=}, expected one of "
        f"{'ddp', 'ds', 'deepspeed', 'horovod', 'hvd'}"
    )

    DEFAULT_TIMEOUT = os.environ.get("TORCH_DDP_TIMEOUT", 3600)
    timeout = (
        DEFAULT_TIMEOUT if timeout is None else (
            int(timeout) if isinstance(timeout, str) else timeout
        )
    )
    port = (
        "1234"
        if port is None
        else str(port)
        if isinstance(port, int)
        else port
    )
    rank = get_rank()
    world_size = get_world_size()
    local_rank = get_local_rank()
    fw = str(framework).lower()
    be = (
        str(get_torch_backend()).lower()
        if backend is None
        else str(backend).lower()
    )
    # be = str(framework).lower()
    # assert fw in {"ds", "deepspeed", "ddp", "horovod", "hvd"}, (
    #     f"Invalid backend: {framework=}, expected one of "
    #     f"{'ds', 'deepspeed', 'ddp', 'horovod', 'hvd'}"
    # )
    # assert be in BACKENDS['pytorch']
    if rank == 0:
        logger.info(
            " ".join(
                [
                    f"Using {fw=} with",
                    "torch_{device,backend}=",
                    "{" + f"{get_torch_device_type()}, {be}" + "}",
                ]
            )
        )
    if fw == "ddp":
        dsetup = setup_torch_DDP(port, timeout, backend=be)
        world_size = dsetup["world_size"]
        rank = dsetup["rank"]
        local_rank = dsetup["local_rank"]
        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
    elif fw in {"deepspeed", "ds"}:
        init_deepspeed(timeout=timeout)
        world_size = get_world_size()
        rank = get_rank()
        local_rank = get_local_rank()
    elif fw in {"horovod", "hvd"}:
        import horovod.torch as hvd  # type:ignore noqa

        _ = None if hvd.is_initialized() else hvd.init()  # type:ignore
        # hvd.init() if not hvd.is_initialized() else None
        rank = hvd.rank()  # type:ignore
        world_size = hvd.size()  # type:ignore
        local_rank = hvd.local_rank()  # type:ignore
        if torch.cuda.is_available():
            torch.cuda.set_device(hvd.local_rank())  # type:ignore
    else:
        raise ValueError(f"Unable to parse backend: {be=}")

    if (
        tensor_parallel_size > 1
        or context_parallel_size > 1
        or pipeline_parallel_size > 1
    ):
        ezpz.tp.initialize_tensor_parallel(
            tensor_parallel_size=tensor_parallel_size,
            pipeline_parallel_size=pipeline_parallel_size,
            context_parallel_size=context_parallel_size,
            tensor_parallel_backend=tensor_parallel_backend,
            pipeline_parallel_backend=pipeline_parallel_backend,
            context_parallel_backend=context_parallel_backend,
            data_parallel_backend=data_parallel_backend,
            timeout=timedelta(seconds=timeout),
        )

    os.environ["world_size"] = str(world_size)
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(local_rank)

    return {"world_size": world_size, "rank": rank, "local_rank": local_rank}

setup_wandb(project_name=None, entity=None, config=None, start_method='thread', outdir=None, init_timeout=300)

Setup wandb for logging.

Parameters:

- project_name (str, optional): The name of the project. Default: None.
- entity (str, optional): The entity name. Default: None.
- config (dict | DictConfig, optional): The configuration dictionary. Default: None.
- start_method (str): The start method for wandb. Default: 'thread'.
- outdir (str | Path | PathLike, optional): The output directory. Default: None.
- init_timeout (int): The timeout for wandb initialization. Default: 300.

Example

    setup_wandb(project_name="my_project", entity="my_entity")

Source code in src/ezpz/dist.py
def setup_wandb(
    project_name: Optional[str] = None,
    entity: Optional[str] = None,
    config: Optional[dict | DictConfig] = None,
    start_method: str = "thread",
    outdir: Optional[str | Path | os.PathLike] = None,
    init_timeout: int = 300,
):
    """Setup wandb for logging.

    Args:
        project_name (str, optional): The name of the project. Defaults to None.
        entity (str, optional): The entity name. Defaults to None.
        config (dict | DictConfig, optional): The configuration dictionary. Defaults to None.
        start_method (str, optional): The start method for wandb. Defaults to "thread".
        outdir (str | Path | os.PathLike, optional): The output directory. Defaults to None.
        init_timeout (int, optional): The timeout for wandb initialization. Defaults to 300.

    Example:
        >>> setup_wandb(project_name="my_project", entity="my_entity")
    """
    WANDB_DISABLED = os.environ.get("WANDB_DISABLED", False)
    WANDB_MODE = os.environ.get("WANDB_MODE", "").lower()
    if WANDB_DISABLED or WANDB_MODE == "disabled":
        logger.warning(
            f"Logging with W&B is disabled!, caught: {WANDB_DISABLED=}"
        )
        return None

    try:
        import wandb
    except (ImportError, ModuleNotFoundError) as e:
        logger.warning(
            "Unable to import `wandb`. Install with `pip install wandb`"
        )
        raise e

    outdir = (
        Path(os.getcwd()).as_posix()
        if outdir is None
        else Path(outdir).as_posix()
    )
    rank = get_rank()
    project_name = (
        project_name
        if project_name is not None
        else os.environ.get(
            "WB_PROJECT",
            os.environ.get(
                "WANDB_PROJECT",
                os.environ.get("WB_PROJECT_NAME", None),
            ),
        )
    )
    if project_name is None:
        import sys

        frame = sys._getframe().f_back
        assert frame is not None
        calling_module = frame.f_code.co_filename
        fp = Path(calling_module)
        project_name = f"{fp.parent.stem}.{fp.stem}"

    logger.info(f"Setting up wandb from {rank=}")
    logger.info(f"Using WB_PROJECT={project_name}")
    tensorboard_dir = (
        os.environ.get("TENSORBOARD_DIR", None)
        if config is None
        else config.get("tensorboard_dir", None)
    )
    if tensorboard_dir is not None:
        logger.info(f"Patching tensorboard from {tensorboard_dir}")
        try:
            wandb.tensorboard.patch(root_logdir=tensorboard_dir)  # type:ignore
        except Exception as exc:
            logger.exception(exc)
    # wbrun_id = wandb.util.generate_id()
    now = datetime.datetime.now()
    dstr = now.strftime("%Y-%m-%d-%H%M%S")
    run = wandb.init(
        entity=entity,
        # resume='allow',
        dir=outdir,
        sync_tensorboard=(tensorboard_dir is not None),  # True,
        project=(project_name if project_name is not None else None),
        # dir=(tensorboard_dir if tensorboard_dir is not None else None),
        settings=wandb.Settings(
            start_method=start_method, init_timeout=init_timeout
        ),
    )
    assert run is not None and run is wandb.run
    # run.log_code(HERE.as_posix(), include_fn=include_file)
    logger.info(f"wandb.run=[{run.name}]({run.url})")
    if (
        wandb is not None
        and wandb.run is not None
        and "DIST_INFO" not in wandb.run.config
    ):
        wandb.run.config.update({"DIST_INFO": get_dist_info()})
    torch_version = torch.__version__
    torch_file = torch.__file__
    run.config.update(
        {
            "created_at": dstr,
            "day": ezpz.get_timestamp("%d"),
            "ezpz_file": ezpz.__file__,
            "ezpz_version": ezpz.__version__,
            "hostname": get_hostname(),
            "month": ezpz.get_timestamp("%m"),
            "outdir": os.getcwd(),
            "pytorch_backend": str(get_torch_backend()).lower(),
            "torch_version": torch_version,
            "torch_version_as_float": get_torch_version_as_float(),
            "torch_file": torch_file,
            "world_size": get_world_size(),
            "year": ezpz.get_timestamp("%Y"),
            "working_directory": os.getcwd(),
        }
    )
    if config is not None:
        if isinstance(config, DictConfig):
            cfg = OmegaConf.to_container(
                config, resolve=True, throw_on_missing=True
            )
            run.config.update({"config": cfg})
        else:
            run.config.update({"config": config})
    env = {
        k: v
        for k, v in dict(os.environ).items()
        if not k.startswith("_ModuleTable")
    }
    _ = env.pop("LS_COLORS", None)
    _ = env.pop("PS1", None)
    run.config.update({"env": env})
    machine = get_machine()
    logger.info(f"Running on {machine=}")
    run.config.update({"machine": machine})
    model_size = os.environ.get("MODEL_SIZE", None)
    if model_size is not None:
        run.config.update({"MODEL_SIZE": model_size})
    return wandb.run
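
A slightly fuller sketch than the example above (illustrative; "my_project" and the config contents are placeholders, and a valid wandb login is assumed):

    from ezpz.dist import setup_wandb

    run = setup_wandb(
        project_name="my_project",
        config={"lr": 1e-3, "batch_size": 32},
    )
    if run is not None:  # None when W&B is disabled via WANDB_DISABLED / WANDB_MODE
        run.log({"loss": 0.123})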

summarize_dict(d, precision=6)

Summarize a dictionary into a string with formatted key-value pairs.

Parameters:

- d (dict): The dictionary to summarize. Required.
- precision (int): The precision for floating point values. Default: 6.

Returns:

- str: A string representation of the dictionary with formatted key-value pairs.

Source code in src/ezpz/utils.py
def summarize_dict(d: dict, precision: int = 6) -> str:
    """
    Summarize a dictionary into a string with formatted key-value pairs.

    Args:
        d (dict): The dictionary to summarize.
        precision (int): The precision for floating point values. Default: ``6``.

    Returns:
        str: A string representation of the dictionary with formatted key-value pairs.
    """
    return " ".join(
        [format_pair(k, v, precision=precision) for k, v in d.items()]
    )
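
Illustrative usage (a minimal sketch; the exact output format depends on `format_pair`):

    from ezpz.utils import summarize_dict

    print(summarize_dict({"loss": 0.123456789, "step": 10}, precision=4))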

synchronize(device=None)

Synchronize the given device.

Parameters:

- device (torch.device | int | str, optional): The device to synchronize. If None, the default device is used. Default: None.

Returns:

- None

Source code in src/ezpz/dist.py
def synchronize(device: Optional[torch.device | int | str] = None):
    """
    Synchronize the given device.

    Args:
        device (torch.device | int | str, optional): The device to synchronize.
            If None, the default device will be used. Defaults to None.

    Returns:
        None
    """
    return (
        torch.cuda.synchronize(device)
        if torch.cuda.is_available()
        else (
            torch.xpu.synchronize(device)
            if torch.xpu.is_available()
            else (
                torch.mps.synchronize()
                if torch.backends.mps.is_available()
                else torch.cpu.synchronize(device)
            )
        )
    )
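
A minimal sketch of using `synchronize` to get accurate device timings (illustrative; on a CPU-only machine the call is effectively a no-op):

    import time

    import torch
    from ezpz.dist import synchronize

    x = torch.randn(1024, 1024)
    t0 = time.perf_counter()
    y = x @ x
    synchronize()  # wait for the device to finish before reading the clock
    print(f"matmul took {time.perf_counter() - t0:.4f}s")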

tensor_parallel_is_initialized()

Check if tensor and data parallel groups are initialized.

Source code in src/ezpz/tp/__init__.py
def tensor_parallel_is_initialized() -> bool:
    """Check if tensor and data parallel groups are initialized."""
    if (
        _TENSOR_PARALLEL_GROUP is None
        or _DATA_PARALLEL_GROUP is None
        or _PIPELINE_PARALLEL_GROUP is None
        or _CONTEXT_PARALLEL_GROUP is None
    ):
        return False
    return True
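
Illustrative usage (a sketch; the parallel groups are only populated after `ezpz.tp.initialize_tensor_parallel` has been called, e.g. via `setup_torch` with parallel sizes greater than 1):

    import ezpz.tp

    if ezpz.tp.tensor_parallel_is_initialized():
        print("tensor/pipeline/context/data parallel groups are ready")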

timeit(func)

Decorator to time a function and log the time taken.

Parameters:

- func (Callable): Function to be timed. Required.

Example

    @timeit
    def my_function(arg1, arg2):
        # Function implementation
        pass

Source code in src/ezpz/dist.py
def timeit(func: Callable):
    """
    Decorator to time a function and log the time taken.

    Args:
        func (Callable): Function to be timed.

    Example:
        @timeit
        def my_function(arg1, arg2):
            # Function implementation
            pass
    """
    try:
        import wandb
    except Exception:
        wandb = None  # type:ignore

    @wraps(func)
    def wrapper(*args, **kwargs):
        t0 = time.perf_counter()
        result = func(*args, **kwargs)
        dt = time.perf_counter() - t0
        fname = getattr(
            func, "__qualname__", getattr(func, "__name__", "unknown")
        )
        logger.info(f"{fname}({args}, {kwargs}) took: {dt=:.4f}s")
        if wandb is not None and wandb.run is not None:
            wandb.log({f"timeit/{fname}": dt})
        return result

    return wrapper
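
A concrete sketch of the decorator in use (illustrative; the timing goes to the module logger, and to wandb when a run is active):

    from ezpz.dist import timeit

    @timeit
    def slow_add(a, b):
        return a + b

    slow_add(1, 2)  # logs something like: slow_add((1, 2), {}) took: dt=0.0000s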

timeitlogit(rank=None, record=True, verbose=False, prefix=None)

Decorator to time a function and log the time taken.

Parameters:

- rank (int, optional): Rank of the process. Default: None (uses get_rank()).
- record (bool): Whether to record the timing to wandb (when a run is active). Default: True.
- verbose (bool): Whether to log the time taken on rank 0. Default: False.
- prefix (str, optional): Prefix for the logged metric name. Default: None (uses "timeit").

Example

    @timeitlogit(rank=0, verbose=True)
    def my_function(arg1, arg2):
        # Function implementation
        pass

Source code in src/ezpz/dist.py
def timeitlogit(
    rank: Optional[int] = None,
    record: bool = True,
    verbose: bool = False,
    prefix: str | None = None,
):
    """Decorator to time a function and log the time taken.

    Args:
        rank (int, optional): Rank of the process. Defaults to None.
        verbose (bool, optional): Whether to log the time taken. Defaults to True.

    Example:
        @timeitlogit(rank=0, verbose=True)
        def my_function(arg1, arg2):
            # Function implementation
            pass
    """
    rank = get_rank() if rank is None else rank
    prefix = "timeit" if prefix is None else prefix
    try:
        import wandb
    except Exception:
        wandb = None  # type:ignore

    def decorator(func: Callable):
        """Decorator to time a function and log the time taken.

        Args:
            func (Callable): Function to be timed.
        """

        @wraps(func)
        def wrapper(*args, **kwargs):
            t0 = time.perf_counter()
            assert isinstance(rank, int)
            result = func(*args, **kwargs)
            dt = time.perf_counter() - t0
            fname = getattr(
                func, "__qualname__", getattr(func, "__name__", "unknown")
            )
            if record and wandb is not None and wandb.run is not None:
                wandb.log({f"{prefix}/{fname}": dt}, commit=False)
            if verbose and rank == 0:
                arg_str = ", ".join(map(str, args))
                kw_str = ", ".join(f"{k}={v}" for k, v in kwargs.items())
                inner = ", ".join(filter(None, [arg_str, kw_str]))
                logger.info(f"{fname}({inner}) took {dt:.4f} s")
                # if wandb is not None and wandb.run is not None:
                #     wandb.log({f"timeit/{fname}": dt}, commit=False)
            # if verbose:
            #     if rank == 0:
            #         astr = []
            #         if len(args) > 0:
            #             astr.append(f"({args}")
            #         _ = (
            #             astr.append(f", {kwargs})")
            #             if len(kwargs) > 0
            #             else (astr.append(")") if len(args) > 0 else "")
            #         )
            #         zstr = [f"Called: '{fname}' with arguments:"]
            #         if len(astr) > 0:
            #             zstr.append(f"{''.join(astr)}")
            #         zstr.append(f"'{fname}' took: {dt=:.4f} s")
            #         logger.info("\n".join(zstr))
            return result

        return wrapper

    return decorator
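
And a sketch of the parameterized variant (illustrative; `prefix` controls the wandb metric name, e.g. "train/slow_add"):

    from ezpz.dist import timeitlogit

    @timeitlogit(rank=0, verbose=True, prefix="train")
    def slow_add(a, b):
        return a + b

    slow_add(1, 2)  # on rank 0, logs roughly: slow_add(1, 2) took 0.0000 s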