[c10d] PGNCCL refactor part 2: Simplify ProcessGroupNCCL into single-device style · pytorch/pytorch@f3e7d80 · GitHub

Commit f3e7d80

kwen2501 authored and pytorchmergebot committed
[c10d] PGNCCL refactor part 2: Simplify ProcessGroupNCCL into single-device style (#119421)
Part 2 and the last part of #118674: introduce the actual "single-device" code change to ProcessGroupNCCL. The assert size == 1 guard and the test refactor were done in #119099. Pull Request resolved: #119421. Approved by: https://github.com/shuqiangzhang
1 parent 0597dab commit f3e7d80
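
In concrete terms, "single-device style" collapses the per-group vectors in the ProcessGroupNCCL interfaces down to a single device and a single communicator. A before/after sketch of the affected virtual signatures, abridged from the diffs below (trailing parameters elided with "..."):

// Before: work objects and error checks spanned vectors of devices/comms.
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
    std::vector<at::Device> devices, int rank, c10d::OpType opType, ...);
std::exception_ptr checkForNCCLErrors(
    const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms);

// After: each work object binds exactly one device and one communicator.
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
    at::Device& device, int rank, c10d::OpType opType, ...);
std::exception_ptr checkForNCCLErrors(
    std::shared_ptr<c10d::NCCLComm>& ncclComm);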

File tree

8 files changed: +828 -1126 lines changed

test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp

Lines changed: 12 additions & 14 deletions
@@ -20,20 +20,18 @@ constexpr int kNcclErrorHandlingVersion = 2400;
 class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
  public:
   WorkNCCLSimulateErrors(
-      const std::vector<at::Device>& devices,
+      at::Device& device,
       bool simulate_error,
       int rank,
       c10d::OpType opType,
       uint64_t seq)
-      : WorkNCCL(devices, rank, opType, seq), simulateError_(simulate_error) {}
+      : WorkNCCL(device, rank, opType, seq), simulateError_(simulate_error) {}

-  std::exception_ptr checkForNCCLErrors(
-      const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms)
-      const override {
+  std::exception_ptr checkForNCCLErrors() override {
     if (simulateError_) {
       return std::make_exception_ptr(std::runtime_error("Error"));
     }
-    return c10d::ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors(ncclComms);
+    return c10d::ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors();
   }

  private:
@@ -50,11 +48,11 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
       : ProcessGroupNCCL(store, rank, size, opts), simulateError_(false) {}

   std::exception_ptr checkForNCCLErrors(
-      const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms) override {
+      std::shared_ptr<c10d::NCCLComm>& ncclComm) override {
     if (simulateError_) {
       return std::make_exception_ptr(std::runtime_error("Error"));
     }
-    return c10d::ProcessGroupNCCL::checkForNCCLErrors(ncclComms);
+    return c10d::ProcessGroupNCCL::checkForNCCLErrors(ncclComm);
   }

   std::chrono::duration<int64_t, std::milli> getWatchdogSleepInterval() {
@@ -63,14 +61,14 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
   }

   c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
-      std::vector<at::Device> devices,
+      at::Device& device,
       int rank,
       c10d::OpType opType,
       const char* profilingTitle,
       const std::vector<at::Tensor>& inputs = {},
       const std::vector<at::Tensor>& outputs = {}) override {
     return c10::make_intrusive<WorkNCCLSimulateErrors>(
-        devices, simulateError_, rank, opType, seq_);
+        device, simulateError_, rank, opType, seq_);
   }

   size_t getNCCLCommCacheSize() {
@@ -92,12 +90,12 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
 class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
  public:
   WorkNCCLTimedoutErrors(
-      const std::vector<at::Device>& devices,
+      at::Device& device,
       bool set_timedout_error,
       int rank,
       c10d::OpType opType,
       uint64_t seq)
-      : WorkNCCL(devices, rank, opType, seq),
+      : WorkNCCL(device, rank, opType, seq),
         setTimedoutError_(set_timedout_error) {}

  private:
@@ -124,14 +122,14 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
         setTimedoutError_(false) {}

   c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
-      std::vector<at::Device> devices,
+      at::Device& device,
       int rank,
       c10d::OpType opType,
       const char* profilingTitle,
       const std::vector<at::Tensor>& inputs = {},
       const std::vector<at::Tensor>& outputs = {}) override {
     return c10::make_intrusive<WorkNCCLTimedoutErrors>(
-        devices, setTimedoutError_, rank, opType, seq_);
+        device, setTimedoutError_, rank, opType, seq_);
   }

   void setTimedoutError() {

test/distributed/test_c10d_nccl.py

Lines changed: 10 additions & 4 deletions
@@ -2947,6 +2947,10 @@ def world_size(self):
     def blocking_wait_error_msg(self):
         return "timeout"

+    @property
+    def remote_error_msg(self):
+        return "remote process exit"
+
     def _run_all_reduce(self, pg):
         pg.allreduce(torch.rand(10).cuda(self.rank))

@@ -2995,8 +2999,9 @@ def _test_nccl_errors_blocking(self, func):
         process_group.allreduce(torch.rand(10).cuda(self.rank))
         if self.rank == 0:
             work = process_group.allreduce(torch.rand(10).cuda(self.rank))
-            with self.assertRaisesRegex(dist.DistBackendError, self.blocking_wait_error_msg):
-                # Operation would time out in blocking mode.
+            with self.assertRaisesRegex(dist.DistBackendError, self.remote_error_msg):
+                # Previously this should timeout; but with newer NCCL version,
+                # it seems NCCL would detect that the peer rank has exited
                 work.wait(timeout=timedelta(seconds=self.op_timeout_sec))
             # Run some GPU operations to make sure cuda has not gotten stuck.
             # It was observed cuda could get stuck if NCCL communicators were
@@ -3064,8 +3069,9 @@ def test_nccl_blocking_wait_with_barrier(self):
         )
         process_group.barrier().wait()
         if self.rank == 0:
-            with self.assertRaisesRegex(dist.DistBackendError, self.blocking_wait_error_msg):
-                # This should timeout
+            with self.assertRaisesRegex(dist.DistBackendError, self.remote_error_msg):
+                # Previously this should timeout; but with newer NCCL version,
+                # it seems NCCL would detect that the peer rank has exited
                 process_group.barrier().wait(timeout=timedelta(seconds=self.op_timeout_sec))

     def _run_invalid_nccl_blocking_wait_env(self, val):
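
(Why the expected regex changes: per the comments added in this diff, newer NCCL versions detect that the peer rank has exited and surface a remote-process-exit error before the surviving rank's wait can time out, so these two tests now match the new remote_error_msg property; the blocking_wait_error_msg property remains defined above.)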

torch/csrc/cuda/nccl.cpp

Lines changed: 6 additions & 8 deletions
@@ -415,20 +415,18 @@ AutoNcclGroup::AutoNcclGroup() {
   (c10::cuda::getFreeMutex())->lock();
 #endif
   comm_nonblocking_ = false;
+  comm_ = nullptr;
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
   detail::NCCL_CHECK(ncclGroupStart());
 #endif
 }

-AutoNcclGroup::AutoNcclGroup(
-    std::vector<ncclComm_t>& comms,
-    bool comm_nonblocking) {
+AutoNcclGroup::AutoNcclGroup(ncclComm_t comm, bool comm_nonblocking) {
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
   // nccl < 2.0 cannot be called concurrently with cudaFree
   (c10::cuda::getFreeMutex())->lock();
 #endif
-  // TODO(eqy): can we make comms_ reference?
-  comms_ = comms;
+  comm_ = comm;
   comm_nonblocking_ = comm_nonblocking;
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
   detail::NCCL_CHECK(ncclGroupStart());
@@ -437,10 +435,10 @@ AutoNcclGroup::AutoNcclGroup(

 AutoNcclGroup::~AutoNcclGroup() noexcept(false) {
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
-  if (!comm_nonblocking_) {
-    detail::NCCL_CHECK(ncclGroupEnd());
+  if (comm_nonblocking_ && comm_ != nullptr) {
+    detail::NCCL_CHECK_TIMEOUT(ncclGroupEnd(), comm_);
   } else {
-    detail::NCCL_CHECK_TIMEOUT(ncclGroupEnd(), comms_);
+    detail::NCCL_CHECK(ncclGroupEnd());
   }
 #endif
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)

torch/csrc/cuda/nccl.h

Lines changed: 2 additions & 2 deletions
@@ -76,9 +76,9 @@ enum class ncclDataType {
 // manages group and lock lifetimes.
 struct AutoNcclGroup {
   AutoNcclGroup();
-  AutoNcclGroup(std::vector<ncclComm_t>& comms, bool comm_nonblocking);
+  AutoNcclGroup(ncclComm_t comm, bool comm_nonblocking);
   ~AutoNcclGroup() noexcept(false);
-  std::vector<ncclComm_t> comms_;
+  ncclComm_t comm_;
   bool comm_nonblocking_;
 };
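
A hypothetical usage sketch for the new single-comm guard (the variable name and flag value are illustrative, and the torch::cuda::nccl namespace is assumed from the surrounding headers): construction runs ncclGroupStart(), destruction runs ncclGroupEnd().

// Sketch: RAII-scoped NCCL group around work on one communicator.
// Assumes `comm` is an already-initialized ncclComm_t.
{
  torch::cuda::nccl::AutoNcclGroup group(comm, /*comm_nonblocking=*/false);
  // ... enqueue NCCL operations on comm here ...
}  // ~AutoNcclGroup() calls ncclGroupEnd(); with comm_nonblocking=true and a
   // non-null comm_, it instead polls via the timeout-aware NCCL_CHECK_TIMEOUT.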

torch/csrc/distributed/c10d/NCCLUtils.hpp

Lines changed: 26 additions & 31 deletions
@@ -126,37 +126,32 @@
     TORCH_CHECK_WITH(DistBackendError, false, err); \
   }

-#define C10D_NCCL_CHECK_TIMEOUT_GROUPEND(cmd, comms_, failureReason) \
-  ncclResult_t state = cmd; \
-  auto startTimepoint = std::chrono::steady_clock::now(); \
-  if (state == ncclInProgress) { \
-    for (const auto i : c10::irange(comms_.size())) { \
-      do { \
-        if (nccl_nonblocking_timeout() > 0) { \
-          auto currentTimepoint = std::chrono::steady_clock::now(); \
-          auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>( \
-              currentTimepoint - startTimepoint) \
-              .count(); \
-          if (timeElapsed > nccl_nonblocking_timeout()) { \
-            std::string err = "NCCL timeout in: " + std::string(__FILE__) + \
-                ":" + std::to_string(__LINE__) + ", " + \
-                ncclGetErrorWithVersion(state) + "\n" + \
-                getNcclErrorDetailStr(state, failureReason); \
-            TORCH_CHECK_WITH(DistBackendError, false, err); \
-          } \
-        } \
-        ncclCommGetAsyncError(comms_[i]->getNcclComm(), &state); \
-      } while (state == ncclInProgress); \
-      if (state != ncclSuccess) { \
-        break; /* fall through to failed case */ \
-      } \
-    } \
-  } \
-  if (state != ncclSuccess) { \
-    std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
-        std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) + \
-        "\n" + getNcclErrorDetailStr(state, failureReason); \
-    TORCH_CHECK_WITH(DistBackendError, false, err); \
+#define C10D_NCCL_CHECK_TIMEOUT_GROUPEND(cmd, comm, failureReason) \
+  ncclResult_t state = cmd; \
+  auto startTimepoint = std::chrono::steady_clock::now(); \
+  if (state == ncclInProgress) { \
+    do { \
+      if (nccl_nonblocking_timeout() > 0) { \
+        auto currentTimepoint = std::chrono::steady_clock::now(); \
+        auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>( \
+            currentTimepoint - startTimepoint) \
+            .count(); \
+        if (timeElapsed > nccl_nonblocking_timeout()) { \
+          std::string err = "NCCL timeout in: " + std::string(__FILE__) + \
+              ":" + std::to_string(__LINE__) + ", " + \
+              ncclGetErrorWithVersion(state) + "\n" + \
+              getNcclErrorDetailStr(state, failureReason); \
+          TORCH_CHECK_WITH(DistBackendError, false, err); \
+        } \
+      } \
+      ncclCommGetAsyncError(comm->getNcclComm(), &state); \
+    } while (state == ncclInProgress); \
+  } \
+  if (state != ncclSuccess) { \
+    std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
+        std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) + \
+        "\n" + getNcclErrorDetailStr(state, failureReason); \
+    TORCH_CHECK_WITH(DistBackendError, false, err); \
   }

 // Macro to print and abort on a non-successful NCCL return value.
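
For readability, this is roughly what the rewritten single-comm macro expands to when wrapping a group end; error-message construction is abridged to placeholder strings here:

// Approximate expansion of
// C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), comm, failureReason).
ncclResult_t state = ncclGroupEnd();
auto startTimepoint = std::chrono::steady_clock::now();
if (state == ncclInProgress) {
  // Poll this one communicator's async state until it leaves ncclInProgress,
  // failing if the configured nonblocking timeout elapses first.
  do {
    if (nccl_nonblocking_timeout() > 0) {
      auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
                         std::chrono::steady_clock::now() - startTimepoint)
                         .count();
      if (elapsed > nccl_nonblocking_timeout()) {
        TORCH_CHECK_WITH(DistBackendError, false, "NCCL timeout in: ..."); // abridged
      }
    }
    ncclCommGetAsyncError(comm->getNcclComm(), &state);
  } while (state == ncclInProgress);
}
if (state != ncclSuccess) {
  TORCH_CHECK_WITH(DistBackendError, false, "NCCL error in: ..."); // abridged
}

The old version wrapped this polling loop in a for over comms_ with per-index break bookkeeping; with one communicator per ProcessGroupNCCL, the outer loop disappears.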

0 commit comments
