[ddp] propagate use_python_reducer to C++ reducer (#152735) · pytorch/pytorch@d1f1ff8 · GitHub

Commit d1f1ff8

xmfan authored and pytorchmergebot committed
[ddp] propagate use_python_reducer to C++ reducer (#152735)
The C++ Reducer is silently incorrect under compiled autograd (CA): its implementation no-ops the collective. I'm guessing it was no-op'd because, with DDP + the python reducer, the C++ reducer is still initialized.

Pull Request resolved: #152735
Approved by: https://github.com/fegin
ghstack dependencies: #153300, #152689
1 parent 1b4749f commit d1f1ff8
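For orientation, here is a minimal sketch of the opt-in path this change enforces, closely modeled on the new tests below; the single-rank fake process group, the toy model, and the simple compiler_fn are illustrative assumptions, not part of the commit. Without the optimize_ddp setting, backward under compiled autograd now raises the RuntimeError shown in the diff instead of silently skipping the allreduce.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch._dynamo import compiled_autograd
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.distributed.fake_pg import FakeStore  # registers the "fake" backend

# Opt in to the python reducer before constructing DDP.
torch._dynamo.config.optimize_ddp = "python_reducer"


def compiler_fn(gm):
    # Any compiled-autograd compiler_fn works; torch.compile is used here for brevity.
    return torch.compile(gm)


dist.init_process_group(backend="fake", rank=0, world_size=2, store=FakeStore())
try:
    model = DDP(nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10)))
    loss = model(torch.randn(10, 10)).sum()
    with compiled_autograd._enable(compiler_fn):
        # With the default optimize_ddp mode (C++ reducer), this backward would
        # now raise a RuntimeError rather than no-oping the collective.
        loss.backward()
finally:
    dist.destroy_process_group()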

File tree

7 files changed: +79 −11 lines changed


test/inductor/test_compiled_autograd.py
Lines changed: 50 additions & 0 deletions

@@ -19,6 +19,7 @@
 from unittest import mock
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import _inductor as inductor
@@ -30,6 +31,7 @@
 from torch._inductor import config as inductor_config
 from torch._inductor.test_case import run_tests, TestCase
 from torch.nn.attention.flex_attention import flex_attention
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     ops,
@@ -4161,6 +4163,54 @@ def aot_eager():
         first, second, third, fourth = fn(eager(), aot_eager())
         self.assertIsNone(third)
 
+    @unittest.skipIf(
+        not torch.distributed.is_available(),
+        "FakePG relies on distributed build",
+    )
+    def test_ddp_cpp_reducer_error(self):
+        from torch.testing._internal.distributed.fake_pg import FakeStore
+
+        store = FakeStore()
+        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
+        try:
+            model = torch.nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
+            model = DDP(model)
+            inputs = torch.randn(10, 10)
+            loss = model(inputs).sum()
+            with compiled_autograd._enable(compiler_fn), self.assertRaisesRegex(
+                RuntimeError,
+                (
+                    r"Compiled autograd is not compatible with C\+\+ DDP Reducer, "
+                    r'please use torch._dynamo.config.optimize_ddp="python_reducer"'
+                ),
+            ):
+                loss.backward()
+
+        finally:
+            dist.destroy_process_group()
+
+    @unittest.skipIf(
+        not torch.distributed.is_available(),
+        "FakePG relies on distributed build",
+    )
+    @config.patch(optimize_ddp="python_reducer")
+    def test_ddp_python_reducer(self):
+        from torch.testing._internal.distributed.fake_pg import FakeStore
+
+        store = FakeStore()
+        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
+        try:
+            model = torch.nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
+            model = DDP(model)
+            inputs = torch.randn(10, 10)
+            loss = model(inputs).sum()
+            with compiled_autograd._enable(compiler_fn):
+                # no error expected
+                loss.backward()
+            self.assertEqual(counters["compiled_autograd"]["captures"], 1)
+        finally:
+            dist.destroy_process_group()
+
 
 def load_test_module(name):
     testdir = Path(__file__).absolute().parent.parent

torch/_C/_distributed_c10d.pyi
Lines changed: 1 addition & 0 deletions

@@ -51,6 +51,7 @@ class Reducer:
         param_to_name_mapping: dict[int, str] = ...,
         first_bucket_types_cap: int = ...,  # kDefaultFirstBucketBytes in reducer.hpp
         skip_all_reduce_unused_params: bool = ...,
+        use_python_reducer: bool = ...,
     ) -> None: ...
     def prepare_for_forward(self) -> None: ...
     def prepare_for_backward(self, output: list[Tensor]) -> None: ...

torch/csrc/autograd/utils/lambda_post_hook.h
Lines changed: 6 additions & 1 deletion

@@ -27,7 +27,12 @@ class LambdaPostHook : public torch::autograd::FunctionPostHook {
     return fn_(outputs, inputs);
   }
 
-  void compiled_args(CompiledNodeArgs& args) const override {}
+  void compiled_args(CompiledNodeArgs& args) const override {
+    if (compiled_fn_ != nullptr) {
+      return compiled_fn_(args);
+    }
+    return FunctionPostHook::compiled_args(args);
+  }
 
 protected:
  std::function<variable_list(const variable_list&, const variable_list&)> fn_;

torch/csrc/distributed/c10d/init.cpp
Lines changed: 6 additions & 3 deletions

@@ -560,7 +560,8 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
              bool gradient_as_bucket_view,
              std::unordered_map<size_t, std::string> param_to_name_mapping,
              int64_t first_bucket_bytes_cap,
-             bool skip_all_reduce_unused_params) {
+             bool skip_all_reduce_unused_params,
+             bool use_python_reducer) {
            // gil_scoped_release is not safe as a call_guard in init.
            // https://github.com/pybind/pybind11/issues/5473
            py::gil_scoped_release nogil{};
@@ -575,7 +576,8 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
                gradient_as_bucket_view,
                std::move(param_to_name_mapping),
                first_bucket_bytes_cap,
-               skip_all_reduce_unused_params);
+               skip_all_reduce_unused_params,
+               use_python_reducer);
          }),
          py::arg("params"),
          py::arg("bucket_indices"),
@@ -588,7 +590,8 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
          py::arg("param_to_name_mapping") =
              std::unordered_map<size_t, std::string>(),
          py::arg("first_bucket_bytes_cap") = ::c10d::kDefaultFirstBucketBytes,
-         py::arg("skip_all_reduce_unused_params") = false)
+         py::arg("skip_all_reduce_unused_params") = false,
+         py::arg("use_python_reducer") = false)
          .def(
              "prepare_for_forward",
              &::c10d::Reducer::prepare_for_forward,

torch/csrc/distributed/c10d/reducer.cpp
Lines changed: 7 additions & 4 deletions

@@ -97,7 +97,8 @@ Reducer::Reducer(
     bool gradient_as_bucket_view,
     std::unordered_map<size_t, std::string> param_names,
     int64_t first_bucket_bytes_cap,
-    bool skip_all_reduce_unused_params)
+    bool skip_all_reduce_unused_params,
+    bool use_python_reducer)
     : params_(std::move(params)),
       process_group_(std::move(process_group)),
       expect_sparse_gradients_(std::move(expect_sparse_gradients)),
@@ -121,7 +122,8 @@ Reducer::Reducer(
       comm_hook_(nullptr),
       ddp_debug_level_(debug_level()),
       param_names_(std::move(param_names)),
-      first_bucket_bytes_cap_(first_bucket_bytes_cap) {
+      first_bucket_bytes_cap_(first_bucket_bytes_cap),
+      use_python_reducer_(use_python_reducer) {
   C10_LOG_API_USAGE_ONCE("torch.distributed.ddp.reducer");
   TORCH_INTERNAL_ASSERT(!params_.empty(), "Expected at least one parameter.");
 
@@ -199,8 +201,9 @@ Reducer::Reducer(
               this->autograd_hook(variable_index);
               return outputs;
             },
-            [=](torch::autograd::CompiledNodeArgs& args) {
-              TORCH_INTERNAL_ASSERT(
+            [this](torch::autograd::CompiledNodeArgs& args) {
+              TORCH_CHECK(
+                  this->use_python_reducer_,
                   "Compiled autograd is not compatible with C++ DDP Reducer, please use torch._dynamo.config.optimize_ddp=\"python_reducer\".");
             })),
         grad_accumulator);

torch/csrc/distributed/c10d/reducer.hpp
Lines changed: 5 additions & 1 deletion

@@ -58,7 +58,8 @@ class TORCH_API Reducer {
       bool gradient_as_bucket_view,
       std::unordered_map<size_t, std::string> param_names,
       int64_t first_bucket_bytes_cap,
-      bool skip_all_reduce_unused_params);
+      bool skip_all_reduce_unused_params,
+      bool use_python_reducer);
 
   ~Reducer() noexcept(false);
 
@@ -562,6 +563,9 @@ class TORCH_API Reducer {
   void checkAndRaiseMarkedTwiceError(size_t curVariableIndex);
   // Retrieves parameter corresponding to the given VariableIndex.
   at::Tensor& get_param_from_index(size_t index);
+  // Python reducer keeps C++ reducer initialized. To remove this flag,
+  // we need to refactor the DDP wrapper's initilization.
+  bool use_python_reducer_;
 
   // Cached bucket index to model parameter mapping. Populated after buckets
   // are rebuilt after which this mapping is static.

torch/nn/parallel/distributed.py
Lines changed: 4 additions & 2 deletions

@@ -657,6 +657,9 @@ def __init__(
     ):
         super().__init__()
         Joinable.__init__(self)
+        self._use_python_reducer = (
+            torch._dynamo.utils.get_optimize_ddp_mode() == "python_reducer"
+        )
         self.logger: Optional[dist.Logger] = None
         if bool(delay_all_reduce_named_params is not None) != bool(
             param_to_hook_all_reduce is not None
@@ -915,8 +918,6 @@ def __init__(
         # True. The hooks will be deregistered if compiled_autograd is not
         # enabled.
        self._accum_grad_hooks: list[RemovableHandle] = []
-        optimize_ddp = torch._dynamo.utils.get_optimize_ddp_mode()
-        self._use_python_reducer = optimize_ddp == "python_reducer"
         if self._use_python_reducer:
             torch._inductor.config._fuse_ddp_communication = True
             torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb
@@ -1228,6 +1229,7 @@ def _ddp_init_helper(
                 else self.bucket_bytes_cap
             ),
             self.skip_all_reduce_unused_params,
+            self._use_python_reducer,
         )
 
         self.logger = dist.Logger(self.reducer)
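Because the flag is now computed at the top of __init__, it is already set when _ddp_init_helper constructs the C++ Reducer. A small usage sketch under the same assumptions as above (fake process group for illustration; _use_python_reducer is a private attribute):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch._dynamo import config as dynamo_config
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.distributed.fake_pg import FakeStore  # registers the "fake" backend

dynamo_config.optimize_ddp = "python_reducer"
dist.init_process_group(backend="fake", rank=0, world_size=2, store=FakeStore())
try:
    ddp = DDP(nn.Linear(10, 10))
    # The mode read in __init__ is forwarded to the C++ Reducer constructor
    # as the new use_python_reducer argument.
    assert ddp._use_python_reducer
finally:
    dist.destroy_process_group()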
