[FSDP] Enable async collectives in FSDP with MPI backend for compute/comm and comm/comm overlap by nariaki3551 · Pull Request #153215 · pytorch/pytorch


Open · wants to merge 5 commits into main

Conversation

nariaki3551 (Contributor) commented May 8, 2025

In FSDP1, all collectives are currently invoked with async_op=False. In addition, the MPI backend does not support CUDA stream-based scheduling, so collectives block both computation and other collectives. As a result, there is no overlap between communication and computation.

This PR enables async_op=True for all_gather and reduce_scatter and explicitly delays the corresponding .wait() calls, allowing FSDP with the MPI backend to overlap compute/communication and communication/communication.

This change only affects the MPI backend and does not change behavior for other backends.
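
Below is a minimal, stand-alone sketch of that pattern (not the PR's code; tensor shapes are arbitrary, and the MPI process group requires a PyTorch build with MPI support): issue the collective with async_op=True, keep the returned Work handle, run independent computation, and call .wait() only when the result is needed.

```python
# Minimal sketch of the async-collective pattern this PR applies inside FSDP.
# Not the PR's code; tensor sizes and names are illustrative.
import torch
import torch.distributed as dist

dist.init_process_group("mpi")  # requires a PyTorch build with MPI support
world_size = dist.get_world_size()

shard = torch.randn(1024)
gathered = [torch.empty_like(shard) for _ in range(world_size)]

# Issue the all-gather without blocking and keep the returned Work handle.
work = dist.all_gather(gathered, shard, async_op=True)

# Overlap: run computation that does not depend on the gathered tensors.
independent = torch.randn(1024).mul_(2.0)

# Delay the wait until the gathered result is actually consumed.
work.wait()
full = torch.cat(gathered)
print(full.sum() + independent.sum())

dist.destroy_process_group()
```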

Changes

  • Use async_op=True in:
    • all_gather during FlatParamHandle.unshard
    • reduce_scatter for gradient reduction
  • Add _all_gather_work and _reduce_scatter_work to FlatParamHandle to store in-flight work objects
  • Track in-flight collectives in _FSDPState (see the sketch after this list) to ensure:
    • At most one all_gather and one reduce_scatter are in flight at a time
    • A new collective is not issued until the previous one has completed via .wait()
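
Purely as an illustration of that tracking, the stand-alone sketch below mimics the "wait before re-issuing" bookkeeping. ShardedParam, unshard_async, and wait_unshard are hypothetical names; only _all_gather_work mirrors an attribute named above, and the class is not FSDP's FlatParamHandle.

```python
# Stand-alone illustration of the "at most one in-flight collective"
# bookkeeping described above. Hypothetical class, not FSDP internals.
import torch
import torch.distributed as dist


class ShardedParam:
    def __init__(self, shard: torch.Tensor, world_size: int):
        self.shard = shard
        self.gathered = [torch.empty_like(shard) for _ in range(world_size)]
        self._all_gather_work = None  # in-flight Work handle, if any

    def unshard_async(self) -> None:
        # Never issue a new all-gather while a previous one is still in flight.
        if self._all_gather_work is not None:
            self._all_gather_work.wait()
        self._all_gather_work = dist.all_gather(
            self.gathered, self.shard, async_op=True
        )

    def wait_unshard(self) -> None:
        # Called right before the gathered parameter is consumed.
        if self._all_gather_work is not None:
            self._all_gather_work.wait()
            self._all_gather_work = None
```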

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k

pytorch-bot (bot) commented May 8, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153215

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 85aee04 with merge base 5683965:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot (bot) added the oncall: distributed and release notes: distributed (fsdp) labels May 8, 2025
nariaki3551 changed the title from "Enable async collectives in FSDP with MPI backend for compute/comm and comm/comm overlap" to "[FSDP] Enable async collectives in FSDP with MPI backend for compute/comm and comm/comm overlap" May 8, 2025
nariaki3551 marked this pull request as ready for review May 8, 2025 23:25
Skylion007 requested a review from awgu May 9, 2025 13:06
awgu (Collaborator) commented May 9, 2025

It is not obvious to me that this change is correct. FSDP1 is written such that waiting on the all-gather (or reduce-scatter respectively) is done by having the current/default/compute stream wait for the separate stream from which the all-gather is issued. I am not sure that waiting on an all-gather at the point of the next all-gather is doing the same thing as that -- in fact, I would be somewhat surprised if this is the same synchronization behavior.
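
For context, the stream-based ordering awgu describes looks roughly like the sketch below (CUDA backends only; the stream name and tensors are illustrative, not FSDP internals). The CPU is never blocked on a Work handle; the compute stream is simply ordered after the stream that issued the all-gather.

```python
# Minimal sketch of stream-ordered synchronization on CUDA backends.
# `unshard_stream` is illustrative and not an FSDP internal attribute.
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

shard = torch.randn(1024, device="cuda")
gathered = [torch.empty_like(shard) for _ in range(dist.get_world_size())]

unshard_stream = torch.cuda.Stream()
with torch.cuda.stream(unshard_stream):
    dist.all_gather(gathered, shard)  # enqueued from the side stream

# Order GPU work only: kernels launched afterwards on the current stream run
# after the all-gather, but the CPU is not blocked on a Work handle here.
torch.cuda.current_stream().wait_stream(unshard_stream)
print(torch.cat(gathered).sum())
```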

colesbury added the triaged label May 13, 2025
nariaki3551 (Contributor, Author)

@awgu

As you pointed out, the PR did not correctly enforce the ordering between communication and computation. This is because the MPI backend does not support CUDA streams, making wait_stream() ineffective for scheduling execution.

To address this, I have added a simple rule in this commit to clarify and enforce synchronization:

  • If stream_A.wait_stream(_unshard_stream) is called, we wait on the issued allgather.
  • If stream_A.wait_stream(_post_backward_stream) is called, we wait on the issued reduce_scatter.
  • Otherwise, we additionally call stream_B.synchronize() when stream_A.wait_stream(stream_B) is used.

This rule enables comm and comp to overlap under the MPI backend, while still preserving the execution order previously enforced via CUDA streams.

  1. unshard (allgather) is issued only after pre_unshard is complete.
  2. reduce_grad (reduce_scatter) is issued only after backward computation is complete.
  3. Forward/backward computation begins only after unshard (allgather) has completed.
  4. FSDP1 waits for the last reduce_scatter to complete in the _post_backward_final_callback function.
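
The rule could be expressed roughly as the wrapper below; patched_wait_stream and state are hypothetical names standing in for the commit's actual plumbing, while _unshard_stream, _post_backward_stream, _all_gather_work, and _reduce_scatter_work are the names used above.

```python
# Rough illustration of the rule described above, with hypothetical names.
# `state` stands in for the FSDP state object holding the in-flight handles.
def patched_wait_stream(stream_a, stream_b, state):
    if stream_b is state._unshard_stream:
        # Waiting on the unshard stream maps to waiting on the issued all_gather.
        if state._all_gather_work is not None:
            state._all_gather_work.wait()
            state._all_gather_work = None
    elif stream_b is state._post_backward_stream:
        # Waiting on the post-backward stream maps to the issued reduce_scatter.
        if state._reduce_scatter_work is not None:
            state._reduce_scatter_work.wait()
            state._reduce_scatter_work = None
    else:
        # The MPI backend cannot order work via CUDA streams, so additionally
        # synchronize the waited-on stream.
        stream_a.wait_stream(stream_b)
        stream_b.synchronize()
```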

Labels
oncall: distributed · open source · release notes: distributed (fsdp) · triaged