[c10d] PGNCCL refactor part 2: Simplify ProcessGroupNCCL into single-device style by kwen2501 · Pull Request #119421 · pytorch/pytorch · GitHub

[c10d] PGNCCL refactor part 2: Simplify ProcessGroupNCCL into single-device style #119421


Closed
wants to merge 14 commits into from
Re-enable flatten path of AG and RS
kwen2501 committed Feb 8, 2024
commit 242b2a54c74cdec7139a5df96e7f62685b98bbeb
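
Note: this commit re-enables the "flatten" fast path for all-gather (AG) and reduce-scatter (RS) in the new single-device style. Instead of flattening a std::vector<std::vector<at::Tensor>> per device with flatten_for_scatter_gather (deleted below), each rank now stacks its flat list of same-sized tensors into one contiguous buffer via newLikeFlat, runs a single NCCL kernel on it, and copies slices in or out. A minimal sketch of the pattern, assuming libtorch; stack_like_flat is an illustrative stand-in for c10d's newLikeFlat, not the PR's code:

#include <torch/torch.h>
#include <vector>

// Illustrative stand-in for c10d's newLikeFlat: one contiguous buffer of
// shape [world_size, ...tensor shape...] with matching dtype and device.
at::Tensor stack_like_flat(const std::vector<at::Tensor>& tensors) {
  std::vector<int64_t> sizes{static_cast<int64_t>(tensors.size())};
  const auto& first = tensors.front();
  sizes.insert(sizes.end(), first.sizes().begin(), first.sizes().end());
  return at::empty(sizes, first.options());
}

int main() {
  // Per-rank output slots for an all-gather with world_size = 4.
  std::vector<at::Tensor> outputs;
  for (int64_t j = 0; j < 4; ++j) {
    outputs.push_back(at::empty({8}));
  }
  at::Tensor flat = stack_like_flat(outputs); // shape [4, 8]
  // ncclAllGather would fill `flat` in one kernel; afterwards each rank's
  // slice is copied back to the user-provided tensors (non-blocking on GPU).
  for (int64_t j = 0; j < 4; ++j) {
    outputs[j].copy_(flat[j], /*non_blocking=*/true);
  }
}
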
127 changes: 26 additions & 101 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -2073,60 +2073,6 @@ bool check_same_size(const std::vector<at::Tensor>& input_tensors) {
return true;
}

// Flatten each list in `tensor_lists' for a gather or scatter operation, and
// ensure compatibility with the corresponding tensor in `other'.
std::vector<at::Tensor> flatten_for_scatter_gather(
std::vector<std::vector<at::Tensor>>& tensor_lists,
std::vector<at::Tensor>& other,
size_t world_size) {
if (tensor_lists.size() != other.size()) {
C10_THROW_ERROR(
ValueError,
"Tensor list operands to scatter/gather must have the same length");
}
const auto num_devices = tensor_lists.size();

std::vector<at::Tensor> flattened;
flattened.resize(num_devices);

for (const auto i : c10::irange(size_t{}, num_devices)) {
if (tensor_lists[i].size() != world_size * num_devices) {
C10_THROW_ERROR(
ValueError,
c10::str(
"Tensor list input to scatter/gather must match number of collective participants ",
"but got ",
tensor_lists[i].size(),
" inputs",
" with world_size ",
world_size,
" and ",
num_devices,
" devices."));
}

// Only check device match for the first tensor in the list; the call to
// newLikeFlat() below will check the rest.
if (tensor_lists[i].front().get_device() != other[i].get_device()) {
C10_THROW_ERROR(
ValueError,
"Corresponding input/output tensors to scatter/gather must all reside"
" on the same device");
}

for (const auto& t : tensor_lists[i]) {
if (t.numel() != other[i].numel()) {
C10_THROW_ERROR(
ValueError,
"All tensor operands to scatter/gather must have the same number of elements");
}
}
// Flatten the tensors (from all ranks) into a single big tensor.
flattened[i] = newLikeFlat(tensor_lists, i);
}
return flattened;
}

} // namespace

c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
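
Note: with the multi-device flatten_for_scatter_gather helper deleted above, the only precondition left for the flattened path is the check_same_size gate whose tail is visible at the top of this hunk. A sketch of that gate's logic (the function body is collapsed in this diff, so this is a reconstruction, assuming at::Tensor::is_same_size):

#include <ATen/ATen.h>
#include <vector>

// The flattened single-kernel path is only valid when every tensor in the
// list has identical sizes, so that one stacked [world_size, ...] buffer
// can hold them all.
bool all_same_size(const std::vector<at::Tensor>& tensors) {
  for (const auto& t : tensors) {
    if (!tensors.front().is_same_size(t)) {
      return false;
    }
  }
  return true;
}
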
@@ -3164,17 +3110,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
globalRankStride, // globalRankStride
this->getSize()); // worldSize

// TODO(kwen2501): re-enable old path
#if 1
if (false) {
#else
bool same_size = check_same_size(outputTensors_);
if (same_size) {
auto outputFlattened =
flatten_for_scatter_gather(outputTensors, inputTensors, size_);
// Flatten a vector of tensors into a single, stacked tensor.
at::Tensor outputFlattened = newLikeFlat(outputTensors_);

return collective(
inputTensors,
inputTensor,
outputFlattened,
[&](at::Tensor& input,
at::Tensor& output,
@@ -3192,7 +3134,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
comm,
stream.stream());
},
[](std::vector<at::cuda::CUDAStream>& ncclStreams,
[](at::cuda::CUDAStream& ncclStream,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
// avoidRecordStreams_ note: We actually don't need to stash anything
// here.
@@ -3205,24 +3147,21 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
// released back to their allocation streams until after work_ is
// waited on.
},
[&](std::vector<at::cuda::CUDAStream>& ncclStreams,
[&](at::cuda::CUDAStream& ncclStream,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
// Copy the flattened output tensors to the outputs.
for (const auto i : c10::irange(outputTensors.size())) {
at::cuda::CUDAStreamGuard guard(ncclStreams[i]);
for (const auto j : c10::irange(outputTensors[0].size())) {
// See [Sync Streams].
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
outputTensors[i][j].storage().data_ptr(), ncclStreams[i]);
}
outputTensors[i][j].copy_(outputFlattened[i][j], true);
at::cuda::CUDAStreamGuard guard(ncclStream);
for (const auto j : c10::irange(outputTensors_.size())) {
// See [Sync Streams].
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
outputTensors_[j].storage().data_ptr(), ncclStream);
}
outputTensors_[j].copy_(outputFlattened[j], true);
}
},
OpType::ALLGATHER,
"nccl:all_gather");
#endif
} else {
const auto num_reduces = outputTensors_.size();
startCoalescing();
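
Note: the second and third lambdas in the allgather hunk above are allocator-safety hooks: the CUDA caching allocator must not hand an output's memory to another stream while the NCCL kernel is still using it. ProcessGroupNCCL picks between two mechanisms via avoidRecordStreams_, as the comments in the diff describe. A condensed sketch of that choice, with the illustrative name keep_alive_for_nccl (the recordStream call itself matches the diff):

#include <ATen/ATen.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>
#include <vector>

// Keep `t`'s storage alive for a kernel queued on `ncclStream`: either tag
// the storage with the stream so the caching allocator defers reuse
// (recordStream), or stash a reference that is held until the user waits on
// the returned Work (the avoidRecordStreams_ path).
void keep_alive_for_nccl(
    const at::Tensor& t,
    c10::cuda::CUDAStream ncclStream,
    std::vector<at::Tensor>& stash,
    bool avoidRecordStreams) {
  if (avoidRecordStreams) {
    stash.push_back(t); // released only after Work::wait()
  } else {
    c10::cuda::CUDACachingAllocator::recordStream(
        t.storage().data_ptr(), ncclStream);
  }
}
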
@@ -3298,22 +3237,14 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
globalRankStride, // globalRankStride
this->getSize()); // worldSize

// TODO(kwen2501): re-enable old path
#if 1
if (false) {
#else
bool same_size = check_same_size(inputTensors_);
if (same_size) {
// @lint-ignore CLANGTIDY
auto tensor = outputTensors.back();

int dev_in_group{0};
auto inputFlattened =
flatten_for_scatter_gather(inputTensors, outputTensors, size_);
// Flatten a vector of tensors into a single, stacked tensor.
at::Tensor inputFlattened = newLikeFlat(inputTensors_);

return collective(
inputFlattened,
outputTensors,
outputTensor,
[&](at::Tensor& input,
at::Tensor& output,
ncclComm_t comm,
@@ -3324,7 +3255,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
}
const auto ncclDataType = getNcclDataType(input.scalar_type());
const auto ncclReduceOp = getNcclReduceOp(
opts.reduceOp, input, ncclDataType, comm, dev_in_group++);
opts.reduceOp, input, ncclDataType, comm);
return ncclReduceScatter(
input.data_ptr(),
output.data_ptr(),
@@ -3334,7 +3265,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
comm,
stream.stream());
},
[&](std::vector<at::cuda::CUDAStream>& ncclStreams,
[&](at::cuda::CUDAStream& ncclStream,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
if (avoidRecordStreams_) {
// We only need to stash inputTensors.
@@ -3346,30 +3277,24 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
// and should also be held by the user until after waiting on
// work_.
auto& v = work->stashed_for_allocator_safety_;
for (const auto i : c10::irange(inputTensors.size())) {
v->insert(
v->end(), inputTensors[i].begin(), inputTensors[i].end());
}
v->insert(v->end(), inputTensors_.begin(), inputTensors_.end());
}

// Copy the input tensors to the flattened inputs.
for (const auto i : c10::irange(inputTensors.size())) {
at::cuda::CUDAStreamGuard guard(ncclStreams[i]);
for (const auto j : c10::irange(inputTensors[0].size())) {
// See [Sync Streams].
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
inputTensors[i][j].storage().data_ptr(), ncclStreams[i]);
}
inputFlattened[i][j].copy_(inputTensors[i][j], true);
at::cuda::CUDAStreamGuard guard(ncclStream);
for (const auto j : c10::irange(inputTensors_.size())) {
// See [Sync Streams].
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
inputTensors_[j].storage().data_ptr(), ncclStream);
}
inputFlattened[j].copy_(inputTensors_[j], true);
}
},
[&](std::vector<at::cuda::CUDAStream>&,
[&](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
OpType::REDUCE_SCATTER,
"nccl:reduce_scatter");
#endif
} else {
const auto num_reduces = inputTensors_.size();
startCoalescing();
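
Note: for reduce_scatter the flattening in the diff above runs in the opposite direction from all-gather: the rank's world_size input tensors are copied into one stacked buffer before the kernel, and ncclReduceScatter leaves a single output shard. A sketch of that direction under the same assumptions as the earlier sketch (illustrative helper, not the PR's code):

#include <torch/torch.h>
#include <vector>

int main() {
  const int64_t world_size = 4;
  std::vector<at::Tensor> inputs;
  for (int64_t j = 0; j < world_size; ++j) {
    inputs.push_back(at::ones({8}));
  }
  // One contiguous [4, 8] buffer, as newLikeFlat would produce.
  at::Tensor flat = at::empty({world_size, 8}, inputs.front().options());
  // Copy-in happens on the NCCL stream in the real code; here sequentially.
  for (int64_t j = 0; j < world_size; ++j) {
    flat[j].copy_(inputs[j], /*non_blocking=*/true);
  }
  // ncclReduceScatter(flat.data_ptr(), output.data_ptr(), ...) would then
  // reduce across ranks and leave this rank's 8-element shard in `output`.
  at::Tensor output = at::empty({8}, inputs.front().options());
}
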