[Reland] Launch kernel on current stream & remove `record_stream` entirely by kwen2501 · Pull Request #150398 · pytorch/pytorch

[Reland] Launch kernel on current stream & remove record_stream entirely #150398

Closed
wants to merge 1 commit

Conversation

kwen2501
Contributor
@kwen2501 kwen2501 commented Apr 1, 2025

Stack from ghstack (oldest at bottom):

Relanding #148590 due to merge conflict.

This PR makes multiple changes to ProcessGroupNCCL (which are unfortunately interrelated):

  1. When async_op=False, we directly launch the collective on the "current" stream, instead of on a trampoline stream and joining back (see the usage sketch after this list).
  2. Entirely remove record_stream and use CPU-side stashing to manage tensor lifetime against allocator recycling.
  3. Remove tensor life management when async_op=False; only use it when async_op=True.
  4. To guard against users not calling work.wait(), we ask the watchdog to unstash tensors once it detects completion of the collectives, so that we do not hold references to tensors forever. This is a safety net rather than a service guarantee; see the discussion here.
  5. Profiles in async_op=False mode will look different -- collective kernels now show up on the same line as compute kernels.
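For context on the user-facing contract, here is a minimal usage sketch against the public `torch.distributed` API (the single-process NCCL setup and the address/port values are placeholder assumptions; the comments restate the behavior described in this PR):

```python
import os
import torch
import torch.distributed as dist

# Placeholder single-process setup; in practice rank/world_size/addr come from your launcher.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)
torch.cuda.set_device(0)

x = torch.ones(1024, device="cuda")

# async_op=False (default): per this PR, the collective is launched on the
# current stream, so later kernels on that stream are ordered after it and
# no CPU-side tensor stashing is needed.
dist.all_reduce(x)

# async_op=True: the collective runs on a separate NCCL stream; keep the Work
# handle and call wait() before consuming the result. wait() also lets c10d
# release the CPU-side reference it stashed for the input tensor.
y = torch.ones(1024, device="cuda")
work = dist.all_reduce(y, async_op=True)
work.wait()

dist.destroy_process_group()
```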

Joint work with @cenzhaometa who wants to remove the event sync overhead.

Squashed contents:

  • PTD creates its own dedicated ncclStream for comm operations
  • it first adds a dependency on the current stream (typically the compute stream) to ensure tensors are ready before invoking the collective
    Such stream synchronization becomes expensive in the inference world (CPU overhead: 70us vs. GPU kernel time: 160us).
    This diff:
  • async=False [default]: use the current stream as the nccl stream and avoid the stream-sync overhead (see the sketch below)
  • async=True: retain the existing logic -- create a new nccl stream and let it wait on the current stream to ensure tensors are ready
  • pass async down from c10d to the NCCL-PG
    This shaves off 50% of the CPU overhead (70us -> 35us), reducing total CPU+GPU time from 230us to 195us (about 15%).
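
For intuition, a minimal sketch of the old trampoline-stream launch versus the new current-stream launch, written with public CUDA stream APIs rather than the actual `ProcessGroupNCCL` internals; `fake_collective` is a hypothetical stand-in for the NCCL kernel launch:

```python
import torch

def fake_collective(t):
    # Hypothetical stand-in for the NCCL kernel enqueued on the active stream.
    t.add_(1)

x = torch.ones(1024, device="cuda")

# Old path: a dedicated internal NCCL stream first waits on the current stream
# (CPU-side event record/wait), the collective runs there, and the caller later
# joins the internal stream back into the current stream.
internal_stream = torch.cuda.Stream()
internal_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(internal_stream):
    fake_collective(x)
torch.cuda.current_stream().wait_stream(internal_stream)

# New path (async_op=False): launch directly on the current stream; ordering
# with the surrounding compute kernels is implicit and the event syncs go away.
fake_collective(x)
```

The new path trades the dedicated communication stream for implicit ordering on the compute stream, which is where the reported CPU-overhead saving comes from.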

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

Differential Revision: D72224314

…irely

Relanding #148590 due to merge conflict.

This PR makes multiple changes to `ProcessGroupNCCL` (which are unfortunately interrelated):
1. When async_op=False, we directly launch the collective on the "current" stream, instead of on a trampoline stream and joining back.
- Resolves #147729
- Resolves #146881
- Also saves two event syncs (which have overhead in the case of HIP) and one pybind call when we call `work.wait()` in distributed_c10d.py on behalf of the user.
2. Entirely remove `record_stream` and use CPU-side stashing to manage tensor lifetime against allocator recycling (see the sketch after this list).
- Resolves #147168
3. Remove tensor life management when async_op=False; only use it when async_op=True.
4. To guard against users not calling `work.wait()`, we ask the watchdog to unstash tensors once it detects completion of the collectives, so that we do not hold references to tensors forever. This is a safety net rather than a service guarantee; see the discussion [here](#147168 (comment)).
5. Profiles in async_op=False mode will look different -- collective kernels now show up on the same line as compute kernels.
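A rough sketch of the lifetime-management idea behind items 2-4 above, in illustrative Python only; the real mechanism is the C++ `TensorShelf` inside `ProcessGroupNCCL`, and the class and call sites below are made up for illustration:

```python
import torch

class ShelfSketch:
    """Illustrative stand-in for the CPU-side stash described above."""
    def __init__(self):
        self._shelf = []            # holding a reference keeps the tensor alive

    def stash(self, tensors):
        self._shelf.extend(tensors)

    def unstash(self):
        self._shelf.clear()         # the caching allocator may now recycle the memory

shelf = ShelfSketch()
inp = torch.ones(1024, device="cuda")

# async_op=True path: instead of inp.record_stream(nccl_stream), the input is
# stashed on the CPU side until the collective is known to be complete.
shelf.stash([inp])
# ... collective enqueued on the NCCL stream ...
# Unstashing happens in work.wait(); if the user never calls wait(), the
# watchdog unstashes after it observes completion, as a safety net.
shelf.unstash()
```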

Joint work with @cenzhaometa who wants to remove the event sync overhead.

Squashed contents:

* [ptd][nccl] use current-stream as nccl-stream under async=False mode (#147820)
PTD's current workflow:
- PTD creates its own dedicated `ncclStream` for comm operations
- it first adds a dependency on the current stream (typically the compute stream) to ensure tensors are ready before invoking the collective
Such stream synchronization becomes expensive in the inference world (CPU overhead: 70us vs. GPU kernel time: 160us).
This diff:
- async=False [default]: use the current stream as the nccl stream and avoid the stream-sync overhead
- async=True: retain the existing logic -- create a new nccl stream and let it wait on the current stream to ensure tensors are ready
- pass async down from c10d to the NCCL-PG
This shaves off 50% of the CPU overhead **(70us -> 35us)**, reducing total CPU+GPU time from **230us to 195us (about 15%)**.
Differential Revision: D70135605

* [PGNCCL] Make avoid-record-stream default

* [c10d] Add asyncOp argument to Ops

* Change python side wait (see the sketch after this list)

* Pass asyncOp at ProcessGroup level

* Watchdog unstashing tensors as a safety net

* Stash tensors for reduce_scatter_v and all_gather_v
Pull Request approved: #149753

* [c10d] Move unstashing from watchdog to main thread
Pull Request approved: #150079

* [PGNCCL][BE] Merge mutex into TensorShelf for encapsulation
Pull Request approved: #150130
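As a hypothetical simplification of what the "Change python side wait" item implies at the distributed_c10d.py level (the helper name and structure below are illustrative assumptions, not the actual code):

```python
import torch
import torch.distributed as dist

def all_reduce_sketch(tensor: torch.Tensor, group: dist.ProcessGroup, async_op: bool = False):
    # Illustrative only; the real dispatch goes through the c10d ops layer.
    work = group.allreduce([tensor])       # collective enqueued by the backend
    if async_op:
        # Caller receives the Work handle and must call work.wait() before use.
        return work
    # async_op=False: the kernel was launched on the current stream, so later
    # kernels on that stream are already ordered after it; the Python-side
    # work.wait() previously issued here (one pybind plus event syncs) is no
    # longer needed.
    return None
```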

[ghstack-poisoned]
@kwen2501 kwen2501 requested a review from larryliu0820 as a code owner April 1, 2025 06:58
@pytorch-bot pytorch-bot bot added the ci-no-td (Do not run TD on this PR), oncall: distributed (Add this issue/PR to distributed oncall triage queue), and release notes: distributed (c10d) (release notes category) labels Apr 1, 2025
pytorch-bot bot commented Apr 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150398

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit acf5139 with merge base 6470b37:

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

kwen2501 added a commit that referenced this pull request Apr 1, 2025
…irely

ghstack-source-id: ce103fc
Pull Request resolved: #150398
@kwen2501 kwen2501 requested a review from atalman April 1, 2025 07:00
@kwen2501
Contributor Author
kwen2501 commented Apr 1, 2025

@kwen2501 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorch-bot pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Apr 1, 2025
Contributor
@atalman atalman left a comment

lgtm

@atalman
Contributor
atalman commented Apr 1, 2025

@pytorchmergebot merge -i

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/rexporting the PR!

Details for Dev Infra team (raised by workflow job)

@kwen2501
Contributor Author
kwen2501 commented Apr 1, 2025

@kwen2501 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@atalman
Contributor
atalman commented Apr 1, 2025

@pytorchmergebot merge -i

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/rexporting the PR!

Details for Dev Infra team (raised by workflow job)

@atalman
Contributor

@pytorchmergebot merge -f "lint is green, diff will be landed with internal changes"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/rexporting the PR!

Details for Dev Infra team (raised by workflow job)

@atalman
Contributor
atalman commented Apr 1, 2025

@pytorchmergebot merge -f "lint is green, diff will be landed with internal changes"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@kwen2501
Contributor Author
kwen2501 commented Apr 1, 2025

Internal diff is D72224314 (contains internal changes)

pytorchmergebot pushed a commit that referenced this pull request Apr 2, 2025
Update the torch-xpu-ops commit to [98c808dea6de7330c415aa777d6921944cf79887](intel/torch-xpu-ops@98c808d), which includes:

- Fixes #150001 by removing pre-CXX11 ABI logic from the build script for XPU
- Fixes #150430
- Fixes the XCCL build issue caused by PR #150398

Pull Request resolved: #150554
Approved by: https://github.com/EikanWang, https://github.com/malfet
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
…irely (pytorch#150398)

Pull Request resolved: pytorch#150398
Approved by: https://github.com/atalman
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
@github-actions github-actions bot deleted the gh/kwen2501/139/head branch May 2, 2025 02:19
github-merge-queue bot pushed a commit to intel/torch-xpu-ops that referenced this pull request May 16, 2025
Refer to pytorch/pytorch#147820 and pytorch/pytorch#150398.

To launch kernels on the current stream and reduce the CPU overhead
introduced by `recordStream`, an `async` option is introduced.

For example, in an `allreduce` operation between two ranks:

- `rank0` corresponds to `device0`, using the current device's `stream0`
to create the communicator and preserving `stream0`.

When `async = true`:
- Both `rank0` and `rank1` perform the collective using `stream0`, which
is associated with the communicator.
- To prevent potential reads by `stream0` from unready tensors (e.g.,
from `rank1`), synchronization with the current stream is required.
- After the collective completes, to prevent premature release of the
input tensors, `recordStream` must be used for stream tracking, or the
tensors need to be temporarily stored (e.g., in `reduce_scatter` or
`all2all`).

When `async = false`:
- Both `rank0` and `rank1` use their respective **current streams** for
collectives (i.e., `rank0` uses `stream0`, `rank1` uses `stream1`).
- In this case, the collective op handles synchronization implicitly.

Previously, we defaulted to `async = true`. Now, the `async` option is
explicitly introduced and set to `false` by default, leveraging the
current stream to avoid the overhead of stream synchronization.
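
A compact sketch of the two modes described above, using public PyTorch stream APIs; `xccl_allreduce` is a hypothetical stand-in for the collective launch on this rank:

```python
import torch

def xccl_allreduce(t):
    # Hypothetical stand-in for the collective kernel enqueued by the backend.
    t.mul_(2)

inp = torch.ones(1024, device="cuda")
comm_stream = torch.cuda.Stream()   # the stream associated with the communicator

# async = true: the collective runs on the communicator's stream, so that stream
# must first wait on the current stream, and the input must be protected from
# premature recycling via record_stream (or temporary stashing).
comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(comm_stream):
    xccl_allreduce(inp)
inp.record_stream(comm_stream)

# async = false (new default): the collective runs on this rank's current stream;
# ordering is implicit and neither record_stream nor stashing is required.
xccl_allreduce(inp)
```

In other words, the async = false default moves both the ordering and the tensor-lifetime question onto the caller's current stream.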

---------

Co-authored-by: mengfei25 <mengfei.li@Intel.com>
Labels
ci-no-td (Do not run TD on this PR)
ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
oncall: distributed (Add this issue/PR to distributed oncall triage queue)
release notes: distributed (c10d) (release notes category)