Fix support for nccl < 2.17 by oraluben · Pull Request #145719 · pytorch/pytorch · GitHub
Fix support for nccl < 2.17 #145719

Closed · oraluben wants to merge 22 commits
12 changes: 1 addition & 11 deletions torch/csrc/cuda/nccl.cpp
@@ -9,22 +9,12 @@
#include <c10/util/hash.h>
#include <c10/util/irange.h>

#include <nccl.h>

#include <sched.h>
#include <limits>
#include <sstream>
#include <type_traits>
#include <unordered_map>

#if (NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 13))
#define NCCL_HAS_REMOTE_ERROR 1
#endif

#if (NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 14))
#define NCCL_HAS_COMM_NONBLOCKING 1
#endif

ncclComm_t* to_nccl_comm(torch::cuda::nccl::ncclComm_t* var) {
return reinterpret_cast<ncclComm_t*>(var);
}
@@ -126,7 +116,7 @@ ncclDataType_t to_nccl_data_type(c10::ScalarType type) {
return ncclDataType_t::ncclUint8;
#endif

#if HAS_NCCL_BF16_DATATYPE
#ifdef NCCL_HAS_BF16_DATATYPE
case at::kBFloat16:
return ncclDataType_t::ncclBfloat16;
#endif
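The two macro blocks deleted above (NCCL_HAS_REMOTE_ERROR, NCCL_HAS_COMM_NONBLOCKING) are not dropped: they move into torch/csrc/cuda/nccl.h below, rewritten in terms of NCCL_VERSION_CODE rather than separate NCCL_MAJOR/NCCL_MINOR comparisons. A minimal standalone sketch (not part of the PR) showing that the two guard styles agree, assuming an NCCL build that provides NCCL_VERSION_CODE and the NCCL_VERSION() helper from <nccl.h>:

// guard_equivalence.cpp -- illustrative only, not from this PR
#include <nccl.h>
#include <cstdio>

int main() {
  // Old style: compare major/minor components in every translation unit.
#if (NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 13))
  constexpr bool old_style = true;
#else
  constexpr bool old_style = false;
#endif
  // New style: compare the packed version code, as nccl.h now does centrally.
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 13, 0)
  constexpr bool new_style = true;
#else
  constexpr bool new_style = false;
#endif
  std::printf("NCCL 2.13+ detected: old_style=%d new_style=%d\n", old_style, new_style);
  return 0;
}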
58 changes: 51 additions & 7 deletions torch/csrc/cuda/nccl.h
@@ -1,21 +1,65 @@
#pragma once

// NOTE: [pytorch nccl defines]

// All NCCL interactions should route through this header.
// Direct inclusion of <nccl.h> should be avoided.
// Version checks/compatibility macros centralized here.

#include <nccl.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <cstddef>
#include <optional>
#include <vector>

static_assert(
NCCL_VERSION_CODE >= NCCL_VERSION(2, 7, 0),
"NCCL version must be 2.7 or later");

// NCCL BFloat16 is enabled only for CUDA 11+ and NCCL versions 2.10+, or for
// HIP 3.1+
#if defined(__CUDA_BF16_TYPES_EXIST__)
#define HAS_NCCL_BF16_DATATYPE \
((NCCL_MAJOR > 2) || (NCCL_MAJOR == 2) && (NCCL_MINOR >= 10))
#elif defined(USE_ROCM) && (TORCH_HIP_VERSION >= 301)
#define HAS_NCCL_BF16_DATATYPE 1
#else
#define HAS_NCCL_BF16_DATATYPE 0
#if defined(__CUDA_BF16_TYPES_EXIST__) && \
(NCCL_VERSION_CODE >= NCCL_VERSION(2, 10, 0))
#define NCCL_HAS_BF16_DATATYPE
#elif defined(RCCL_BFLOAT16)
#define NCCL_HAS_BF16_DATATYPE
#endif

#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 10, 0)
#define NCCL_HAS_AVG
#endif

#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 11, 0)
#define ENABLE_NCCL_PREMUL_SUM_SUPPORT
#endif

// ncclGetLastError() is enabled only for NCCL versions 2.13+
// ncclRemoteError only exists in NCCL versions 2.13+
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 13, 0)
#define NCCL_HAS_REMOTE_ERROR
#endif

#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 14, 0)
#define NCCL_HAS_COMM_NONBLOCKING
#endif

// Note: the first version that supports ncclConfig_t is 2.14. Here we
// fast-forward the version requirement to 2.17 where ncclConfig_t has CTA and
// CGA fields because they have already been pybinded out.
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 17, 0)
#define NCCL_HAS_CONFIG
#endif

#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 18, 0)
#define NCCL_HAS_COMM_SPLIT
#endif

#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 19, 0)
#define NCCL_HAS_COMM_REGISTER
#define NCCL_HAS_MEM_ALLOC
#endif

namespace torch::cuda::nccl {
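With the guards centralized, consumers test a named capability instead of re-deriving version math. A hypothetical consumer sketch (not from this PR; the function name is illustrative) of how code elsewhere might gate on NCCL_HAS_CONFIG, which is the macro tied to the 2.17 boundary the PR title refers to:

// config_gate_sketch.cpp -- hypothetical, not part of the PR
#include <nccl.h>

// Same gate as in torch/csrc/cuda/nccl.h above.
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 17, 0)
#define NCCL_HAS_CONFIG
#endif

ncclResult_t initCommForRank(ncclComm_t* comm, int nranks, ncclUniqueId id, int rank) {
#ifdef NCCL_HAS_CONFIG
  // NCCL >= 2.17: ncclConfig_t carries the CTA/CGA fields that are exposed
  // to Python, so a config-based (optionally nonblocking) init is available.
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
  config.blocking = 0;  // request nonblocking initialization
  return ncclCommInitRankConfig(comm, nranks, id, rank, &config);
#else
  // Older NCCL (>= 2.7 per the static_assert above): blocking initializer only.
  return ncclCommInitRank(comm, nranks, id, rank);
#endif
}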
42 changes: 22 additions & 20 deletions torch/csrc/cuda/python_nccl.cpp
@@ -15,43 +15,43 @@

using namespace at;
using namespace torch;
using namespace torch::cuda::nccl;
using namespace torch::cuda::nccl::detail;
using namespace torch::cuda;

static const char* COMM_CAPSULE_NAME = "torch.cuda.nccl.Communicator";

PyObject* THCPModule_nccl_version(PyObject* self, PyObject* args) {
return PyLong_FromUnsignedLongLong(version());
return PyLong_FromUnsignedLongLong(torch::cuda::nccl::version());
}

PyObject* THCPModule_nccl_version_suffix(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS
return PyBytes_FromString(version_suffix());
return PyBytes_FromString(torch::cuda::nccl::version_suffix());
END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_nccl_unique_id(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS
ncclUniqueId id;
torch::cuda::nccl::ncclUniqueId id;
get_unique_id(id);
return PyBytes_FromStringAndSize((char*)&id, NCCL_UNIQUE_ID_BYTES);
END_HANDLE_TH_ERRORS
}

static ncclComm_t unpack_nccl_comm(PyObject* capsule) {
ncclComm_t comm =
(ncclComm_t)PyCapsule_GetPointer(capsule, COMM_CAPSULE_NAME);
static torch::cuda::nccl::ncclComm_t unpack_nccl_comm(PyObject* capsule) {
torch::cuda::nccl::ncclComm_t comm =
(torch::cuda::nccl::ncclComm_t)PyCapsule_GetPointer(
capsule, COMM_CAPSULE_NAME);
if (!comm)
throw python_error();
return comm;
}

static void destroy_nccl_comm(PyObject* capsule) {
HANDLE_TH_ERRORS
ncclComm_t comm = unpack_nccl_comm(capsule);
torch::cuda::nccl::ncclComm_t comm = unpack_nccl_comm(capsule);
{
pybind11::gil_scoped_release no_gil;
comm_destroy(comm);
torch::cuda::nccl::comm_destroy(comm);
}
END_HANDLE_TH_ERRORS_RET()
}
@@ -73,19 +73,21 @@ static std::vector<std::optional<at::cuda::CUDAStream>> unpack_streams(
static at::Tensor extract_tensor(PyObject* obj);
static std::vector<at::Tensor> extract_tensors(PyObject* obj);

static std::vector<ncclComm_t> unpack_comms(PyObject* obj, size_t size) {
static std::vector<torch::cuda::nccl::ncclComm_t> unpack_comms(
PyObject* obj,
size_t size) {
if (obj == Py_None) {
return std::vector<ncclComm_t>();
return std::vector<torch::cuda::nccl::ncclComm_t>();
}
std::vector<ncclComm_t> comms;
std::vector<torch::cuda::nccl::ncclComm_t> comms;
if (PyCapsule_CheckExact(obj)) {
comms = {unpack_nccl_comm(obj)};
} else {
auto seq = THPObjectPtr(PySequence_Fast(obj, "comm is not a sequence"));
if (!seq)
throw python_error();
auto size = PySequence_Fast_GET_SIZE(seq.get());
comms = std::vector<ncclComm_t>(size);
comms = std::vector<torch::cuda::nccl::ncclComm_t>(size);
for (const auto i : c10::irange(size)) {
comms[i] = unpack_nccl_comm(PySequence_Fast_GET_ITEM(seq.get(), i));
}
@@ -116,12 +118,12 @@ PyObject* THCPModule_nccl_init_rank(PyObject* self, PyObject* args) {
id_len,
")");

ncclUniqueId commId;
torch::cuda::nccl::ncclUniqueId commId;
memcpy(&commId, id, NCCL_UNIQUE_ID_BYTES);
ncclComm_t comm = nullptr;
torch::cuda::nccl::ncclComm_t comm = nullptr;
{
pybind11::gil_scoped_release no_gil;
comm = comm_init_rank(nranks, commId, rank);
comm = torch::cuda::nccl::comm_init_rank(nranks, commId, rank);
}
return PyCapsule_New(comm, COMM_CAPSULE_NAME, &destroy_nccl_comm);
END_HANDLE_TH_ERRORS
@@ -186,7 +188,7 @@ PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args) {

{
pybind11::gil_scoped_release no_gil;
all_reduce(inputs, outputs, op, streams, user_comms);
torch::cuda::nccl::all_reduce(inputs, outputs, op, streams, user_comms);
}

Py_RETURN_NONE;
@@ -249,7 +251,7 @@ PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) {

{
pybind11::gil_scoped_release no_gil;
all_gather(inputs, outputs, streams, user_comms);
torch::cuda::nccl::all_gather(inputs, outputs, streams, user_comms);
}

Py_RETURN_NONE;
@@ -282,7 +284,7 @@ PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) {

{
pybind11::gil_scoped_release no_gil;
reduce_scatter(inputs, outputs, op, streams, user_comms);
torch::cuda::nccl::reduce_scatter(inputs, outputs, op, streams, user_comms);
}

Py_RETURN_NONE;
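The changes in this file are mechanical: the using-directives for torch::cuda::nccl are dropped and every wrapper type and function is spelled out. Presumably this is needed because torch/csrc/cuda/nccl.h now includes <nccl.h> directly, so both the raw ::ncclComm_t from NCCL and PyTorch's opaque torch::cuda::nccl::ncclComm_t are visible (the reinterpret_cast in to_nccl_comm above shows they are distinct types), and an unqualified name under the old using-directive would be ambiguous. A standalone illustration with hypothetical stand-in names, not the real types:

// ambiguity_sketch.cpp -- stand-ins for ::ncclComm_t vs torch::cuda::nccl::ncclComm_t
struct vendor_comm;                // plays the role of NCCL's own struct
using ncclComm_t = vendor_comm*;   // global ::ncclComm_t, as declared by <nccl.h>

namespace torch_wrap {             // plays the role of torch::cuda::nccl
struct wrapped_comm;
using ncclComm_t = wrapped_comm*;
}

using namespace torch_wrap;        // analogous to the removed using-directive

int main() {
  // ncclComm_t a = nullptr;           // error: 'ncclComm_t' is ambiguous
  torch_wrap::ncclComm_t b = nullptr;  // OK: qualify the wrapper type
  ::ncclComm_t c = nullptr;            // OK: qualify the vendor type
  (void)b; (void)c;
  return 0;
}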
24 changes: 11 additions & 13 deletions torch/csrc/distributed/c10d/NCCLUtils.cpp
@@ -133,13 +133,15 @@ ncclComm_t NCCLComm::getNcclComm() {
". ",
commFailureMsg));
}
#ifdef NCCL_HAS_COMM_NONBLOCKING
// In non-blocking mode, ensure comm is ready.
if (nonBlocking_) {
// Wait with long interval if communicator is being initialized.
bool longInterval = !initialized_;
waitReady(longInterval);
// ncclComm_ should be initialized by now
}
#endif
if (!initialized_) {
// TODO: see if we can consolidate other `initialized_` flipping here.
// Maintaining it elsewhere is some work.
@@ -150,6 +152,7 @@ ncclComm_t NCCLComm::getNcclComm() {
return ncclComm_;
}

#if defined(NCCL_HAS_COMM_SPLIT) && !defined(FBCODE_CAFFE2)
// Wait for the communicator to be ready. This is a blocking function.
// Arguments:
// longInterval: if true, wait with sleep of an interval; otherwise, wait
@@ -165,6 +168,7 @@ void NCCLComm::waitReady(bool longInterval) {
C10D_NCCL_CHECK_TIMEOUT(ncclInProgress, ncclComm_, std::nullopt);
}
}
#endif

std::optional<std::string> NCCLComm::getNcclCommFailureReason() const {
LockType lock(mutex_);
@@ -245,7 +249,11 @@ void NCCLComm::finalize() {
}
at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_);
auto comm = getNcclComm();
#ifdef NCCL_HAS_COMM_NONBLOCKING
C10D_NCCL_CHECK_NONBLOCKING(ncclCommFinalize(comm), std::nullopt);
#else
C10D_NCCL_CHECK(ncclCommDestroy(comm), std::nullopt);
#endif
}

void NCCLComm::destroy() {
@@ -265,7 +273,6 @@ void NCCLComm::destroy() {
void NCCLComm::abort(std::optional<std::string> commFailureReason) {
LockType lock(mutex_);
at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_);
#ifdef ENABLE_NCCL_ERROR_CHECKING
if (aborted_ && !initialized_) {
// Should not abort twice.
return;
@@ -305,10 +312,6 @@ void NCCLComm::abort(std::optional<std::string> commFailureReason) {
if (ncclAsyncErr_ == ncclSuccess) {
ncclAsyncErr_ = ncclSystemError;
}
#else
// This is a NOOP, if error checks are disabled.
return;
#endif
}

bool NCCLComm::isInitialized() const {
@@ -327,17 +330,12 @@ uint64_t NCCLComm::getCommSplitCounter() const {

ncclResult_t NCCLComm::checkForNcclError() {
LockType lock(mutex_);
#ifdef ENABLE_NCCL_ERROR_CHECKING
if (ncclAsyncErr_ != ncclSuccess) {
return ncclAsyncErr_;
}
C10D_NCCL_CHECK(
ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_);
return ncclAsyncErr_;
#else
// Always return success, if error checks are disabled.
return ncclSuccess;
#endif
}

ncclResult_t NCCLComm::registerSegment(
@@ -511,7 +509,7 @@ std::string getNcclErrorDetailStr(
}
std::string interpret;
std::string err;
#ifdef ENABLE_NCCL_GET_LAST_ERROR
#ifdef NCCL_HAS_REMOTE_ERROR
auto ret = ncclGetLastError(nullptr);
if (ret) {
err = "\nLast error:\n" + std::string(ret);
@@ -526,7 +524,7 @@ std::string getNcclErrorDetailStr(
case ncclSystemError:
interpret =
"ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error. ";
#ifndef NCCL_REMOTE_ERROR
#ifndef NCCL_HAS_REMOTE_ERROR
// Before ncclRemoteError was created, unexpected remote disconnect was
// categorized as ncclSystemError
interpret += "It can be also caused by unexpected exit of a remote peer.";
@@ -542,7 +540,7 @@
interpret =
"ncclInvalidUsage: This usually reflects invalid usage of NCCL library.";
break;
#ifdef NCCL_REMOTE_ERROR
#ifdef NCCL_HAS_REMOTE_ERROR
case ncclRemoteError:
interpret =
"ncclRemoteError: A call failed possibly due to a network error or a remote process exiting prematurely.";
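Two compile-time switches disappear from this file. The ENABLE_NCCL_ERROR_CHECKING fallbacks in abort() and checkForNcclError() are removed because the async-error API (ncclCommGetAsyncError, ncclCommAbort) has been part of NCCL since roughly 2.4, well below the 2.7 floor that nccl.h now static_asserts, so the no-op branches were dead code. finalize() instead keys off NCCL_HAS_COMM_NONBLOCKING, since ncclCommFinalize only exists alongside nonblocking support (2.14+); older builds fall back to ncclCommDestroy. A minimal sketch of the now-unconditional error query, using only the raw NCCL API rather than PyTorch's C10D_NCCL_CHECK wrappers:

// async_error_sketch.cpp -- illustrative only, not from this PR
#include <nccl.h>
#include <cstdio>

// Returns true if the communicator has recorded an asynchronous error.
bool commHasAsyncError(ncclComm_t comm) {
  ncclResult_t asyncErr = ncclSuccess;
  if (ncclCommGetAsyncError(comm, &asyncErr) != ncclSuccess) {
    return true;  // the query itself failed
  }
  if (asyncErr != ncclSuccess) {
    std::printf("async NCCL error: %s\n", ncclGetErrorString(asyncErr));
    return true;
  }
  return false;
}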