[WIP][ptd][nccl] use current-stream as nccl-stream under async=False mode by cenzhaometa · Pull Request #147820 · pytorch/pytorch

[WIP][ptd][nccl] use current-stream as nccl-stream under async=False mode #147820


Closed
wants to merge 1 commit into from
4 changes: 3 additions & 1 deletion torch/_C/_distributed_c10d.pyi
@@ -2,7 +2,7 @@
# mypy: disable-error-code="type-arg"
from datetime import timedelta
from enum import Enum
from typing import Any, overload
from typing import Any, overload, Optional

import torch
from torch import Tensor
@@ -134,6 +134,8 @@ class BroadcastOptions:
class AllreduceOptions:
reduceOp: ReduceOp
timeout: timedelta
asyncOp: bool
sparseIndices: Optional[Tensor]
Contributor
Is this line added by mistake?


class AllreduceCoalescedOptions(AllreduceOptions): ...

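For context, a minimal Python-side sketch of the fields the updated stub declares on AllreduceOptions; it assumes only the asyncOp binding added in init.cpp further down this diff, nothing beyond the stub.

# Sketch only: exercises the AllreduceOptions fields declared in the stub.
from datetime import timedelta

from torch._C._distributed_c10d import AllreduceOptions, ReduceOp

opts = AllreduceOptions()
opts.reduceOp = ReduceOp.SUM
opts.timeout = timedelta(seconds=30)
opts.asyncOp = False  # new field: request the synchronous (current-stream) path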
11 changes: 8 additions & 3 deletions torch/csrc/distributed/c10d/Ops.cpp
@@ -19,9 +19,9 @@ TORCH_LIBRARY(c10d, m) {
m.def(
"broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
m.def(
"allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
"allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
m.def(
"allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work");
"allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool asyncOp, int timeout) -> __torch__.torch.classes.c10d.Work");
m.def(
"allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
m.def(
@@ -169,12 +169,13 @@ IMPL_BROADCAST(PrivateUse1)
const c10::intrusive_ptr<ProcessGroup>& process_group, \
const c10::intrusive_ptr<ReduceOp>& reduce_op, \
const std::optional<at::Tensor>& sparse_indices, \
bool asyncOp, \
int64_t timeout) { \
auto tensor_vec = tensors.vec(); \
auto work = process_group->getBackend(c10::DeviceType::DEV) -> allreduce( \
tensor_vec, \
AllreduceOptions{ \
*reduce_op.get(), std::chrono::milliseconds(timeout)}); \
*reduce_op.get(), std::chrono::milliseconds(timeout), asyncOp}); \
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>( \
std::move(tensor_vec), work); \
}
@@ -188,11 +189,13 @@ IMPL_ALLREDUCE(PrivateUse1)
at::TensorList tensors, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
const c10::intrusive_ptr<ReduceOp>& reduce_op, \
bool asyncOp, \
int64_t timeout) { \
auto tensor_vec = tensors.vec(); \
AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; \
opts.reduceOp = *reduce_op.get(); \
opts.timeout = std::chrono::milliseconds(timeout); \
opts.asyncOp = asyncOp; \
return process_group->getBackend(c10::DeviceType::DEV) \
->allreduce_coalesced(tensor_vec, opts); \
}
@@ -464,6 +467,7 @@ allreduce_sparse_cuda_(
const c10::intrusive_ptr<ProcessGroup>& process_group,
const c10::intrusive_ptr<ReduceOp>& reduce_op,
const std::optional<at::Tensor>& sparse_indices,
bool asyncOp,
int64_t timeout) {
auto tensor_vec = tensors.vec();
auto work = process_group->getBackend(c10::DeviceType::CUDA)
@@ -472,6 +476,7 @@
AllreduceOptions{
*reduce_op,
std::chrono::milliseconds(timeout),
asyncOp,
sparse_indices});

return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
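A quick way to check the widened dispatcher schema from Python; a sketch that assumes a distributed-enabled build with this patch applied.

import torch

# The c10d ops are registered when torch is built with distributed support.
print(torch.ops.c10d.allreduce_.default._schema)
# expected to list: ..., Tensor? sparse_indices, bool asyncOp, int timeout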
4 changes: 4 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroup.hpp
@@ -224,13 +224,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
const std::optional<at::Tensor>& sparse_indices,
bool,
int64_t)>();

auto work = std::get<1>(op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.sparseIndices,
opts.asyncOp,
opts.timeout.count()));

if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -250,12 +252,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
bool,
int64_t)>();

auto work = op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.asyncOp,
opts.timeout.count());

if (c10d::allow_inflight_collective_as_graph_input()) {
29 changes: 19 additions & 10 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -3238,7 +3238,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
OpType opType,
const char* profilingTitle,
bool avoidRecordStreams,
bool nanCheck) {
bool nanCheck,
bool asyncOp) {
// Environment setting by the user may add onto collective call's option
avoidRecordStreams |= avoidRecordStreams_;
nanCheck &= enableNanCheck_;
@@ -3283,11 +3284,15 @@
}
}

// Used many times below, so we stash the unordered_map lookup
auto ncclStream = ncclStreams_.at(key);

// First let NCCL streams wait for input tensors allocation streams
syncStream(device, ncclEvents_[key], ncclStream);
// In asyncOp=false [default] mode, we use the current stream as ncclStream;
// otherwise, we use a dedicated ncclStream and let it sync on the current stream.
auto ncclStream = at::cuda::getCurrentCUDAStream(device.index());
if (asyncOp) {
// Used many times below, so we stash the unordered_map lookup
ncclStream = ncclStreams_.at(key);
// First let NCCL streams wait for input tensors allocation streams
syncStream(device, ncclEvents_[key], ncclStream);
}

bool enqueue =
!coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None;
@@ -3883,7 +3888,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
OpType opType,
const char* profilingTitle,
bool avoidRecordStreams,
bool nanCheck) {
bool nanCheck,
bool asyncOp) {
auto inputs = std::vector<at::Tensor>{input};
auto outputs = std::vector<at::Tensor>{output};
return collective(
@@ -3895,7 +3901,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
opType,
profilingTitle,
avoidRecordStreams,
nanCheck);
nanCheck,
asyncOp);
}

template <typename Fn>
@@ -3906,7 +3913,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
OpType opType,
const char* profilingTitle,
bool avoidRecordStreams,
bool nanCheck) {
bool nanCheck,
bool asyncOp) {
auto inputs = std::vector<at::Tensor>{input};
auto outputs = std::vector<at::Tensor>{output};
return collective(
@@ -3920,7 +3928,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
opType,
profilingTitle,
avoidRecordStreams,
nanCheck);
nanCheck,
asyncOp);
}

template <typename Fn>
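The hunk above is the core behavioral change: with asyncOp=false the collective is launched on the caller's current CUDA stream instead of the dedicated NCCL stream. A hedged sketch of what that means for stream ordering, assuming an initialized NCCL process group on a CUDA build:

import torch
import torch.distributed as dist

side_stream = torch.cuda.Stream()
x = torch.randn(1 << 20, device="cuda")

# Make the side stream wait for work already queued on the default stream.
side_stream.wait_stream(torch.cuda.current_stream())

with torch.cuda.stream(side_stream):
    y = x * 2                           # kernel enqueued on side_stream
    dist.all_reduce(y, async_op=False)  # expected to be enqueued on side_stream too
    z = y + 1                           # ordered after the all_reduce on the same stream

# Join back to the default stream before consuming z there.
torch.cuda.current_stream().wait_stream(side_stream)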
9 changes: 6 additions & 3 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -876,7 +876,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
OpType opType,
const char* profilingTitle = nullptr,
bool avoidRecordStreams = false,
bool nanCheck = true);
bool nanCheck = true,
bool asyncOp = false);

template <typename Fn, typename PreProcess, typename PostProcess>
c10::intrusive_ptr<Work> collective(
@@ -888,7 +889,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
OpType opType,
const char* profilingTitle = nullptr,
bool avoidRecordStreams = false,
bool nanCheck = true);
bool nanCheck = true,
bool asyncOp = false);

template <typename Fn, typename PreProcess, typename PostProcess>
c10::intrusive_ptr<Work> collective(
@@ -900,7 +902,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
OpType opType,
const char* profilingTitle = nullptr,
bool avoidRecordStreams = false,
bool nanCheck = true);
bool nanCheck = true,
bool asyncOp = false);

template <typename Fn>
c10::intrusive_ptr<Work> collectiveCoalesced(
1 change: 1 addition & 0 deletions torch/csrc/distributed/c10d/Types.hpp
@@ -122,6 +122,7 @@ struct BroadcastOptions {
struct AllreduceOptions {
ReduceOp reduceOp = ReduceOp::SUM;
std::chrono::milliseconds timeout = kUnsetTimeout;
bool asyncOp = false;
std::optional<at::Tensor> sparseIndices = std::nullopt;
};

6 changes: 4 additions & 2 deletions torch/csrc/distributed/c10d/init.cpp
@@ -999,13 +999,15 @@ This class does not support ``__members__`` property.)");
py::class_<::c10d::AllreduceOptions>(module, "AllreduceOptions")
.def(py::init<>())
.def_readwrite("reduceOp", &::c10d::AllreduceOptions::reduceOp)
.def_readwrite("timeout", &::c10d::AllreduceOptions::timeout);
.def_readwrite("timeout", &::c10d::AllreduceOptions::timeout)
.def_readwrite("asyncOp", &::c10d::AllreduceOptions::asyncOp);

py::class_<::c10d::AllreduceCoalescedOptions>(
module, "AllreduceCoalescedOptions")
.def(py::init<>())
.def_readwrite("reduceOp", &::c10d::AllreduceCoalescedOptions::reduceOp)
.def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout);
.def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout)
.def_readwrite("asyncOp", &::c10d::AllreduceCoalescedOptions::asyncOp);

py::class_<::c10d::ReduceOptions>(module, "ReduceOptions")
.def(py::init<>())
2 changes: 2 additions & 0 deletions torch/distributed/distributed_c10d.py
@@ -2806,6 +2806,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):

opts = AllreduceOptions()
opts.reduceOp = op
opts.asyncOp = async_op
if group is None:
group = _get_default_group()

@@ -2882,6 +2883,7 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):

opts = AllreduceCoalescedOptions()
opts.reduceOp = op
opts.asyncOp = async_op
group = group or _get_default_group()
work = group.allreduce_coalesced(tensors, opts)

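User-facing effect of the two wrapper changes, as a hedged sketch: async_op is now forwarded into opts.asyncOp for both all_reduce and all_reduce_coalesced (assumes an initialized process group).

import torch
import torch.distributed as dist

t = torch.ones(8, device="cuda")

# async_op=False (default): opts.asyncOp=False; the collective is expected to
# run on the current stream and no Work handle is returned.
dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=False)

# async_op=True: opts.asyncOp=True; a Work handle is returned and must be waited on.
work = dist.all_reduce_coalesced([t, t.clone()], op=dist.ReduceOp.SUM, async_op=True)
work.wait()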