@@ -2906,8 +2906,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast(
2906
2906
std::vector<at::Tensor>& tensors,
2907
2907
const BroadcastOptions& opts) {
2908
2908
TORCH_CHECK (tensors.size () == 1 , MULTI_DEVICE_ERROR_MSG);
2909
- check_gpu_tensors_different_devices (tensors);
2910
2909
auto tensor = tensors.back ();
2910
+ check_gpu_single_tensor (tensor);
2911
2911
2912
2912
// @lint-ignore CLANGTIDY
2913
2913
RECORD_PARAM_COMMS_DATA (
@@ -2993,9 +2993,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce(
2993
2993
std::vector<at::Tensor>& tensors,
2994
2994
const ReduceOptions& opts) {
2995
2995
TORCH_CHECK (tensors.size () == 1 , MULTI_DEVICE_ERROR_MSG);
2996
- check_gpu_tensors_different_devices (tensors);
2997
2996
// @lint-ignore CLANGTIDY
2998
2997
auto tensor = tensors.back ();
2998
+ check_gpu_single_tensor (tensor);
2999
2999
RECORD_PARAM_COMMS_DATA (
3000
3000
static_cast <int >(
3001
3001
this ->getSequenceNumberForGroup () + 1 ), // seq + 1 to match collective
@@ -3086,9 +3086,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
3086
3086
std::vector<at::Tensor>& inputTensors,
3087
3087
const AllgatherOptions& opts) {
3088
3088
TORCH_CHECK (inputTensors.size () == 1 , MULTI_DEVICE_ERROR_MSG);
3089
- check_gpu_tensors_different_devices (inputTensors);
3090
3089
// @lint-ignore CLANGTIDY
3091
3090
auto inputTensor = inputTensors.back ();
3091
+ check_gpu_single_tensor (inputTensor);
3092
3092
// @lint-ignore CLANGTIDY
3093
3093
auto outputTensors_ = outputTensors.back ();
3094
3094
@@ -3214,9 +3214,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
3214
3214
std::vector<std::vector<at::Tensor>>& inputTensors,
3215
3215
const ReduceScatterOptions& opts) {
3216
3216
TORCH_CHECK (outputTensors.size () == 1 , MULTI_DEVICE_ERROR_MSG);
3217
- check_gpu_tensors_different_devices (outputTensors);
3218
3217
// @lint-ignore CLANGTIDY
3219
3218
auto outputTensor = outputTensors.back ();
3219
+ check_gpu_single_tensor (outputTensor);
3220
3220
// @lint-ignore CLANGTIDY
3221
3221
auto inputTensors_ = inputTensors.back ();
3222
3222
@@ -3655,9 +3655,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::send(
3655
3655
int dstRank,
3656
3656
int /* unused */ ) {
3657
3657
TORCH_CHECK (tensors.size () == 1 , MULTI_DEVICE_ERROR_MSG);
3658
- check_gpu_tensors_different_devices (tensors, true );
3659
3658
// @lint-ignore CLANGTIDY
3660
3659
auto tensor = tensors.back ();
3660
+ check_gpu_single_tensor (tensor, true );
3661
3661
3662
3662
RECORD_PARAM_COMMS_DATA (
3663
3663
static_cast <int >(
@@ -3696,9 +3696,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::recv(
3696
3696
int srcRank,
3697
3697
int /* unused */ ) {
3698
3698
TORCH_CHECK (tensors.size () == 1 , MULTI_DEVICE_ERROR_MSG);
3699
- check_gpu_tensors_different_devices (tensors, true );
3700
3699
// @lint-ignore CLANGTIDY
3701
3700
auto tensor = tensors.back ();
3701
+ check_gpu_single_tensor (tensor, true );
3702
3702
3703
3703
RECORD_PARAM_COMMS_DATA (
3704
3704
static_cast <int >(
0 commit comments