[c10d] PGNCCL refactor part 1: adds assert size==1 (#119099) · pytorch/pytorch@e1a0ad4

Commit e1a0ad4

kwen2501 authored and clee2000 committed
[c10d] PGNCCL refactor part 1: adds assert size==1 (#119099)
Breaking #118674 into multiple smaller PRs. This is the first one. It adds `assert size==1` to PGNCCL, and refactors some old tests written in multi-device style (which would otherwise fail at the assert).

Pull Request resolved: #119099
Approved by: https://github.com/wconstab
1 parent b7e5a7b commit e1a0ad4
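
The `assert size==1` itself lands in ProcessGroupNCCL, which this capture does not show (the diff below covers only one of the four changed files, the C++ error tests). As a hedged sketch of the invariant being added, with `assertSingleDeviceInput` as a hypothetical helper name rather than the actual PGNCCL code:

```cpp
#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include <vector>

// Hypothetical illustration of the invariant, not the real PGNCCL source:
// each ProcessGroupNCCL rank now drives exactly one CUDA device, so the
// tensor list passed to a collective must have size 1.
void assertSingleDeviceInput(const std::vector<at::Tensor>& tensors) {
  TORCH_CHECK(
      tensors.size() == 1,
      "ProcessGroupNCCL expects exactly one tensor (one device) per rank, got ",
      tensors.size());
}
```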

File tree

4 files changed: +161 -209 lines changed


test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp

Lines changed: 2 additions & 15 deletions
```diff
@@ -254,16 +254,12 @@ class ProcessGroupNCCLErrorsTest : public ::testing::Test {
   void SetUp() override {
     // Enable LOG(INFO) messages.
     c10::initLogging();
-    size_t numDevices = cudaNumDevices();
+    size_t numDevices = 1; // One device per rank (thread)
     TemporaryFile file;
     store_ = c10::make_intrusive<::c10d::FileStore>(file.path, 1);
 
-    at::cuda::OptionalCUDAGuard deviceGuard;
     tensors_.resize(numDevices);
-    for (const auto i : c10::irange(numDevices)) {
-      deviceGuard.set_index(i);
-      tensors_[i] = at::ones({3, 3}, at::kCUDA);
-    }
+    tensors_[0] = at::empty({3, 3}, at::kCUDA);
   }
 
   void TearDown() override {
```
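
The SetUp change above is the heart of the test refactor: the fixture no longer enumerates every visible GPU. As a rough sketch of the two styles, using illustrative helper names rather than anything in the test file, the old multi-device pattern and the new one-device-per-rank pattern look like this:

```cpp
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAGuard.h>
#include <c10/util/irange.h>
#include <vector>

// Old multi-device style: one process owns every visible GPU and builds
// one tensor per device, which is what now trips the size==1 assert.
std::vector<at::Tensor> makeTensorsMultiDevice(size_t numDevices) {
  std::vector<at::Tensor> tensors(numDevices);
  at::cuda::OptionalCUDAGuard deviceGuard;
  for (const auto i : c10::irange(numDevices)) {
    deviceGuard.set_index(i); // switch the current device for each tensor
    tensors[i] = at::ones({3, 3}, at::kCUDA);
  }
  return tensors;
}

// New style: each rank (one thread per rank in these tests) owns a single
// device, so the per-rank tensor list always has size 1.
std::vector<at::Tensor> makeTensorsSingleDevice() {
  return {at::empty({3, 3}, at::kCUDA)};
}
```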
```diff
@@ -286,7 +282,6 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) {
 
   auto work = pg.allreduce(tensors_);
   work->wait();
-  EXPECT_TRUE(work->isSuccess());
   EXPECT_EQ(1, pg.getNCCLCommCacheSize());
 
   // Now run all reduce with errors.
@@ -296,7 +291,6 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) {
 
   // Verify the work item failed.
   EXPECT_TRUE(work->isCompleted());
-  EXPECT_FALSE(work->isSuccess());
   EXPECT_THROW(work->wait(), std::runtime_error);
 
   // Communicators might be aborted here, further operations would fail.
@@ -314,7 +308,6 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) {
 
   auto work = pg.allreduce(tensors_);
   work->wait();
-  EXPECT_TRUE(work->isSuccess());
   EXPECT_EQ(1, pg.getNCCLCommCacheSize());
 
   // Now run all reduce with errors.
@@ -336,7 +329,6 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) {
 
   auto work = pg.allreduce(tensors_);
   pg.barrier()->wait();
-  EXPECT_TRUE(work->isSuccess());
   EXPECT_EQ(1, pg.getNCCLCommCacheSize());
 
   // Now run all reduce with errors.
@@ -347,10 +339,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) {
   work->wait();
   pg.barrier()->wait();
 
-  // Verify the work item failed.
   EXPECT_TRUE(work->isCompleted());
-  EXPECT_FALSE(work->isSuccess());
-
   // Communicators might be aborted here, further operations would fail.
 }
 
```
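
The deletions in the hunks above all follow one pattern: the `isSuccess()` checks are dropped, and the tests instead rely on `isCompleted()` plus the exception rethrown by `wait()`. A minimal sketch of the resulting idiom, assuming the fixture above and a `pg` constructed as in these tests (the test name is invented for illustration):

```cpp
#include <gtest/gtest.h>
#include <stdexcept>

// Illustrative only: after this refactor, a failed collective is observed
// through isCompleted() and the exception from wait(), not isSuccess().
TEST_F(ProcessGroupNCCLErrorsTest, sketchFailureObservation) {
  auto work = pg.allreduce(tensors_); // assume an error has been injected
  EXPECT_TRUE(work->isCompleted());
  EXPECT_THROW(work->wait(), std::runtime_error);
}
```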
```diff
@@ -426,7 +415,6 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
   // Normal collective case.
   auto work = pg.allreduce(tensors_);
   work->wait();
-  EXPECT_TRUE(work->isSuccess());
 
   work = pg.allreduce(tensors_);
   {
@@ -440,7 +428,6 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
     EXPECT_TRUE(pg.getErrorCaughtFlag());
   }
   work->wait();
-  EXPECT_TRUE(work->isSuccess());
   EXPECT_TRUE(traces.size() > 0);
   auto filename = c10::str(tempFilename, 0);
   auto traceFromStorage = readTraceFromFile(filename, traces.size());
```
