[wip][ca][ddp] traceable C++ reducer by xmfan · Pull Request #153501 · pytorch/pytorch · GitHub

[wip][ca][ddp] traceable C++ reducer #153501

Draft
xmfan wants to merge 4 commits into gh/xmfan/237/base
2 changes: 2 additions & 0 deletions torch/_dynamo/compiled_autograd.py
@@ -728,6 +728,7 @@ def proxy_call_hook(self, hook, *args, **kwargs):
)

def unpack_hook(self, hook_id, data_id):
breakpoint()
assert self.hooks_proxy is not None
hook = self.hooks_proxy[hook_id] # type: ignore[index]
data = self.packed_data_proxy[data_id] # type: ignore[index]
@@ -779,6 +780,7 @@ def pre_hook(self, inputs, hook_id):
return inputs

def post_hook(self, outputs, inputs, hook_id):
breakpoint()
assert self.hooks_proxy is not None
hook = self.hooks_proxy[hook_id] # type: ignore[index]
proxies = self.proxy_call_hook(
8 changes: 8 additions & 0 deletions torch/csrc/autograd/function_hook.h
@@ -42,6 +42,14 @@ struct TORCH_API FunctionPostHook {
std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
typeid(*this).name());
}

virtual void apply_with_saved(
Variable& tensor,
torch::dynamo::autograd::SwapSavedVariables& saved) const {
throw std::runtime_error(
std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
typeid(*this).name());
}
};

struct TORCH_API PostAccumulateGradHook {
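For context, the new `apply_with_saved` virtual extends the `FunctionPostHook` interface alongside the existing `compiled_args()`; both default to throwing until a subclass opts in. Below is a minimal sketch, not part of this diff, of what a traceable post hook could look like. `TraceablePostHook` is hypothetical; the `operator()`/`compiled_args`/`apply_with_saved` signatures and the `saved.before()`/`saved.after()` swap pattern come from this PR.

```cpp
#include <torch/csrc/autograd/function_hook.h>

// Hypothetical subclass sketch: a post hook that also supports compiled
// autograd capture through the apply_with_saved() extension point added here.
struct TraceablePostHook : torch::autograd::FunctionPostHook {
  torch::autograd::variable_list operator()(
      const torch::autograd::variable_list& outputs,
      const torch::autograd::variable_list& /*inputs*/) override {
    return outputs; // eager path: pass gradients through unchanged
  }

  void compiled_args(
      torch::dynamo::autograd::CompiledNodeArgs& /*args*/) const override {
    // Collect state the compiled graph should guard on (no-op in this sketch).
  }

  void apply_with_saved(
      torch::autograd::Variable& tensor,
      torch::dynamo::autograd::SwapSavedVariables& saved) const override {
    saved.before(tensor); // swap the real tensor for its traced proxy
    // ... emit traced operations against the proxy here ...
    saved.after(tensor); // restore the real tensor after tracing
  }
};
```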
28 changes: 22 additions & 6 deletions torch/csrc/autograd/utils/lambda_post_hook.h
@@ -10,16 +10,23 @@ class LambdaPostHook : public torch::autograd::FunctionPostHook {
using variable_list = std::vector<torch::autograd::Variable>;
using fn_type =
std::function<variable_list(const variable_list&, const variable_list&)>;
using compiled_fn_type = std::function<void(CompiledNodeArgs&)>;
using compiled_args_fn_type = std::function<void(CompiledNodeArgs&)>;
using compiled_apply_fn_type =
std::function<void(Variable&, SwapSavedVariables&)>;

public:
// The lambda function takes as arguments the outputs and inputs of the
// autograd function and can modify the outputs of the autograd function by
// returning a new output if needed.
/* implicit */ LambdaPostHook(fn_type fn) : fn_(std::move(fn)) {}

LambdaPostHook(fn_type fn, compiled_fn_type compiled_fn)
: fn_(std::move(fn)), compiled_fn_(std::move(compiled_fn)) {}
LambdaPostHook(
fn_type fn,
compiled_args_fn_type compiled_args_fn,
compiled_apply_fn_type compiled_apply_fn)
: fn_(std::move(fn)),
compiled_args_fn_(std::move(compiled_args_fn)),
compiled_apply_fn_(std::move(compiled_apply_fn)) {}

variable_list operator()(
const variable_list& outputs,
@@ -28,15 +35,24 @@ class LambdaPostHook : public torch::autograd::FunctionPostHook {
}

void compiled_args(CompiledNodeArgs& args) const override {
if (compiled_fn_ != nullptr) {
return compiled_fn_(args);
if (compiled_args_fn_ != nullptr) {
return compiled_args_fn_(args);
}
return FunctionPostHook::compiled_args(args);
}

void apply_with_saved(Variable& inputs, SwapSavedVariables& saved)
const override {
if (compiled_apply_fn_ != nullptr) {
return compiled_apply_fn_(inputs, saved);
}
return FunctionPostHook::apply_with_saved(inputs, saved);
}

protected:
std::function<variable_list(const variable_list&, const variable_list&)> fn_;
compiled_fn_type compiled_fn_{};
compiled_args_fn_type compiled_args_fn_{};
compiled_apply_fn_type compiled_apply_fn_{};
};

} // namespace torch::autograd::utils
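For illustration, a minimal construction sketch (not from this PR) using the extended three-argument constructor; the lambda bodies are placeholders, while the callable signatures match the `fn_type`, `compiled_args_fn_type`, and `compiled_apply_fn_type` aliases in the header above.

```cpp
#include <torch/csrc/autograd/utils/lambda_post_hook.h>

#include <memory>
#include <vector>

using torch::autograd::utils::LambdaPostHook;
using variable_list = std::vector<torch::autograd::Variable>;

// Hypothetical usage sketch of the three-argument constructor added above.
auto hook = std::make_unique<LambdaPostHook>(
    // Eager post hook: runs after the node, may rewrite its outputs.
    [](const variable_list& outputs, const variable_list& /*inputs*/) {
      return outputs;
    },
    // Capture-time callback: collect guard state for compiled autograd.
    [](torch::dynamo::autograd::CompiledNodeArgs& /*args*/) {},
    // Trace-time callback: replay the hook against swapped-in proxies.
    [](torch::autograd::Variable& var,
       torch::dynamo::autograd::SwapSavedVariables& saved) {
      saved.before(var);
      // traced ops against the proxy of `var` would go here
      saved.after(var);
    });
```

A hook built this way would be registered on a grad accumulator node, matching how the Reducer installs its per-parameter post hooks in reducer.cpp below.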
2 changes: 1 addition & 1 deletion torch/csrc/distributed/c10d/init.cpp
@@ -3565,7 +3565,7 @@ such as `dist.all_reduce(tensor, async_op=True)`.
[](const std::vector<at::Tensor>& tensors,
const std::vector<size_t>& bucket_size_limits,
const std::vector<bool>& expect_sparse_gradient,
const std::vector<int64_t>& tensor_indices,
const std::vector<size_t>& tensor_indices,
const std::optional<std::shared_ptr<::c10d::Logger>>& logger) {
if (logger.has_value()) {
std::weak_ptr<::c10d::Logger> logger_weakref = logger.value();
151 changes: 142 additions & 9 deletions torch/csrc/distributed/c10d/reducer.cpp
@@ -85,6 +85,10 @@ std::vector<at::Tensor> extractTensors(const c10::IValue& result) {
return result.toTensorVector();
}

bool should_ddp_set_last_bucket_as_small() {
return getCvarString({"DDP_SET_LAST_BUCKET_CAP"}, "N/A") == "1";
}

} // namespace

Reducer::Reducer(
@@ -126,7 +130,7 @@
use_python_reducer_(use_python_reducer) {
C10_LOG_API_USAGE_ONCE("torch.distributed.ddp.reducer");
TORCH_INTERNAL_ASSERT(!params_.empty(), "Expected at least one parameter.");

std::cout << "hello from c++ reducer" << std::endl;
if (ddp_debug_level_ != c10d::DebugLevel::Off) {
LOG(INFO) << "Reducer initialized with bucket_bytes_cap: "
<< bucket_bytes_cap_
@@ -174,6 +178,7 @@ Reducer::Reducer(
// can be marked as ready for reduction.
{
const auto variable_count = params_.size();
std::cout << "reducer found " << variable_count << " variables" << std::endl;
grad_accumulators_.resize(variable_count);
for (const auto variable_index : c10::irange(variable_count)) {
auto& variable = params_[variable_index];
@@ -198,13 +203,124 @@ Reducer::Reducer(
this->rpc_context_.set(
ThreadLocalDistAutogradContext::getContextPtr());
#endif
std::cout << "marking variable " << variable_index << " as ready" << std::endl;
this->autograd_hook(variable_index);
return outputs;
},
[this](torch::autograd::CompiledNodeArgs& args) {
TORCH_CHECK(
this->use_python_reducer_,
"Compiled autograd is not compatible with C++ DDP Reducer, please use torch._dynamo.config.optimize_ddp=\"python_reducer\".");
[this, variable_index](torch::autograd::CompiledNodeArgs& args) {
std::cout << "collecting the post hook on variable_index=" << variable_index << std::endl;
if (use_python_reducer_) {
return;
}

// filters out unsupported DDP arguments
auto str =
"Compiled autograd is not compatible with C++ DDP Reducer, please use torch._dynamo.config.optimize_ddp=\"python_reducer\".";
// std::cout << "mixed precision" << std::endl;
TORCH_CHECK(!mixed_precision_param_dtype_.has_value(), str);
// std::cout << "find unused" << std::endl;
TORCH_CHECK(!find_unused_parameters_, str);
// std::cout << "ddp debug level" << std::endl;
TORCH_CHECK(ddp_debug_level_ == c10d::DebugLevel::Off, str);
// std::cout << "rpc" << std::endl;
TORCH_CHECK(rpc_context_.context_ptr.load() == nullptr, str);

// TODO: if not expect autograd hooks, means no sync
// std::cout << "expect hooks" << std::endl;
TORCH_CHECK(expect_autograd_hooks_, str);

// std::cout << "expect spars" << std::endl;
for (bool b : expect_sparse_gradients_) {
TORCH_CHECK(!b, str);
}
// std::cout << "bucket view" << std::endl;
TORCH_CHECK(!gradient_as_bucket_view_, str);
// std::cout << "comm hook non nullptr" << std::endl;
TORCH_CHECK(comm_hook_ == nullptr, str);
// std::cout << "not use static world size" << std::endl;
TORCH_CHECK(forwardPassWorkHandle_.useStaticWorldSize, str);
TORCH_CHECK(!should_ddp_set_last_bucket_as_small(), str);
// ignore param_names_
// todo: skip create_graph with ddp message
if (static_graph_) {
TORCH_WARN_ONCE(
"static_graph ignored, compiled autograd always rebuilds buckets when param ready order changes.");
}
int div_factor = process_group_->getSize();
args.collect(div_factor);
args.collect_ddp_param_index(variable_index);
// collect size limit etc.
// Rewrite C++ Reducer

// temp validation
if (args.retrieve_ddp_param_index_order().size() == params_.size()) {
std::cout << std::endl;
std::cout << "first_bucket_bytes_cap_=" << first_bucket_bytes_cap_ << ", bucket_bytes_cap_=" << bucket_bytes_cap_ << std::endl;
std::cout << "ALL PARAMS GOT HOOKS" << std::endl;
auto [buckets, bucket_size_limits] = compute_bucket_assignment_by_size(
params_,
{static_cast<size_t>(first_bucket_bytes_cap_), static_cast<size_t>(bucket_bytes_cap_)},
/* expect_sparse_gradient */ {},
/* tensor_indices*/ args.retrieve_ddp_param_index_order(),
/* logger */ {}
);

std::cout << "param order: ";
for (auto index : args.retrieve_ddp_param_index_order()) {
auto tensor = params_[index];
size_t bytes = tensor.numel() * tensor.element_size();
std::cout << index << " (" << bytes << " bytes), ";
}
std::cout << std::endl;

std::string bucket_size_limits_str = "";
for (auto limit : bucket_size_limits) {
bucket_size_limits_str += (std::to_string(limit) + ", ");
}
std::cout << "limits per bucket: " << bucket_size_limits_str << std::endl;
for (size_t i = 0; i < buckets.size(); i++) {
std::cout << "bucket " << i << ": " << std::endl;
for (auto& index : buckets[i]) {
std::cout << index << ", ";
}
std::cout << std::endl;
}
std::cout << std::endl;
}

},
[this, variable_index](
torch::autograd::Variable& variable,
torch::autograd::SwapSavedVariables& saved) {
bool is_first_hook = true;
if (is_first_hook) {
auto [buckets, _] = compute_bucket_assignment_by_size(
params_,
{static_cast<size_t>(first_bucket_bytes_cap_), static_cast<size_t>(bucket_bytes_cap_)},
/* expect_sparse_gradient */ {},
/* tensor_indices*/ saved.retrieve_ddp_param_index_order(),
/* logger */ {}
);
}
// TODO: NOTHING IS CALLING THIS rn
at::Tensor& param = get_param_from_index(variable_index);
saved.before(param);
int div_factor = process_group_->getSize();
// need to swap the param to its proxy
// then we can call the bucket with the proxies.
// and when bucket size cap reached, launch
bool should_issue = true;
if (should_issue) {
// should issue bucket
const auto& pyinterface =
torch::dynamo::autograd::getPyCompilerInterface();
pyinterface->call_unpack(
saved.get_py_compiler(), 0, div_factor);
} else {
// // should bucket
// saved.state.ddp_bucket.emplace_back(param);
}
saved.after(param);
})),
grad_accumulator);

@@ -537,7 +653,7 @@ void Reducer::push_rebuilt_params_for_all_indices() {

void Reducer::push_rebuilt_params(const size_t& index) {
rebuilt_params_.push_back(params_[index]);
rebuilt_param_indices_.push_back(static_cast<int64_t>(index));
rebuilt_param_indices_.push_back(index);
}

void Reducer::set_divide_factor() {
@@ -1678,6 +1794,9 @@ void Reducer::finalize_backward() {
"currently only support to skip all reduce for unused params "
"when skip_all_reduce_unused_params_ is true.");
continue;
} else {
std::cout << "skipping bucket work" << std::endl;
continue;
}

bucket.future_work->wait();
@@ -1892,8 +2011,7 @@ bool Reducer::rebuild_buckets() {
std::vector<size_t> bucket_size_limits;
bucket_size_limits.push_back(first_bucket_bytes_cap_);
bucket_size_limits.push_back(bucket_bytes_cap_);
auto ddp_set_last_bucket_as_small =
(getCvarString({"DDP_SET_LAST_BUCKET_CAP"}, "N/A") == "1");
bool ddp_set_last_bucket_as_small = should_ddp_set_last_bucket_as_small();

if (ddp_set_last_bucket_as_small) {
// Reverse so that first_bucket_bytes_cap_ (smaller bucket) becomes the last
@@ -2166,7 +2284,7 @@
const std::vector<at::Tensor>& tensors,
const std::vector<size_t>& bucket_size_limits,
const std::vector<bool>& expect_sparse_gradient,
const std::vector<int64_t>& tensor_indices,
const std::vector<size_t>& tensor_indices,
const std::optional<std::weak_ptr<c10d::Logger>>& logger) {
// Either expect_sparse_gradient is not specified or it has as many elements
// as the vector with tensors.
Expand Down Expand Up @@ -2284,6 +2402,20 @@ compute_bucket_assignment_by_size(
bucket_indices.emplace_back(std::get<0>(bucket_indices_with_size));
per_bucket_size_limits.emplace_back(std::get<1>(bucket_indices_with_size));
}

std::cout << std::endl;
std::cout << std::endl;
std::cout << "Finished computing bucket assignment" << std::endl;
for (size_t i=0; i<bucket_indices.size(); i++) {
std::cout << "bucket["<<i<<"]: ";
for (const auto& variable_index : bucket_indices[i]) {
std::cout << variable_index << ", ";
}
std::cout << std::endl;
}
std::cout << std::endl;
std::cout << std::endl;

return std::make_tuple(bucket_indices, per_bucket_size_limits);
}

@@ -2401,6 +2533,7 @@ void verify_params_across_processes(
}

void Reducer::remove_autograd_hooks() {
std::cout << "===========================REMOVING AUTOGRAD HOOKS======================" << std::endl;
// Remove all hooks on variables registered by this Reducer. This is necessary
// to make DDP failure recoverable. Otherwise, multiple Reducer instances
// (from recoveries) will add their hooks to the original model, and those
4 changes: 2 additions & 2 deletions torch/csrc/distributed/c10d/reducer.hpp
@@ -460,7 +460,7 @@ class TORCH_API Reducer {
// Following variables are to help build dynamic bucket order
bool has_rebuilt_bucket_;
std::vector<at::Tensor> rebuilt_params_;
std::vector<int64_t> rebuilt_param_indices_;
std::vector<size_t> rebuilt_param_indices_;
const int64_t bucket_bytes_cap_;

#ifndef _WIN32
@@ -587,7 +587,7 @@
const std::vector<at::Tensor>& tensors,
const std::vector<size_t>& bucket_size,
const std::vector<bool>& expect_sparse_gradient = {},
const std::vector<int64_t>& tensor_indices = {},
const std::vector<size_t>& tensor_indices = {},
const std::optional<std::weak_ptr<c10d::Logger>>& logger = {});

// Verify models across all processes are the same as model on rank 0 with
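As a closing illustration of the `int64_t` → `size_t` signature change above, a minimal call sketch. The helper `buckets_for_order` is hypothetical, and `ready_order` stands in for the parameter-index order that compiled autograd observes during backward; the call itself follows the declaration in reducer.hpp.

```cpp
#include <torch/csrc/distributed/c10d/reducer.hpp>

// Hypothetical helper: recompute DDP bucket assignments from an observed
// parameter-ready order, passing the indices as size_t per the new signature.
std::vector<std::vector<size_t>> buckets_for_order(
    const std::vector<at::Tensor>& params,
    const std::vector<size_t>& ready_order,
    size_t first_bucket_cap_bytes,
    size_t bucket_cap_bytes) {
  auto [bucket_indices, per_bucket_limits] =
      ::c10d::compute_bucket_assignment_by_size(
          params,
          /*bucket_size_limits=*/{first_bucket_cap_bytes, bucket_cap_bytes},
          /*expect_sparse_gradient=*/{},
          /*tensor_indices=*/ready_order,
          /*logger=*/{});
  (void)per_bucket_limits; // per-bucket byte caps, unused in this sketch
  return bucket_indices;
}
```

This mirrors the call this PR issues from the compiled hooks, where buckets are recomputed once the observed ready order covers all parameters.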