NCCL Fast Init - CPU Optimizations for NCCL Initialization Large Scale #1789
Conversation
At large scale (32K+ GPUs) we start to see significant initialization time coming from `ncclBuildRing` and `initTransportsRank`, often several dozen seconds at 100K scale. This change optimizes the two functions to remove the overhead and enables NCCL to initialize quickly at 100K scale.
sprintf(prefix, "[%d] Channel %d Next : ", rank, r);
dumpLine(next+r*nranks, nranks, prefix);*/

std::vector<bool> rankBitSet(nranks, false);
I'm using std::vector here but open to other choices. My rationale for using vector is that:
- NCCL already uses the C++ std library, so using vector doesn't add a new dependency.
- Memory is automatically managed with a std container.
- Most notably, std::vector<bool> has a space-efficient implementation that uses 1 bit per entry instead of 1 byte, making it space- as well as page-cache-efficient (see the sketch below). https://en.cppreference.com/w/cpp/container/vector_bool/reference
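To illustrate the idea, here is a minimal, hypothetical sketch; it is not the actual `ncclBuildRing` code, and `buildRingSketch` and `candidateOrder` are made-up names for illustration:

```cpp
#include <vector>

// Illustrative only: track which ranks have already been placed in a ring so
// the "is this rank used yet?" check is an O(1) bit test instead of a linear
// scan over all previously placed ranks.
int buildRingSketch(int nranks, const std::vector<int>& candidateOrder,
                    std::vector<int>& ringOut) {
  std::vector<bool> rankBitSet(nranks, false);  // 1 bit per rank
  ringOut.clear();
  ringOut.reserve(nranks);
  for (int r : candidateOrder) {
    if (r < 0 || r >= nranks) return -1;  // invalid rank id
    if (rankBitSet[r]) continue;          // already placed: O(1) membership test
    rankBitSet[r] = true;
    ringOut.push_back(r);
  }
  return static_cast<int>(ringOut.size()) == nranks ? 0 : -1;  // all ranks placed once
}
```

At 100K ranks the bitset occupies roughly 12.5 KB versus ~100 KB for a byte-per-rank array, so it touches far fewer cache lines and pages during the construction loop.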
Hi @sjeaugey and @marksantesson - I've sent out a new PR for optimizing init times at large scale that surprised us at Meta. We believe it'll be of help to the broader community. I'm happy to iterate with you to improve the PR based on your feedback.
Hi @saifhhasan. Thanks for your contribution. I have a question about the
Hi @saifhhasan. We will have a similar change in
Thank you @stephenmsachs for helping take a look at the PR. Regarding the init.cc changes, what you present could be an edge case.
GPU-Initiated Networking (GIN):
* Provides device-side API for integrating GPU-Initiated Networking
capability into application kernels.
* New transport layer called DOCA GPUNetIO.
* New ncclGin construct to create, destroy and manipulate GIN contexts.
* New ncclGinBarrierSession to provide synchronization functionality.
* New put, signal, counter operations for data movement and signaling.
* GIN API signatures and functionalities are subject to change.
* GIN Support Requirements
* CUDA 12.2 or later when compiling the GPU code
* NVIDIA GPUs: Volta or newer. NVIDIA GPU drivers >= 510.40.3
* NVIDIA NICs: CX4 or newer. rdma-core >= 44.0
* Requires nvidia-peermem or DMABUF support. When using DMABUF, Linux
kernel >= 6.1 is required.
New ncclCommRevoke API for fault tolerance:
* Introduces ncclCommRevoke to quiesce ongoing NCCL work on a
communicator without freeing resources.
* This answers the need for a lightweight way to cancel in-flight
collectives and bring a communicator to a safe state before
split/shrink/finalize/destroy.
* Includes optional cross-rank coordination (global barrier) and
supports blocking/non-blocking usage.
New NCCL Environment Plugin:
* The env plugin allows users to set NCCL environment variables, for
example, after loading them from a centralized database.
* The NCCL_ENV_PLUGIN variable can be used to let NCCL load an external
environment plugin.
New NCCL Examples on GitHub:
* The NCCL examples directory provides users and developers with
practical code samples that highlight NCCL’s core features.
* It covers basic operations like communicator initialization,
point-to-point communication, and collective operations, as well as
advanced features such as user buffer registration, symmetric memory,
and the device API.
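As a companion to the examples entry above, a minimal sketch of the basic pattern those samples cover: single-process communicator initialization plus one collective. This is a sketch under assumptions (at least two visible GPUs, error checking omitted), not one of the official examples:

```cpp
// Minimal single-process, multi-GPU all-reduce with NCCL: one communicator per
// visible GPU, one collective launched from a single thread.
#include <cuda_runtime.h>
#include <nccl.h>
#include <cstdio>
#include <vector>

int main() {
  int nDev = 0;
  cudaGetDeviceCount(&nDev);
  if (nDev < 2) { printf("need at least 2 GPUs\n"); return 0; }

  const size_t count = 1 << 20;  // elements per GPU
  std::vector<float*> sendbuf(nDev), recvbuf(nDev);
  std::vector<cudaStream_t> streams(nDev);
  std::vector<ncclComm_t> comms(nDev);

  // Per-GPU buffers and streams.
  for (int i = 0; i < nDev; ++i) {
    cudaSetDevice(i);
    cudaMalloc(&sendbuf[i], count * sizeof(float));
    cudaMalloc(&recvbuf[i], count * sizeof(float));
    cudaMemset(sendbuf[i], 0, count * sizeof(float));  // arbitrary filler data
    cudaStreamCreate(&streams[i]);
  }

  // Communicator initialization: one call covers all local devices.
  ncclCommInitAll(comms.data(), nDev, nullptr);

  // Collective operation: sum all-reduce across the nDev communicators.
  ncclGroupStart();
  for (int i = 0; i < nDev; ++i) {
    ncclAllReduce(sendbuf[i], recvbuf[i], count, ncclFloat, ncclSum,
                  comms[i], streams[i]);
  }
  ncclGroupEnd();

  // Completion and cleanup.
  for (int i = 0; i < nDev; ++i) {
    cudaSetDevice(i);
    cudaStreamSynchronize(streams[i]);
    cudaFree(sendbuf[i]);
    cudaFree(recvbuf[i]);
    ncclCommDestroy(comms[i]);
  }
  printf("all-reduce across %d GPUs completed\n", nDev);
  return 0;
}
```

Multi-process (one rank per process) setups instead exchange a `ncclUniqueId` out of band and call `ncclCommInitRank`; the examples directory covers both patterns.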
Device API improvements:
* Adds ncclFindWindow API.
* Adds new ncclBarrierSession to provide hybrid synchronization
functionality.
* Makes multimem available with as few as two ranks.
* Removes distance (NCCL_P2P_LEVEL) considerations from determining the
availability of symmetric memory.
Enhanced NCCL RAS output:
* Extends RAS subsystem with JSON format to support machine-parsable
metrics collection.
* Enables structured data export for monitoring tools, dashboards, and
automated analysis systems.
GitHub pull requests resolved:
* Fast Init - CPU Optimizations for NCCL Initialization Large Scale.
(PR #1789)
* Fast Init - Improve Bootstrap AllGather by 2x at large scale by
sending bootstrap information bidirectionally. (PR #1791)
* Fixes spurious failures when PyTorch is statically linked with
NCCL-2.28.3 because the error is not drained but instead gets propagated
into the next CUDA kernel invocation. (PR #1864)
Other notable improvements:
* Fixes multicast object leaks in case of failed NVLS user buffer
registrations, which could lead to crashes. Avoids such registration
attempts in case of the use of incompatible memory allocators.
* Fixes potential data corruption with built-in symmetric kernels for
small messages with size granularity under 8 bytes or when multiple
symmetric operations were aggregated in a group.
* Generalizes the existing point-to-point scheduling to the case of
uneven GPU counts per node.
* Fixes a crash when network plugin assignment fails.
* Fixes a large performance issue with NCCL_CROSS_NIC=0 and certain
split mask settings, where NCCL cannot find a viable ring.
* Fixes crash when NCCL is compiled with recent CUDA versions but
running on hosts with certain specific older CUDA drivers.
This has been accepted and released in the latest release, 2.28.7. @saifhhasan
Problem
At large scale (32K+ GPUs) we start to see significant initialization time coming from `ncclBuildRing` and `initTransportsRank`, often several dozen seconds at 100K scale. This occurs because both of these functions perform nested loops of complexity O(N*N), which at 100K scale translates to roughly 10B loop iterations, each with multiple instructions. We used CPU profiling tools to spot the loops that were taking an excessive amount of time during the initialization phase.
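To make the scale concrete, a back-of-envelope estimate (assuming on the order of $10^9$ simple loop iterations per second per core, an assumption rather than a number from this PR):

$$
N = 10^5 \;\Rightarrow\; N^2 = 10^{10} \text{ iterations} \;\approx\; \frac{10^{10}\ \text{iterations}}{10^9\ \text{iterations/s}} = 10\ \text{s of CPU time per } O(N^2) \text{ pass.}
$$

Several such passes (for example, one per ring or per channel) would multiply this, which is consistent with the dozens of seconds reported above.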
Observations
This change optimizes the two functions to remove the overhead and enables NCCL to initialize quickly at 100K scale. We tested this at Meta at scale and observed the following savings:
- `ncclBuildRing`: saves 26s of busy CPU cycles.
- `initTransportsRank`: the optimization saves 11+ seconds.

Most notably, this patch helps reduce CPU utilization during the job startup phase, which is also crucial for other ongoing operations at the job level (checkpoint initialization, model loading, data fetching, etc.).
Testing
At Meta we've tested these fixes on our large-scale clusters. On top of that, I'm also adding a simple standalone benchmark to demonstrate the performance improvement of `ncclBuildRings` before and after. The numbers below are for rank counts varying from 1024 to 96K. With 16 rings we save about 13s; for 32 rings it'll be about 52s.

Test Binary Code
Can be compiled standalone by linking with gtest and google-benchmark
nccl-build-ring-benchmark.cpp.txt
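For readers who want to reproduce a similar measurement without the attached file, a google-benchmark harness along these lines could work. This is a sketch under assumptions: `buildRingSketch` is a hypothetical stand-in for the routine under test, not NCCL's `ncclBuildRings`, and it can be compiled with something like `g++ -O2 bench.cc -lbenchmark -lpthread`:

```cpp
// Sketch of a standalone google-benchmark harness (not the attached benchmark).
#include <benchmark/benchmark.h>
#include <numeric>
#include <vector>

// Hypothetical stand-in: place every rank into a ring once, using a bitset for
// the O(1) "already placed" check (see the earlier sketch in this thread).
static void buildRingSketch(int nranks, const std::vector<int>& order,
                            std::vector<int>& out) {
  std::vector<bool> seen(nranks, false);
  out.clear();
  out.reserve(nranks);
  for (int r : order) {
    if (!seen[r]) { seen[r] = true; out.push_back(r); }
  }
}

static void BM_BuildRing(benchmark::State& state) {
  const int nranks = static_cast<int>(state.range(0));
  std::vector<int> order(nranks);
  std::iota(order.begin(), order.end(), 0);
  std::vector<int> ring;
  for (auto _ : state) {
    buildRingSketch(nranks, order, ring);
    benchmark::DoNotOptimize(ring.data());
  }
  state.SetComplexityN(nranks);
}

// Sweep rank counts up to ~96K, mirroring the scales quoted above.
BENCHMARK(BM_BuildRing)->Arg(1024)->Arg(8192)->Arg(32768)->Arg(98304)->Complexity();
BENCHMARK_MAIN();
```

The `Complexity()` report makes it easy to confirm that the optimized routine scales roughly linearly in the number of ranks rather than quadratically.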