[ptd][nccl] provide a knob to use current-stream as nccl-stream (#147820) · pytorch/pytorch@9bac479 · GitHub

Commit 9bac479

cenzhaometa authored and facebook-github-bot committed
[ptd][nccl] provide a knob to use current-stream as nccl-stream (#147820)

Summary:

PTD current workflow:
- PTD creates its own dedicated `ncclStream` for comm operations
- it first adds a dependency on the current stream (typically the compute stream) to ensure tensors are ready before invoking the collective

Such stream synchronization becomes expensive in the inference world (CPU overhead: 70us vs. GPU kernel time: 160us).

This diff:
- introduces a new env `TORCH_NCCL_USE_CURRENT_STREAM_AS_NCCL_STREAM=1`
- when it is specified, PTD uses the current stream as the nccl stream and avoids the stream sync

This shaves off 50% of the CPU overhead **(70us -> 35us)** and reduces total CPU/GPU time by 15%, from **230us to 195us**.

Test Plan:

# AMD

> before

```
[cenzhao@devgpu039.atn3 ~/fbsource/fbcode (2265d32f0)]$ buck2 run @//mode/opt-amd-gpu -c fbcode.split-dwarf=True //param_bench/train/comms/pt:launcher -- --launcher mpi --nnode 1 --collective all_reduce --b 20M --e 20M --data-type bfloat16 --backend nccl --n 100 --w 5 --envs "NCCL_DEBUG_FILE=/tmp/dedicated_log_rccl.%h.%p.log;NCCL_DEBUG=INFO;NCCL_DEBUG_SUBSYS=INIT,COLL;MSCCL_ALGO_DIR=/data/users/${USER}/fbsource/third-party/rccl/develop/tools/msccl-algorithms;RCCL_MSCCLPP_THRESHOLD=$((128*1024*1024));RCCL_MSCCLPP_ENABLE=1;TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK=1;" --size-start-profiler 20M
```

https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/dynocli/devgpu039.atn3.facebook.com/rank-0.Feb_24_16_19_28.354787.pt.trace.json.gz&bucket=hpc_traces

{F1975408857}

- c10d::allreduce_ (69us)
- cudaStreamSync (23us)
- nccl::all_reduce (26us)

> after

```
[cenzhao@devgpu039.atn3 ~/fbsource/fbcode (2265d32f0)]$ buck2 run @//mode/opt-amd-gpu -c fbcode.split-dwarf=True //param_bench/train/comms/pt:launcher -- --launcher mpi --nnode 1 --collective all_reduce --b 20M --e 20M --data-type bfloat16 --backend nccl --n 100 --w 5 --envs "NCCL_DEBUG_FILE=/tmp/dedicated_log_rccl.%h.%p.log;NCCL_DEBUG=INFO;NCCL_DEBUG_SUBSYS=INIT,COLL;MSCCL_ALGO_DIR=/data/users/${USER}/fbsource/third-party/rccl/develop/tools/msccl-algorithms;RCCL_MSCCLPP_THRESHOLD=$((128*1024*1024));RCCL_MSCCLPP_ENABLE=1;TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK=1;TORCH_NCCL_USE_CURRENT_STREAM_AS_NCCL_STREAM=1" --size-start-profiler 20M
```

https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/dynocli/devgpu039.atn3.facebook.com/rank-4.Feb_24_16_22_56.534269.pt.trace.json.gz&bucket=hpc_traces

{F1975408962}

- c10d::allreduce_ (37us)
- cudaStreamSync (gone)
- nccl::all_reduce (20us)

# NV

> before

```
[cenzhao@devgpu019.prn3 ~/fbsource/fbcode (e3f64263c)]$ buck2 run @//mode/opt -c fbcode.split-dwarf=True //param_bench/train/comms/pt:launcher -- --launcher mpi --nnode 1 --collective all_reduce --b 20M --e 20M --data-type bfloat16 --backend nccl --n 100 --w 5 --envs "NCCL_DEBUG_FILE=/tmp/dedicated_log_rccl.%h.%p.log;NCCL_DEBUG=INFO;NCCL_DEBUG_SUBSYS=INIT,COLL;" --size-start-profiler 20M
```

https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/dynocli/devgpu019.prn3.facebook.com/rank-2.Feb_25_11_11_28.3328768.pt.trace.json.gz&bucket=hpc_traces

{F1975437097}

- c10d::allreduce_ (62us)
- cudaStreamWait (0us)
- nccl::all_reduce (47us)

> after

```
[cenzhao@devgpu019.prn3 ~/fbsource/fbcode (e3f64263c)]$ buck2 run @//mode/opt -c fbcode.split-dwarf=True //param_bench/train/comms/pt:launcher -- --launcher mpi --nnode 1 --collective all_reduce --b 20M --e 20M --data-type bfloat16 --backend nccl --n 100 --w 5 --envs "NCCL_DEBUG_FILE=/tmp/dedicated_log_rccl.%h.%p.log;NCCL_DEBUG=INFO;NCCL_DEBUG_SUBSYS=INIT,COLL;TORCH_NCCL_USE_CURRENT_STREAM_AS_NCCL_STREAM=1" --size-start-profiler 20M
```

https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/dynocli/devgpu019.prn3.facebook.com/rank-4.Feb_25_11_17_05.3469865.pt.trace.json.gz&bucket=hpc_traces

{F1975437192}

- c10d::allreduce_ (62us)
- cudaStreamWait (gone)
- nccl::all_reduce (53us)

Differential Revision: D70135605
1 parent 91e7c79 commit 9bac479

8 files changed (+54, −26 lines)

torch/_C/_distributed_c10d.pyi

Lines changed: 3 additions & 1 deletion
```diff
@@ -2,7 +2,7 @@
 # mypy: disable-error-code="type-arg"
 from datetime import timedelta
 from enum import Enum
-from typing import Any, overload
+from typing import Any, overload, Optional

 import torch
 from torch import Tensor
@@ -134,6 +134,8 @@ class BroadcastOptions:
 class AllreduceOptions:
     reduceOp: ReduceOp
     timeout: timedelta
+    asyncOp: bool
+    sparseIndices: Optional[Tensor]

 class AllreduceCoalescedOptions(AllreduceOptions): ...
```

torch/csrc/distributed/c10d/Ops.cpp

Lines changed: 11 additions & 6 deletions
```diff
@@ -19,9 +19,9 @@ TORCH_LIBRARY(c10d, m) {
   m.def(
       "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
   m.def(
-      "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
+      "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, int timeout, bool asyncOp) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
   m.def(
-      "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work");
+      "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout, bool asyncOp) -> __torch__.torch.classes.c10d.Work");
   m.def(
       "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
   m.def(
@@ -169,12 +169,13 @@ IMPL_BROADCAST(PrivateUse1)
       const c10::intrusive_ptr<ProcessGroup>& process_group, \
       const c10::intrusive_ptr<ReduceOp>& reduce_op, \
       const std::optional<at::Tensor>& sparse_indices, \
-      int64_t timeout) { \
+      int64_t timeout, \
+      bool asyncOp) { \
     auto tensor_vec = tensors.vec(); \
     auto work = process_group->getBackend(c10::DeviceType::DEV) -> allreduce( \
         tensor_vec, \
         AllreduceOptions{ \
-            *reduce_op.get(), std::chrono::milliseconds(timeout)}); \
+            *reduce_op.get(), std::chrono::milliseconds(timeout), asyncOp}); \
     return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>( \
         std::move(tensor_vec), work); \
   }
@@ -188,11 +189,13 @@ IMPL_ALLREDUCE(PrivateUse1)
       at::TensorList tensors, \
       const c10::intrusive_ptr<ProcessGroup>& process_group, \
       const c10::intrusive_ptr<ReduceOp>& reduce_op, \
-      int64_t timeout) { \
+      int64_t timeout, \
+      bool asyncOp) { \
     auto tensor_vec = tensors.vec(); \
     AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; \
     opts.reduceOp = *reduce_op.get(); \
     opts.timeout = std::chrono::milliseconds(timeout); \
+    opts.asyncOp = asyncOp; \
    return process_group->getBackend(c10::DeviceType::DEV) \
        ->allreduce_coalesced(tensor_vec, opts); \
  }
@@ -464,14 +467,16 @@ allreduce_sparse_cuda_(
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const c10::intrusive_ptr<ReduceOp>& reduce_op,
     const std::optional<at::Tensor>& sparse_indices,
-    int64_t timeout) {
+    int64_t timeout,
+    bool asyncOp) {
   auto tensor_vec = tensors.vec();
   auto work = process_group->getBackend(c10::DeviceType::CUDA)
                   ->allreduce_sparse(
                       tensor_vec,
                       AllreduceOptions{
                           *reduce_op,
                           std::chrono::milliseconds(timeout),
+                          asyncOp,
                           sparse_indices});

   return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
```
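Since the dispatcher schema itself gains the trailing `asyncOp` argument, the change is visible from Python; a quick sanity sketch (assuming a build that includes this diff):

```python
import torch
import torch.distributed  # ensure the c10d library and its ops are loaded

# Print the registered schema; with this diff it should end in
# "... int timeout, bool asyncOp) -> ..." per the m.def() above.
print(torch.ops.c10d.allreduce_.default._schema)
```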

torch/csrc/distributed/c10d/ProcessGroup.hpp

Lines changed: 8 additions & 4 deletions
```diff
@@ -224,14 +224,16 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
             const c10::intrusive_ptr<::c10d::ReduceOp>&,
             const std::optional<at::Tensor>& sparse_indices,
-            int64_t)>();
+            int64_t,
+            bool)>();

     auto work = std::get<1>(op.call(
         tensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         c10::make_intrusive<ReduceOp>(opts.reduceOp),
         opts.sparseIndices,
-        opts.timeout.count()));
+        opts.timeout.count(),
+        opts.asyncOp));

     if (c10d::allow_inflight_collective_as_graph_input()) {
       for (const auto& tensor : tensors) {
@@ -250,13 +252,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             at::TensorList,
             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
             const c10::intrusive_ptr<::c10d::ReduceOp>&,
-            int64_t)>();
+            int64_t,
+            bool)>();

     auto work = op.call(
         tensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         c10::make_intrusive<ReduceOp>(opts.reduceOp),
-        opts.timeout.count());
+        opts.timeout.count(),
+        opts.asyncOp);

     if (c10d::allow_inflight_collective_as_graph_input()) {
       for (const auto& tensor : tensors) {
```

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 19 additions & 10 deletions
```diff
@@ -3238,7 +3238,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
     OpType opType,
     const char* profilingTitle,
     bool avoidRecordStreams,
-    bool nanCheck) {
+    bool nanCheck,
+    bool asyncOp) {
   // Environment setting by the user may add onto collective call's option
   avoidRecordStreams |= avoidRecordStreams_;
   nanCheck &= enableNanCheck_;
@@ -3283,11 +3284,15 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
     }
   }

-  // Used many times below, so we stash the unordered_map lookup
-  auto ncclStream = ncclStreams_.at(key);
-
-  // First let NCCL streams wait for input tensors allocation streams
-  syncStream(device, ncclEvents_[key], ncclStream);
+  // in asyncOp=false [default] mode, we use currentStream as ncclStream
+  // otherwise, we use separate ncclStream and let it sync on currentStream
+  auto ncclStream = at::cuda::getCurrentCUDAStream(device.index());
+  if (asyncOp) {
+    // Used many times below, so we stash the unordered_map lookup
+    ncclStream = ncclStreams_.at(key);
+    // First let NCCL streams wait for input tensors allocation streams
+    syncStream(device, ncclEvents_[key], ncclStream);
+  }

   bool enqueue =
       !coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None;
@@ -3883,7 +3888,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
     OpType opType,
     const char* profilingTitle,
     bool avoidRecordStreams,
-    bool nanCheck) {
+    bool nanCheck,
+    bool asyncOp) {
   auto inputs = std::vector<at::Tensor>{input};
   auto outputs = std::vector<at::Tensor>{output};
   return collective(
@@ -3895,7 +3901,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
       opType,
       profilingTitle,
       avoidRecordStreams,
-      nanCheck);
+      nanCheck,
+      asyncOp);
 }

 template <typename Fn>
@@ -3906,7 +3913,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
     OpType opType,
     const char* profilingTitle,
     bool avoidRecordStreams,
-    bool nanCheck) {
+    bool nanCheck,
+    bool asyncOp) {
   auto inputs = std::vector<at::Tensor>{input};
   auto outputs = std::vector<at::Tensor>{output};
   return collective(
@@ -3920,7 +3928,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
       opType,
       profilingTitle,
       avoidRecordStreams,
-      nanCheck);
+      nanCheck,
+      asyncOp);
 }

 template <typename Fn>
```
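The behavioral fork in the stream-selection hunk above can be sketched with the public CUDA stream API; this is a conceptual analogue of the C++ logic, not the actual implementation:

```python
import torch


def pick_nccl_stream(
    async_op: bool, dedicated_comm_stream: torch.cuda.Stream
) -> torch.cuda.Stream:
    """Conceptual analogue of the stream selection in ProcessGroupNCCL::collective."""
    current = torch.cuda.current_stream()
    if not async_op:
        # asyncOp=false (default): launch the collective directly on the
        # caller's current stream -- the event record/wait done by syncStream,
        # and its CPU overhead, disappear.
        return current
    # asyncOp=true: keep the dedicated comm stream, but make it wait on the
    # current stream first so input tensors are ready (what syncStream does).
    dedicated_comm_stream.wait_stream(current)
    return dedicated_comm_stream
```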

torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp

Lines changed: 6 additions & 3 deletions
```diff
@@ -876,7 +876,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       OpType opType,
       const char* profilingTitle = nullptr,
       bool avoidRecordStreams = false,
-      bool nanCheck = true);
+      bool nanCheck = true,
+      bool asyncOp = false);

   template <typename Fn, typename PreProcess, typename PostProcess>
   c10::intrusive_ptr<Work> collective(
@@ -888,7 +889,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       OpType opType,
       const char* profilingTitle = nullptr,
       bool avoidRecordStreams = false,
-      bool nanCheck = true);
+      bool nanCheck = true,
+      bool asyncOp = false);

   template <typename Fn, typename PreProcess, typename PostProcess>
   c10::intrusive_ptr<Work> collective(
@@ -900,7 +902,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       OpType opType,
       const char* profilingTitle = nullptr,
       bool avoidRecordStreams = false,
-      bool nanCheck = true);
+      bool nanCheck = true,
+      bool asyncOp = false);

   template <typename Fn>
   c10::intrusive_ptr<Work> collectiveCoalesced(
```

torch/csrc/distributed/c10d/Types.hpp

Lines changed: 1 addition & 0 deletions
```diff
@@ -122,6 +122,7 @@ struct BroadcastOptions {
 struct AllreduceOptions {
   ReduceOp reduceOp = ReduceOp::SUM;
   std::chrono::milliseconds timeout = kUnsetTimeout;
+  bool asyncOp = false;
   std::optional<at::Tensor> sparseIndices = std::nullopt;
 };
```

torch/csrc/distributed/c10d/init.cpp

Lines changed: 4 additions & 2 deletions
```diff
@@ -999,13 +999,15 @@ This class does not support ``__members__`` property.)");
   py::class_<::c10d::AllreduceOptions>(module, "AllreduceOptions")
       .def(py::init<>())
       .def_readwrite("reduceOp", &::c10d::AllreduceOptions::reduceOp)
-      .def_readwrite("timeout", &::c10d::AllreduceOptions::timeout);
+      .def_readwrite("timeout", &::c10d::AllreduceOptions::timeout)
+      .def_readwrite("asyncOp", &::c10d::AllreduceOptions::asyncOp);

   py::class_<::c10d::AllreduceCoalescedOptions>(
       module, "AllreduceCoalescedOptions")
       .def(py::init<>())
       .def_readwrite("reduceOp", &::c10d::AllreduceCoalescedOptions::reduceOp)
-      .def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout);
+      .def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout)
+      .def_readwrite("asyncOp", &::c10d::AllreduceCoalescedOptions::asyncOp);

   py::class_<::c10d::ReduceOptions>(module, "ReduceOptions")
       .def(py::init<>())
```
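With the new `def_readwrite` bindings, the flag is settable from Python on the options objects; a small sanity sketch (the import path matches what `distributed_c10d.py` itself uses):

```python
from torch._C._distributed_c10d import AllreduceOptions, ReduceOp

opts = AllreduceOptions()
opts.reduceOp = ReduceOp.SUM
opts.asyncOp = True  # newly exposed field; defaults to False per Types.hpp
```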

torch/distributed/distributed_c10d.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -2806,6 +2806,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):

     opts = AllreduceOptions()
     opts.reduceOp = op
+    opts.asyncOp = async_op
     if group is None:
         group = _get_default_group()

@@ -2882,6 +2883,7 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):

     opts = AllreduceCoalescedOptions()
     opts.reduceOp = op
+    opts.asyncOp = async_op
     group = group or _get_default_group()
     work = group.allreduce_coalesced(tensors, opts)
```