8000 [3/N] Set correct device to CUDA guards · pytorch/pytorch@bd9e7b0 · GitHub
[go: up one dir, main page]

Skip to content

Commit bd9e7b0

Browse files
committed
[3/N] Set correct device to CUDA guards
ghstack-source-id: f1a5b94 Pull Request resolved: #134357
1 parent 2e36eb5 commit bd9e7b0

File tree

2 files changed

+29
-8
lines changed

2 files changed

+29
-8
lines changed

test/distributed/test_c10d_nccl.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ def test_close_pg(self):
350350
@parametrize("type", [torch.float16, torch.float32, torch.float64, torch.bfloat16])
351351
@skip_if_rocm
352352
def test_nan_assert(self, type):
353+
# Expecting a device-side error when NaN is detected
353354
os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
354355
store = c10d.FileStore(self.file_name, self.world_size)
355356
pg = self._create_process_group_nccl(store, self.opts())
@@ -366,6 +367,24 @@ def test_nan_assert(self, type):
366367
# reset env
367368
os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
368369

370+
@requires_nccl()
371+
@skip_if_lt_x_gpu(2)
372+
def test_nan_check(self):
373+
# Not expecting an error, NaN check should not make legit code fail
374+
os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
375+
store = c10d.FileStore(self.file_name, self.world_size)
376+
device = torch.device("cuda:%d" % self.rank)
377+
c10d.init_process_group(
378+
backend="nccl", store=store, rank=self.rank, world_size=self.world_size
379+
)
380+
x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank
381+
t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
382+
c10d.broadcast(x, src=0)
383+
c10d.all_reduce(t)
384+
c10d.destroy_process_group()
385+
# reset env
386+
os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
387+
369388
@requires_nccl()
370389
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
371390
def test_destruct_before_terminate_pg(self):

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,6 +1162,10 @@ void ProcessGroupNCCL::abortCommsFromMap(
11621162
at::cuda::OptionalCUDAGuard gpuGuard;
11631163
at::DeviceIndex deviceIndex = getIndexFromDeviceKey(devName);
11641164
if (deviceIndex >= 0) {
1165+
// For P2P comms, the deviceIndex could be -1 (invalid), as the keys in
1166+
// the map could be non deviceIndex, but rank to rank numbers. So we
1167+
// indeed need to check if deviceIndex >= 0
1168+
// TODO: fix `getIndexFromDeviceKey` or fix `DeviceKey`
11651169
gpuGuard.set_index(deviceIndex);
11661170
}
11671171
LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ "
@@ -2162,7 +2166,9 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
21622166
bool batchP2P = ncclActiveGroupCounter_ > 0;
21632167
bool singleP2POp = isP2POp(opType, batchP2P);
21642168

2165-
at::cuda::OptionalCUDAGuard gpuGuard;
2169+
// Get the device index
2170+
auto deviceIndex = device.index();
2171+
at::cuda::OptionalCUDAGuard gpuGuard(device);
21662172

21672173
// [Group Start/End Note] This is used to ensure that nccl communicator will
21682174
// be created before communication primitives are called. Let's look at this
@@ -2202,10 +2208,6 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
22022208
rank = p2pRank;
22032209
}
22042210

2205-
// Get the device index
2206-
auto deviceIndex = device.index();
2207-
gpuGuard.set_index(deviceIndex);
2208-
22092211
#ifdef NCCL_HAS_COMM_SPLIT
22102212
if (options_->split_from) {
22112213
TORCH_CHECK(
@@ -2715,7 +2717,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
27152717
work->stashed_for_allocator_safety_->push_back(input);
27162718
}
27172719

2718-
at::cuda::OptionalCUDAGuard gpuGuard;
2720+
at::cuda::OptionalCUDAGuard gpuGuard(device);
27192721

27202722
if (nanCheck) {
27212723
checkForNan(input, ncclStream);
@@ -2880,7 +2882,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
28802882
std::make_shared<std::vector<at::Tensor>>(inputs);
28812883
}
28822884

2883-
at::cuda::OptionalCUDAGuard gpuGuard;
2885+
at::cuda::OptionalCUDAGuard gpuGuard(device);
28842886

28852887
// Start event should only be recorded before the ncclGroupStart() (which
28862888
// happens inside AutoNcclGroup guard below)
@@ -3148,7 +3150,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
31483150
}
31493151

31503152
// is gpuGuard needed for the if block below, or can i swap them
3151-
at::cuda::OptionalCUDAGuard gpuGuard;
3153+
at::cuda::OptionalCUDAGuard gpuGuard(device);
31523154

31533155
// Only check for NaN for send ops, for recv ops `tensor` can be a random
31543156
// placeholder

0 commit comments

Comments (0)