Distributed Autograd - FAST mode backward pass implementation. · pytorch/pytorch@cf24179 · GitHub


Commit cf24179

pritam committed
Distributed Autograd - FAST mode backward pass implementation.
Pull Request resolved: #27022

This change implements the "FAST" mode distributed autograd backward pass as described in #23110.

At a high level the backward pass works as follows:
1. We start by computing dependencies on the node that calls `torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in the backward call and all the 'send' functions present in the current autograd context. The "FAST" mode assumes all 'send' functions are part of the autograd computation.
3. Once the dependency computation is done, the distributed autograd engine calls the local autograd engine to execute the autograd graph. Note that the autograd graph on a single node is not necessarily connected because of inter-node communication. As a result, we have special handling to ensure the local autograd engine executes the entire graph starting from the provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async RPC to send the gradients over to the appropriate node and stores a future in the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and enqueued on the local autograd engine. If this is the first time the node is hearing about this autograd context id on the backward pass, the node computes dependencies for the local autograd engine.
6. As part of computing dependencies, the distributed autograd engine discovers all leaf nodes and ensures those are passed as 'outputs' to the local autograd engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then accumulated in `DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine to complete and also waits for all the 'Futures' (stored in 4.) for the respective RPCs to finish.

We have made the following changes to the local autograd engine for this purpose:
1. Expose GraphTask and NodeTask so that the distributed autograd engine can use them.
2. Expose an `execute_with_graph_task` API which lets the distributed engine build a GraphTask and pass it to the local autograd engine.
3. Expose an `enqueue_on_cpu` API, which allows the distributed engine to build a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.

In addition to this, a few general improvements:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If Future.wait() returns a message of type EXCEPTION, we throw an appropriate exception instead of just returning the message. This is in line with what most Future.wait() APIs do.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map from Tensor to respective gradient for the provided context_id on the local node.

ghstack-source-id: 91794926
Differential Revision: [D17652615](https://our.internmc.facebook.com/intern/diff/D17652615/)
1 parent f35d7d4 commit cf24179
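To make the workflow described above concrete, here is a minimal usage sketch written against the Python-facing API referenced in the commit message (a distributed backward call plus `get_gradients(context_id)`). The module paths, exact signatures, and the worker name "worker1" are illustrative assumptions, not taken verbatim from this commit.

# Minimal sketch (assumption: the dist_autograd/rpc module paths and signatures
# shown here are illustrative and may differ from what this commit exposes).
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

def run_training_step():
    # A distributed autograd context records every 'send'/'recv' function pair
    # created by RPCs issued inside it.
    with dist_autograd.context() as context_id:
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)

        # Forward pass: this RPC attaches a 'send' function on the caller and a
        # matching 'recv' function on the callee ("worker1").
        loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()

        # FAST-mode backward: dependencies are computed from the roots and from
        # all 'send' functions in the current context, then each node runs its
        # (possibly disconnected) local autograd graph.
        dist_autograd.backward([loss])

        # Gradients are accumulated in the DistAutogradContext rather than in
        # .grad, and are looked up per context id.
        grads = dist_autograd.get_gradients(context_id)
        print(grads[t1], grads[t2])

In this sketch, the backward pass on the caller eventually hits the 'recv' function for the RPC result and issues the PropagateGradients call to "worker1", where the matching 'send' function is looked up and enqueued on that node's local autograd engine, as in steps 4 and 5 of the description above.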

File tree

46 files changed: +1675 −455 lines


caffe2/CMakeLists.txt

Lines changed: 6 additions & 1 deletion
@@ -481,16 +481,21 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/autograd/context/dist_autograd_container.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/autograd/context/dist_autograd_context.cpp
+    ${TORCH_SRC_DIR}/csrc/distributed/autograd/engine/dist_engine.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/autograd/functions/recvrpc_backward.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/autograd/functions/sendrpc_backward.cpp
+    ${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/autograd_metadata.cpp
+    ${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp
+    ${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.cpp
+    ${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/autograd/utils.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/rpc/future_message.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/rpc/message.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_remote_call.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_udf_call.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_udf_resp.cpp
+    ${TORCH_SRC_DIR}/csrc/distributed/rpc/rpc_agent.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/rpc/request_callback.cpp
-    ${TORCH_SRC_DIR}/csrc/distributed/rpc/rpc_with_autograd.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_proto.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_call.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_remote_call.cpp

test/cpp/dist_autograd/test_dist_autograd.cpp

Lines changed: 8 additions & 40 deletions
@@ -3,8 +3,8 @@
 #include <ATen/ATen.h>
 #include <torch/csrc/distributed/autograd/context/dist_autograd_container.h>
 #include <torch/csrc/distributed/autograd/context/dist_autograd_context.h>
+#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
 #include <torch/csrc/distributed/autograd/utils.h>
-#include <torch/csrc/distributed/rpc/rpc_with_autograd.h>
 #include <torch/torch.h>
 
 using namespace torch::distributed::autograd;
@@ -20,38 +20,6 @@ class DistAutogradTest : public ::testing::Test {
 
 DistAutogradContext* DistAutogradTest::autogradContainer_ = nullptr;
 
-TEST_F(DistAutogradTest, TestSendFunction) {
-  // Initialize input tensors requiring grad.
-  auto options = at::TensorOptions().requires_grad(true);
-  auto in1 = torch::ones({3, 3}, options);
-  auto in2 = torch::ones({3, 3}, options);
-  ASSERT_FALSE(in1.grad().defined());
-  ASSERT_FALSE(in2.grad().defined());
-
-  autogradContainer_->newContext();
-  DistAutogradContext& autogradContext = autogradContainer_->currentContext();
-  // Attach the send autograd function to tensors.
-  std::vector<torch::Tensor> tensors = {in1, in2};
-  addSendRpcBackward(autogradContext, AutogradMetadata(1, 1), tensors);
-  auto send_function = autogradContext.sendFunctions()[1];
-  ASSERT_NE(send_function, nullptr);
-
-  // Build loss and attach it as input to send autograd function.
-  auto o1 = torch::autograd::Variable(torch::ones({3, 3}));
-  auto edge = torch::autograd::Edge(send_function, 0);
-  o1.set_gradient_edge(edge);
-  auto o2 = torch::autograd::Variable(torch::ones({3, 3}));
-  edge = torch::autograd::Edge(send_function, 1);
-  o2.set_gradient_edge(edge);
-  auto loss = torch::add(o1, o2);
-
-  // Run backwards pass and verify gradients accumulated.
-  auto gradient = torch::autograd::Variable(torch::rand({3, 3}));
-  loss.backward(gradient, false, false);
-  ASSERT_TRUE(in1.grad().defined());
-  ASSERT_TRUE(in2.grad().defined());
-}
-
 TEST_F(DistAutogradTest, TestSendFunctionInvalidInputs) {
   auto options = at::TensorOptions().requires_grad(true);
   auto in1 = torch::ones({3, 3}, options);
@@ -64,12 +32,12 @@ TEST_F(DistAutogradTest, TestSendFunctionInvalidInputs) {
   addSendRpcBackward(autogradContext, AutogradMetadata(1, 1), tensors);
   auto send_function = autogradContext.sendFunctions()[1];
 
-  // Build loss and attach it as input to send autograd function.
-  auto loss = torch::autograd::Variable(torch::ones({3, 3}));
-  loss.set_gradient_edge(torch::autograd::Edge(send_function, 1));
+  // This should fail since the SendRpcBackward function shouldn't receive any
+  // inputs grad.
+  EXPECT_THROW(send_function->apply({in1, in2}), c10::Error);
 
-  // This should fail since the SendRpcBackward function is looking for two
-  // inputs and as a result encounters an undefined grad.
-  EXPECT_THROW(
-      loss.backward(torch::autograd::Variable(), false, false), c10::Error);
+  // This should fail since the SendRpcBackward function encounters an undefined
+  // grad.
+  send_function->setGrads({in1, torch::autograd::Variable()});
+  EXPECT_THROW(send_function->apply({}), c10::Error);
 }

0 commit comments