gpauloski/kfac-pytorch: Distributed K-FAC Preconditioner for PyTorch
Distributed K-FAC Preconditioner for PyTorch


K-FAC, Kronecker-factored Approximate Curvature, is a second-order optimization method based on an efficient approximation of the Fisher information matrix (see the original paper). This repository provides a PyTorch implementation of K-FAC as a preconditioner to standard PyTorch optimizers with support for single-device or distributed training. The distributed strategy is implemented using KAISA, a K-FAC-enabled, Adaptable, Improved, and Scalable second-order optimizer framework, where the placement of the second-order computations and gradient preconditioning is controlled by the gradient worker fraction parameter (see the paper for more details). KAISA has been shown to reduce time-to-convergence in PyTorch distributed training applications such as ResNet-50, Mask R-CNN, and BERT.
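For a fully connected layer y = W a with weight gradient grad of shape (out, in), K-FAC approximates the layer's Fisher block as the Kronecker product A ⊗ G, where A is the second moment of the layer inputs and G is the second moment of the gradients with respect to the layer outputs. Because (A ⊗ G)^-1 vec(grad) = vec(G^-1 grad A^-1), the preconditioned gradient needs only two small solves instead of inverting the full block. The NumPy sketch below checks that identity on random data. It is a simplified illustration, not this library's implementation (which works on PyTorch tensors); the helper name `kfac_precondition` and the per-factor damping are assumptions.

```python
import numpy as np


def kfac_precondition(grad, A, G, damping=0.0):
    """Hypothetical helper: return G^-1 @ grad @ A^-1 with per-factor damping.

    grad: (out, in) weight gradient of a linear layer y = W a
    A:    (in, in) second moment of the layer inputs, E[a a^T]
    G:    (out, out) second moment of the output gradients, E[g g^T]
    """
    A_d = A + damping * np.eye(A.shape[0])
    G_d = G + damping * np.eye(G.shape[0])
    # Left-multiply by G^-1, then right-multiply by A^-1. A is symmetric, so
    # X @ A^-1 == (A^-1 @ X.T).T; solves avoid forming explicit inverses.
    left = np.linalg.solve(G_d, grad)
    return np.linalg.solve(A_d, left.T).T


rng = np.random.default_rng(0)
n, d_in, d_out = 64, 5, 3
a = rng.standard_normal((n, d_in))   # layer inputs for a mini-batch
g = rng.standard_normal((n, d_out))  # gradients w.r.t. the layer outputs
A = a.T @ a / n
G = g.T @ g / n
grad = g.T @ a / n                   # (out, in) weight gradient

precond = kfac_precondition(grad, A, G)

# Reference: solve against the explicit Kronecker-factored Fisher block.
# np.kron(A, G) corresponds to a column-major vec, hence order="F".
direct = np.linalg.solve(np.kron(A, G), grad.flatten(order="F"))
print(np.allclose(direct, precond.flatten(order="F")))  # True
```

Adding damping to each factor separately, as in this sketch, is a common simplification; it is not identical to damping the full block, since (A + γI) ⊗ (G + γI) ≠ A ⊗ G + γI.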


Installation

K-FAC only requires PyTorch 1.8 or later. The example scripts have additional requirements defined in examples/requirements.txt.


$ git clone https://github.com/gpauloski/kfac-pytorch.git
$ cd kfac-pytorch
$ pip install .  # Use -e to install in development mode

If NVIDIA Apex is installed with its C extensions, K-FAC uses Apex's optimized flatten and unflatten operations for collective communication.
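Flattening packs several tensors of different shapes into one contiguous buffer, so a single collective call can replace one call per tensor; unflattening restores the original shapes afterwards. The sketch below illustrates the idea with NumPy arrays standing in for PyTorch tensors and a plain sum standing in for an allreduce across two simulated workers. `flatten` and `unflatten` here are hypothetical helpers, not the Apex API.

```python
import numpy as np


def flatten(arrays):
    """Pack arrays into one contiguous 1-D buffer and remember their shapes."""
    shapes = [arr.shape for arr in arrays]
    return np.concatenate([arr.ravel() for arr in arrays]), shapes


def unflatten(buffer, shapes):
    """Split a flat buffer back into arrays with the recorded shapes."""
    out, offset = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        out.append(buffer[offset:offset + size].reshape(shape))
        offset += size
    return out


# Two simulated workers, each holding gradients for a (3, 4) weight and a (3,) bias.
worker_grads = [
    [np.full((3, 4), 1.0), np.full(3, 2.0)],
    [np.full((3, 4), 3.0), np.full(3, 4.0)],
]

buffers = []
for grads in worker_grads:
    flat, shapes = flatten(grads)
    buffers.append(flat)

# Stand-in for one allreduce over the flat buffers, followed by averaging.
averaged = unflatten(sum(buffers) / len(buffers), shapes)
print(averaged[0][0, 0], averaged[1][0])  # 2.0 3.0
```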


Usage

K-FAC requires minimal code to integrate into existing training scripts. See the K-FAC docstring for a detailed list of K-FAC parameters.

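import torch
from torch import optim

from kfac.preconditioner import KFACPreconditioner

model = torch.nn.parallel.DistributedDataParallel(...)
optimizer = optim.SGD(model.parameters(), ...)
preconditioner = KFACPreconditioner(model, ...)

# train_loader and criterion are defined elsewhere in the training script
for data, target in train_loader:
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    # Precondition the gradients with K-FAC before the optimizer update
    preconditioner.step()
    optimizer.step()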

See the wiki for more details on K-FAC's features.
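The gradient worker fraction described in the introduction is one of the options that can be passed to `KFACPreconditioner(model, ...)` in the usage example (check the docstring for its exact name and accepted values). As a rough, hypothetical sketch following the strategy names used in the KAISA paper: with `world_size` workers, the fraction sets how many workers compute the preconditioned gradient for each layer. One worker per layer corresponds to the memory-optimized MEM-OPT strategy, all workers to the communication-optimized COMM-OPT strategy, and anything in between to HYBRID-OPT. The helper name, the rounding, and the divisibility check are assumptions, not the library's logic.

```python
from __future__ import annotations


def kaisa_strategy(world_size: int, grad_worker_fraction: float) -> tuple[int, str]:
    """Hypothetical helper: gradient workers per layer and the KAISA strategy name."""
    if world_size < 1:
        raise ValueError("world_size must be positive")
    if not 0 < grad_worker_fraction <= 1:
        raise ValueError("grad_worker_fraction must be in (0, 1]")
    # Number of workers that compute the preconditioned gradient for a layer.
    grad_workers = max(1, round(world_size * grad_worker_fraction))
    # Assumption: workers must split evenly into groups of grad_workers.
    if world_size % grad_workers != 0:
        raise ValueError("world_size must be divisible by the number of gradient workers")
    if grad_workers == world_size:
        return grad_workers, "COMM-OPT"  # every worker preconditions every layer
    if grad_workers == 1:
        return grad_workers, "MEM-OPT"  # one worker per layer, result broadcast
    return grad_workers, "HYBRID-OPT"


for fraction in (1 / 8, 0.5, 1.0):
    print(fraction, kaisa_strategy(8, fraction))
```

Roughly, larger fractions trade extra memory (more workers hold each layer's second-order data) for less per-iteration communication of preconditioned gradients; the KAISA paper discusses this trade-off in detail.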


Examples

Example scripts for training ResNet models on CIFAR-10 and ImageNet-1k are provided in examples/.


Developing

tox and pre-commit are used for development. Pre-commit enforces code formatting, linting, and type-checking in this repository.

To get started with local development (note: Python 3.11 is supported, but some testing dependencies are not available for it):

$ tox --devenv venv -e py310
$ . venv/bin/activate
$ pre-commit install

Note that the tox recipes install CPU-only PyTorch because GPUs are not available in CI.

To verify code passes pre-commit, run:

$ pre-commit run --all-files

Tox can also be used to run the test suite:

$ tox -e py39  # run all tests in Python 3.9

Citations and References

The K-FAC code is based on Chaoqi Wang's KFAC-PyTorch. The ResNet models for CIFAR-10 are from Yerlan Idelbayev's pytorch_resnet_cifar10. The CIFAR-10 and ImageNet-1k training scripts are modeled after Horovod's example PyTorch training scripts.

The code used in "Convolutional Neural Network Training with Distributed K-FAC" is frozen in the kfac-lw and kfac-opt branches. The code used in "KAISA: An Adaptive Second-order Optimizer Framework for Deep Neural Networks" is frozen in the hybrid-opt branch.

If you use this code in your work, please cite the SC '20 and SC '21 papers:

@inproceedings{pauloski2020kfac,
    author = {Pauloski, J. Gregory and Zhang, Zhao and Huang, Lei and Xu, Weijia and Foster, Ian T.},
    title = {Convolutional {N}eural {N}etwork {T}raining with {D}istributed {K}-{FAC}},
    year = {2020},
    isbn = {9781728199986},
    publisher = {IEEE Press},
    booktitle = {Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis},
    articleno = {94},
    numpages = {14},
    location = {Atlanta, Georgia},
    series = {SC '20},
    doi = {10.5555/3433701.3433826}
}

@inproceedings{pauloski2021kaisa,
    author = {Pauloski, J. Gregory and Huang, Qi and Huang, Lei and Venkataraman, Shivaram and Chard, Kyle and Foster, Ian and Zhang, Zhao},
    title = {KAISA: {A}n {A}daptive {S}econd-{O}rder {O}ptimizer {F}ramework for {D}eep {N}eural {N}etworks},
    year = {2021},
    isbn = {9781450384421},
    publisher = {Association for Computing Machinery},
    address = {New York, NY, USA},
    url = {https://doi.org/10.1145/3458817.3476152},
    doi = {10.1145/3458817.3476152},
    booktitle = {Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis},
    articleno = {13},
    numpages = {14},
    location = {St. Louis, Missouri},
    series = {SC '21}
}