@@ -1103,11 +1103,8 @@ void ProcessGroupNCCL::abortCommsFromMap(
1103
1103
for (auto & it : ncclCommsMap) {
1104
1104
auto & devName = it.first ;
1105
1105
auto & ncclComm = it.second ;
1106
- at::cuda::OptionalCUDAGuard gpuGuard;
1107
1106
at::DeviceIndex deviceIndex = getIndexFromDeviceKey (devName);
1108
- if (deviceIndex >= 0 ) {
1109
- gpuGuard.set_index (deviceIndex);
1110
- }
1107
+ at::cuda::OptionalCUDAGuard gpuGuard (deviceIndex);
1111
1108
LOG (INFO) << logPrefix () << " ProcessGroupNCCL destroying ncclComm_ "
1112
1109
<< ncclComm->ncclComm_ << " on CUDA device: " << devName;
1113
1110
ncclComm->ncclCommAbort (abortReason);
@@ -2132,7 +2129,9 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
2132
2129
<< timerDeltaMs << " ms" ;
2133
2130
}
2134
2131
2135
- at::cuda::OptionalCUDAGuard gpuGuard;
2132
+ // Get the device index
2133
+ auto deviceIndex = device.index ();
2134
+ at::cuda::OptionalCUDAGuard gpuGuard (device);
2136
2135
2137
2136
// [Group Start/End Note] This is used to ensure that nccl communicator will
2138
2137
// be created before communication primitives are called. Let's look at this
@@ -2172,10 +2171,6 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
2172
2171
rank = p2pRank;
2173
2172
}
2174
2173
2175
- // Get the device index
2176
- auto deviceIndex = device.index ();
2177
- gpuGuard.set_index (deviceIndex);
2178
-
2179
2174
#ifdef NCCL_HAS_COMM_SPLIT
2180
2175
if (options_->split_from ) {
2181
2176
TORCH_CHECK (
@@ -2665,7 +2660,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
2665
2660
work->stashed_for_allocator_safety_ ->push_back (input);
2666
2661
}
2667
2662
2668
- at::cuda::OptionalCUDAGuard gpuGuard;
2663
+ at::cuda::OptionalCUDAGuard gpuGuard (device) ;
2669
2664
2670
2665
if (nanCheck) {
2671
2666
checkForNan (input, ncclStream);
@@ -2830,7 +2825,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
2830
2825
std::make_shared<std::vector<at::Tensor>>(inputs);
2831
2826
}
2832
2827
2833
- at::cuda::OptionalCUDAGuard gpuGuard;
2828
+ at::cuda::OptionalCUDAGuard gpuGuard (device) ;
2834
2829
2835
2830
// Start event should only be recorded before the ncclGroupStart() (which
2836
2831
// happens inside AutoNcclGroup guard below)
@@ -3098,7 +3093,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
3098
3093
}
3099
3094
3100
3095
// is gpuGuard needed for the if block below, or can i swap them
3101
- at::cuda::OptionalCUDAGuard gpuGuard;
3096
+ at::cuda::OptionalCUDAGuard gpuGuard (device) ;
3102
3097
3103
3098
// Only check for NaN for send ops, for recv ops `tensor` can be a random
3104
3099
// placeholder
0 commit comments