[wip][ca][ddp] traceable C++ reducer · pytorch/pytorch@5b669f2 · GitHub

Commit 5b669f2

[wip][ca][ddp] traceable C++ reducer
ghstack-source-id: c4086d2
Pull Request resolved: #153501
1 parent dd30f61 commit 5b669f2

File tree

6 files changed: +74 -11 lines changed


torch/_dynamo/compiled_autograd.py

Lines changed: 1 addition & 0 deletions
@@ -779,6 +779,7 @@ def pre_hook(self, inputs, hook_id):
         return inputs
 
     def post_hook(self, outputs, inputs, hook_id):
+        breakpoint()
         assert self.hooks_proxy is not None
         hook = self.hooks_proxy[hook_id]  # type: ignore[index]
         proxies = self.proxy_call_hook(

torch/csrc/autograd/function_hook.h

Lines changed: 8 additions & 0 deletions
@@ -42,6 +42,14 @@ struct TORCH_API FunctionPostHook {
         std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
         typeid(*this).name());
   }
+
+  virtual void apply_with_saved(
+      Variable& tensor,
+      torch::dynamo::autograd::SwapSavedVariables& saved) const {
+    throw std::runtime_error(
+        std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
+        typeid(*this).name());
+  }
 };
 
 struct TORCH_API PostAccumulateGradHook {
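
The new virtual gives every FunctionPostHook a compiled-autograd apply entry point next to compiled_args. A minimal sketch of a subclass that opts in; the hook type and the comment bodies here are hypothetical and not part of this commit:

#include <torch/csrc/autograd/function_hook.h>

// Hypothetical subclass: leaves gradients untouched in eager mode and
// overrides both compiled-autograd entry points so neither base-class
// "nyi" throw is hit while tracing.
struct TracedNoopPostHook : torch::autograd::FunctionPostHook {
  std::vector<torch::autograd::Variable> operator()(
      const std::vector<torch::autograd::Variable>& outputs,
      const std::vector<torch::autograd::Variable>& /*inputs*/) override {
    return outputs; // eager path: pass gradients through unchanged
  }

  void compiled_args(
      torch::dynamo::autograd::CompiledNodeArgs& /*args*/) const override {
    // collect any hook state that should specialize the compiled graph
  }

  void apply_with_saved(
      torch::autograd::Variable& /*tensor*/,
      torch::dynamo::autograd::SwapSavedVariables& /*saved*/) const override {
    // emit the traced equivalent of operator() via the saved-variable swapper
  }
};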

torch/csrc/autograd/utils/lambda_post_hook.h

Lines changed: 22 additions & 6 deletions
@@ -10,16 +10,23 @@ class LambdaPostHook : public torch::autograd::FunctionPostHook {
   using variable_list = std::vector<torch::autograd::Variable>;
   using fn_type =
       std::function<variable_list(const variable_list&, const variable_list&)>;
-  using compiled_fn_type = std::function<void(CompiledNodeArgs&)>;
+  using compiled_args_fn_type = std::function<void(CompiledNodeArgs&)>;
+  using compiled_apply_fn_type =
+      std::function<void(Variable&, SwapSavedVariables&)>;
 
  public:
   // The lambda function takes as arguments the outputs and inputs of the
   // autograd function and can modify the outputs of the autograd function by
   // returning a new output if needed.
   /* implicit */ LambdaPostHook(fn_type fn) : fn_(std::move(fn)) {}
 
-  LambdaPostHook(fn_type fn, compiled_fn_type compiled_fn)
-      : fn_(std::move(fn)), compiled_fn_(std::move(compiled_fn)) {}
+  LambdaPostHook(
+      fn_type fn,
+      compiled_args_fn_type compiled_args_fn,
+      compiled_apply_fn_type compiled_apply_fn)
+      : fn_(std::move(fn)),
+        compiled_args_fn_(std::move(compiled_args_fn)),
+        compiled_apply_fn_(std::move(compiled_apply_fn)) {}
 
   variable_list operator()(
       const variable_list& outputs,
@@ -28,15 +35,24 @@ class LambdaPostHook : public torch::autograd::FunctionPostHook {
   }
 
   void compiled_args(CompiledNodeArgs& args) const override {
-    if (compiled_fn_ != nullptr) {
-      return compiled_fn_(args);
+    if (compiled_args_fn_ != nullptr) {
+      return compiled_args_fn_(args);
     }
     return FunctionPostHook::compiled_args(args);
  }
 
+  void apply_with_saved(Variable& inputs, SwapSavedVariables& saved)
+      const override {
+    if (compiled_apply_fn_ != nullptr) {
+      return compiled_apply_fn_(inputs, saved);
+    }
+    return FunctionPostHook::apply_with_saved(inputs, saved);
+  }
+
  protected:
   std::function<variable_list(const variable_list&, const variable_list&)> fn_;
-  compiled_fn_type compiled_fn_{};
+  compiled_args_fn_type compiled_args_fn_{};
+  compiled_apply_fn_type compiled_apply_fn_{};
 };
 
 } // namespace torch::autograd::utils
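
With the split, a call site now passes both compiled callbacks when constructing the hook. A rough sketch of that shape; the node and the lambda bodies here are placeholders, and the reducer.cpp hunk below is the real in-tree user:

#include <memory>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/utils/lambda_post_hook.h>

using torch::autograd::variable_list;

// `node` stands in for whatever autograd node the hook is attached to,
// e.g. a gradient accumulator as in reducer.cpp below.
void register_sketch_hook(torch::autograd::Node& node) {
  node.add_post_hook(std::make_unique<torch::autograd::utils::LambdaPostHook>(
      /*fn=*/[](const variable_list& outputs, const variable_list& /*inputs*/) {
        return outputs; // eager post-hook body
      },
      /*compiled_args_fn=*/[](torch::dynamo::autograd::CompiledNodeArgs& args) {
        // fold anything that should specialize the compiled graph into the
        // cache key, e.g. args.collect(...)
      },
      /*compiled_apply_fn=*/[](torch::autograd::Variable& variable,
                               torch::dynamo::autograd::SwapSavedVariables& saved) {
        // emit the traced equivalent of the eager hook
      }));
}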

torch/csrc/distributed/c10d/reducer.cpp

Lines changed: 37 additions & 5 deletions
@@ -126,7 +126,7 @@ Reducer::Reducer(
       use_python_reducer_(use_python_reducer) {
   C10_LOG_API_USAGE_ONCE("torch.distributed.ddp.reducer");
   TORCH_INTERNAL_ASSERT(!params_.empty(), "Expected at least one parameter.");
-
+  std::cout << "hello from c++ reducer" << std::endl;
   if (ddp_debug_level_ != c10d::DebugLevel::Off) {
     LOG(INFO) << "Reducer initialized with bucket_bytes_cap: "
               << bucket_bytes_cap_
@@ -174,6 +174,7 @@ Reducer::Reducer(
   // can be marked as ready for reduction.
   {
     const auto variable_count = params_.size();
+    std::cout << "reducer found " << variable_count << " variables" << std::endl;
     grad_accumulators_.resize(variable_count);
     for (const auto variable_index : c10::irange(variable_count)) {
       auto& variable = params_[variable_index];
@@ -187,6 +188,7 @@ Reducer::Reducer(
       using torch::distributed::autograd::ThreadLocalDistAutogradContext;
 #endif
       // Hook to execute after the gradient accumulator has executed.
+      std::cout << "registering the post hook" << std::endl;
       hooks_.emplace_back(
           grad_accumulator->add_post_hook(std::make_unique<
                                           torch::autograd::utils::
@@ -201,12 +203,41 @@ Reducer::Reducer(
               this->autograd_hook(variable_index);
               return outputs;
             },
-            [this](torch::autograd::CompiledNodeArgs& args) {
-              TORCH_CHECK(
-                  this->use_python_reducer_,
-                  "Compiled autograd is not compatible with C++ DDP Reducer, please use torch._dynamo.config.optimize_ddp=\"python_reducer\".");
+            [this, variable_index](torch::autograd::CompiledNodeArgs& args) {
+              std::cout << "collecting the post hook on variable_index=" << variable_index << std::endl;
+              if (this->use_python_reducer_) {
+                return;
+              }
+
+              // filters out unsupported DDP arguments
+              auto str =
+                  "Compiled autograd is not compatible with C++ DDP Reducer, please use torch._dynamo.config.optimize_ddp=\"python_reducer\".";
+              TORCH_CHECK(!mixed_precision_param_dtype_.has_value(), str);
+              TORCH_CHECK(!find_unused_parameters_, str);
+              TORCH_CHECK(ddp_debug_level_ == c10d::DebugLevel::Off, str);
+              TORCH_CHECK(rpc_context_.context_ptr.load() == nullptr, str);
+              if (static_graph_) {
+                TORCH_WARN_ONCE(
+                    "static_graph ignored, compiled autograd always rebuilds buckets when param ready order changes.");
+              }
+
+              // Attempt to trace C++ Reducer
+              args.collect(variable_index);
+              // args.cpp_post_hook();
+              // at::Tensor& param = get_param_from_index(variable_index);
+            },
+            [this, variable_index](
+                torch::autograd::Variable& variable,
+                torch::autograd::SwapSavedVariables& saved) {
+              // update bucketing state in tracker
+              // saved.compiler_call.update_reducer_state
+              // issue bucketing op with the correct tensors
+              // pycompiler.call_ddp_autograd_hook(bucket: List[Tensor])
+              // then bucket and issue collective
+              return;
             })),
          grad_accumulator);
+      std::cout << "registered post hook on " << &(*grad_accumulator) << std::endl;
 
       // Map raw function pointer to parameter index.
       // This is used later on when the autograd graph is traversed
@@ -2401,6 +2432,7 @@ void verify_params_across_processes(
 }
 
 void Reducer::remove_autograd_hooks() {
+  std::cout << "===========================REMOVING AUTOGRAD HOOKS======================" << std::endl;
   // Remove all hooks on variables registered by this Reducer. This is necessary
   // to make DDP failure recoverable. Otherwise, multiple Reducer instances
   // (from recoveries) will add their hooks to the original model, and those
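
The TODO comments in the third (compiled apply) lambda sketch the eventual plan: track which parameters have become ready, group their gradients into buckets, and issue one collective per full bucket from the traced graph. Purely as an illustration of that bucketing shape; none of these names exist in this commit, the capacity is a simple element count rather than DDP's byte-based bucket_bytes_cap_, and the traced version would emit these steps through compiled autograd instead of running them eagerly:

#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

#include <ATen/core/Tensor.h>

// Hypothetical helper: accumulate ready gradients and flush a collective
// once the bucket is full.
struct ReadyBucket {
  std::vector<at::Tensor> grads;
  std::size_t capacity{25};

  // Returns true if the bucket filled up and a collective was launched.
  bool on_grad_ready(
      at::Tensor grad,
      const std::function<void(std::vector<at::Tensor>&)>& launch_allreduce) {
    grads.push_back(std::move(grad));
    if (grads.size() < capacity) {
      return false;
    }
    launch_allreduce(grads); // e.g. an allreduce over the bucket's tensors
    grads.clear();
    return true;
  }
};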

torch/csrc/dynamo/compiled_autograd.h

Lines changed: 4 additions & 0 deletions
@@ -367,6 +367,8 @@ struct AutogradCompilerCall {
   // pynode -> backward and backward state idx
   std::unordered_map<const Node*, std::pair<size_t, std::optional<size_t>>>
       pynode_objs;
+  // C++ reducer state
+
 };
 
 class CompiledNodeArgs {
@@ -611,6 +613,7 @@
 #undef COLLECT_AS_BYTES
 
   void collect_hooks_from(Node* fn) {
+    std::cout << "collecting hooks from " << fn->name() << "(" << fn << ")" << std::endl;
     for (auto& i : fn->tensor_pre_hooks()) {
       i->compiled_args(*this);
     }
@@ -621,6 +624,7 @@
       i->compiled_args(*this);
     }
     for (auto& i : fn->post_hooks()) {
+      std::cout << "found post hook" << std::endl;
       i->compiled_args(*this);
     }
     collect_size(_node_call.tensor_pre_hooks.size());

torch/csrc/dynamo/python_compiled_autograd.cpp

Lines changed: 2 additions & 0 deletions
@@ -1053,6 +1053,7 @@ static CacheNode* _compiled_autograd_impl(
     }
 
     SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call);
+    std::cout << call.node->name() << " has post hooks before apply_with_saved? " << (!call.post_hooks.empty()) << std::endl;
     variable_list outputs = call.node->apply_with_saved(inputs, saved);
     saved.debug_asserts();
     saved.before(call.node->next_edges());
@@ -1104,6 +1105,7 @@
     saved.after(call.node->next_edges());
     saved.debug_asserts();
 
+    std::cout << call.node->name() << " has post hooks after apply_with_saved? " << (!call.post_hooks.empty()) << std::endl;
     if (!call.post_hooks.empty()) {
       THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
       THPObjectPtr pyoutputs(THPVariable_WrapList(outputs));

0 commit comments
