[FSDP] Enable async collectives in FSDP with MPI backend for compute/comm and comm/comm overlap #153215
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153215
Note: Links to docs will display an error until the doc builds have completed. ✅ No failures as of commit 85aee04 with merge base 5683965. This comment was automatically generated by Dr. CI and updates every 15 minutes.
It is not obvious to me that this change is correct. FSDP1 is written such that waiting on the all-gather (or reduce-scatter respectively) is done by having the current/default/compute stream wait for the separate stream from which the all-gather is issued. I am not sure that waiting on an all-gather at the point of the next all-gather is doing the same thing as that -- in fact, I would be somewhat surprised if this is the same synchronization behavior.
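For context, a minimal sketch of the stream-based ordering described above, assuming a stream-aware backend such as NCCL; this is not the actual FSDP1 source, and all names are illustrative:

```python
import torch
import torch.distributed as dist

# Side stream from which the all-gather is issued (illustrative).
all_gather_stream = torch.cuda.Stream()

def unshard(shard: torch.Tensor, full_param: torch.Tensor) -> None:
    # Issue the all-gather on the side stream so it can run ahead of compute.
    with torch.cuda.stream(all_gather_stream):
        dist.all_gather_into_tensor(full_param, shard)
    # "Waiting" on the all-gather: the current (compute) stream is ordered
    # after everything already enqueued on the all-gather stream.
    torch.cuda.current_stream().wait_stream(all_gather_stream)
    # Kernels launched on the compute stream from here on see full_param.
```

With the MPI backend there are no CUDA streams to order against, so this `wait_stream()`-based handshake has no effect, which is the situation the comments below address.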
As you pointed out, the PR did not correctly enforce the ordering between communication and computation. This is because the MPI backend does not support CUDA streams, making wait_stream() ineffective for scheduling execution. To address this, I have added a simple rule in this commit to clarify and enforce synchronization:
This rule enables comm and comp to overlap under the MPI backend, while still preserving the execution order previously enforced via CUDA streams.
In FSDP1, all collectives are currently invoked with async_op=False. Additionally, when using the MPI backend, CUDA stream-based scheduling is not supported, causing collectives to block both computation and other collectives. As a result, there is no overlap between communication and computation.
This PR enables async_op=True for all_gather and reduce_scatter, and explicitly delays the corresponding .wait() calls. This allows FSDP with the MPI backend to benefit from comp/comm and comm/comm overlap.
This change only affects the MPI backend and does not change behavior for other backends.
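A minimal sketch of the async pattern described above, assuming the MPI backend supports these collectives with async_op=True; the helper names are illustrative, not FSDP internals:

```python
import torch
import torch.distributed as dist

def start_unshard(shard: torch.Tensor, full_param: torch.Tensor):
    # Launch the all-gather without blocking; returns a Work handle.
    return dist.all_gather_into_tensor(full_param, shard, async_op=True)

def forward_with_prefetch(layer, x, next_shard, next_full):
    # Start gathering the next layer's parameters (comm) while this layer's
    # computation runs (comp): comm/comp overlap without CUDA streams.
    next_work = start_unshard(next_shard, next_full)
    out = layer(x)
    # Delay the wait until just before the gathered parameters are consumed.
    next_work.wait()
    return out

def start_grad_reduce(grad: torch.Tensor, grad_shard: torch.Tensor):
    # Reduce-scatter the gradients asynchronously; the .wait() can likewise be
    # deferred so it overlaps with backward compute of other layers.
    return dist.reduce_scatter_tensor(grad_shard, grad, async_op=True)
```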
Changes
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k