Update on "[PGNCCL] Launch kernel on current stream & remove record_stream entirely" · pytorch/pytorch@6a5be4e · GitHub

Commit 6a5be4e

Update on "[PGNCCL] Launch kernel on current stream & remove record_stream entirely"
This PR has multiple changes to `ProcessGroupNCCL` (which unfortunately are related):

1. When async_op=False, we directly launch the collective on the "current" stream, instead of launching on a trampoline stream and joining back (see the usage sketch below).
   - Resolves #147729
   - Resolves #146881
   - Also saves two event syncs (which have overhead in the HIP case) and one pybind call when we invoke `work.wait()` in distributed_c10d.py on behalf of the user.
2. Entirely remove `record_stream` and use CPU-side stashing to manage tensor lifetime against allocator recycling.
   - Resolves #147168
3. Remove tensor lifetime management when async_op=False; only use it when async_op=True.
4. To guard against a user who never calls `work.wait()`, we ask the watchdog to unstash tensors after it detects completion of the collectives, so that we do not hold references to tensors forever. This is a safety net rather than a service guarantee; see the discussion [here](#147168 (comment)).
5. Profiles in async_op=False mode will look different: collective kernels now show up on the same line as compute kernels.

Joint work with cenzhaometa, who wants to remove the event-sync overhead.

Cc: ngimel awgu Aidyn-A skyw wconstab leonardo0lyj

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
2 parents a424cd1 + 13fc2ac commit 6a5be4e
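
For orientation, here is a minimal usage sketch of the two modes described in the commit message. It assumes a one-process-per-GPU launch (e.g. via torchrun) with the NCCL backend available; the tensor name and size are illustrative only and not part of this PR.

```python
import torch
import torch.distributed as dist

# Assumes torchrun-style launch: one process per GPU, NCCL backend available.
dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
t = torch.ones(1024, device="cuda")  # illustrative tensor

# async_op=False: with this PR the kernel is launched on the *current* stream,
# so later kernels on that stream are already ordered after the collective;
# no trampoline stream, no extra event syncs, no tensor stashing.
dist.all_reduce(t, async_op=False)

# async_op=True: the returned work keeps CPU-side references to the tensors
# (instead of calling Tensor.record_stream), so the caching allocator cannot
# recycle them early; work.wait() releases those references (the watchdog
# unstashes them after completion as a safety net if wait() is never called).
work = dist.all_reduce(t, async_op=True)
work.wait()

dist.destroy_process_group()
```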

File tree

3 files changed, +44 -2 lines


test/distributed/test_c10d_ops_nccl.py

Lines changed: 26 additions & 0 deletions

@@ -733,6 +733,32 @@ def reduce_scatter_base(output_t, input_t):
             # fails the check because the dtype is different
             reduce_scatter_base(output_t, tensor)
 
+    @requires_nccl()
+    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    def test_reduce_scatter_v(self):
+        device = torch.device("cuda", self.rank_to_GPU[self.rank][0])
+        # A list of tensors with different sizes
+        input_list = [torch.ones(i, device=device) for i in range(self.world_size)]
+        # The i-th output should have size i
+        output = torch.zeros(self.rank, device=device)
+        work = c10d.reduce_scatter(output, input_list, group=self.pg, async_op=True)
+        expected = torch.ones(self.rank, device=device) * self.world_size
+        work.wait()
+        self.assertEqual(expected, output)
+
+    @requires_nccl()
+    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    def test_all_gather_v(self):
+        device = torch.device("cuda", self.rank_to_GPU[self.rank][0])
+        # A list of tensors with different sizes
+        output_list = [torch.zeros(i, device=device) for i in range(self.world_size)]
+        # The i-th input has size i, filled with value i
+        input = torch.ones(self.rank, device=device) * self.rank
+        work = c10d.all_gather(output_list, input, group=self.pg, async_op=True)
+        expected = [torch.ones(i, device=device) * i for i in range(self.world_size)]
+        work.wait()
+        self.assertEqual(expected, output_list)
+
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
     def test_reduce_scatter_ops(self):

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 14 additions & 2 deletions

@@ -3172,6 +3172,7 @@ void ProcessGroupNCCL::startCoalescing() {
 
   coalescedDevice_.set_index(-1);
   coalescedComm_ = nullptr;
+  coalescedTensors_.clear();
   coalescing_state_ |= CoalActive;
   groupStart();
 }
@@ -3217,6 +3218,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing(OpType optype) {
   work->store_ = store_;
   assignTimeoutToWork(work, options_);
 
+  // Hand over references to tensors during coalescing to work's stash
+  work->stashTensors(coalescedTensors_);
+
   // Record start before ncclGroupEnd
   if (work->timingEnabled_) {
     work->ncclStartEvent_->record(ncclStream);
@@ -3239,6 +3243,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing(OpType optype) {
   // Reset coalescing state
   coalescing_state_ = 0;
   coalescedComm_ = nullptr;
+  coalescedTensors_.clear();
   // If in async mode, return work; otherwise, kernel is enqueued on current
   // stream, no need to return work
   return coalescedAsync_ ? work : nullptr;
@@ -3325,8 +3330,15 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
   // stream, we don't need to do anything for tensor lifetime management.
   // Otherwise, we need to stage the tensors until `work.wait()`.
   if (asyncOp) {
-    work->stashTensors(inputs);
-    work->stashTensors(outputs);
+    if (coalescing_state_) {
+      coalescedTensors_.insert(
+          coalescedTensors_.end(), inputs.begin(), inputs.end());
+      coalescedTensors_.insert(
+          coalescedTensors_.end(), outputs.begin(), outputs.end());
+    } else {
+      work->stashTensors(inputs);
+      work->stashTensors(outputs);
+    }
   }
 
   if (nanCheck) {
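
The coalescing path above can be exercised from Python roughly as sketched below. `_coalescing_manager` is a private helper in `torch.distributed.distributed_c10d`, so its exact signature is an assumption here and may differ across versions. Inside the context, each collective's inputs and outputs are appended to `coalescedTensors_` rather than stashed per op, and `endCoalescing` hands the whole batch to the single work's stash, which `wait()` later releases.

```python
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _coalescing_manager  # private API

dist.init_process_group("nccl")
rank = dist.get_rank()
device = torch.device("cuda", rank)
torch.cuda.set_device(device)

tensors = [torch.ones(1 << i, device=device) for i in range(4)]

# With async_ops=True, the collectives below are coalesced into one NCCL group.
# On the C++ side their tensors accumulate in coalescedTensors_ and are moved
# into the returned work when the coalescing context ends.
with _coalescing_manager(device=device, async_ops=True) as cm:
    for t in tensors:
        dist.all_reduce(t)
cm.wait()  # releases the stashed references once the coalesced work completes

dist.destroy_process_group()
```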

torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp

Lines changed: 4 additions & 0 deletions

@@ -1232,6 +1232,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // Whether the coalesced calls are sync or async.
   bool coalescedAsync_;
 
+  // keeps track of input and output tensors when coalescing is in flight. Will
+  // hand over these tensors to WorkNCCL's stash when coalescing is ended.
+  std::vector<at::Tensor> coalescedTensors_;
+
   // Whether or not wait() and synchronize() are blocking operations that wait
   // for the operation to complete.
   bool blockingWait_ = false;
