Add send and recv backward functions for builtin operators RPC. (#25527) · thiagocrepaldi/pytorch@73173fd

Commit 73173fd

pritamdamania authored and Thiago Crepaldi committed
Add send and recv backward functions for builtin operators RPC. (pytorch#25527)
Summary: Pull Request resolved: pytorch#25527

Master GH issue: pytorch#23110.

This change builds upon pytorch#24876 and provides all the autograd hooks needed for a forward pass with distributed rpc for builtin operators. It does not address distributed rpc for python UDFs; that will be addressed in follow-up PRs.

Summary of changes:
1. Attach send autograd functions when a request is sent from the client and a response is sent from the server.
2. Attach receive autograd functions when a request is received on the server and a response is received on the client.
3. Generate a globally unique autograd_message_id for each send/recv autograd function pair to uniquely identify them.

ghstack-source-id: 91240466

Test Plan: unit tests.

Differential Revision: D17148077

fbshipit-source-id: 192d8a3f552ed7cc939f55dcca332965c9bd3233
1 parent 7d9e5b5 commit 73173fd
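
The end-to-end effect of these hooks is easiest to see from the Python test added in this commit (test/dist_autograd_test.py below). As a rough sketch, assuming the distributed RPC APIs of this branch (dist.rpc_sync, dist.init_model_parallel, and the private dist_autograd context helpers); the function name is only for illustration:

import torch
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd

def forward_over_rpc(dst_rank):
    # Assumes init_process_group()/init_model_parallel() have already been
    # called for this worker, as the dist_init decorator below does.
    with dist_autograd.context() as context_id:
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)

        # Sending the request attaches a SendRpcBackward function to t1/t2 on
        # this worker; receiving the response attaches a RecvRpcBackward
        # function that becomes the grad_fn of the returned tensor.
        ret = dist.rpc_sync('worker{}'.format(dst_rank), torch.add, args=(t1, t2))

        ctx = dist_autograd._current_context()
        send_functions = ctx._send_functions()  # keyed by autograd message id
        recv_functions = ctx._recv_functions()  # keyed by autograd message id
        assert len(send_functions) == 1 and len(recv_functions) == 1
        assert ret.grad_fn == list(recv_functions.values())[0]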


58 files changed: +1508, −496 lines (only a subset of the changed files is shown below).

caffe2/CMakeLists.txt

Lines changed: 10 additions & 1 deletion
@@ -481,18 +481,27 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
   if (NOT INTERN_BUILD_MOBILE)
     list(APPEND TORCH_SRCS
       ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/autograd/context/dist_autograd_container.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/autograd/context/dist_autograd_context.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/autograd/functions/recvrpc_backward.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/autograd/functions/sendrpc_backward.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/autograd/utils.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/rpc/future_message.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/rpc/message.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_udf_call.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_udf_resp.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/rpc/request_callback.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/rpc/rpc_with_autograd.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_call.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_remote_call.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_rref_proto.cpp
-      ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_ret.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_resp.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/rpc/utils.cpp
       ${TORCH_SRC_DIR}/csrc/jit/export.cpp
       ${TORCH_SRC_DIR}/csrc/jit/import_legacy.cpp
       ${TORCH_SRC_DIR}/csrc/jit/netdef_converter.cpp
       ${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/fused_kernel.cpp
+      ${TORCH_SRC_DIR}/csrc/utils/byte_order.cpp
     )
   endif()

test/cpp/dist_autograd/test_dist_autograd.cpp

Lines changed: 28 additions & 6 deletions
@@ -1,20 +1,39 @@
 #include <gtest/gtest.h>
 
 #include <ATen/ATen.h>
+#include <torch/csrc/distributed/autograd/context/dist_autograd_container.h>
+#include <torch/csrc/distributed/autograd/context/dist_autograd_context.h>
 #include <torch/csrc/distributed/autograd/utils.h>
+#include <torch/csrc/distributed/rpc/rpc_with_autograd.h>
 #include <torch/torch.h>
 
-TEST(DistAutogradTest, TestSendFunction) {
+using namespace torch::distributed::autograd;
+using namespace torch::distributed::rpc;
+
+class DistAutogradTest : public ::testing::Test {
+ protected:
+  static void SetUpTestCase() {
+    autogradContainer_ = &DistAutogradContainer::init(0);
+  }
+  static DistAutogradContainer* autogradContainer_;
+};
+
+DistAutogradContainer* DistAutogradTest::autogradContainer_ = nullptr;
+
+TEST_F(DistAutogradTest, TestSendFunction) {
   // Initialize input tensors requiring grad.
   auto options = at::TensorOptions().requires_grad(true);
   auto in1 = torch::ones({3, 3}, options);
   auto in2 = torch::ones({3, 3}, options);
   ASSERT_FALSE(in1.grad().defined());
   ASSERT_FALSE(in2.grad().defined());
 
+  autogradContainer_->newContext();
+  DistAutogradContext& autogradContext = autogradContainer_->currentContext();
   // Attach the send autograd function to tensors.
-  auto send_function =
-      torch::distributed::autograd::addSendRpcBackward({in1, in2});
+  std::vector<torch::Tensor> tensors = {in1, in2};
+  addSendRpcBackward(autogradContext, AutogradMetadata(1, 1), tensors);
+  auto send_function = autogradContext.sendFunctions()[1];
   ASSERT_NE(send_function, nullptr);
 
   // Build loss and attach it as input to send autograd function.
@@ -33,14 +52,17 @@ TEST(DistAutogradTest, TestSendFunction) {
   ASSERT_TRUE(in2.grad().defined());
 }
 
-TEST(DistAutogradTest, TestSendFunctionInvalidInputs) {
+TEST_F(DistAutogradTest, TestSendFunctionInvalidInputs) {
   auto options = at::TensorOptions().requires_grad(true);
   auto in1 = torch::ones({3, 3}, options);
   auto in2 = torch::ones({3, 3}, options);
 
+  autogradContainer_->newContext();
+  DistAutogradContext& autogradContext = autogradContainer_->currentContext();
   // Attach the send autograd function to tensors.
-  auto send_function =
-      torch::distributed::autograd::addSendRpcBackward({in1, in2});
+  std::vector<torch::Tensor> tensors = {in1, in2};
+  addSendRpcBackward(autogradContext, AutogradMetadata(1, 1), tensors);
+  auto send_function = autogradContext.sendFunctions()[1];
 
   // Build loss and attach it as input to send autograd function.
   auto loss = torch::autograd::Variable(torch::ones({3, 3}));

test/dist_autograd_test.py

Lines changed: 50 additions & 27 deletions
@@ -1,35 +1,20 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 
-import sys
 import torch.distributed as dist
 import torch.distributed.autograd as dist_autograd
-from functools import wraps
+from dist_utils import dist_init
 import six
 import unittest
 import torch
+import time
 
-if not dist.is_available():
-    print("c10d not available, skipping tests")
-    sys.exit(0)
-
-def dist_init(func):
-    """
-    We use this decorator for setting up and tearing down state since
-    MultiProcessTestCase runs each `test*` method in a separate process and
-    each process just runs the `test*` method without actually calling
-    'setUp' and 'tearDown' methods of unittest.
-    """
-    @wraps(func)
-    def wrapper(self):
-        self.worker_id = self.rank
-        store = dist.FileStore(self.file_name, self.world_size)
-        dist.init_process_group(backend='gloo', rank=self.rank,
-                                world_size=self.world_size, store=store)
-        dist.init_model_parallel('worker%d' % self.rank)
-        func(self)
-        dist.join_rpc()
-
-    return wrapper
+prev_rank_rpc_done = False
+prev_rank_context_id = 0
+def _set_rpc_done(context_id):
+    global prev_rank_rpc_done
+    global prev_rank_context_id
+    prev_rank_rpc_done = True
+    prev_rank_context_id = context_id
 
 @unittest.skipIf(not six.PY3, "Pytorch distributed autograd package "
                  "does not support python2")
@@ -41,6 +26,10 @@ def world_size(self):
 
     @dist_init
     def test_autograd_context(self):
+        # Verify max possible id.
+        max_auto_increment = 281474976710655
+        self.assertEqual(max_auto_increment + (self.worker_id << 48), dist_autograd._get_max_id())
+
         context_ids = []
         for i in range(1000):
             with dist_autograd.context() as context_id:
@@ -54,12 +43,13 @@ def test_autograd_context(self):
             dist_autograd._retrieve_context(context_id)
 
     @dist_init
-    def test_autograd_send_function(self):
+    def test_autograd_functions(self):
         dst_rank = (self.rank + 1) % self.world_size
         with dist_autograd.context() as context_id:
             t1 = torch.ones(3, 3, requires_grad=True)
             t2 = torch.zeros(3, 3, requires_grad=True)
             ret = dist.rpc_sync('worker{}'.format(dst_rank), torch.add, args=(t1, t2))
+            dist.rpc_sync('worker{}'.format(dst_rank), _set_rpc_done, args=(context_id,))
 
             # Get send function.
             ctx = dist_autograd._current_context()
@@ -68,7 +58,7 @@ def test_autograd_send_function(self):
             self.assertEqual(1, len(send_functions))
 
             # Retrieve the next functions in the graph.
-            next_funcs = send_functions[0].next_functions
+            next_funcs = list(send_functions.values())[0].next_functions
             self.assertEqual(2, len(next_funcs))
 
             # We should now hit t1 and t2 in the autograd graph.
@@ -79,6 +69,39 @@ def test_autograd_send_function(self):
             self.assertEqual(t2, next_funcs[1][0].variable)
             self.assertEqual(0, next_funcs[1][1])
 
+            # Test recv functions.
+            recv_functions = ctx._recv_functions()
+            self.assertEqual(1, len(recv_functions))
+            self.assertEqual(ret.grad_fn, list(recv_functions.values())[0])
+
+            # We should have send/recv functions from the previous rank, get all
+            # contexts in this node to find them.
+
+            # Wait for the prev rank to be done with rpc.
+            while not prev_rank_rpc_done:
+                time.sleep(0.1)
+                pass
+
+            # Now verify the autograd graph.
+            ctx = dist_autograd._retrieve_context(prev_rank_context_id)
+
+            # Get the send function.
+            send_functions = ctx._send_functions()
+            self.assertEqual(1, len(send_functions))
+
+            # Verify next function is AddBackward0
+            next_funcs = list(send_functions.values())[0].next_functions
+            self.assertEqual(1, len(next_funcs))
+            add_backward_fn = next_funcs[0][0]
+            self.assertEqual('AddBackward0', add_backward_fn.name())
+
+            # Verify the next two functions are the same recv backward function.
+            next_funcs = add_backward_fn.next_functions
+            self.assertEqual(2, len(next_funcs))
+            self.assertEqual('torch::distributed::autograd::RecvRpcBackward', next_funcs[0][0].name())
+            self.assertEqual('torch::distributed::autograd::RecvRpcBackward', next_funcs[1][0].name())
+            self.assertEqual(next_funcs[0][0], next_funcs[1][0])
+
         # autograd context should be cleaned up by now.
         with self.assertRaises(RuntimeError):
             ctx = dist_autograd._retrieve_context(context_id)
@@ -99,7 +122,7 @@ def test_rpc_complex_args(self):
             self.assertEqual(torch.stack(tensors), ret)
 
             # Verify appropriate tensors have been attached the autograd graph.
-            next_funcs = dist_autograd._current_context()._send_functions()[0].next_functions
+            next_funcs = list(dist_autograd._current_context()._send_functions().values())[0].next_functions
             idx = 0
             for i in range(num_tensors):
                 if i % 2 == 0:
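
A note on the globally unique ids from point 3 of the commit summary: the max-id check added in test_autograd_context above implies that each id packs the worker id into the bits above a 48-bit auto-increment counter (281474976710655 is 2^48 - 1). A minimal sketch of that layout; the helper below is illustrative, not a PyTorch API:

MAX_AUTO_INCREMENT = (1 << 48) - 1  # 281474976710655, as asserted in the test

def _make_unique_id(worker_id, auto_increment):
    # Hypothetical helper mirroring the layout implied by the test: the high
    # bits carry the worker id, the low 48 bits carry a per-worker counter.
    assert 0 <= auto_increment <= MAX_AUTO_INCREMENT
    return (worker_id << 48) | auto_increment

# The largest id worker 3 can produce matches what _get_max_id() reports:
assert _make_unique_id(3, MAX_AUTO_INCREMENT) == MAX_AUTO_INCREMENT + (3 << 48)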

test/dist_utils.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+from os import getenv
+from functools import wraps
+import torch.distributed as dist
+from torch.distributed.rpc_api import RpcBackend
+
+if not dist.is_available():
+    print("c10d not available, skipping tests")
+    sys.exit(0)
+
+
+BACKEND = getenv('RPC_BACKEND', RpcBackend.PROCESS_GROUP)
+RPC_INIT_URL = getenv('RPC_INIT_URL', '')
+
+def dist_init(func):
+    """
+    We use this decorator for setting up and tearing down state since
+    MultiProcessTestCase runs each `test*` method in a separate process and
+    each process just runs the `test*` method without actually calling
+    'setUp' and 'tearDown' methods of unittest.
+    """
+    @wraps(func)
+    def wrapper(self):
+        self.worker_id = self.rank
+        store = dist.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(backend='gloo', rank=self.rank,
+                                world_size=self.world_size, store=store)
+        dist.init_model_parallel(self_name='worker%d' % self.rank,
+                                 backend=BACKEND,
+                                 self_rank=self.rank,
+                                 init_method=RPC_INIT_URL)
+        func(self)
+        dist.join_rpc()
+
+    return wrapper
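
For completeness, a hypothetical sketch of how a test consumes this decorator, mirroring the usage in test/dist_autograd_test.py above. The test class, rank/world_size/file_name values, and test name here are illustrative only; the real tests inherit from MultiProcessTestCase, which supplies self.rank, self.file_name and self.world_size and forks one process per rank:

import unittest

from dist_utils import dist_init


class ExampleDistTest(unittest.TestCase):
    # Normally provided by MultiProcessTestCase; hard-coded for illustration.
    rank = 0
    world_size = 1
    file_name = '/tmp/example_filestore'

    @dist_init
    def test_example(self):
        # By the time this body runs, the decorator has initialized the
        # process group and the RPC layer for 'worker0'; dist.join_rpc()
        # runs after the body returns.
        self.assertEqual(0, self.worker_id)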
