[3/N] Set correct device to CUDA guards · pytorch/pytorch@2496c6e · GitHub

Commit 2496c6e
[3/N] Set correct device to CUDA guards
ghstack-source-id: 94c9997
Pull Request resolved: #134357
1 parent 2bb435c commit 2496c6e

File tree

1 file changed (+7, -12)

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 7 additions & 12 deletions
@@ -1103,11 +1103,8 @@ void ProcessGroupNCCL::abortCommsFromMap(
   for (auto& it : ncclCommsMap) {
     auto& devName = it.first;
     auto& ncclComm = it.second;
-    at::cuda::OptionalCUDAGuard gpuGuard;
     at::DeviceIndex deviceIndex = getIndexFromDeviceKey(devName);
-    if (deviceIndex >= 0) {
-      gpuGuard.set_index(deviceIndex);
-    }
+    at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex);
     LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ "
               << ncclComm->ncclComm_ << " on CUDA device: " << devName;
     ncclComm->ncclCommAbort(abortReason);

@@ -2132,7 +2129,9 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
               << timerDeltaMs << " ms";
   }
 
-  at::cuda::OptionalCUDAGuard gpuGuard;
+  // Get the device index
+  auto deviceIndex = device.index();
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
 
   // [Group Start/End Note] This is used to ensure that nccl communicator will
   // be created before communication primitives are called. Let's look at this

@@ -2172,10 +2171,6 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
     rank = p2pRank;
   }
 
-  // Get the device index
-  auto deviceIndex = device.index();
-  gpuGuard.set_index(deviceIndex);
-
 #ifdef NCCL_HAS_COMM_SPLIT
   if (options_->split_from) {
     TORCH_CHECK(

@@ -2665,7 +2660,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
     work->stashed_for_allocator_safety_->push_back(input);
   }
 
-  at::cuda::OptionalCUDAGuard gpuGuard;
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
 
   if (nanCheck) {
     checkForNan(input, ncclStream);

@@ -2830,7 +2825,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
         std::make_shared<std::vector<at::Tensor>>(inputs);
   }
 
-  at::cuda::OptionalCUDAGuard gpuGuard;
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
 
   // Start event should only be recorded before the ncclGroupStart() (which
   // happens inside AutoNcclGroup guard below)

@@ -3098,7 +3093,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
   }
 
   // is gpuGuard needed for the if block below, or can i swap them
-  at::cuda::OptionalCUDAGuard gpuGuard;
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
 
   // Only check for NaN for send ops, for recv ops `tensor` can be a random
   // placeholder
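
Every hunk makes the same substitution: instead of default-constructing at::cuda::OptionalCUDAGuard and calling set_index() later, the guard now receives the target device (or device index) in its constructor, so the current CUDA device is switched at the point the guard comes into scope. Below is a minimal standalone sketch of the two styles, assuming a program built against LibTorch with a CUDA device available; run_on and the tensor allocations are illustrative and not part of this commit, and the guard class is spelled here as c10::cuda::OptionalCUDAGuard (declared in c10/cuda/CUDAGuard.h), the same type the diff refers to as at::cuda::OptionalCUDAGuard.

#include <c10/cuda/CUDAGuard.h>
#include <torch/torch.h>

// Illustrative helper: allocate a small tensor on `device` under a guard.
void run_on(const at::Device& device) {
  {
    // Old style: default-construct the guard, then set the index later.
    // The current CUDA device only changes once set_index() runs.
    c10::cuda::OptionalCUDAGuard gpuGuard;
    gpuGuard.set_index(device.index());
    auto t = torch::ones({2, 2}, torch::TensorOptions().device(device));
  }
  {
    // New style (as in this commit): pass the device to the constructor,
    // so the switch happens as soon as the guard is created.
    c10::cuda::OptionalCUDAGuard gpuGuard(device);
    auto t = torch::ones({2, 2}, torch::TensorOptions().device(device));
  }
}

int main() {
  if (torch::cuda::is_available()) {
    run_on(at::Device(at::kCUDA, 0));
  }
  return 0;
}

Constructing the guard with the device up front removes the window in which code could run before any set_index() call takes effect, which appears to be the point of "[3/N] Set correct device to CUDA guards".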
