10000 [Reland] Launch kernel on current stream & remove `record_stream` entirely by kwen2501 · Pull Request #150398 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[Reland] Launch kernel on current stream & remove record_stream entirely #150398

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,9 @@ class TestDebugInfoWriter : public c10d::DebugInfoWriter {
};

TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
// Note (kwen2501) 03/07/2025
// TODO: re-enable
GTEST_SKIP() << "Skipping test as the trace write seems unstable.";
int heartBeatIntervalInSec = 2;
std::string timeInterval = std::to_string(heartBeatIntervalInSec);
ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0);
Expand Down
26 changes: 26 additions & 0 deletions test/distributed/test_c10d_ops_nccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,32 @@ def reduce_scatter_base(output_t, input_t):
# fails the check because the dtype is different
reduce_scatter_base(output_t, tensor)

@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_reduce_scatter_v(self):
device = torch.device("cuda", self.rank_to_GPU[self.rank][0])
# A list of tensors with different sizes
input_list = [torch.ones(i, device=device) for i in range(self.world_size)]
# The i-th output should have size i
output = torch.zeros(self.rank, device=device)
work = c10d.reduce_scatter(output, input_list, group=self.pg, async_op=True)
expected = torch.ones(self.rank, device=device) * self.world_size
work.wait()
self.assertEqual(expected, output)

@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_all_gather_v(self):
device = torch.device("cuda", self.rank_to_GPU[self.rank][0])
# A list of tensors with different sizes
output_list = [torch.zeros(i, device=device) for i in range(self.world_size)]
# The i-th input has size i, filled with value i
input = torch.ones(self.rank, device=device) * self.rank
work = c10d.all_gather(output_list, input, group=self.pg, async_op=True)
expected = [torch.ones(i, device=device) * i for i in range(self.world_size)]
work.wait()
self.assertEqual(expected, output_list)

@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_reduce_scatter_ops(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@
("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)),
("aten::all_gather_into_tensor", datetime.date(9999, 1, 30)),
("aten::all_reduce", datetime.date(9999, 1, 30)),
# These ops are defined in torch/csrc/distributed/c10d/Ops.cpp
# TODO: add back restriction when c10d ops can be exported
("c10d::.*", datetime.date(9999, 1, 1)),
]

ALLOW_LIST_COMPILED = [
Expand Down
8 changes: 7 additions & 1 deletion torch/_C/_distributed_c10d.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# mypy: disable-error-code="type-arg"
from datetime import timedelta
from enum import Enum
from typing import Any, overload
from typing import Any, Optional, overload

import torch
from torch import Tensor
Expand Down Expand Up @@ -139,6 +139,8 @@ class BroadcastOptions:
class AllreduceOptions:
reduceOp: ReduceOp
timeout: timedelta
asyncOp: bool
sparseIndices: Optional[Tensor]

class AllreduceCoalescedOptions(AllreduceOptions): ...

Expand All @@ -147,6 +149,7 @@ class ReduceOptions:
rootRank: int
rootTensor: int
timeout: timedelta
asyncOp: bool

class AllgatherOptions:
timeout: timedelta
Expand All @@ -155,6 +158,7 @@ class AllgatherOptions:
class GatherOptions:
rootRank: int
timeout: timedelta
asyncOp: bool

class ScatterOptions:
rootRank: int
Expand All @@ -170,9 +174,11 @@ class BarrierOptions:
device_ids: list[int]
device: torch.device
timeout: timedelta
asyncOp: bool

class AllToAllOptions:
timeout: timedelta
asyncOp: bool

class Store:
def set(self, key: str, value: str): ...
Expand Down
130 changes: 79 additions & 51 deletions torch/csrc/distributed/c10d/Ops.cpp

Large diffs are not rendered by default.

43 changes: 33 additions & 10 deletions torch/csrc/distributed/c10d/ProcessGroup.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
const std::optional<at::Tensor>& sparse_indices,
bool,
int64_t)>();

auto work = std::get<1>(op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.sparseIndices,
opts.asyncOp,
opts.timeout.count()));

if (c10d::allow_inflight_collective_as_graph_input()) {
Expand All @@ -250,12 +252,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
bool,
int64_t)>();

auto work = op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.asyncOp,
opts.timeout.count());

if (c10d::allow_inflight_collective_as_graph_input()) {
Expand All @@ -277,13 +281,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ReduceOp>&,
int64_t,
int64_t,
bool,
int64_t)>();
auto work = op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.rootRank,
opts.rootTensor,
opts.asyncOp,
opts.timeout.count());

if (c10d::allow_inflight_collective_as_graph_input()) {
Expand All @@ -306,12 +312,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const std::vector<std::vector<at::Tensor>>&,
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
bool,
int64_t)>();

auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.asyncOp,
opts.timeout.count()));

if (c10d::allow_inflight_collective_as_graph_input()) {
Expand Down Expand Up @@ -363,18 +371,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
std::vector<std::vector<at::Tensor>>& outputTensorLists,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::allgather_coalesced_", "")
.typed<c10::intrusive_ptr<Work>(
const std::vector<std::vector<at::Tensor>>&,
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::allgather_coalesced_", "")
.typed<c10::intrusive_ptr<Work>(
const std::vector<std::vector<at::Tensor>>&,
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
bool)>();

auto work = op.call(
outputTensorLists,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.asyncOp);

if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor_list : outputTensorLists) {
Expand All @@ -399,12 +408,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
.typed<c10::intrusive_ptr<Work>(
const at::TensorList,
const at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
bool)>();

auto work = op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.asyncOp);

if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : outputTensors) {
Expand All @@ -425,12 +436,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t,
bool,
int64_t)>();
auto work = op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.rootRank,
opts.asyncOp,
opts.timeout.count());

if (c10d::allow_inflight_collective_as_graph_input()) {
Expand Down Expand Up @@ -487,12 +500,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const std::vector<std::vector<at::Tensor>>&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
bool,
int64_t)>();
auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.asyncOp,
opts.timeout.count()));

if (c10d::allow_inflight_collective_as_graph_input()) {
Expand Down Expand Up @@ -546,13 +561,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
bool,
int64_t)>();

auto work = op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.asyncOp,
opts.timeout.count());

if (c10d::allow_inflight_collective_as_graph_input()) {
Expand All @@ -577,13 +594,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
std::vector<int64_t>,
std::vector<int64_t>,
bool,
int64_t)>();
auto work = op.call(
outputBuffer,
inputBuffer,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
outputSplitSizes,
inputSplitSizes,
opts.asyncOp,
opts.timeout.count());

if (c10d::allow_inflight_collective_as_graph_input()) {
Expand All @@ -604,11 +623,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList&,
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
bool,
int64_t)>();
auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.asyncOp,
opts.timeout.count()));

if (c10d::allow_inflight_collective_as_graph_input()) {
Expand Down Expand Up @@ -778,12 +799,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
at::Tensor,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const std::vector<int64_t>&,
bool,
int64_t)>();

auto work = op.call(
tensor,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.device_ids,
opts.asyncOp,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
c10d::register_work(tensor, work);
Expand Down
Loading
Loading
0