10000 [PGNCCL] Launch kernel on current stream & remove `record_stream` ent… · pytorch/pytorch@ef6296e · GitHub
[go: up one dir, main page]

Skip to content

Commit ef6296e

Browse files
kwen2501pytorchmergebot
authored andcommitted
[PGNCCL] Launch kernel on current stream & remove record_stream entirely (#148590)
This PR has multiple changes to `ProcessGroupNCCL` (which unfortunately are related): 1. When async_op=False, we directly launch the collective on "current" stream, instead of a trampoline stream and join back. - Resolves #147729 - Resolves #146881 - Also saves two event syncs (which have overhead in case of HIP) and one pybind when we call `work.wait()` in distributed_c10d.py on behalf of user. 2. Entirely remove `record_stream` and use CPU-side stashing for managing tensor lifetime against recycling. - Resolves #147168 3. Remove tensor life management when async_op=False; only use it when async_op=True. 4. To guard against user not calling `work.wait()`, we ask watchdog to unstash tensors after detecting completion of collectives, to prevent us from holding reference to tensors forever. This is a safety net, rather than a service guarantee, see discussion [here](#147168 (comment)). 5. Profile in async_op=False mode would look different -- collective kernels would show up in the same line and compute kernels. Joint work with @cenzhaometa who wants to remove the event sync overhead. Cc: @ngimel @awgu @Aidyn-A @skyw @wconstab @leonardo0lyj Differential Revision: [D70937982](https://our.internmc.facebook.com/intern/diff/D70937982) Pull Request resolved: #148590 Approved by: https://github.com/eqy, https://github.com/Aidyn-A, https://github.com/fduwjj
1 parent b366f33 commit ef6296e

File tree

11 files changed

+411
-362
lines changed
  • distributed
  • testing/_internal/distributed
  • 11 files changed

    +411
    -362
    lines changed

    test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp

    Lines changed: 3 additions & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -363,6 +363,9 @@ class TestDebugInfoWriter : public c10d::DebugInfoWriter {
    363363
    };
    364364

    365365
    TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
    366+
    // Note (kwen2501) 03/07/2025
    367+
    // TODO: re-enable
    368+
    GTEST_SKIP() << "Skipping test as the trace write seems unstable.";
    366369
    int heartBeatIntervalInSec = 2;
    367370
    std::string timeInterval = std::to_string(heartBeatIntervalInSec);
    368371
    ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0);

    test/forward_backward_compatibility/check_forward_backward_compatibility.py

    Lines changed: 3 additions & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -126,6 +126,9 @@
    126126
    ("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)),
    127127
    ("aten::all_gather_into_tensor", datetime.date(9999, 1, 30)),
    128128
    ("aten::all_reduce", datetime.date(9999, 1, 30)),
    129+
    # These ops are defined in torch/csrc/distributed/c10d/Ops.cpp
    130+
    # TODO: add back restriction when c10d ops can be exported
    131+
    ("c10d::.*", datetime.date(9999, 1, 1)),
    129132
    ]
    130133

    131134
    ALLOW_LIST_COMPILED = [

    torch/_C/_distributed_c10d.pyi

    Lines changed: 7 additions & 1 deletion
    Original file line numberDiff line numberDiff line change
    @@ -2,7 +2,7 @@
    22
    # mypy: disable-error-code="type-arg"
    33
    from datetime import timedelta
    44
    from enum import Enum
    5-
    from typing import Any, overload
    5+
    from typing import Any, Optional, overload
    66

    77
    import torch
    88
    from torch import Tensor
    @@ -139,6 +139,8 @@ class BroadcastOptions:
    139139
    class AllreduceOptions:
    140140
    reduceOp: ReduceOp
    141141
    timeout: timedelta
    142+
    asyncOp: bool
    143+
    sparseIndices: Optional[Tensor]
    142144

    143145
    class AllreduceCoalescedOptions(AllreduceOptions): ...
    144146

    @@ -147,6 +149,7 @@ class ReduceOptions:
    147149
    rootRank: int
    148150
    rootTensor: int
    149151
    timeout: timedelta
    152+
    asyncOp: bool
    150153

    151154
    class AllgatherOptions:
    152155
    timeout: timedelta
    @@ -155,6 +158,7 @@ class AllgatherOptions:
    155158
    class GatherOptions:
    156159
    rootRank: int
    157160
    timeout: timedelta
    161+
    asyncOp: bool
    158162

    159163
    class ScatterOptions:
    160164
    rootRank: int
    @@ -170,9 +174,11 @@ class BarrierOptions:
    170174
    device_ids: list[int]
    171175
    device: torch.device
    172176
    timeout: timedelta
    177+
    asyncOp: bool
    173178

    174179
    class AllToAllOptions:
    175180
    timeout: timedelta
    181+
    asyncOp: bool
    176182

    177183
    class Store:
    178184
    def set(self, key: str, value: str): ...

    torch/csrc/distributed/c10d/Ops.cpp

    Lines changed: 79 additions & 51 deletions
    Large diffs are not rendered by default.

    torch/csrc/distributed/c10d/ProcessGroup.hpp

    Lines changed: 33 additions & 10 deletions
    10000
    Original file line numberDiff line numberDiff line change
    @@ -224,13 +224,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
    224224
    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
    225225
    const c10::intrusive_ptr<::c10d::ReduceOp>&,
    226226
    const std::optional<at::Tensor>& sparse_indices,
    227+
    bool,
    227228
    int64_t)>();
    228229

    229230
    auto work = std::get<1>(op.call(
    230231
    tensors,
    231232
    c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
    232233
    c10::make_intrusive<ReduceOp>(opts.reduceOp),
    233234
    opts.sparseIndices,
    235+
    opts.asyncOp,
    234236
    opts.timeout.count()));
    235237

    236238
    if (c10d::allow_inflight_collective_as_graph_input()) {
    @@ -250,12 +252,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
    250252
    at::TensorList,
    251253
    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
    252254
    const c10::intrusive_ptr<::c10d::ReduceOp>&,
    255+
    bool,
    253256
    int64_t)>();
    254257

    255258
    auto work = op.call(
    256259
    tensors,
    257260
    c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
    258261
    c10::make_intrusive<ReduceOp>(opts.reduceOp),
    262+
    opts.asyncOp,
    259263
    opts.timeout.count());
    260264

    261265
    if (c10d::allow_inflight_collective_as_graph_input()) {
    @@ -277,13 +281,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
    277281
    const c10::intrusive_ptr<::c10d::ReduceOp>&,
    278282
    int64_t,
    279283
    int64_t,
    284+
    bool,
    280285
    int64_t)>();
    281286
    auto work = op.call(
    282287
    tensors,
    283288
    c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
    284289
    c10::make_intrusive<ReduceOp>(opts.reduceOp),
    285290
    opts.rootRank,
    286291
    opts.rootTensor,
    292+
    opts.asyncOp,
    287293
    opts.timeout.count());
    288294

    289295
    if (c10d::allow_inflight_collective_as_graph_input()) {
    @@ -306,12 +312,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
    306312
    const std::vector<std::vector<at::Tensor>>&,
    307313
    at::TensorList,
    308314
    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
    315+
    bool,
    309316
    int64_t)>();
    310317

    311318
    auto work = std::get<1>(op.call(
    312319
    outputTensors,
    313320
    inputTensors,
    314321
    c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
    322+
    opts.asyncOp,
    315323
    opts.timeout.count()));
    316324

    317325
    if (c10d::allow_inflight_collective_as_graph_input()) {
    @@ -363,18 +371,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
    363371
    std::vector<std::vector<at::Tensor>>& outputTensorLists,
    364372
    std::vector<at::Tensor>& inputTensors,
    365373
    const AllgatherOptions& opts = AllgatherOptions()) {
    366-
    static auto op =
    367-
    c10::Dispatcher::singleton()
    368-
    .findSchemaOrThrow("c10d::allgather_coalesced_", "")
    369-
    .typed<c10::intrusive_ptr<Work>(
    370-
    const std::vector<std::vector<at::Tensor>>&,
    371-
    const at::TensorList&,
    372-
    const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
    374+
    static auto op = c10::Dispatcher::singleton()
    375+
    .findSchemaOrThrow("c10d::allgather_coalesced_", "")
    376+
    .typed<c10::intrusive_ptr<Work>(
    377+
    const std::vector<std::vector<at::Tensor>>&,
    378+
    const at::TensorList&,
    379+
    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
    380+
    bool)>();
    373381

    374382
    auto work = op.call(
    375383
    outputTensorLists,
    376384
    inputTensors,
    377-
    c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
    385+
    c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
    386+
    opts.asyncOp);
    378387

    379388
    if (c10d::allow_inflight_collective_as_graph_input()) {
    380389
    for (const auto& tensor_list : outputTensorLists) {
    @@ -399,12 +408,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
    399408
    .typed<c10::intrusive_ptr<Work>(
    400409
    const at::TensorList,
    401410
    const at::TensorList,
    402-
    const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
    411+
    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
    412+
    bool)>();
    403413

    404414
    auto work = op.call(
    405415
    outputTensors,
    406416
    inputTensors,
    407-
    c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
    417+
    c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
    418+
    opts.asyncOp);
    408419

    409420
    if (c10d::allow_inflight_collective_as_graph_input()) {
    410421
    for (const auto& tensor : outputTensors) {
    @@ -425,12 +436,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
    425436
    const at::TensorList&,
    426437
    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
    427438
    int64_t,
    439+
    bool,
    428440
    int64_t)>();
    429441
    auto work = op.call(
    430442
    outputTensors,
    431443
    inputTensors,
    432444
    c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
    433445
    opts.rootRank,
    446+
    opts.asyncOp,
    434447
    opts.timeout.count());
    435448

    436449
    if (c10d::allow_inflight_collective_as_graph_input()) {
    @@ -487,12 +500,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
    487500
    const std::vector<std::vector<at::Tensor>>&,
    488501
    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
    489502
    const c10::intrusive_ptr<::c10d::ReduceOp>&,
    503+
    bool,
    490504
    int64_t)>();
    491505
    auto work = std::get<1>(op.call(
    492506
    outputTensors,
    493507
    inputTensors,
    494508
    c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
    495509
    c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
    510+
    opts.asyncOp,
    496511
    opts.timeout.count()));
    497512

    498513
    if (c10d::allow_inflight_collective_as_graph_input()) {
    @@ -546,13 +561,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
    546561
    const at::TensorList,
    547562
    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
    548563
    const c10::intrusive_ptr<::c10d::ReduceOp>&,
    564+
    bool,
    549565
    int64_t)>();
    550566

    551567
    auto work = op.call(
    552568
    outputTensors,
    553569
    inputTensors,
    554570
    c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
    555571
    c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
    572+
    opts.asyncOp,
    556573
    opts.timeout.count());
    557574

    558575
    if (c10d::allow_inflight_collective_as_graph_input()) {
    @@ -577,13 +594,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
    577594
    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
    578595
    std::vector<int64_t>,
    579596
    std::vector<int64_t>,
    597+
    bool,
    580598
    int64_t)>();
    581599
    auto work = op.call(
    582600
    outputBuffer,
    583601
    inputBuffer,
    584602
    c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
    585603
    outputSplitSizes,
    586604
    inputSplitSizes,
    605+
    opts.asyncOp,
    587606
    opts.timeout.count());
    588607

    589608
    if (c10d::allow_inflight_collective_as_graph_input()) {
    @@ -604,11 +623,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
    604623
    const at::TensorList&,
    605624
    const at::TensorList&,
    606625
    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
    626+
    bool,
    607627
    int64_t)>();
    608628
    auto work = std::get<1>(op.call(
    609629
    outputTensors,
    610630
    inputTensors,
    611631
    c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
    632+
    opts.asyncOp,
    612633
    opts.timeout.count()));
    613634

    614635
    if (c10d::allow_inflight_collective_as_graph_input()) {
    @@ -778,12 +799,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
    778799
    at::Tensor,
    779800
    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
    780801
    const std::vector<int64_t>&,
    802+
    bool,
    781803
    int64_t)>();
    782804

    783805
    auto work = op.call(
    784806
    tensor,
    785807
    c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
    786808
    opts.device_ids,
    809+
    opts.asyncOp,
    787810
    opts.timeout.count());
    788811
    if (c10d::allow_inflight_collective_as_graph_input()) {
    789812
    c10d::register_work(tensor, work);

    0 commit comments

    Comments
     (0)
    0