Revert "[c10d] PGNCCL refactor part 2: Simplify ProcessGroupNCCL into… · pytorch/pytorch@8f8a72a · GitHub

    Commit 8f8a72a

    pytorchmergebot and clee2000 authored and committed
    Revert "[c10d] PGNCCL refactor part 2: Simplify ProcessGroupNCCL into single-device style (#119421)"
    This reverts commit f3e7d80. Reverted #119421 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](#119421 (comment)))
    1 parent 2dbb455 commit 8f8a72a
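
    For orientation before the hunks: the reverted PR had collapsed ProcessGroupNCCL's per-collective bookkeeping into a single-device style, and this revert restores the multi-device signatures in which one Work object spans a vector of devices and NCCL communicators. The recap below is distilled from the hunks in this commit, not verbatim declarations:

    // Work-level (as exercised by the test overrides below):
    //   single-device style (removed):  WorkNCCL(at::Device& device, int rank, c10d::OpType opType, uint64_t seq);
    //                                   std::exception_ptr checkForNCCLErrors();
    //   multi-device style (restored):  WorkNCCL(const std::vector<at::Device>& devices, int rank, c10d::OpType opType, uint64_t seq);
    //                                   std::exception_ptr checkForNCCLErrors(
    //                                       const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms) const;
    //
    // ProcessGroup-level error checking:
    //   single-device style (removed):  checkForNCCLErrors(std::shared_ptr<c10d::NCCLComm>& ncclComm);
    //   multi-device style (restored):  checkForNCCLErrors(const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms);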

    File tree

    8 files changed: +1126 -828 lines changed


    test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp

    Lines changed: 14 additions & 12 deletions
    @@ -20,18 +20,20 @@ constexpr int kNcclErrorHandlingVersion = 2400;
     class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
      public:
       WorkNCCLSimulateErrors(
    -      at::Device& device,
    +      const std::vector<at::Device>& devices,
           bool simulate_error,
           int rank,
           c10d::OpType opType,
           uint64_t seq)
    -      : WorkNCCL(device, rank, opType, seq), simulateError_(simulate_error) {}
    +      : WorkNCCL(devices, rank, opType, seq), simulateError_(simulate_error) {}

    -  std::exception_ptr checkForNCCLErrors() override {
    +  std::exception_ptr checkForNCCLErrors(
    +      const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms)
    +      const override {
         if (simulateError_) {
           return std::make_exception_ptr(std::runtime_error("Error"));
         }
    -    return c10d::ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors();
    +    return c10d::ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors(ncclComms);
       }

      private:
    @@ -48,11 +50,11 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
           : ProcessGroupNCCL(store, rank, size, opts), simulateError_(false) {}

       std::exception_ptr checkForNCCLErrors(
    -      std::shared_ptr<c10d::NCCLComm>& ncclComm) override {
    +      const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms) override {
         if (simulateError_) {
           return std::make_exception_ptr(std::runtime_error("Error"));
         }
    -    return c10d::ProcessGroupNCCL::checkForNCCLErrors(ncclComm);
    +    return c10d::ProcessGroupNCCL::checkForNCCLErrors(ncclComms);
       }

       std::chrono::duration<int64_t, std::milli> getWatchdogSleepInterval() {
    @@ -61,14 +63,14 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
       }

       c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
    -      at::Device& device,
    +      std::vector<at::Device> devices,
           int rank,
           c10d::OpType opType,
           const char* profilingTitle,
           const std::vector<at::Tensor>& inputs = {},
           const std::vector<at::Tensor>& outputs = {}) override {
         return c10::make_intrusive<WorkNCCLSimulateErrors>(
    -        device, simulateError_, rank, opType, seq_);
    +        devices, simulateError_, rank, opType, seq_);
       }

       size_t getNCCLCommCacheSize() {
    @@ -90,12 +92,12 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
     class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
      public:
       WorkNCCLTimedoutErrors(
    -      at::Device& device,
    +      const std::vector<at::Device>& devices,
           bool set_timedout_error,
           int rank,
           c10d::OpType opType,
           uint64_t seq)
    -      : WorkNCCL(device, rank, opType, seq),
    +      : WorkNCCL(devices, rank, opType, seq),
             setTimedoutError_(set_timedout_error) {}

      private:
    @@ -122,14 +124,14 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
             setTimedoutError_(false) {}

       c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
    -      at::Device& device,
    +      std::vector<at::Device> devices,
           int rank,
           c10d::OpType opType,
           const char* profilingTitle,
           const std::vector<at::Tensor>& inputs = {},
           const std::vector<at::Tensor>& outputs = {}) override {
         return c10::make_intrusive<WorkNCCLTimedoutErrors>(
    -        device, setTimedoutError_, rank, opType, seq_);
    +        devices, setTimedoutError_, rank, opType, seq_);
       }

       void setTimedoutError() {

    test/distributed/test_c10d_nccl.py

    Lines changed: 4 additions & 10 deletions
    @@ -2947,10 +2947,6 @@ def world_size(self):
         def blocking_wait_error_msg(self):
             return "timeout"

    -    @property
    -    def remote_error_msg(self):
    -        return "remote process exit"
    -
         def _run_all_reduce(self, pg):
             pg.allreduce(torch.rand(10).cuda(self.rank))

    @@ -2999,9 +2995,8 @@ def _test_nccl_errors_blocking(self, func):
             process_group.allreduce(torch.rand(10).cuda(self.rank))
             if self.rank == 0:
                 work = process_group.allreduce(torch.rand(10).cuda(self.rank))
    -            with self.assertRaisesRegex(dist.DistBackendError, self.remote_error_msg):
    -                # Previously this should timeout; but with newer NCCL version,
    -                # it seems NCCL would detect that the peer rank has exited
    +            with self.assertRaisesRegex(dist.DistBackendError, self.blocking_wait_error_msg):
    +                # Operation would time out in blocking mode.
                     work.wait(timeout=timedelta(seconds=self.op_timeout_sec))
                 # Run some GPU operations to make sure cuda has not gotten stuck.
                 # It was observed cuda could get stuck if NCCL communicators were
    @@ -3069,9 +3064,8 @@ def test_nccl_blocking_wait_with_barrier(self):
             )
             process_group.barrier().wait()
             if self.rank == 0:
    -            with self.assertRaisesRegex(dist.DistBackendError, self.remote_error_msg):
    -                # Previously this should timeout; but with newer NCCL version,
    -                # it seems NCCL would detect that the peer rank has exited
    +            with self.assertRaisesRegex(dist.DistBackendError, self.blocking_wait_error_msg):
    +                # This should timeout
                     process_group.barrier().wait(timeout=timedelta(seconds=self.op_timeout_sec))

         def _run_invalid_nccl_blocking_wait_env(self, val):

    torch/csrc/cuda/nccl.cpp

    Lines changed: 8 additions & 6 deletions
    @@ -415,18 +415,20 @@ AutoNcclGroup::AutoNcclGroup() {
       (c10::cuda::getFreeMutex())->lock();
     #endif
       comm_nonblocking_ = false;
    -  comm_ = nullptr;
     #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
       detail::NCCL_CHECK(ncclGroupStart());
     #endif
     }

    -AutoNcclGroup::AutoNcclGroup(ncclComm_t comm, bool comm_nonblocking) {
    +AutoNcclGroup::AutoNcclGroup(
    +    std::vector<ncclComm_t>& comms,
    +    bool comm_nonblocking) {
     #if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
       // nccl < 2.0 cannot be called concurrently with cudaFree
       (c10::cuda::getFreeMutex())->lock();
     #endif
    -  comm_ = comm;
    +  // TODO(eqy): can we make comms_ reference?
    +  comms_ = comms;
       comm_nonblocking_ = comm_nonblocking;
     #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
       detail::NCCL_CHECK(ncclGroupStart());
    @@ -435,10 +437,10 @@ AutoNcclGroup::AutoNcclGroup(ncclComm_t comm, bool comm_nonblocking) {

     AutoNcclGroup::~AutoNcclGroup() noexcept(false) {
     #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
    -  if (comm_nonblocking_ && comm_ != nullptr) {
    -    detail::NCCL_CHECK_TIMEOUT(ncclGroupEnd(), comm_);
    -  } else {
    +  if (!comm_nonblocking_) {
         detail::NCCL_CHECK(ncclGroupEnd());
    +  } else {
    +    detail::NCCL_CHECK_TIMEOUT(ncclGroupEnd(), comms_);
       }
     #endif
     #if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)

    torch/csrc/cuda/nccl.h

    Lines changed: 2 additions & 2 deletions
    @@ -76,9 +76,9 @@ enum class ncclDataType {
     // manages group and lock lifetimes.
     struct AutoNcclGroup {
       AutoNcclGroup();
    -  AutoNcclGroup(ncclComm_t comm, bool comm_nonblocking);
    +  AutoNcclGroup(std::vector<ncclComm_t>& comms, bool comm_nonblocking);
       ~AutoNcclGroup() noexcept(false);
    -  ncclComm_t comm_;
    +  std::vector<ncclComm_t> comms_;
       bool comm_nonblocking_;
     };
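
    As context for the restored multi-communicator API: the guard brackets a batch of per-communicator NCCL calls between ncclGroupStart() and ncclGroupEnd() (see the nccl.cpp hunks above). The sketch below is an illustration only, not code from this commit; the helper name, the loop body, and the assumption that the struct lives in the torch::cuda::nccl namespace are mine.

    #include <torch/csrc/cuda/nccl.h> // declares AutoNcclGroup (struct shown above); namespace assumed
    #include <vector>

    // Hypothetical helper: issue one NCCL op per communicator inside a single group.
    // 'comms' would hold one initialized ncclComm_t per participating device.
    void launch_grouped_ops(std::vector<ncclComm_t>& comms) {
      // Constructor calls ncclGroupStart(); destructor calls ncclGroupEnd(),
      // polling per-comm async errors when comm_nonblocking is true.
      torch::cuda::nccl::AutoNcclGroup group_guard(comms, /*comm_nonblocking=*/false);
      for (auto& comm : comms) {
        (void)comm; // e.g. enqueue an ncclAllReduce / ncclBroadcast for this comm here
      }
    } // the NCCL group is ended here, when group_guard is destroyed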

    torch/csrc/distributed/c10d/NCCLUtils.hpp

    Lines changed: 31 additions & 26 deletions
    @@ -126,32 +126,37 @@
         TORCH_CHECK_WITH(DistBackendError, false, err); \
       }

    -#define C10D_NCCL_CHECK_TIMEOUT_GROUPEND(cmd, comm, failureReason) \
    -  ncclResult_t state = cmd; \
    -  auto startTimepoint = std::chrono::steady_clock::now(); \
    -  if (state == ncclInProgress) { \
    -    do { \
    -      if (nccl_nonblocking_timeout() > 0) { \
    -        auto currentTimepoint = std::chrono::steady_clock::now(); \
    -        auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>( \
    -                               currentTimepoint - startTimepoint) \
    -                               .count(); \
    -        if (timeElapsed > nccl_nonblocking_timeout()) { \
    -          std::string err = "NCCL timeout in: " + std::string(__FILE__) + \
    -              ":" + std::to_string(__LINE__) + ", " + \
    -              ncclGetErrorWithVersion(state) + "\n" + \
    -              getNcclErrorDetailStr(state, failureReason); \
    -          TORCH_CHECK_WITH(DistBackendError, false, err); \
    -        } \
    -      } \
    -      ncclCommGetAsyncError(comm->getNcclComm(), &state); \
    -    } while (state == ncclInProgress); \
    -  } \
    -  if (state != ncclSuccess) { \
    -    std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
    -        std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) + \
    -        "\n" + getNcclErrorDetailStr(state, failureReason); \
    -    TORCH_CHECK_WITH(DistBackendError, false, err); \
    +#define C10D_NCCL_CHECK_TIMEOUT_GROUPEND(cmd, comms_, failureReason) \
    +  ncclResult_t state = cmd; \
    +  auto startTimepoint = std::chrono::steady_clock::now(); \
    +  if (state == ncclInProgress) { \
    +    for (const auto i : c10::irange(comms_.size())) { \
    +      do { \
    +        if (nccl_nonblocking_timeout() > 0) { \
    +          auto currentTimepoint = std::chrono::steady_clock::now(); \
    +          auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>( \
    +                                 currentTimepoint - startTimepoint) \
    +                                 .count(); \
    +          if (timeElapsed > nccl_nonblocking_timeout()) { \
    +            std::string err = "NCCL timeout in: " + std::string(__FILE__) + \
    +                ":" + std::to_string(__LINE__) + ", " + \
    +                ncclGetErrorWithVersion(state) + "\n" + \
    +                getNcclErrorDetailStr(state, failureReason); \
    +            TORCH_CHECK_WITH(DistBackendError, false, err); \
    +          } \
    +        } \
    +        ncclCommGetAsyncError(comms_[i]->getNcclComm(), &state); \
    +      } while (state == ncclInProgress); \
    +      if (state != ncclSuccess) { \
    +        break; /* fall through to failed case */ \
    +      } \
    +    } \
    +  } \
    +  if (state != ncclSuccess) { \
    +    std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
    +        std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) + \
    +        "\n" + getNcclErrorDetailStr(state, failureReason); \
    +    TORCH_CHECK_WITH(DistBackendError, false, err); \
       }

     // Macro to print and abort on a non-successful NCCL return value.
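
    To make the restored macro's control flow easier to follow, here is the same logic written out as an ordinary function rather than a macro. This is a readability sketch only: error reporting is simplified to plain exceptions instead of TORCH_CHECK_WITH, the function name is mine, and nccl_nonblocking_timeout(), ncclCommGetAsyncError(), and NCCLComm::getNcclComm() are used as they appear in the hunk above.

    #include <torch/csrc/distributed/c10d/NCCLUtils.hpp> // NCCLComm, nccl_nonblocking_timeout()
    #include <chrono>
    #include <memory>
    #include <stdexcept>
    #include <vector>

    // Sketch of C10D_NCCL_CHECK_TIMEOUT_GROUPEND as a plain function (illustration only).
    // After ncclGroupEnd() returns ncclInProgress (nonblocking communicators), poll each
    // communicator's async error state until it leaves ncclInProgress or the configured
    // timeout elapses; any non-success result is surfaced as an error.
    void check_group_end_result(
        ncclResult_t state, // result of the wrapped ncclGroupEnd() call
        const std::vector<std::shared_ptr<c10d::NCCLComm>>& comms) {
      auto start = std::chrono::steady_clock::now();
      if (state == ncclInProgress) {
        for (const auto& comm : comms) {
          do {
            if (c10d::nccl_nonblocking_timeout() > 0) {
              auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
                                 std::chrono::steady_clock::now() - start)
                                 .count();
              if (elapsed > c10d::nccl_nonblocking_timeout()) {
                throw std::runtime_error("NCCL timeout while finishing group end");
              }
            }
            ncclCommGetAsyncError(comm->getNcclComm(), &state);
          } while (state == ncclInProgress);
          if (state != ncclSuccess) {
            break; // stop at the first communicator that reports a failure
          }
        }
      }
      if (state != ncclSuccess) {
        throw std::runtime_error("NCCL error while finishing group end");
      }
    }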

    0 commit comments