Re-enable flatten path of AG and RS · pytorch/pytorch@db92422 · GitHub


Commit db92422

Re-enable flatten path of AG and RS
1 parent c144c1b commit db92422
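
AG and RS here are the all-gather and reduce-scatter collectives in ProcessGroupNCCL. With the flatten path re-enabled, when every tensor in the output (or input) list has the same size, the list is stacked into one contiguous buffer via newLikeFlat, a single NCCL collective runs over that buffer, and per-rank slices are copied in or out around the call. The snippet below is an editor-added, minimal sketch of the all-gather direction under those assumptions; it builds against libtorch only, and fake_all_gather is a hypothetical stand-in for the real NCCL call, so treat it as an illustration rather than code from this commit.

// Minimal sketch of the flattened all-gather path (assumes libtorch;
// fake_all_gather is a hypothetical stand-in for ncclAllGather).
#include <torch/torch.h>
#include <vector>

// Pretend collective: fill each row of `flat` as if every rank contributed `input`.
static void fake_all_gather(const torch::Tensor& input, torch::Tensor& flat) {
  for (int64_t r = 0; r < flat.size(0); ++r) {
    flat[r].copy_(input);
  }
}

int main() {
  const int64_t world_size = 4; // illustrative value
  torch::Tensor inputTensor = torch::ones({8});

  // Per-rank output tensors, all the same shape (the check_same_size case).
  std::vector<torch::Tensor> outputTensors;
  for (int64_t r = 0; r < world_size; ++r) {
    outputTensors.push_back(torch::empty({8}));
  }

  // Roughly what newLikeFlat(outputTensors_) provides in the diff: one
  // contiguous [world_size, ...] buffer matching the outputs' dtype/device.
  torch::Tensor outputFlattened =
      torch::empty({world_size, 8}, outputTensors[0].options());

  // One collective into the flat buffer instead of world_size small ones.
  fake_all_gather(inputTensor, outputFlattened);

  // Copy each slice back to the caller's tensors, mirroring the copy-out
  // lambda in ProcessGroupNCCL::allgather.
  for (int64_t r = 0; r < world_size; ++r) {
    outputTensors[r].copy_(outputFlattened[r], /*non_blocking=*/true);
  }
  return 0;
}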

File tree

1 file changed: +26 additions, -101 deletions


torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 26 additions & 101 deletions
@@ -2071,60 +2071,6 @@ bool check_same_size(const std::vector<at::Tensor>& input_tensors) {
   return true;
 }

-// Flatten each list in `tensor_lists' for a gather or scatter operation, and
-// ensure compatibility with the corresponding tensor in `other'.
-std::vector<at::Tensor> flatten_for_scatter_gather(
-    std::vector<std::vector<at::Tensor>>& tensor_lists,
-    std::vector<at::Tensor>& other,
-    size_t world_size) {
-  if (tensor_lists.size() != other.size()) {
-    C10_THROW_ERROR(
-        ValueError,
-        "Tensor list operands to scatter/gather must have the same length");
-  }
-  const auto num_devices = tensor_lists.size();
-
-  std::vector<at::Tensor> flattened;
-  flattened.resize(num_devices);
-
-  for (const auto i : c10::irange(size_t{}, num_devices)) {
-    if (tensor_lists[i].size() != world_size * num_devices) {
-      C10_THROW_ERROR(
-          ValueError,
-          c10::str(
-              "Tensor list input to scatter/gather must match number of collective participants ",
-              "but got ",
-              tensor_lists[i].size(),
-              " inputs",
-              " with world_size ",
-              world_size,
-              " and ",
-              num_devices,
-              " devices."));
-    }
-
-    // Only check device match for the first tensor in the list; the call to
-    // newLikeFlat() below will check the rest.
-    if (tensor_lists[i].front().get_device() != other[i].get_device()) {
-      C10_THROW_ERROR(
-          ValueError,
-          "Corresponding input/output tensors to scatter/gather must all reside"
-          " on the same device");
-    }
-
-    for (const auto& t : tensor_lists[i]) {
-      if (t.numel() != other[i].numel()) {
-        C10_THROW_ERROR(
-            ValueError,
-            "All tensor operands to scatter/gather must have the same number of elements");
-      }
-    }
-    // Flatten the tensors (from all ranks) into a single big tensor.
-    flattened[i] = newLikeFlat(tensor_lists, i);
-  }
-  return flattened;
-}
-
 } // namespace

 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
@@ -3159,17 +3105,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
       globalRankStride, // globalRankStride
       this->getSize()); // worldSize

-  // TODO(kwen2501): re-enable old path
-#if 1
-  if (false) {
-#else
   bool same_size = check_same_size(outputTensors_);
   if (same_size) {
-    auto outputFlattened =
-        flatten_for_scatter_gather(outputTensors, inputTensors, size_);
+    // Flatten a vector of tensors into a single, stacked tensor.
+    at::Tensor outputFlattened = newLikeFlat(outputTensors_);

     return collective(
-        inputTensors,
+        inputTensor,
         outputFlattened,
         [&](at::Tensor& input,
             at::Tensor& output,
@@ -3187,7 +3129,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
               comm,
               stream.stream());
         },
-        [](std::vector<at::cuda::CUDAStream>& ncclStreams,
+        [](at::cuda::CUDAStream& ncclStream,
           c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
           // avoidRecordStreams_ note: We actually don't need to stash anything
           // here.
@@ -3200,24 +3142,21 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
          // released back to their allocation streams until after work_ is
          // waited on.
        },
-        [&](std::vector<at::cuda::CUDAStream>& ncclStreams,
+        [&](at::cuda::CUDAStream& ncclStream,
            c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
          // Copy the flattened output tensors to the outputs.
-          for (const auto i : c10::irange(outputTensors.size())) {
-            at::cuda::CUDAStreamGuard guard(ncclStreams[i]);
-            for (const auto j : c10::irange(outputTensors[0].size())) {
-              // See [Sync Streams].
-              if (!avoidRecordStreams_) {
-                c10::cuda::CUDACachingAllocator::recordStream(
-                    outputTensors[i][j].storage().data_ptr(), ncclStreams[i]);
-              }
-              outputTensors[i][j].copy_(outputFlattened[i][j], true);
+          at::cuda::CUDAStreamGuard guard(ncclStream);
+          for (const auto j : c10::irange(outputTensors_.size())) {
+            // See [Sync Streams].
+            if (!avoidRecordStreams_) {
+              c10::cuda::CUDACachingAllocator::recordStream(
+                  outputTensors_[j].storage().data_ptr(), ncclStream);
            }
+            outputTensors_[j].copy_(outputFlattened[j], true);
          }
        },
        OpType::ALLGATHER,
        "nccl:all_gather");
-#endif
  } else {
    const auto num_reduces = outputTensors_.size();
    startCoalescing();
@@ -3292,22 +3231,14 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
       globalRankStride, // globalRankStride
       this->getSize()); // worldSize

-  // TODO(kwen2501): re-enable old path
-#if 1
-  if (false) {
-#else
   bool same_size = check_same_size(inputTensors_);
   if (same_size) {
-    // @lint-ignore CLANGTIDY
-    auto tensor = outputTensors.back();
-
-    int dev_in_group{0};
-    auto inputFlattened =
-        flatten_for_scatter_gather(inputTensors, outputTensors, size_);
+    // Flatten a vector of tensors into a single, stacked tensor.
+    at::Tensor inputFlattened = newLikeFlat(inputTensors_);

     return collective(
         inputFlattened,
-        outputTensors,
+        outputTensor,
         [&](at::Tensor& input,
             at::Tensor& output,
             ncclComm_t comm,
@@ -3318,7 +3249,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
           }
           const auto ncclDataType = getNcclDataType(input.scalar_type());
           const auto ncclReduceOp = getNcclReduceOp(
-              opts.reduceOp, input, ncclDataType, comm, dev_in_group++);
+              opts.reduceOp, input, ncclDataType, comm);
           return ncclReduceScatter(
               input.data_ptr(),
               output.data_ptr(),
@@ -3328,7 +3259,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
              comm,
              stream.stream());
        },
-        [&](std::vector<at::cuda::CUDAStream>& ncclStreams,
+        [&](at::cuda::CUDAStream& ncclStream,
            c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
          if (avoidRecordStreams_) {
            // We only need to stash inputTensors.
@@ -3340,30 +3271,24 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
            // and should also be held by the user until after waiting on
            // work_.
            auto& v = work->stashed_for_allocator_safety_;
-            for (const auto i : c10::irange(inputTensors.size())) {
-              v->insert(
-                  v->end(), inputTensors[i].begin(), inputTensors[i].end());
-            }
+            v->insert(v->end(), inputTensors_.begin(), inputTensors_.end());
          }

          // Copy the input tensors to the flattened inputs.
-          for (const auto i : c10::irange(inputTensors.size())) {
-            at::cuda::CUDAStreamGuard guard(ncclStreams[i]);
-            for (const auto j : c10::irange(inputTensors[0].size())) {
-              // See [Sync Streams].
-              if (!avoidRecordStreams_) {
-                c10::cuda::CUDACachingAllocator::recordStream(
-                    inputTensors[i][j].storage().data_ptr(), ncclStreams[i]);
-              }
-              inputFlattened[i][j].copy_(inputTensors[i][j], true);
+          at::cuda::CUDAStreamGuard guard(ncclStream);
+          for (const auto j : c10::irange(inputTensors_.size())) {
+            // See [Sync Streams].
+            if (!avoidRecordStreams_) {
+              c10::cuda::CUDACachingAllocator::recordStream(
+                  inputTensors_[j].storage().data_ptr(), ncclStream);
            }
+            inputFlattened[j].copy_(inputTensors_[j], true);
          }
        },
-        [&](std::vector<at::cuda::CUDAStream>&,
+        [&](at::cuda::CUDAStream&,
            c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
        OpType::REDUCE_SCATTER,
        "nccl:reduce_scatter");
-#endif
  } else {
    const auto num_reduces = inputTensors_.size();
    startCoalescing();

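For completeness, the reduce-scatter side of the re-enabled path is the mirror image: the same-size input tensors are packed into one flattened buffer before a single ncclReduceScatter call. The sketch below is an editor-added illustration under the same assumptions as the earlier one (libtorch only; fake_reduce_scatter is a hypothetical stand-in that just sum-reduces across ranks), not code from this commit.

// Minimal sketch of the flattened reduce-scatter path (assumes libtorch;
// fake_reduce_scatter is a hypothetical stand-in for the real NCCL call).
#include <torch/torch.h>
#include <vector>

// Pretend collective: sum-reduce over the world_size dimension; a real run
// would leave each rank with only its own reduced chunk.
static void fake_reduce_scatter(const torch::Tensor& flat, torch::Tensor& out) {
  out.copy_(flat.sum(0));
}

int main() {
  const int64_t world_size = 4; // illustrative value

  // Per-rank input tensors, all the same shape (the check_same_size case).
  std::vector<torch::Tensor> inputTensors;
  for (int64_t r = 0; r < world_size; ++r) {
    inputTensors.push_back(torch::full({8}, static_cast<double>(r)));
  }
  torch::Tensor outputTensor = torch::empty({8});

  // Roughly what newLikeFlat(inputTensors_) provides in the diff: one
  // contiguous [world_size, ...] buffer matching the inputs' dtype/device.
  torch::Tensor inputFlattened =
      torch::empty({world_size, 8}, inputTensors[0].options());

  // Copy-in: pack the caller's tensors into the flat buffer, mirroring the
  // copy lambda in ProcessGroupNCCL::reduce_scatter.
  for (int64_t r = 0; r < world_size; ++r) {
    inputFlattened[r].copy_(inputTensors[r], /*non_blocking=*/true);
  }

  // One collective over the flat buffer instead of world_size small ones.
  fake_reduce_scatter(inputFlattened, outputTensor);
  return 0;
}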