Update on "[reland] Add autograd hook for python rpc call" · pytorch/pytorch@60e332b · GitHub

Commit 60e332b

Update on "[reland] Add autograd hook for python rpc call"
1. Currently, if the autograd context is valid, an RPC is sent with autograd metadata even when the tensors do not require grads and no grad functions are attached. This is not ideal. This diff makes sure an RPC carries autograd metadata only if the autograd context is valid and the tensors require grads.
2. Meanwhile, create a utility to attach autograd info and functions as needed.
3. Add autograd send/recv functions for python rpc calls.
4. Make changes to support nested python rpc calls.
5. Disallow nested dist autograd contexts (already landed in #27022).

Differential Revision: [D18017554](https://our.internmc.facebook.com/intern/diff/D18017554/)

[ghstack-poisoned]
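As a rough illustration of item 1, here is a minimal Python sketch of the gating condition; the actual change lives in the C++ RPC layer, and the helper name below is hypothetical, not part of this commit:

```python
import torch

def should_attach_autograd_metadata(context_valid, tensors):
    # Hypothetical helper: attach send/recv autograd functions to an outgoing
    # RPC only when a distributed autograd context is active AND at least one
    # argument tensor actually requires grad.
    return context_valid and any(t.requires_grad for t in tensors)

# Illustrative checks of the condition described in item 1.
t_plain = torch.ones(2, 2)
t_grad = torch.ones(2, 2, requires_grad=True)
print(should_attach_autograd_metadata(True, [t_plain, t_grad]))   # True
print(should_attach_autograd_metadata(False, [t_plain, t_grad]))  # False
print(should_attach_autograd_metadata(True, [t_plain]))           # False
```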
1 parent 867661f commit 60e332b

File tree: 2 files changed (+7, -6 lines)

test/dist_autograd_test.py

Lines changed: 6 additions & 5 deletions
@@ -15,6 +15,7 @@
 # sent from prev rank respectively.
 # rpc_done[2] and ctx_ids[2] represents for prev of prev rank.
 # rpc_done[3] and ctx_ids[3] represents for prev of prev of prev rank.
+# rpc_done[0] and ctx_ids[0] represents for current rank, but mostly not used.
 rpc_done = [False, False, False, False]
 ctx_ids = [-1, -1, -1, -1]
 
@@ -163,8 +164,8 @@ def _verify_graph_for_first_rpc_call(self, send_function, recv_function, t1, t2,
         self.assertEqual(ret.grad_fn, recv_function)
 
     # For a context passed from previous nested chain calls, this rank
-    # recevied two tensors t1 and t2, execute torch.add(t1, t2) and send result
-    # tensor t3 back.
+    # receives two tensors t1 and t2, executes torch.add(t1, t2) and sends
+    # result tensor t3 back.
     # For this context in this rank, it expects graph like this:
     #  send and recv functions:
     #       rpcSendBackward
@@ -191,15 +192,15 @@ def _verify_graph_for_rpc_call_exec(self, send_function):
         self.assertEqual(next_funcs[0][0], next_funcs[1][0])
 
     # For a context passed from previous nested chain calls, this rank
-    # recevied two tensors t1 and t2, forwarding t1 and t2 tensors using
+    # receives two tensors t1 and t2, forwards t1 and t2 tensors using
     # nested rpc call to next dst. In return route, receive result tensor t3
     # from next dst and forwarding t3 back to previous calls.
     # For this context in this rank, it expects graph like this:
-    #  send and recv functions while recevive and forward t1 and t2:
+    #  send and recv functions for receving and forwarding t1 and t2:
     #       rpcSendBackward
     #      /          \
     #  t1.recvRpcBackward    t2.recvRpcBackward
-    #  send and recv functions while receive and forward t3:
+    #  send and recv functions for receiving and forwarding t3:
     #       rpcSendBackward
     #          |
     #      t3.recvRpcBackward
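The graphs in these comments come from a nested call chain. Below is a small sketch of that chain using the public `torch.distributed.rpc` and `torch.distributed.autograd` APIs; it assumes RPC has already been initialized on workers named "worker1" and "worker2", and the worker and function names are illustrative, not taken from this commit:

```python
import torch
import torch.distributed.rpc as rpc
import torch.distributed.autograd as dist_autograd

def add_on_last_rank(t1, t2):
    # Innermost callee: executes torch.add(t1, t2) and sends result t3 back.
    return torch.add(t1, t2)

def forward_to_next(next_dst, t1, t2):
    # Middle rank: receives t1 and t2, forwards them with a nested RPC, then
    # forwards the returned t3 back to the previous caller. On this rank the
    # context records the send/recv functions sketched in the comments above.
    return rpc.rpc_sync(next_dst, add_on_last_rank, args=(t1, t2))

def run_nested_chain():
    # Caller side; assumes rpc.init_rpc(...) already ran on every worker.
    t1 = torch.ones(3, 3, requires_grad=True)
    t2 = torch.ones(3, 3, requires_grad=True)
    with dist_autograd.context() as context_id:
        # Because the context is valid and t1/t2 require grad, each hop
        # attaches autograd metadata (rpcSendBackward / recvRpcBackward).
        t3 = rpc.rpc_sync("worker1", forward_to_next, args=("worker2", t1, t2))
        return t3
```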

torch/csrc/distributed/rpc/request_callback.cpp

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ struct ClearAutogradContextGuard {
 } // anonymous namespace
 
 Message RequestCallback::operator()(Message& request) const {
-  // For a rev thread, current context id should be invalid outside
+  // For a recv thread, current context id should be invalid outside
   // processMessage().
   ClearAutogradContextGuard guard;
   try {
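The guard's job here is to reset the thread-local current context id whenever a recv thread finishes processing a message, even on error. A rough Python analogue of that RAII pattern, purely illustrative (the real mechanism is the C++ `ClearAutogradContextGuard` above, and the names below are hypothetical):

```python
import threading
from contextlib import contextmanager

# Hypothetical stand-in for the thread-local "current context id"; treated as
# invalid (-1) outside of message processing.
_local = threading.local()

@contextmanager
def clear_autograd_context_guard():
    try:
        yield
    finally:
        # Always reset, even if processing raised, so a recv thread never
        # leaks a context id across requests.
        _local.context_id = -1

def process_message(request):
    with clear_autograd_context_guard():
        _local.context_id = 42  # pretend this request activated a context
        return "response to " + request

print(process_message("req"))
print(getattr(_local, "context_id", -1))  # -1: cleared after processing
```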
