@@ -1,35 +1,20 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 
-import sys
 import torch.distributed as dist
 import torch.distributed.autograd as dist_autograd
-from functools import wraps
+from dist_utils import dist_init
 import six
 import unittest
 import torch
+import time
 
-if not dist.is_available():
-    print("c10d not available, skipping tests")
-    sys.exit(0)
-
-def dist_init(func):
-    """
-    We use this decorator for setting up and tearing down state since
-    MultiProcessTestCase runs each `test*` method in a separate process and
-    each process just runs the `test*` method without actually calling
-    'setUp' and 'tearDown' methods of unittest.
-    """
-    @wraps(func)
-    def wrapper(self):
-        self.worker_id = self.rank
-        store = dist.FileStore(self.file_name, self.world_size)
-        dist.init_process_group(backend='gloo', rank=self.rank,
-                                world_size=self.world_size, store=store)
-        dist.init_model_parallel('worker%d' % self.rank)
-        func(self)
-        dist.join_rpc()
-
-    return wrapper
+prev_rank_rpc_done = False
+prev_rank_context_id = 0
+def _set_rpc_done(context_id):
+    global prev_rank_rpc_done
+    global prev_rank_context_id
+    prev_rank_rpc_done = True
+    prev_rank_context_id = context_id
 
 @unittest.skipIf(not six.PY3, "Pytorch distributed autograd package "
                  "does not support python2")
@@ -41,6 +26,10 @@ def world_size(self):
 
     @dist_init
     def test_autograd_context(self):
+        # Verify max possible id.
+        max_auto_increment = 281474976710655
+        self.assertEqual(max_auto_increment + (self.worker_id << 48), dist_autograd._get_max_id())
+
         context_ids = []
         for i in range(1000):
             with dist_autograd.context() as context_id:
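The constant in the new assertion is 2**48 - 1, so the check is consistent with a context id layout that keeps a 48-bit auto-increment counter in the low bits and places the worker id in the bits above it; that layout is inferred from the assertion, not stated by the patch. A quick arithmetic sketch:

max_auto_increment = 2 ** 48 - 1
assert max_auto_increment == 281474976710655

worker_id = 3  # example value
max_id = max_auto_increment + (worker_id << 48)
# Adding the counter maximum equals OR-ing it into the low 48 bits, since the
# two bit ranges do not overlap.
assert max_id == (worker_id << 48) | max_auto_increment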
@@ -54,12 +43,13 @@ def test_autograd_context(self):
                 dist_autograd._retrieve_context(context_id)
 
     @dist_init
-    def test_autograd_send_function(self):
+    def test_autograd_functions(self):
         dst_rank = (self.rank + 1) % self.world_size
         with dist_autograd.context() as context_id:
             t1 = torch.ones(3, 3, requires_grad=True)
             t2 = torch.zeros(3, 3, requires_grad=True)
             ret = dist.rpc_sync('worker{}'.format(dst_rank), torch.add, args=(t1, t2))
+            dist.rpc_sync('worker{}'.format(dst_rank), _set_rpc_done, args=(context_id,))
 
             # Get send function.
             ctx = dist_autograd._current_context()
@@ -68,7 +58,7 @@ def test_autograd_send_function(self):
             self.assertEqual(1, len(send_functions))
 
             # Retrieve the next functions in the graph.
-            next_funcs = send_functions[0].next_functions
+            next_funcs = list(send_functions.values())[0].next_functions
             self.assertEqual(2, len(next_funcs))
 
             # We should now hit t1 and t2 in the autograd graph.
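Here, and again in test_rpc_complex_args below, indexing _send_functions() with [0] is replaced by taking the first value, which suggests the call now returns a mapping (presumably keyed by some send id) rather than a list. A tiny sketch of the access pattern with stand-in values:

send_functions = {1: 'send_fn'}  # stand-in for ctx._send_functions(): id -> SendRpcBackward
the_only_send_fn = list(send_functions.values())[0]
assert the_only_send_fn == 'send_fn'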
@@ -79,6 +69,39 @@ def test_autograd_send_function(self):
             self.assertEqual(t2, next_funcs[1][0].variable)
             self.assertEqual(0, next_funcs[1][1])
 
+            # Test recv functions.
+            recv_functions = ctx._recv_functions()
+            self.assertEqual(1, len(recv_functions))
+            self.assertEqual(ret.grad_fn, list(recv_functions.values())[0])
+
+            # We should have send/recv functions from the previous rank, get all
+            # contexts in this node to find them.
+
+            # Wait for the prev rank to be done with rpc.
+            while not prev_rank_rpc_done:
+                time.sleep(0.1)
+                pass
+
+            # Now verify the autograd graph.
+            ctx = dist_autograd._retrieve_context(prev_rank_context_id)
+
+            # Get the send function.
+            send_functions = ctx._send_functions()
+            self.assertEqual(1, len(send_functions))
+
+            # Verify next function is AddBackward0
+            next_funcs = list(send_functions.values())[0].next_functions
+            self.assertEqual(1, len(next_funcs))
+            add_backward_fn = next_funcs[0][0]
+            self.assertEqual('AddBackward0', add_backward_fn.name())
+
+            # Verify the next two functions are the same recv backward function.
+            next_funcs = add_backward_fn.next_functions
+            self.assertEqual(2, len(next_funcs))
+            self.assertEqual('torch::distributed::autograd::RecvRpcBackward', next_funcs[0][0].name())
+            self.assertEqual('torch::distributed::autograd::RecvRpcBackward', next_funcs[1][0].name())
+            self.assertEqual(next_funcs[0][0], next_funcs[1][0])
+
         # autograd context should be cleaned up by now.
         with self.assertRaises(RuntimeError):
             ctx = dist_autograd._retrieve_context(context_id)
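The assertions above walk the graph through .next_functions and compare node names via .name(). For reference, the same traversal on a purely local version of the computation looks like the sketch below; under distributed autograd the leaf-side nodes seen here are replaced by the single shared RecvRpcBackward node the test checks for (illustration only, not part of the patch):

import torch

t1 = torch.ones(3, 3, requires_grad=True)
t2 = torch.zeros(3, 3, requires_grad=True)
ret = torch.add(t1, t2)

add_fn = ret.grad_fn
print(add_fn.name())  # AddBackward0

# Each entry is a (next function, input index) pair; with local autograd both
# leaves resolve to their own AccumulateGrad node rather than a shared recv node.
for next_fn, input_nr in add_fn.next_functions:
    print(next_fn.name(), input_nr)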
@@ -99,7 +122,7 @@ def test_rpc_complex_args(self):
         self.assertEqual(torch.stack(tensors), ret)
 
         # Verify appropriate tensors have been attached to the autograd graph.
-        next_funcs = dist_autograd._current_context()._send_functions()[0].next_functions
+        next_funcs = list(dist_autograd._current_context()._send_functions().values())[0].next_functions
         idx = 0
         for i in range(num_tensors):
             if i % 2 == 0: