[3/N] Set correct device to CUDA guards (#134357) · tolleybot/pytorch@13114da · GitHub

Commit 13114da

kwen2501 authored and pytorchmergebot committed
[3/N] Set correct device to CUDA guards (pytorch#134357)
In `collective()`, `pointToPoint()` and `collectiveCoalesced()`, CUDA guards were created with an unset (default) CUDA device. This was the cause of the illegal memory access (IMA) hit by the NaN checker in issue pytorch#134062. With this fix, `torch.cuda.set_device(device)` is no longer needed to work around the IMA. Also refactored a couple of places where the guard is created: preferably we create the guard with a known device rather than setting the device afterwards.

Pull Request resolved: pytorch#134357
Approved by: https://github.com/wconstab, https://github.com/shuqiangzhang
ghstack dependencies: pytorch#134300, pytorch#134345
1 parent be7752e commit 13114da
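
For illustration, here is a minimal standalone sketch (not part of the commit) of the pattern the patch moves to: constructing the CUDA guard with a known device instead of default-constructing it and calling set_index() later. It assumes a libtorch build with CUDA and at least two visible GPUs; device index 1 is an arbitrary example, and the c10::cuda spelling below names the same guard type the diff refers to as at::cuda::OptionalCUDAGuard.

// Sketch only: shows the effect of OptionalCUDAGuard with and without a
// device, outside of ProcessGroupNCCL. Build against libtorch with CUDA.
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
#include <iostream>

int main() {
  if (c10::cuda::device_count() < 2) {
    std::cout << "demo needs at least 2 GPUs\n";
    return 0;
  }
  const c10::Device device(c10::kCUDA, 1); // arbitrary non-default device

  {
    // Old pattern: a default-constructed guard switches nothing, so kernels
    // issued here run on whatever the current device happens to be
    // (device 0 unless the caller already called set_device / cudaSetDevice).
    c10::cuda::OptionalCUDAGuard gpuGuard;
    std::cout << "default guard, current device: "
              << static_cast<int>(c10::cuda::current_device()) << '\n';
  }
  {
    // New pattern: constructing the guard with the known device makes the
    // enclosed work target that device and restores the previous one on exit.
    c10::cuda::OptionalCUDAGuard gpuGuard(device);
    std::cout << "guard(device), current device: "
              << static_cast<int>(c10::cuda::current_device()) << '\n';
  }
  return 0;
}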

File tree

2 files changed: +29 −8 lines

test/distributed/test_c10d_nccl.py

Lines changed: 19 additions & 0 deletions
@@ -350,6 +350,7 @@ def test_close_pg(self):
     @parametrize("type", [torch.float16, torch.float32, torch.float64, torch.bfloat16])
     @skip_if_rocm
     def test_nan_assert(self, type):
+        # Expecting a device-side error when NaN is detected
         os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
         store = c10d.FileStore(self.file_name, self.world_size)
         pg = self._create_process_group_nccl(store, self.opts())
@@ -388,6 +389,24 @@ def test_nan_p2p(self):
         # reset env
         os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_nan_check(self):
+        # Not expecting an error, NaN check should not make legit code fail
+        os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
+        store = c10d.FileStore(self.file_name, self.world_size)
+        device = torch.device("cuda:%d" % self.rank)
+        c10d.init_process_group(
+            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
+        )
+        x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank
+        t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
+        c10d.broadcast(x, src=0)
+        c10d.all_reduce(t)
+        c10d.destroy_process_group()
+        # reset env
+        os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
+
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
     def test_destruct_before_terminate_pg(self):

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 10 additions & 8 deletions
@@ -1162,6 +1162,10 @@ void ProcessGroupNCCL::abortCommsFromMap(
     at::cuda::OptionalCUDAGuard gpuGuard;
     at::DeviceIndex deviceIndex = getIndexFromDeviceKey(devName);
     if (deviceIndex >= 0) {
+      // For P2P comms, the deviceIndex could be -1 (invalid), as the keys in
+      // the map could be non deviceIndex, but rank to rank numbers. So we
+      // indeed need to check if deviceIndex >= 0
+      // TODO: fix `getIndexFromDeviceKey` or fix `DeviceKey`
       gpuGuard.set_index(deviceIndex);
     }
     LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ "
@@ -2162,7 +2166,9 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
   bool batchP2P = ncclActiveGroupCounter_ > 0;
   bool singleP2POp = isP2POp(opType, batchP2P);
 
-  at::cuda::OptionalCUDAGuard gpuGuard;
+  // Get the device index
+  auto deviceIndex = device.index();
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
 
   // [Group Start/End Note] This is used to ensure that nccl communicator will
   // be created before communication primitives are called. Let's look at this
@@ -2202,10 +2208,6 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
     rank = p2pRank;
   }
 
-  // Get the device index
-  auto deviceIndex = device.index();
-  gpuGuard.set_index(deviceIndex);
-
 #ifdef NCCL_HAS_COMM_SPLIT
   if (options_->split_from) {
     TORCH_CHECK(
@@ -2715,7 +2717,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
     work->stashed_for_allocator_safety_->push_back(input);
   }
 
-  at::cuda::OptionalCUDAGuard gpuGuard;
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
 
   if (nanCheck) {
     checkForNan(input, ncclStream);
@@ -2880,7 +2882,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
         std::make_shared<std::vector<at::Tensor>>(inputs);
   }
 
-  at::cuda::OptionalCUDAGuard gpuGuard;
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
 
   // Start event should only be recorded before the ncclGroupStart() (which
   // happens inside AutoNcclGroup guard below)
@@ -3148,7 +3150,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
   }
 
   // is gpuGuard needed for the if block below, or can i swap them
-  at::cuda::OptionalCUDAGuard gpuGuard;
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
 
   // Only check for NaN for send ops, for recv ops `tensor` can be a random
   // placeholder
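
As context for the new comment in abortCommsFromMap(): a rough, hypothetical sketch of why a communicator-map key may not yield a valid device index. It assumes per-device keys look like "0" while P2P keys encode a rank pair such as "0:1", as the added comment describes; parseDeviceIndex below is illustrative only and is not the actual getIndexFromDeviceKey.

// Illustrative only: mimics the idea that a key holding a rank pair cannot
// be interpreted as a CUDA device index, hence the `deviceIndex >= 0` check.
#include <iostream>
#include <string>

int parseDeviceIndex(const std::string& key) {
  try {
    size_t pos = 0;
    int idx = std::stoi(key, &pos);
    if (pos != key.size()) {
      return -1; // trailing characters, e.g. a "low:high" rank-pair key
    }
    return idx;
  } catch (const std::exception&) {
    return -1; // not a number at all
  }
}

int main() {
  std::cout << parseDeviceIndex("3") << '\n';   // 3  -> safe to set the guard
  std::cout << parseDeviceIndex("0:1") << '\n'; // -1 -> skip gpuGuard.set_index
}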
