pytorch/pytorch · Commit 573b7e2
[ptd][nccl] use current-stream as nccl-stream under async=False mode (#147820)

Summary:

PTD current workflow:
- PTD creates its own dedicated `ncclStream` for comm operation
- it will first add a dependency on current-stream (typically the compute stream) to ensure tensors are ready before invoking the collective

Such stream synchronization becomes expensive in the inference world (CPU overhead: 70us vs. GPU kernel time: 160us).

This diff:
- async=False [default]: use current-stream as nccl-stream and avoid the stream-sync overhead
- async=True: retain existing logic, i.e. create a new nccl-stream and let it wait on current-stream to ensure tensors are ready
- pass asyncOp down from c10d to NCCL-PG

This shaves off 50% of the CPU overhead **(70us -> 35us)**, reducing total CPU/GPU time from **230us to 195us**, about 15%.

Differential Revision: D70135605

[PGNCCL] Make avoid-record-stream default
[c10d] Add asyncOp argument to Ops
Change python side wait
Pass asyncOp at ProcessGroup level
Watchdog unstashing tensors as a safety net
lint

ghstack-source-id: 27b228c
Pull Request resolved: #148590
1 parent 666508e commit 573b7e2
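For context, here is a minimal sketch of the user-facing distinction the summary describes, using the public torch.distributed API. The snippet is illustrative, not part of the commit; it assumes dist.init_process_group("nccl", ...) has already run and the tensor lives on the current CUDA device:

import torch
import torch.distributed as dist

def allreduce_both_modes(tensor: torch.Tensor) -> None:
    # async_op=False (default): after this change, the NCCL kernel is
    # enqueued on the caller's current CUDA stream, so no extra
    # stream synchronization is needed before the collective launches.
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=False)

    # async_op=True: retains the old behavior. A dedicated NCCL stream
    # waits on the current stream so inputs are ready, and wait()
    # re-synchronizes the caller with the collective's completion.
    work = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=True)
    work.wait()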

File tree: 11 files changed (+411, -362 lines)

test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp

Lines changed: 3 additions & 0 deletions
@@ -363,6 +363,9 @@ class TestDebugInfoWriter : public c10d::DebugInfoWriter {
 };

 TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
+  // Note (kwen2501) 03/07/2025
+  // TODO: re-enable
+  GTEST_SKIP() << "Skipping test as the trace write seems unstable.";
   int heartBeatIntervalInSec = 2;
   std::string timeInterval = std::to_string(heartBeatIntervalInSec);
   ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0);

test/forward_backward_compatibility/check_forward_backward_compatibility.py

Lines changed: 3 additions & 0 deletions
@@ -126,6 +126,9 @@
     ("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)),
     ("aten::all_gather_into_tensor", datetime.date(9999, 1, 30)),
     ("aten::all_reduce", datetime.date(9999, 1, 30)),
+    # These ops are defined in torch/csrc/distributed/c10d/Ops.cpp
+    # TODO: add back restriction when c10d ops can be exported
+    ("c10d::.*", datetime.date(9999, 1, 1)),
 ]

 ALLOW_LIST_COMPILED = [
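As an aside, a hedged sketch of how an allow list like this is typically consumed: each pattern is compiled to a regex, and a schema is exempt from the backward-compatibility check while its expiry date lies in the future. The is_allow_listed helper below is a simplified stand-in for illustration, not the script's actual function:

import datetime
import re

ALLOW_LIST = [
    ("aten::all_reduce", datetime.date(9999, 1, 30)),
    ("c10d::.*", datetime.date(9999, 1, 1)),  # entry added by this commit
]

# Compiled once up front, mirroring the script's ALLOW_LIST_COMPILED.
ALLOW_LIST_COMPILED = [(re.compile(p), d) for p, d in ALLOW_LIST]

def is_allow_listed(schema_name: str) -> bool:
    # A schema is skipped while its expiry date has not yet passed.
    today = datetime.date.today()
    return any(
        regex.match(schema_name) and today < expiry
        for regex, expiry in ALLOW_LIST_COMPILED
    )

assert is_allow_listed("c10d::allreduce_")  # matched by "c10d::.*"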

torch/_C/_distributed_c10d.pyi

Lines changed: 7 additions & 1 deletion
@@ -2,7 +2,7 @@
 # mypy: disable-error-code="type-arg"
 from datetime import timedelta
 from enum import Enum
-from typing import Any, overload
+from typing import Any, Optional, overload

 import torch
 from torch import Tensor
@@ -139,6 +139,8 @@ class BroadcastOptions:
 class AllreduceOptions:
     reduceOp: ReduceOp
     timeout: timedelta
+    asyncOp: bool
+    sparseIndices: Optional[Tensor]

 class AllreduceCoalescedOptions(AllreduceOptions): ...

@@ -147,6 +149,7 @@ class ReduceOptions:
     rootRank: int
     rootTensor: int
     timeout: timedelta
+    asyncOp: bool

 class AllgatherOptions:
     timeout: timedelta
@@ -155,6 +158,7 @@ class AllgatherOptions:
 class GatherOptions:
     rootRank: int
     timeout: timedelta
+    asyncOp: bool

 class ScatterOptions:
     rootRank: int
@@ -170,9 +174,11 @@ class BarrierOptions:
     device_ids: list[int]
     device: torch.device
     timeout: timedelta
+    asyncOp: bool

 class AllToAllOptions:
     timeout: timedelta
+    asyncOp: bool

 class Store:
     def set(self, key: str, value: str): ...
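The new asyncOp fields are plain attributes on these pybind-exposed options structs. Below is a small sketch of setting one explicitly through the low-level ProcessGroup entry point (most callers instead pass async_op= to the functional wrappers); it assumes an initialized NCCL process group:

import torch
import torch.distributed as dist
from torch._C._distributed_c10d import AllreduceOptions, ReduceOp

# Assumes dist.init_process_group("nccl", ...) has already run.
pg = dist.group.WORLD

opts = AllreduceOptions()
opts.reduceOp = ReduceOp.SUM
opts.asyncOp = True  # the field this commit threads down to the NCCL PG

tensors = [torch.ones(8, device="cuda")]
work = pg.allreduce(tensors, opts)  # low-level ProcessGroup API
work.wait()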

torch/csrc/distributed/c10d/Ops.cpp

Lines changed: 79 additions & 51 deletions
Large diffs are not rendered by default.

torch/csrc/distributed/c10d/ProcessGroup.hpp

Lines changed: 33 additions & 10 deletions
@@ -224,13 +224,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
               const c10::intrusive_ptr<::c10d::ProcessGroup>&,
               const c10::intrusive_ptr<::c10d::ReduceOp>&,
               const std::optional<at::Tensor>& sparse_indices,
+              bool,
               int64_t)>();

     auto work = std::get<1>(op.call(
         tensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         c10::make_intrusive<ReduceOp>(opts.reduceOp),
         opts.sparseIndices,
+        opts.asyncOp,
         opts.timeout.count()));

     if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -250,12 +252,14 @@
               at::TensorList,
               const c10::intrusive_ptr<::c10d::ProcessGroup>&,
               const c10::intrusive_ptr<::c10d::ReduceOp>&,
+              bool,
               int64_t)>();

     auto work = op.call(
         tensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         c10::make_intrusive<ReduceOp>(opts.reduceOp),
+        opts.asyncOp,
         opts.timeout.count());

     if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -277,13 +281,15 @@
               const c10::intrusive_ptr<::c10d::ReduceOp>&,
               int64_t,
               int64_t,
+              bool,
               int64_t)>();
     auto work = op.call(
         tensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         c10::make_intrusive<ReduceOp>(opts.reduceOp),
         opts.rootRank,
         opts.rootTensor,
+        opts.asyncOp,
         opts.timeout.count());

     if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -306,12 +312,14 @@
               const std::vector<std::vector<at::Tensor>>&,
               at::TensorList,
               const c10::intrusive_ptr<::c10d::ProcessGroup>&,
+              bool,
               int64_t)>();

     auto work = std::get<1>(op.call(
         outputTensors,
         inputTensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
+        opts.asyncOp,
         opts.timeout.count()));

     if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -363,18 +371,19 @@
       std::vector<std::vector<at::Tensor>>& outputTensorLists,
       std::vector<at::Tensor>& inputTensors,
       const AllgatherOptions& opts = AllgatherOptions()) {
-    static auto op =
-        c10::Dispatcher::singleton()
-            .findSchemaOrThrow("c10d::allgather_coalesced_", "")
-            .typed<c10::intrusive_ptr<Work>(
-                const std::vector<std::vector<at::Tensor>>&,
-                const at::TensorList&,
-                const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
+    static auto op = c10::Dispatcher::singleton()
+                         .findSchemaOrThrow("c10d::allgather_coalesced_", "")
+                         .typed<c10::intrusive_ptr<Work>(
+                             const std::vector<std::vector<at::Tensor>>&,
+                             const at::TensorList&,
+                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
+                             bool)>();

     auto work = op.call(
         outputTensorLists,
         inputTensors,
-        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
+        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
+        opts.asyncOp);

     if (c10d::allow_inflight_collective_as_graph_input()) {
       for (const auto& tensor_list : outputTensorLists) {
@@ -399,12 +408,14 @@
             .typed<c10::intrusive_ptr<Work>(
                 const at::TensorList,
                 const at::TensorList,
-                const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
+                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
+                bool)>();

     auto work = op.call(
         outputTensors,
         inputTensors,
-        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
+        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
+        opts.asyncOp);

     if (c10d::allow_inflight_collective_as_graph_input()) {
       for (const auto& tensor : outputTensors) {
@@ -425,12 +436,14 @@
                 const at::TensorList&,
                 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                 int64_t,
+                bool,
                 int64_t)>();
     auto work = op.call(
         outputTensors,
         inputTensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         opts.rootRank,
+        opts.asyncOp,
         opts.timeout.count());

     if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -487,12 +500,14 @@
                 const std::vector<std::vector<at::Tensor>>&,
                 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                 const c10::intrusive_ptr<::c10d::ReduceOp>&,
+                bool,
                 int64_t)>();
     auto work = std::get<1>(op.call(
         outputTensors,
         inputTensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
+        opts.asyncOp,
         opts.timeout.count()));

     if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -546,13 +561,15 @@
                 const at::TensorList,
                 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                 const c10::intrusive_ptr<::c10d::ReduceOp>&,
+                bool,
                 int64_t)>();

     auto work = op.call(
         outputTensors,
         inputTensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
+        opts.asyncOp,
         opts.timeout.count());

     if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -577,13 +594,15 @@
                 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                 std::vector<int64_t>,
                 std::vector<int64_t>,
+                bool,
                 int64_t)>();
     auto work = op.call(
         outputBuffer,
         inputBuffer,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         outputSplitSizes,
         inputSplitSizes,
+        opts.asyncOp,
         opts.timeout.count());

     if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -604,11 +623,13 @@
                 const at::TensorList&,
                 const at::TensorList&,
                 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
+                bool,
                 int64_t)>();
     auto work = std::get<1>(op.call(
         outputTensors,
         inputTensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
+        opts.asyncOp,
         opts.timeout.count()));

     if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -778,12 +799,14 @@
                 at::Tensor,
                 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                 const std::vector<int64_t>&,
+                bool,
                 int64_t)>();

     auto work = op.call(
         tensor,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         opts.device_ids,
+        opts.asyncOp,
         opts.timeout.count());
     if (c10d::allow_inflight_collective_as_graph_input()) {
       c10d::register_work(tensor, work);
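To make the new stream semantics concrete, here is a speculative sketch (not part of this diff) of what "use current-stream as nccl-stream" means for a caller under async=False: the collective is expected to land on whatever CUDA stream is current, so ordering with surrounding kernels on that stream comes for free. Assumes an initialized NCCL process group and a CUDA device:

import torch
import torch.distributed as dist

side_stream = torch.cuda.Stream()
x = torch.randn(1024, device="cuda")

# Make side_stream wait for x's producer on the default stream.
side_stream.wait_stream(torch.cuda.current_stream())

with torch.cuda.stream(side_stream):
    y = x * 2           # compute kernel enqueued on side_stream
    dist.all_reduce(y)  # async_op=False: NCCL kernel also on side_stream
    z = y + 1           # runs after the collective; same-stream ordering

# Join side_stream back into the default stream before consuming z.
torch.cuda.current_stream().wait_stream(side_stream)
print(z.sum().item())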
