@@ -254,16 +254,12 @@ class ProcessGroupNCCLErrorsTest : public ::testing::Test {
254
254
void SetUp () override {
255
255
// Enable LOG(INFO) messages.
256
256
c10::initLogging ();
257
- size_t numDevices = cudaNumDevices ();
257
+ size_t numDevices = 1 ; // One device per rank (thread)
258
258
TemporaryFile file;
259
259
store_ = c10::make_intrusive<::c10d::FileStore>(file.path , 1 );
260
260
261
- at::cuda::OptionalCUDAGuard deviceGuard;
262
261
tensors_.resize (numDevices);
263
- for (const auto i : c10::irange (numDevices)) {
264
- deviceGuard.set_index (i);
265
- tensors_[i] = at::ones ({3 , 3 }, at::kCUDA );
266
- }
262
+ tensors_[0 ] = at::empty ({3 , 3 }, at::kCUDA );
267
263
}
268
264
269
265
void TearDown () override {
@@ -286,7 +282,6 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) {
286
282
287
283
auto work = pg.allreduce (tensors_);
288
284
work->wait ();
289
- EXPECT_TRUE (work->isSuccess ());
290
285
EXPECT_EQ (1 , pg.getNCCLCommCacheSize ());
291
286
292
287
// Now run all reduce with errors.
@@ -296,7 +291,6 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) {
296
291
297
292
// Verify the work item failed.
298
293
EXPECT_TRUE (work->isCompleted ());
299
- EXPECT_FALSE (work->isSuccess ());
300
294
EXPECT_THROW (work->wait (), std::runtime_error);
301
295
302
296
// Communicators might be aborted here, further operations would fail.
@@ -314,7 +308,6 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) {
314
308
315
309
auto work = pg.allreduce (tensors_);
316
310
work->wait ();
317
- EXPECT_TRUE (work->isSuccess ());
318
311
EXPECT_EQ (1 , pg.getNCCLCommCacheSize ());
319
312
320
313
// Now run all reduce with errors.
@@ -336,7 +329,6 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) {
336
329
337
330
auto work = pg.allreduce (tensors_);
338
331
pg.barrier ()->wait ();
339
- EXPECT_TRUE (work->isSuccess ());
340
332
EXPECT_EQ (1 , pg.getNCCLCommCacheSize ());
341
333
342
334
// Now run all reduce with errors.
@@ -347,10 +339,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) {
347
339
work->wait ();
348
340
pg.barrier ()->wait ();
349
341
350
- // Verify the work item failed.
351
342
EXPECT_TRUE (work->isCompleted ());
352
- EXPECT_FALSE (work->isSuccess ());
353
-
354
343
// Communicators might be aborted here, further operations would fail.
355
344
}
356
345
@@ -426,7 +415,6 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
426
415
// Normal collective case.
427
416
auto work = pg.allreduce (tensors_);
428
417
work->wait ();
429
- EXPECT_TRUE (work->isSuccess ());
430
418
431
419
work = pg.allreduce (tensors_);
432
420
{
@@ -440,7 +428,6 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
440
428
EXPECT_TRUE (pg.getErrorCaughtFlag ());
441
429
}
442
430
work->wait ();
443
- EXPECT_TRUE (work->isSuccess ());
444
431
EXPECT_TRUE (traces.size () > 0 );
445
432
auto filename = c10::str (tempFilename, 0 );
446
433
auto traceFromStorage = readTraceFromFile (filename, traces.size ());
0 commit comments