[wip][ca][ddp] traceable C++ reducer · pytorch/pytorch@fc93022 · GitHub
[go: up one dir, main page]

Skip to content

Commit fc93022

Browse files
committed
[wip][ca][ddp] traceable C++ reducer
ghstack-source-id: d9b4ff7 Pull Request resolved: #153501
1 parent 17634fb commit fc93022

File tree

8 files changed

+220
-22
lines changed

8 files changed

+220
-22
lines changed

torch/_dynamo/compiled_autograd.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,7 @@ def proxy_call_hook(self, hook, *args, **kwargs):
728728
)
729729

730730
def unpack_hook(self, hook_id, data_id):
731+
breakpoint()
731732
assert self.hooks_proxy is not None
732733
hook = self.hooks_proxy[hook_id] # type: ignore[index]
733734
data = self.packed_data_proxy[data_id] # type: ignore[index]
@@ -779,6 +780,7 @@ def pre_hook(self, inputs, hook_id):
779780
return inputs
780781

781782
def post_hook(self, outputs, inputs, hook_id):
783+
breakpoint()
782784
assert self.hooks_proxy is not None
783785
hook = self.hooks_proxy[hook_id] # type: ignore[index]
784786
proxies = self.proxy_call_hook(

torch/csrc/autograd/function_hook.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ struct TORCH_API FunctionPostHook {
4242
std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
4343
typeid(*this).name());
4444
}
45+
46+
virtual void apply_with_saved(
47+
Variable& tensor,
48+
torch::dynamo::autograd::SwapSavedVariables& saved) const {
49+
throw std::runtime_error(
50+
std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
51+
typeid(*this).name());
52+
}
4553
};
4654

4755
struct TORCH_API PostAccumulateGradHook {

torch/csrc/autograd/utils/lambda_post_hook.h

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,23 @@ class LambdaPostHook : public torch::autograd::FunctionPostHook {
1010
using variable_list = std::vector<torch::autograd::Variable>;
1111
using fn_type =
1212
std::function<variable_list(const variable_list&, const variable_list&)>;
13-
using compiled_fn_type = std::function<void(CompiledNodeArgs&)>;
13+
using compiled_args_fn_type = std::function<void(CompiledNodeArgs&)>;
14+
using compiled_apply_fn_type =
15+
std::function<void(Variable&, SwapSavedVariables&)>;
1416

1517
public:
1618
// The lambda function takes as arguments the outputs and inputs of the
1719
// autograd function and can modify the outputs of the autograd function by
1820
// returning a new output if needed.
1921
/* implicit */ LambdaPostHook(fn_type fn) : fn_(std::move(fn)) {}
2022

21-
LambdaPostHook(fn_type fn, compiled_fn_type compiled_fn)
22-
: fn_(std::move(fn)), compiled_fn_(std::move(compiled_fn)) {}
23+
LambdaPostHook(
24+
fn_type fn,
25+
compiled_args_fn_type compiled_args_fn,
26+
compiled_apply_fn_type compiled_apply_fn)
27+
: fn_(std::move(fn)),
28+
compiled_args_fn_(std::move(compiled_args_fn)),
29+
compiled_apply_fn_(std::move(compiled_apply_fn)) {}
2330

2431
variable_list operator()(
2532
const variable_list& outputs,
@@ -28,15 +35,24 @@ class LambdaPostHook : public torch::autograd::FunctionPostHook {
2835
}
2936

3037
void compiled_args(CompiledNodeArgs& args) const override {
31-
if (compiled_fn_ != nullptr) {
32-
return compiled_fn_(args);
38+
if (compiled_args_fn_ != nullptr) {
39+
return compiled_args_fn_(args);
3340
}
3441
return FunctionPostHook::compiled_args(args);
3542
}
3643

44+
void apply_with_saved(Variable& inputs, SwapSavedVariables& saved)
45+
const override {
46+
if (compiled_apply_fn_ != nullptr) {
47+
return compiled_apply_fn_(inputs, saved);
48+
}
49+
return FunctionPostHook::apply_with_saved(inputs, saved);
50+
}
51+
3752
protected:
3853
std::function<variable_list(const variable_list&, const variable_list&)> fn_;
39-
compiled_fn_type compiled_fn_{};
54+
compiled_args_fn_type compiled_args_fn_{};
55+
compiled_apply_fn_type compiled_apply_fn_{};
4056
};
4157

4258
} // namespace torch::autograd::utils

torch/csrc/distributed/c10d/init.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3565,7 +3565,7 @@ such as `dist.all_reduce(tensor, async_op=True)`.
35653565
[](const std::vector<at::Tensor>& tensors,
35663566
const std::vector<size_t>& bucket_size_limits,
35673567
const std::vector<bool>& expect_sparse_gradient,
3568-
const std::vector<int64_t>& tensor_indices,
3568+
const std::vector<size_t>& tensor_indices,
35693569
const std::optional<std::shared_ptr<::c10d::Logger>>& logger) {
35703570
if (logger.has_value()) {
35713571
std::weak_ptr<::c10d::Logger> logger_weakref = logger.value();

torch/csrc/distributed/c10d/reducer.cpp

Lines changed: 142 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ std::vector<at::Tensor> extractTensors(const c10::IValue& result) {
8585
return result.toTensorVector();
8686
}
8787

88+
bool should_ddp_set_last_bucket_as_small() {
89+
return getCvarString({"DDP_SET_LAST_BUCKET_CAP"}, "N/A") == "1";
90+
}
91+
8892
} // namespace
8993

9094
Reducer::Reducer(
@@ -126,7 +130,7 @@ Reducer::Reducer(
126130
use_python_reducer_(use_python_reducer) {
127131
C10_LOG_API_USAGE_ONCE("torch.distributed.ddp.reducer");
128132
TORCH_INTERNAL_ASSERT(!params_.empty(), "Expected at least one parameter.");
129-
133+
std::cout << "hello from c++ reducer" << std::endl;
130134
if (ddp_debug_level_ != c10d::DebugLevel::Off) {
131135
LOG(INFO) << "Reducer initialized with bucket_bytes_cap: "
132136
<< bucket_bytes_cap_
@@ -174,6 +178,7 @@ Reducer::Reducer(
174178
// can be marked as ready for reduction.
175179
{
176180
const auto variable_count = params_.size();
181+
std::cout << "reducer found " << variable_count << " variables" << std::endl;
177182
grad_accumulators_.resize(variable_count);
178183
for (const auto variable_index : c10::irange(variable_count)) {
179184
auto& variable = params_[variable_index];
@@ -198,13 +203,124 @@ Reducer::Reducer(
198203
this->rpc_context_.set(
199204
ThreadLocalDistAutogradContext::getContextPtr());
200205
#endif
206+
std::cout << "marking variable " << variable_index << " as ready" << std::endl;
201207
this->autograd_hook(variable_index);
202208
return outputs;
203209
},
204-
[this](torch::autograd::CompiledNodeArgs& args) {
205-
TORCH_CHECK(
206-
this->use_python_reducer_,
207-
"Compiled autograd is not compatible with C++ DDP Reducer, please use torch._dynamo.config.optimize_ddp=\"python_reducer\".");
210+
[this, variable_index](torch::autograd::CompiledNodeArgs& args) {
211+
std::cout << "collecting the post hook on variable_index=" << variable_index << std::endl;
212+
if (use_python_reducer_) {
213+
return;
214+
}
215+
216+
// filters out unsupported DDP arguments
217+
auto str =
218+
"Compiled autograd is not compatible with C++ DDP Reducer, please use torch._dynamo.config.optimize_ddp=\"python_reducer\".";
219+
// std::cout << "mixed precision" << std::endl;
220+
TORCH_CHECK(!mixed_precision_param_dtype_.has_value(), str);
221+
// std::cout << "find unused" << std::endl;
222+
TORCH_CHECK(!find_unused_parameters_, str);
223+
// std::cout << "ddp debug level" << std::endl;
224+
TORCH_CHECK(ddp_debug_level_ == c10d::DebugLevel::Off, str);
225+
// std::cout << "rpc" << std::endl;
226+
TORCH_CHECK(rpc_context_.context_ptr.load() == nullptr, str);
227+
228+
// TODO: if not expect autograd hooks, means no sync
229+
// std::cout << "expect hooks" << std::endl;
230+
TORCH_CHECK(expect_autograd_hooks_, str);
231+
232+
// std::cout << "expect spars" << std::endl;
233+
for (bool b : ex 10000 pect_sparse_gradients_) {
234+
TORCH_CHECK(!b, str);
235+
}
236+
// std::cout << "bucket view" << std::endl;
237+
TORCH_CHECK(!gradient_as_bucket_view_, str);
238+
// std::cout << "comm hook non nullptr" << std::endl;
239+
TORCH_CHECK(comm_hook_ == nullptr, str);
240+
// std::cout << "not use static world size" << std::endl;
241+
TORCH_CHECK(forwardPassWorkHandle_.useStaticWorldSize, str);
242+
TORCH_CHECK(!should_ddp_set_last_bucket_as_small(), str);
243+
// ignore param_names_
244+
// todo: skip create_graph with ddp message
245+
if (static_graph_) {
246+
TORCH_WARN_ONCE(
247+
"static_graph ignored, compiled autograd always rebuilds buckets when param ready order changes.");
248+
}
249+
int div_factor = process_group_->getSize();
250+
args.collect(div_factor);
251+
args.collect_ddp_param_index(variable_index);
252+
// collect size limit etc.
253+
// Rewrite C++ Reducer
254+
255+
// temp validation
256+
if (args.retrieve_ddp_param_index_order().size() == params_.size()) {
257+
std::cout << std::endl;
258+
std::cout << "first_bucket_bytes_cap_=" << first_bucket_bytes_cap_ << ", bucket_bytes_cap_=" << bucket_bytes_cap_ << std::endl;
259+
std::cout << "ALL PARAMS GOT HOOKS" << std::endl;
260+
auto [buckets, bucket_size_limits] = compute_bucket_assignment_by_size(
261+
params_,
262+
{static_cast<size_t>(first_bucket_bytes_cap_), static_cast<size_t>(bucket_bytes_cap_)},
263+
/* expect_sparse_gradient */ {},
264+
/* tensor_indices*/ args.retrieve_ddp_param_index_order(),
265+
/* logger */ {}
266+
);
267+
268+
std::cout << "param order: ";
269+
for (auto index : args.retrieve_ddp_param_index_order()) {
270+
auto tensor = params_[index];
271+
size_t mb = tensor.numel() * tensor.element_size();
272+
std::cout << index << " ("<< mb <<" MiB), ";
273+
}
274+
std::cout << std::endl;
275+
276+
std::string bucket_size_limits_str = "";
277+
for (auto limit : bucket_size_limits) {
278+
bucket_size_limits_str += (std::to_string(limit) + ", ");
279+
}
280+
std::cout << "limits per bucket: " << bucket_size_limits_str << std::endl;
281+
for (size_t i = 0; i < buckets.size(); i++) {
282+
std::cout << "bucket " << i << ": " << std::endl;
283+
for (auto& index : buckets[i]) {
284+
std::cout << index << ", ";
285+
}
286+
std::cout << std::endl;
287+
}
288+
std::cout << std::endl;
289+
}
290+
291+
},
292+
[this, variable_index](
293+
torch::autograd::Variable& variable,
294+
torch::autograd::SwapSavedVariables& saved) {
295+
bool is_first_hook = true;
296+
if (is_first_hook) {
297+
auto [buckets, _] = compute_bucket_assignment_by_size(
298+
params_,
299+
{static_cast<size_t>(first_bucket_bytes_cap_), static_cast<size_t>(bucket_bytes_cap_)},
300+
/* expect_sparse_gradient */ {},
301+
/* tensor_indices*/ saved.retrieve_ddp_param_index_order(),
302+
/* logger */ {}
303+
);
304+
}
305+
// TODO: NOTHING IS CALLING THIS rn
306+
at::Tensor& param = get_param_from_index(variable_index);
307+
saved.before(param);
308+
int div_factor = process_group_->getSize();
309+
// need to swap the param to its proxy
310+
// then we can call the bucket with the proxies.
311+
// and when bucket size cap reached, launch
312+
bool should_issue = true;
313+
if (should_issue) {
314+
// should issue bucket
315+
const auto& pyinterface =
316+
torch::dynamo::autograd::getPyCompilerInterface();
317+
pyinterface->call_unpack(
318+
saved.get_py_compiler(), 0, div_factor);
319+
} else {
320+
// // should bucket
321+
// saved.state.ddp_bucket.emplace_back(param);
322+
}
323+
saved.after(param);
208324
})),
209325
grad_accumulator);
210326

@@ -537,7 +653,7 @@ void Reducer::push_rebuilt_params_for_all_indices() {
537653

538654
void Reducer::push_rebuilt_params(const size_t& index) {
539655
rebuilt_params_.push_back(params_[index]);
540-
rebuilt_param_indices_.push_back(static_cast<int64_t>(index));
656+
rebuilt_param_indices_.push_back(index);
541657
}
542658

543659
void Reducer::set_divide_factor() {
@@ -1678,6 +1794,9 @@ void Reducer::finalize_backward() {
16781794
"currently only support to skip all reduce for unused params "
16791795
"when skip_all_reduce_unused_params_ is true.");
16801796
continue;
1797+
} else {
1798+
std::cout << "skipping bucket work" << std::endl;
1799+
continue;
16811800
}
16821801

16831802
bucket.future_work->wait();
@@ -1892,8 +2011,7 @@ bool Reducer::rebuild_buckets() {
18922011
std::vector<size_t> bucket_size_limits;
18932012
bucket_size_limits.push_back(first_bucket_bytes_cap_);
18942013
bucket_size_limits.push_back(bucket_bytes_cap_);
1895-
auto ddp_set_last_bucket_as_small =
1896-
(getCvarString({"DDP_SET_LAST_BUCKET_CAP"}, "N/A") == "1");
2014+
bool ddp_set_last_bucket_as_small = should_ddp_set_last_bucket_as_small();
18972015

18982016
if (ddp_set_last_bucket_as_small) {
18992017
// Reverse so that first_bucket_bytes_cap_ (smaller bucket) becomes the last
@@ -2166,7 +2284,7 @@ compute_bucket_assignment_by_size(
21662284
const std::vector<at::Tensor>& tensors,
21672285
const std::vector<size_t>& bucket_size_limits,
21682286
const std::vector<bool>& expect_sparse_gradient,
2169-
const std::vector<int64_t>& tensor_indices,
2287+
const std::vector<size_t>& tensor_indices,
21702288
const std::optional<std::weak_ptr<c10d::Logger>>& logger) {
21712289
// Either expect_sparse_gradient is not specified or it has as many elements
21722290
// as the vector with tensors.
@@ -2284,6 +2402,20 @@ compute_bucket_assignment_by_size(
22842402
bucket_indices.emplace_back(std::get<0>(bucket_indices_with_size));
22852403
per_bucket_size_limits.emplace_back(std::get<1>(bucket_indices_with_size));
22862404
}
2405+
2406+
std::cout << std::endl;
2407+
std::cout << std::endl;
2408+
std::cout << "Finished computing bucket assignment" << std::endl;
2409+
for (size_t i=0; i<bucket_indices.size(); i++) {
2410+
std::cout << "bucket["<<i<<"]: ";
2411+
for (const auto& variable_index : bucket_indices[i]) {
2412+
std::cout << variable_index << ", ";
2413+
}
2414+
std::cout << std::endl;
2415+
}
2416+
std::cout << std::endl;
2417+
std::cout << std::endl;
2418+
22872419
return std::make_tuple(bucket_indices, per_bucket_size_limits);
22882420
}
22892421

@@ -2401,6 +2533,7 @@ void verify_params_across_processes(
24012533
}
24022534

24032535
void Reducer::remove_autograd_hooks() {
2536+
std::cout << "===========================REMOVING AUTOGRAD HOOKS======================" << std::endl;
24042537
// Remove all hooks on variables registered by this Reducer. This is necessary
24052538
// to make DDP failure recoverable. Otherwise, multiple Reducer instances
24062539
// (from recoveries) will add their hooks to the original model, and those

torch/csrc/distributed/c10d/reducer.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ class TORCH_API Reducer {
460460
// Following variables are to help build dynamic bucket order
461461
bool has_rebuilt_bucket_;
462462
std::vector<at::Tensor> rebuilt_params_;
463-
std::vector<int64_t> rebuilt_param_indices_;
463+
std::vector<size_t> rebuilt_param_indices_;
464464
const int64_t bucket_bytes_cap_;
465465

466466
#ifndef _WIN32
@@ -587,7 +587,7 @@ compute_bucket_assignment_by_size(
587587
const std::vector<at::Tensor>& tensors,
588588
const std::vector<size_t>& bucket_size,
589589
const std::vector<bool>& expect_sparse_gradient = {},
590-
const std::vector<int64_t>& tensor_indices = {},
590+
const std::vector<size_t>& tensor_indices = {},
591591
const std::optional<std::weak_ptr<c10d::Logger>>& logger = {});
592592

593593
// Verify models across all processes are the same as model on rank 0 with

0 commit comments

Comments (0)