8000 [PGNCCL] Launch kernel on current stream & remove `record_stream` entirely by kwen2501 · Pull Request #148590 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[PGNCCL] Launch kernel on current stream & remove record_stream entirely #148590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from

Conversation

kwen2501
Copy link
Contributor
@kwen2501 kwen2501 commented Mar 5, 2025

Stack from ghstack (oldest at bottom):

This PR has multiple changes to ProcessGroupNCCL (which unfortunately are related):

  1. When async_op=False, we directly launch the collective on "current" stream, instead of a trampoline stream and join back.
  1. Entirely remove record_stream and use CPU-side stashing for managing tensor lifetime against recycling.
  1. Remove tensor life management when async_op=False; only use it when async_op=True.
  2. To guard against user not calling work.wait(), we ask watchdog to unstash tensors after detecting completion of collectives, to prevent us from holding reference to tensors forever. This is a safety net, rather than a service guarantee, see discussion here.
  3. Profile in async_op=False mode would look different -- collective kernels would show up in the same line and compute kernels.

Joint work with @cenzhaometa who wants to remove the event sync overhead.

Cc: @ngimel @awgu @Aidyn-A @skyw @wconstab @leonardo0lyj

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@diff-train-skip-merge

Differential Revision: D71652868

…147820)

Summary:

PTD current workflow:
- PTD creates its own dedicated `ncclStream` for comm operation
- it will first add a dependency on current-stream (typically the compute stream) to ensure tensors are ready before invoking collective

such stream synchronization become expensive in Inference world (cpu overhead: 70us vs GPU kernel time: 160us).

This diff:
- async=False [default], will use current-stream as nccl-stream and avoid the stream-sync overhead
- async=True, will retain existing logic: create new nccl-stream, let it wait on current-stream to ensure tensors are ready
- pass down async from c10d down to NCCL-PG

this helps shave off 50% CPU overhead **(70us -> 35us)**, which reduce total CPU/GPU from **230us to 195us by 15%**

Differential Revision: D70135605

[PGNCCL] Make avoid-record-stream default

[c10d] Add asyncOp argument to Ops

Change python side wait

Pass asyncOp at ProcessGroup level

Watchdog unstashing tensors as a safety net

lint

[ghstack-poisoned]
@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category labels Mar 5, 2025
Copy link
pytorch-bot bot commented Mar 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148590

Note: Links to docs will display an error until the docs builds have been completed.

❌ 7 New Failures, 2 Unrelated Failures

As of commit e933dfb with merge base 666508e (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

kwen2501 added a commit that referenced this pull request Mar 5, 2025
…147820)

Summary:

PTD current workflow:
- PTD creates its own dedicated `ncclStream` for comm operation
- it will first add a dependency on current-stream (typically the compute stream) to ensure tensors are ready before invoking collective

such stream synchronization become expensive in Inference world (cpu overhead: 70us vs GPU kernel time: 160us).

This diff:
- async=False [default], will use current-stream as nccl-stream and avoid the stream-sync overhead
- async=True, will retain existing logic: create new nccl-stream, let it wait on current-stream to ensure tensors are ready
- pass down async from c10d down to NCCL-PG

this helps shave off 50% CPU overhead **(70us -> 35us)**, which reduce total CPU/GPU from **230us to 195us by 15%**

Differential Revision: D70135605

[PGNCCL] Make avoid-record-stream default

[c10d] Add asyncOp argument to Ops

Change python side wait

Pass asyncOp at ProcessGroup level

Watchdog unstashing tensors as a safety net

lint

ghstack-source-id: 4793680
Pull Request resolved: #148590
@kwen2501 kwen2501 changed the title [ptd][nccl] use current-stream as nccl-stream under async=False mode (#147820) [PGNCCL] Launch kernel on current stream & remove record_stream entirely Mar 5, 2025
@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 5, 2025
…stream` entirely"


This PR has multiple changes to `ProcessGroupNCCL` (which unfortunately have to be atomic):
1. When async_op=False, we directly launch the collective on "current" stream instead of a trampoline stream and join back. 
- Resolves #147729
- Resolves #146881
- Also saves an event sync and one pybind during the unnecessary `work.wait()` called by distributed_c10d.py.
2. Entirely remove `record_stream` and use CPU-side stashing for managing tensor lifetime against recycling.
- Resolves #147168
3. Remove tensor life management when async_op=False; only use it when async_op=True.
4. To guard against user not calling `work.wait()`, we ask watchdog to unstash tensors after detecting completion of collectives, to prevent us from holding reference to tensors forever. This is a safety net, rather than a service guarantee, see discussion [here](#147168 (comment)).
5. Profile in async_op=False mode would look different -- collective kernels would show up in the same line and compute kernels.

Joint work with cenzhaometa who wants to remove the event sync overhead.

Cc: ngimel awgu Aidyn-A skyw wconstab leonardo0lyj

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Mar 6, 2025
…147820)

Summary:

PTD current workflow:
- PTD creates its own dedicated `ncclStream` for comm operation
- it will first add a dependency on current-stream (typically the compute stream) to ensure tensors are ready before invoking collective

such stream synchronization become expensive in Inference world (cpu overhead: 70us vs GPU kernel time: 160us).

This diff:
- async=False [default], will use current-stream as nccl-stream and avoid the stream-sync overhead
- async=True, will retain existing logic: create new nccl-stream, let it wait on current-stream to ensure tensors are ready
- pass down async from c10d down to NCCL-PG

this helps shave off 50% CPU overhead **(70us -> 35us)**, which reduce total CPU/GPU from **230us to 195us by 15%**

Differential Revision: D70135605

[PGNCCL] Make avoid-record-stream default

[c10d] Add asyncOp argument to Ops

Change python side wait

Pass asyncOp at ProcessGroup level

Watchdog unstashing tensors as a safety net

lint

ghstack-source-id: ac5295d
Pull Request resolved: #148590
Copy link
linux-foundation-easycla bot commented Mar 6, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

@taozhiwei
Copy link
Contributor

#148553
I mentioned a similar one

…stream` entirely"


This PR has multiple changes to `ProcessGroupNCCL` (which unfortunately have to be atomic):
1. When async_op=False, we directly launch the collective on "current" stream instead of a trampoline stream and join back. 
- Resolves #147729
- Resolves #146881
- Also saves an event sync and one pybind during the unnecessary `work.wait()` called by distributed_c10d.py.
2. Entirely remove `record_stream` and use CPU-side stashing for managing tensor lifetime against recycling.
- Resolves #147168
3. Remove tensor life management when async_op=False; only use it when async_op=True.
4. To guard against user not calling `work.wait()`, we ask watchdog to unstash tensors after detecting completion of collectives, to prevent us from holding reference to tensors forever. This is a safety net, rather than a service guarantee, see discussion [here](#147168 (comment)).
5. Profile in async_op=False mode would look different -- collective kernels would show up in the same line and compute kernels.

Joint work with cenzhaometa who wants to remove the event sync overhead.

Cc: ngimel awgu Aidyn-A skyw wconstab leonardo0lyj

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Mar 6, 2025
…147820)

Summary:

PTD current workflow:
- PTD creates its own dedicated `ncclStream` for comm operation
- it will first add a dependency on current-stream (typically the compute stream) to ensure tensors are ready before invoking collective

such stream synchronization become expensive in Inference world (cpu overhead: 70us vs GPU kernel time: 160us).

This diff:
- async=False [default], will use current-stream as nccl-stream and avoid the stream-sync overhead
- async=True, will retain existing logic: create new nccl-stream, let it wait on current-stream to ensure tensors are ready
- pass down async from c10d down to NCCL-PG

this helps shave off 50% CPU overhead **(70us -> 35us)**, which reduce total CPU/GPU from **230us to 195us by 15%**

Differential Revision: D70135605

[PGNCCL] Make avoid-record-stream default

[c10d] Add asyncOp argument to Ops

Change python side wait

Pass asyncOp at ProcessGroup level

Watchdog unstashing tensors as a safety net

lint

ghstack-source-id: 0f222d3
Pull Request resolved: #148590
@kwen2501 kwen2501 added the keep-going Don't stop on first failure, keep running tests until the end label Mar 6, 2025
…stream` entirely"


This PR has multiple changes to `ProcessGroupNCCL` (which unfortunately have to be atomic):
1. When async_op=False, we directly launch the collective on "current" stream instead of a trampoline stream and join back. 
- Resolves #147729
- Resolves #146881
- Also saves an event sync and one pybind during the unnecessary `work.wait()` called by distributed_c10d.py.
2. Entirely remove `record_stream` and use CPU-side stashing for managing tensor lifetime against recycling.
- Resolves #147168
3. Remove tensor life management when async_op=False; only use it when async_op=True.
4. To guard against user not calling `work.wait()`, we ask watchdog to unstash tensors after detecting completion of collectives, to prevent us from holding reference to tensors forever. This is a safety net, rather than a service guarantee, see discussion [here](#147168 (comment)).
5. Profile in async_op=False mode would look different -- collective kernels would show up in the same line and compute kernels.

Joint work with cenzhaometa who wants to remove the event sync overhead.

Cc: ngimel awgu Aidyn-A skyw wconstab leonardo0lyj

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Mar 6, 2025
…147820)

Summary:

PTD current workflow:
- PTD creates its own dedicated `ncclStream` for comm operation
- it will first add a dependency on current-stream (typically the compute stream) to ensure tensors are ready before invoking collective

such stream synchronization become expensive in Inference world (cpu overhead: 70us vs GPU kernel time: 160us).

This diff:
- async=False [default], will use current-stream as nccl-stream and avoid the stream-sync overhead
- async=True, will retain existing logic: create new nccl-stream, let it wait on current-stream to ensure tensors are ready
- pass down async from c10d down to NCCL-PG

this helps shave off 50% CPU overhead **(70us -> 35us)**, which reduce total CPU/GPU from **230us to 195us by 15%**

Differential Revision: D70135605

[PGNCCL] Make avoid-record-stream default

[c10d] Add asyncOp argument to Ops

Change python side wait

Pass asyncOp at ProcessGroup level

Watchdog unstashing tensors as a safety net

lint

ghstack-source-id: 8306cce
Pull Request resolved: #148590
@kwen2501
Copy link
Contributor Author
kwen2501 commented Mar 6, 2025

@albanD Would appreciate your help:

The PR is introducing backward incompatible changes to the operator library. Please contact PyTorch team to confirm whether this change is wanted or not

We are adding asyncOp argument to the ops.
Thanks!

@taozhiwei
Copy link
Contributor
taozhiwei commented Mar 7, 2025

Stack from ghstack (oldest at bottom):

This PR has multiple changes to ProcessGroupNCCL (which unfortunately have to be atomic):

  1. When async_op=False, we directly launch the collective on "current" stream instead of a trampoline stream and join back.
  1. Entirely remove record_stream and use CPU-side stashing for managing tensor lifetime against recycling.
  1. Remove tensor life management when async_op=False; only use it when async_op=True.
  2. To guard against user not calling work.wait(), we ask watchdog to unstash tensors after detecting completion of collectives, to prevent us from holding reference to tensors forever. This is a safety net, rather than a service guarantee, see discussion here.
  3. Profile in async_op=False mode would look different -- collective kernels would show up in the same line and compute kernels.

Joint work with @cenzhaometa who wants to remove the event sync overhead.

Cc: @ngimel @awgu @Aidyn-A @skyw @wconstab @leonardo0lyj

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

  1. In cuda graph mode,watchdogHandler is disable. when async_op=True,and user not calling work.wait(), Can't input be released?
  2. when async_op=False, users code will ensure input is not released prematurely;There shouldn't be any need to make current stream as nccl stream? make current stream as nccl stream will break the user's habit of looking at Profile.

@kwen2501
Copy link
Contributor Author
kwen2501 commented Mar 7, 2025
  1. Users are in general expected to call work.wait() in async mode, esp when they are CUDA Graphing. Failure to do so can results in holding of the tensors. That's the side effect of making avoid-record the default. We are just using watchdog to mitigate that side effect.
  2. The change to use current stream has some other reasons than tensor lifetime management, such as reducing stream sync overhead.

…stream` entirely"


This PR has multiple changes to `ProcessGroupNCCL` (which unfortunately have to be atomic):
1. When async_op=False, we directly launch the collective on "current" stream instead of a trampoline stream and join back. 
- Resolves #147729
- Resolves #146881
- Also saves an event sync and one pybind during the unnecessary `work.wait()` called by distributed_c10d.py.
2. Entirely remove `record_stream` and use CPU-side stashing for managing tensor lifetime against recycling.
- Resolves #147168
3. Remove tensor life management when async_op=False; only use it when async_op=True.
4. To guard against user not calling `work.wait()`, we ask watchdog to unstash tensors after detecting completion of collectives, to prevent us from holding reference to tensors forever. This is a safety net, rather than a service guarantee, see discussion [here](#147168 (comment)).
5. Profile in async_op=False mode would look different -- collective kernels would show up in the same line and compute kernels.

Joint work with cenzhaometa who wants to remove the event sync overhead.

Cc: ngimel awgu Aidyn-A skyw wconstab leonardo0lyj

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Mar 7, 2025
…147820)

Summary:

PTD current workflow:
- PTD creates its own dedicated `ncclStream` for comm operation
- it will first add a dependency on current-stream (typically the compute stream) to ensure tensors are ready before invoking collective

such stream synchronization become expensive in Inference world (cpu overhead: 70us vs GPU kernel time: 160us).

This diff:
- async=False [default], will use current-stream as nccl-stream and avoid the stream-sync overhead
- async=True, will retain existing logic: create new nccl-stream, let it wait on current-stream to ensure tensors are ready
- pass down async from c10d down to NCCL-PG

this helps shave off 50% CPU overhead **(70us -> 35us)**, which reduce total CPU/GPU from **230us to 195us by 15%**

Differential Revision: D70135605

[PGNCCL] Make avoid-record-stream default

[c10d] Add asyncOp argument to Ops

Change python side wait

Pass asyncOp at ProcessGroup level

Watchdog unstashing tensors as a safety net

lint

ghstack-source-id: c31ca32
Pull Request resolved: #148590
@taozhiwei
Copy link
Contributor
  1. Users are in general expected to call work.wait() in async mode, esp when they are CUDA Graphing. Failure to do so can results in holding of the tensors. That's the side effect of making avoid-record the default. We are just using watchdog to mitigate that side effect.
  2. The change to use current stream has some other reasons than tensor lifetime management, such as reducing stream sync overhead.

should be able to add a check in A syncStream If it's the same stream, don't call record and block anymore?

…stream` entirely"


This PR has multiple changes to `ProcessGroupNCCL` (which unfortunately have to be atomic):
1. When async_op=False, we directly launch the collective on "current" stream instead of a trampoline stream and join back. 
- Resolves #147729
- Resolves #146881
- Also saves an event sync and one pybind during the unnecessary `work.wait()` called by distributed_c10d.py.
2. Entirely remove `record_stream` and use CPU-side stashing for managing tensor lifetime against recycling.
- Resolves #147168
3. Remove tensor life management when async_op=False; only use it when async_op=True.
4. To guard against user not calling `work.wait()`, we ask watchdog to unstash tensors after detecting completion of collectives, to prevent us from holding reference to tensors forever. This is a safety net, rather than a service guarantee, see discussion [here](#147168 (comment)).
5. Profile in async_op=False mode would look different -- collective kernels would show up in the same line and compute kernels.

Joint work with cenzhaometa who wants to remove the event sync overhead.

Cc: ngimel awgu Aidyn-A skyw wconstab leonardo0lyj

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Mar 7, 2025
…147820)

Summary:

PTD current workflow:
- PTD creates its own dedicated `ncclStream` for comm operation
- it will first add a dependency on current-stream (typically the compute stream) to ensure tensors are ready before invoking collective

such stream synchronization become expensive in 
8000
Inference world (cpu overhead: 70us vs GPU kernel time: 160us).

This diff:
- async=False [default], will use current-stream as nccl-stream and avoid the stream-sync overhead
- async=True, will retain existing logic: create new nccl-stream, let it wait on current-stream to ensure tensors are ready
- pass down async from c10d down to NCCL-PG

this helps shave off 50% CPU overhead **(70us -> 35us)**, which reduce total CPU/GPU from **230us to 195us by 15%**

Differential Revision: D70135605

[PGNCCL] Make avoid-record-stream default

[c10d] Add asyncOp argument to Ops

Change python side wait

Pass asyncOp at ProcessGroup level

Watchdog unstashing tensors as a safety net

lint

ghstack-source-id: dfc758a
Pull Request resolved: #148590
@kwen2501 kwen2501 requested review from wconstab, fduwjj, eqy and Aidyn-A March 7, 2025 18:19
@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/rexporting the PR!

Details for Dev Infra team Raised by workflow job

@atalman
Copy link
Contributor
atalman commented Mar 31, 2025

@pytorchbot merge -f "already landed"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/rexporting the PR!

Details for Dev Infra team Raised by workflow job

@atalman
Copy link
Contributor
atalman commented Mar 31, 2025

@pytorchbot merge -f "already landed internally"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x 269f45f641011f4da5fc7d38e973036e04489b72 returned non-zero exit code 1

Auto-merging torch/_C/_distributed_c10d.pyi
Auto-merging torch/csrc/distributed/c10d/ProcessGroup.hpp
Auto-merging torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
CONFLICT (content): Merge conflict in torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
Auto-merging torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
Auto-merging torch/csrc/distributed/c10d/init.cpp
Auto-merging torch/distributed/distributed_c10d.py
error: could not apply 269f45f6410... [ptd][nccl] use current-stream as nccl-stream under async=False mode (#147820)
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Details for Dev Infra team Raised by workflow job

atalman added a commit to atalman/pytorch that referenced this pull request Mar 31, 2025
malfet pushed a commit that referenced this pull request Mar 31, 2025
…eam` entirely (#148590) (#150352)

Revert "[PGNCCL] Launch kernel on current stream & remove `record_stream` entirely (#148590)"

This reverts commit ef6296e.
kwen2501 added a commit that referenced this pull request Apr 1, 2025
…irely

Relanding #148590 due to merge conflict.

This PR has multiple changes to `ProcessGroupNCCL` (which unfortunately are related):
1. When async_op=False, we directly launch the collective on "current" stream, instead of a trampoline stream and join back.
- Resolves #147729
- Resolves #146881
- Also saves two event syncs (which have overhead in case of HIP) and one pybind when we call `work.wait()` in distributed_c10d.py on behalf of user.
2. Entirely remove `record_stream` and use CPU-side stashing for managing tensor lifetime against recycling.
- Resolves #147168
3. Remove tensor life management when async_op=False; only use it when async_op=True.
4. To guard against user not calling `work.wait()`, we ask watchdog to unstash tensors after detecting completion of collectives, to prevent us from holding reference to tensors forever. This is a safety net, rather than a service guarantee, see discussion [here](#147168 (comment)).
5. Profile in async_op=False mode would look different -- collective kernels would show up in the same line and compute kernels.

Joint work with @cenzhaometa who wants to remove the event sync overhead.

Squashed contents:

* [ptd][nccl] use current-stream as nccl-stream under async=False mode (#147820)
PTD current workflow:
- PTD creates its own dedicated `ncclStream` for comm operation
- it will first add a dependency on current-stream (typically the compute stream) to ensure tensors are ready before invoking collective
such stream synchronization become expensive in Inference world (cpu overhead: 70us vs GPU kernel time: 160us).
This diff:
- async=False [default], will use current-stream as nccl-stream and avoid the stream-sync overhead
- async=True, will retain existing logic: create new nccl-stream, let it wait on current-stream to ensure tensors are ready
- pass down async from c10d down to NCCL-PG
this helps shave off 50% CPU overhead **(70us -> 35us)**, which reduce total CPU/GPU from **230us to 195us by 15%**
Differential Revision: D70135605

* [PGNCCL] Make avoid-record-stream default

* [c10d] Add asyncOp argument to Ops

* Change python side wait

* Pass asyncOp at ProcessGroup level

* Watchdog unstashing tensors as a safety net

* Stash tensors for reduce_scatter_v and all_gather_v
Pull Request approved: #149753

* [c10d] Move unstashing from watchdog to main thread
Pull Request approved: #150079

* [PGNCCL][BE] Merge mutex into TensorShelf for encapsulation
Pull Request approved: #150130

[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Apr 1, 2025
…irely

Relanding #148590 due to merge conflict.

This PR has multiple changes to `ProcessGroupNCCL` (which unfortunately are related):
1. When async_op=False, we directly launch the collective on "current" stream, instead of a trampoline stream and join back.
- Resolves #147729
- Resolves #146881
- Also saves two event syncs (which have overhead in case of HIP) and one pybind when we call `work.wait()` in distributed_c10d.py on behalf of user.
2. Entirely remove `record_stream` and use CPU-side stashing for managing tensor lifetime against recycling.
- Resolves #147168
3. Remove tensor life management when async_op=False; only use it when async_op=True.
4. To guard against user not calling `work.wait()`, we ask watchdog to unstash tensors after detecting completion of collectives, to prevent us from holding reference to tensors forever. This is a safety net, rather than a service guarantee, see discussion [here](#147168 (comment)).
5. Profile in async_op=False mode would look different -- collective kernels would show up in the same line and compute kernels.

Joint work with cenzhaometa who wants to remove the event sync overhead.

Squashed contents:

* [ptd][nccl] use current-stream as nccl-stream under async=False mode (#147820)
PTD current workflow:
- PTD creates its own dedicated `ncclStream` for comm operation
- it will first add a dependency on current-stream (typically the compute stream) to ensure tensors are ready before invoking collective
such stream synchronization become expensive in Inference world (cpu overhead: 70us vs GPU kernel time: 160us).
This diff:
- async=False [default], will use current-stream as nccl-stream and avoid the stream-sync overhead
- async=True, will retain existing logic: create new nccl-stream, let it wait on current-stream to ensure tensors are ready
- pass down async from c10d down to NCCL-PG
this helps shave off 50% CPU overhead **(70us -> 35us)**, which reduce total CPU/GPU from **230us to 195us by 15%**
Differential Revision: D70135605

* [PGNCCL] Make avoid-record-stream default

* [c10d] Add asyncOp argument to Ops

* Change python side wait

* Pass asyncOp at ProcessGroup level

* Watchdog unstashing tensors as a safety net

* Stash tensors for reduce_scatter_v and all_gather_v
Pull Request approved: #149753

* [c10d] Move unstashing from watchdog to main thread
Pull Request approved: #150079

* [PGNCCL][BE] Merge mutex into TensorShelf for encapsulation
Pull Request approved: #150130

ghstack-source-id: ce103fc
Pull Request resolved: #150398
github-merge-queue bot pushed a commit to intel/torch-xpu-ops that referenced this pull request Apr 1, 2025
Reverts #1450

Original PR (pytorch/pytorch#148590) in PyTorch
got reverted:
pytorch/pytorch@afa1eda

---------

Co-authored-by: Yutao Xu <yutao.xu@intel.com>
pytorchmergebot pushed a commit that referenced this pull request Apr 1, 2025
…irely (#150398)

Relanding #148590 due to merge conflict.

This PR has multiple changes to `ProcessGroupNCCL` (which unfortunately are related):
1. When async_op=False, we directly launch the collective on "current" stream, instead of a trampoline stream and join back.
- Resolves #147729
- Resolves #146881
- Also saves two event syncs (which have overhead in case of HIP) and one pybind when we call `work.wait()` in distributed_c10d.py on behalf of user.
2. Entirely remove `record_stream` and use CPU-side stashing for managing tensor lifetime against recycling.
- Resolves #147168
3. Remove tensor life management when async_op=False; only use it when async_op=True.
4. To guard against user not calling `work.wait()`, we ask watchdog to unstash tensors after detecting completion of collectives, to prevent us from holding reference to tensors forever. This is a safety net, rather than a service guarantee, see discussion [here](#147168 (comment)).
5. Profile in async_op=False mode would look different -- collective kernels would show up in the same line and compute kernels.

Joint work with @cenzhaometa who wants to remove the event sync overhead.

Squashed contents:

* [ptd][nccl] use current-stream as nccl-stream under async=False mode (#147820)
PTD current workflow:
- PTD creates its own dedicated `ncclStream` for comm operation
- it will first add a dependency on current-stream (typically the compute stream) to ensure tensors are ready before invoking collective
such stream synchronization become expensive in Inference world (cpu overhead: 70us vs GPU kernel time: 160us).
This diff:
- async=False [default], will use current-stream as nccl-stream and avoid the stream-sync overhead
- async=True, will retain existing logic: create new nccl-stream, let it wait on current-stream to ensure tensors are ready
- pass down async from c10d down to NCCL-PG
this helps shave off 50% CPU overhead **(70us -> 35us)**, which reduce total CPU/GPU from **230us to 195us by 15%**

* [PGNCCL] Make avoid-record-stream default

* [c10d] Add asyncOp argument to Ops

* Change python side wait

* Pass asyncOp at ProcessGroup level

* Watchdog unstashing tensors as a safety net

* Stash tensors for reduce_scatter_v and all_gather_v
Pull Request approved: #149753

* [c10d] Move unstashing from watchdog to main thread
Pull Request approved: #150079

* [PGNCCL][BE] Merge mutex into TensorShelf for encapsulation
Pull Request approved: #150130

Pull Request resolved: #150398
Approved by: https://github.com/atalman
chuanqi129 pushed a commit to intel/torch-xpu-ops that referenced this pull request Apr 2, 2025
Pytorch introduce new stream method in
pytorch/pytorch#148590, which update base
distributed interface. This pr align with latest register interface.
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
…irely (pytorch#150398)

Relanding pytorch#148590 due to merge conflict.

This PR has multiple changes to `ProcessGroupNCCL` (which unfortunately are related):
1. When async_op=False, we directly launch the collective on "current" stream, instead of a trampoline stream and join back.
- Resolves pytorch#147729
- Resolves pytorch#146881
- Also saves two event syncs (which have overhead in case of HIP) and one pybind when we call `work.wait()` in distributed_c10d.py on behalf of user.
2. Entirely remove `record_stream` and use CPU-side stashing for managing tensor lifetime against recycling.
- Resolves pytorch#147168
3. Remove tensor life management when async_op=False; only use it when async_op=True.
4. To guard against user not calling `work.wait()`, we ask watchdog to unstash tensors after detecting completion of collectives, to prevent us from holding reference to tensors forever. This is a safety net, rather than a service guarantee, see discussion [here](pytorch#147168 (comment)).
5. Profile in async_op=False mode would look different -- collective kernels would show up in the same line and compute kernels.

Joint work with @cenzhaometa who wants to remove the event sync overhead.

Squashed contents:

* [ptd][nccl] use current-stream as nccl-stream under async=False mode (pytorch#147820)
PTD current workflow:
- PTD creates its own dedicated `ncclStream` for comm operation
- it will first add a dependency on current-stream (typically the compute stream) to ensure tensors are ready before invoking collective
such stream synchronization become expensive in Inference world (cpu overhead: 70us vs GPU kernel time: 160us).
This diff:
- async=False [default], will use current-stream as nccl-stream and avoid the stream-sync overhead
- async=True, will retain existing logic: create new nccl-stream, let it wait on current-stream to ensure tensors are ready
- pass down async from c10d down to NCCL-PG
this helps shave off 50% CPU overhead **(70us -> 35us)**, which reduce total CPU/GPU from **230us to 195us by 15%**

* [PGNCCL] Make avoid-record-stream default

* [c10d] Add asyncOp argument to Ops

* Change python side wait

* Pass asyncOp at ProcessGroup level

* Watchdog unstashing tensors as a safety net

* Stash tensors for reduce_scatter_v and all_gather_v
Pull Request approved: pytorch#149753

* [c10d] Move unstashing from watchdog to main thread
Pull Request approved: pytorch#150079

* [PGNCCL][BE] Merge mutex into TensorShelf for encapsulation
Pull Request approved: pytorch#150130

Pull Request resolved: pytorch#150398
Approved by: https://github.com/atalman
@kwen2501
Copy link
Contributor Author
kwen2501 commented May 6, 2025

Landed

@kwen2501 kwen2501 closed this May 6, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ci-no-td Do not run TD on this PR ciflow/trunk Trigger trunk jobs on your pull request keep-going Don't stop on first failure, keep running tests until the end Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category Reverted
Projects
None yet
Development

Successfully merging this pull request may close these issues.

0