Swap check_gpu_tensors_different_devices with check_gpu_single_tensor · pytorch/pytorch@45f0b15 · GitHub
[go: up one dir, main page]

Skip to content

Commit 45f0b15

Browse files
committed
Swap check_gpu_tensors_different_devices with check_gpu_single_tensor
1 parent 4de6892 commit 45f0b15

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2906,8 +2906,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast(
29062906
std::vector<at::Tensor>& tensors,
29072907
const BroadcastOptions& opts) {
29082908
TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
2909-
check_gpu_tensors_different_devices(tensors);
29102909
auto tensor = tensors.back();
2910+
check_gpu_single_tensor(tensor);
29112911

29122912
// @lint-ignore CLANGTIDY
29132913
RECORD_PARAM_COMMS_DATA(
@@ -2993,9 +2993,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce(
29932993
std::vector<at::Tensor>& tensors,
29942994
const ReduceOptions& opts) {
29952995
TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
2996-
check_gpu_tensors_different_devices(tensors);
29972996
// @lint-ignore CLANGTIDY
29982997
auto tensor = tensors.back();
2998+
check_gpu_single_tensor(tensor);
29992999
RECORD_PARAM_COMMS_DATA(
30003000
static_cast<int>(
30013001
this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
@@ -3086,9 +3086,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
30863086
std::vector<at::Tensor>& inputTensors,
30873087
const AllgatherOptions& opts) {
30883088
TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
3089-
check_gpu_tensors_different_devices(inputTensors);
30903089
// @lint-ignore CLANGTIDY
30913090
auto inputTensor = inputTensors.back();
3091+
check_gpu_single_tensor(inputTensor);
30923092
// @lint-ignore CLANGTIDY
30933093
auto outputTensors_ = outputTensors.back();
30943094

@@ -3214,9 +3214,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
32143214
std::vector<std::vector<at::Tensor>>& inputTensors,
32153215
const ReduceScatterOptions& opts) {
32163216
TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
3217-
check_gpu_tensors_different_devices(outputTensors);
32183217
// @lint-ignore CLANGTIDY
32193218
auto outputTensor = outputTensors.back();
3219+
check_gpu_single_tensor(outputTensor);
32203220
// @lint-ignore CLANGTIDY
32213221
auto inputTensors_ = inputTensors.back();
32223222

@@ -3655,9 +3655,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::send(
36553655
int dstRank,
36563656
int /* unused */) {
36573657
TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
3658-
check_gpu_tensors_different_devices(tensors, true);
36593658
// @lint-ignore CLANGTIDY
36603659
auto tensor = tensors.back();
3660+
check_gpu_single_tensor(tensor, true);
36613661

36623662
RECORD_PARAM_COMMS_DATA(
36633663
static_cast<int>(
@@ -3696,9 +3696,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::recv(
36963696
int srcRank,
36973697
int /* unused */) {
36983698
TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
3699-
check_gpu_tensors_different_devices(tensors, true);
37003699
// @lint-ignore CLANGTIDY
37013700
auto tensor = tensors.back();
3701+
check_gpu_single_tensor(tensor, true);
37023702

37033703
RECORD_PARAM_COMMS_DATA(
37043704
static_cast<int>(

0 commit comments

Comments (0)