Stash tensors for reduce_scatter_v and all_gather_v #150332
by kwen2501 · pytorch/pytorch


Status: Closed · wants to merge 1 commit
26 changes: 26 additions & 0 deletions test/distributed/test_c10d_ops_nccl.py
@@ -733,6 +733,32 @@ def reduce_scatter_base(output_t, input_t):
# fails the check because the dtype is different
reduce_scatter_base(output_t, tensor)

@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_reduce_scatter_v(self):
device = torch.device("cuda", self.rank_to_GPU[self.rank][0])
# A list of tensors with different sizes
input_list = [torch.ones(i, device=device) for i in range(self.world_size)]
# The i-th output should have size i
output = torch.zeros(self.rank, device=device)
work = c10d.reduce_scatter(output, input_list, group=self.pg, async_op=True)
expected = torch.ones(self.rank, device=device) * self.world_size
work.wait()
self.assertEqual(expected, output)

@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_all_gather_v(self):
device = torch.device("cuda", self.rank_to_GPU[self.rank][0])
# A list of tensors with different sizes
output_list = [torch.zeros(i, device=device) for i in range(self.world_size)]
# The i-th input has size i, filled with value i
input = torch.ones(self.rank, device=device) * self.rank
work = c10d.all_gather(output_list, input, group=self.pg, async_op=True)
expected = [torch.ones(i, device=device) * i for i in range(self.world_size)]
work.wait()
self.assertEqual(expected, output_list)

@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_reduce_scatter_ops(self):
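For reference, here is a minimal standalone sketch of the pattern the two new tests exercise: uneven ("vector-variant") reduce_scatter and all_gather issued with async_op=True and waited on later. It is illustrative only, assuming a typical torchrun launch with one GPU per rank and the NCCL backend; the script name and structure are not part of this PR.

# uneven_collectives_sketch.py -- illustrative sketch, not part of this PR.
# Launch with e.g.: torchrun --nproc_per_node=2 uneven_collectives_sketch.py
import os

import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
    torch.cuda.set_device(device)

    # reduce_scatter_v: rank r receives a tensor of size r, summed over ranks.
    rs_inputs = [torch.ones(i, device=device) for i in range(world_size)]
    rs_output = torch.zeros(rank, device=device)
    rs_work = dist.reduce_scatter(rs_output, rs_inputs, async_op=True)

    # all_gather_v: rank r contributes a tensor of size r filled with value r.
    ag_outputs = [torch.zeros(i, device=device) for i in range(world_size)]
    ag_input = torch.ones(rank, device=device) * rank
    ag_work = dist.all_gather(ag_outputs, ag_input, async_op=True)

    # With async_op=True the kernels may still be in flight here; with this
    # PR the backend stashes the tensor references so they stay alive even
    # if the caller drops its own references before wait().
    rs_work.wait()
    ag_work.wait()

    assert torch.equal(rs_output, torch.ones(rank, device=device) * world_size)
    for i, out in enumerate(ag_outputs):
        assert torch.equal(out, torch.ones(i, device=device) * i)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()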
16 changes: 14 additions & 2 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -3172,6 +3172,7 @@ void ProcessGroupNCCL::startCoalescing() {

coalescedDevice_.set_index(-1);
coalescedComm_ = nullptr;
coalescedTensors_.clear();
coalescing_state_ |= CoalActive;
groupStart();
}
@@ -3217,6 +3218,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing(OpType optype) {
work->store_ = store_;
assignTimeoutToWork(work, options_);

// Hand over the tensor references collected during coalescing to the work's stash
work->stashTensors(coalescedTensors_);

// Record start before ncclGroupEnd
if (work->timingEnabled_) {
work->ncclStartEvent_->record(ncclStream);
@@ -3239,6 +3243,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing(OpType optype) {
// Reset coalescing state
coalescing_state_ = 0;
coalescedComm_ = nullptr;
coalescedTensors_.clear();
// If in async mode, return work; otherwise, kernel is enqueued on current
// stream, no need to return work
return coalescedAsync_ ? work : nullptr;
@@ -3325,8 +3330,15 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
// stream, we don't need to do anything for tensor lifetime management.
// Otherwise, we need to stage the tensors until `work.wait()`.
if (asyncOp) {
work->stashTensors(inputs);
work->stashTensors(outputs);
if (coalescing_state_) {
coalescedTensors_.insert(
coalescedTensors_.end(), inputs.begin(), inputs.end());
coalescedTensors_.insert(
coalescedTensors_.end(), outputs.begin(), outputs.end());
} else {
work->stashTensors(inputs);
work->stashTensors(outputs);
}
}

if (nanCheck) {
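The change above makes `collective()` accumulate tensor references in `coalescedTensors_` while a coalescing group is open, instead of stashing them on the per-op work object, and `endCoalescing()` then hands the accumulated references to the final work's stash. Below is a minimal Python sketch of that ownership pattern; `FakeWork` and `FakeProcessGroup` are hypothetical stand-ins, not the real c10d classes, and no actual communication happens.

from typing import List

import torch


class FakeWork:
    """Stand-in for WorkNCCL: holds stashed tensors until wait()."""

    def __init__(self) -> None:
        self._stash: List[torch.Tensor] = []

    def stash_tensors(self, tensors: List[torch.Tensor]) -> None:
        # Keep strong references so the caller is free to drop its own.
        self._stash.extend(tensors)

    def wait(self) -> None:
        # Once the (hypothetical) kernels have finished, release the
        # references so the allocator may reuse the memory.
        self._stash.clear()


class FakeProcessGroup:
    """Stand-in for ProcessGroupNCCL's coalescing bookkeeping."""

    def __init__(self) -> None:
        self._coalescing = False
        self._coalesced_tensors: List[torch.Tensor] = []

    def start_coalescing(self) -> None:
        self._coalescing = True
        self._coalesced_tensors.clear()

    def collective(
        self,
        inputs: List[torch.Tensor],
        outputs: List[torch.Tensor],
        work: FakeWork,
        async_op: bool,
    ) -> None:
        if async_op:
            if self._coalescing:
                # Mirrors the coalescedTensors_ accumulation in the diff:
                # the per-op work is not the one users will wait on.
                self._coalesced_tensors.extend(inputs)
                self._coalesced_tensors.extend(outputs)
            else:
                work.stash_tensors(inputs)
                work.stash_tensors(outputs)

    def end_coalescing(self) -> FakeWork:
        work = FakeWork()
        # Hand over everything collected while coalescing was in flight,
        # then reset the coalescing state, as endCoalescing() does.
        work.stash_tensors(self._coalesced_tensors)
        self._coalescing = False
        self._coalesced_tensors = []
        return work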
4 changes: 4 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -1232,6 +1232,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Whether the coalesced calls are sync or async.
bool coalescedAsync_;

// Keeps track of input and output tensors while coalescing is in flight. These
// tensors are handed over to WorkNCCL's stash when coalescing ends.
std::vector<at::Tensor> coalescedTensors_;

// Whether or not wait() and synchronize() are blocking operations that wait
// for the operation to complete.
bool blockingWait_ = false;
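Why the stash matters at all (per the comment in `collective()` about tensor lifetime management): with asyncOp the collective is launched on a communication stream rather than the current stream, so if the caller drops its Python references before the kernel finishes, the caching allocator could recycle that memory while NCCL is still reading or writing it. A hedged illustration of the hazard pattern, with a hypothetical helper name:

import torch
import torch.distributed as dist


def launch_uneven_all_gather(rank: int, world_size: int, device: torch.device):
    """Issue an async all_gather with uneven sizes and return only the handles
    the caller cares about.

    The local reference to `inp` dies when this function returns; the
    stashing mechanism (extended by this PR to coalesced ops such as this
    uneven all_gather) is what keeps the tensor alive until the collective
    has actually completed.
    """
    outputs = [torch.zeros(i, device=device) for i in range(world_size)]
    inp = torch.ones(rank, device=device) * rank
    work = dist.all_gather(outputs, inp, async_op=True)
    return work, outputs


# Caller side, inside an initialized NCCL process group:
#   work, outputs = launch_uneven_all_gather(
#       dist.get_rank(), dist.get_world_size(), device)
#   ... other GPU work may be enqueued here while the collective runs ...
#   work.wait()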