[3/N] Set correct device to CUDA guards (#134357) · Chao1Han/pytorch@afc76c6 · GitHub


Commit afc76c6

kwen2501 authored and pytorchmergebot committed
[3/N] Set correct device to CUDA guards (pytorch#134357)
In `collective()`, `pointToPoint()` and `collectiveCoalesced()`, the CUDA guards were created with an unset (default) CUDA device. This was the cause of the illegal memory access (IMA) hit by the NaN checker in issue pytorch#134062. With this fix, `torch.cuda.set_device(device)` is no longer needed to work around the IMA. Also refactored a couple of places where the guard is created: preferably, the guard is created with a known device rather than having its device set later.

Pull Request resolved: pytorch#134357
Approved by: https://github.com/wconstab, https://github.com/shuqiangzhang
ghstack dependencies: pytorch#134300, pytorch#134345
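As a user-facing illustration, here is a minimal sketch (not part of the commit) of the behavior described above: with `TORCH_NCCL_NAN_CHECK=1`, collectives on per-rank devices run without the `torch.cuda.set_device(device)` workaround. It assumes a 2-GPU host and a `torchrun --nproc_per_node=2` launch with the default env:// rendezvous; the added test below uses a `FileStore` instead.

import os

import torch
import torch.distributed as dist


def main() -> None:
    # Enable NCCL NaN checking, as in the added test.
    os.environ["TORCH_NCCL_NAN_CHECK"] = "1"

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"cuda:{rank}")

    # No torch.cuda.set_device(device) call needed once the guards
    # pick up the correct device.
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    x = torch.ones(10, dtype=torch.bfloat16, device=device) * rank
    dist.broadcast(x, src=0)  # exercises the collective() path
    dist.all_reduce(x)        # NaN checker runs without error on legit data

    dist.destroy_process_group()


if __name__ == "__main__":
    main()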
1 parent 5ff97e7 commit afc76c6

File tree: 2 files changed (+29, -8 lines)

test/distributed/test_c10d_nccl.py

Lines changed: 19 additions & 0 deletions
@@ -350,6 +350,7 @@ def test_close_pg(self):
     @parametrize("type", [torch.float16, torch.float32, torch.float64, torch.bfloat16])
     @skip_if_rocm
     def test_nan_assert(self, type):
+        # Expecting a device-side error when NaN is detected
         os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
         store = c10d.FileStore(self.file_name, self.world_size)
         pg = self._create_process_group_nccl(store, self.opts())
@@ -388,6 +389,24 @@ def test_nan_p2p(self):
         # reset env
         os.environ["TORCH_NCCL_NAN_CHECK"] = "0"

+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_nan_check(self):
+        # Not expecting an error, NaN check should not make legit code fail
+        os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
+        store = c10d.FileStore(self.file_name, self.world_size)
+        device = torch.device("cuda:%d" % self.rank)
+        c10d.init_process_group(
+            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
+        )
+        x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank
+        t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
+        c10d.broadcast(x, src=0)
+        c10d.all_reduce(t)
+        c10d.destroy_process_group()
+        # reset env
+        os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
+
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
     def test_destruct_before_terminate_pg(self):

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 10 additions & 8 deletions
@@ -1148,6 +1148,10 @@ void ProcessGroupNCCL::abortCommsFromMap(
     at::cuda::OptionalCUDAGuard gpuGuard;
     at::DeviceIndex deviceIndex = getIndexFromDeviceKey(devName);
     if (deviceIndex >= 0) {
+      // For P2P comms, the deviceIndex could be -1 (invalid), as the keys in
+      // the map could be non deviceIndex, but rank to rank numbers. So we
+      // indeed need to check if deviceIndex >= 0
+      // TODO: fix `getIndexFromDeviceKey` or fix `DeviceKey`
       gpuGuard.set_index(deviceIndex);
     }
     LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ "
@@ -2141,7 +2145,9 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
   bool batchP2P = ncclActiveGroupCounter_ > 0;
   bool singleP2POp = isP2POp(opType, batchP2P);

-  at::cuda::OptionalCUDAGuard gpuGuard;
+  // Get the device index
+  auto deviceIndex = device.index();
+  at::cuda::OptionalCUDAGuard gpuGuard(device);

   // [Group Start/End Note] This is used to ensure that nccl communicator will
   // be created before communication primitives are called. Let's look at this
@@ -2181,10 +2187,6 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
     rank = p2pRank;
   }

-  // Get the device index
-  auto deviceIndex = device.index();
-  gpuGuard.set_index(deviceIndex);
-
 #ifdef NCCL_HAS_COMM_SPLIT
   if (options_->split_from) {
     TORCH_CHECK(
@@ -2694,7 +2696,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
     work->stashed_for_allocator_safety_->push_back(input);
   }

-  at::cuda::OptionalCUDAGuard gpuGuard;
+  at::cuda::OptionalCUDAGuard gpuGuard(device);

   if (nanCheck) {
     checkForNan(input, ncclStream);
@@ -2859,7 +2861,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
         std::make_shared<std::vector<at::Tensor>>(inputs);
   }

-  at::cuda::OptionalCUDAGuard gpuGuard;
+  at::cuda::OptionalCUDAGuard gpuGuard(device);

   // Start event should only be recorded before the ncclGroupStart() (which
   // happens inside AutoNcclGroup guard below)
@@ -3127,7 +3129,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
   }

   // is gpuGuard needed for the if block below, or can i swap them
-  at::cuda::OptionalCUDAGuard gpuGuard;
+  at::cuda::OptionalCUDAGuard gpuGuard(device);

   // Only check for NaN for send ops, for recv ops `tensor` can be a random
   // placeholder

0 commit comments
