Add autograd hook for python rpc call · pytorch/pytorch@9d4309d · GitHub
Commit 9d4309d
Add autograd hook for python rpc call
Pull Request resolved: #27576

1. Currently, if the autograd context is valid, an rpc is sent with autograd metadata even when its tensors do not require grads and no grad functions are attached. This is not ideal. This diff makes sure an rpc with autograd metadata is sent only if the autograd context is valid and the tensors require grads.
2. Meanwhile, create a utility to attach autograd info and functions as needed.
3. Add autograd send/recv functions for python rpc calls.
4. Make changes to support nested python rpc calls.
5. Disallow nested dist autograd contexts (was landed in #27022).

ghstack-source-id: 92154535
Differential Revision: [D17819153](https://our.internmc.facebook.com/intern/diff/D17819153/)
1 parent a5ac7f6 commit 9d4309d
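
The core rule of change 1, as a minimal standalone sketch (stand-in types, not the real torch/c10 API): autograd metadata is attached to an outgoing rpc only when a distributed autograd context is valid and at least one tensor requires grad.

#include <cstdint>
#include <vector>

// Stand-in for torch::Tensor; only the flag this rule inspects.
struct Tensor {
  bool requires_grad;
};

// Stand-in for torch::autograd::compute_requires_grad().
bool anyRequiresGrad(const std::vector<Tensor>& tensors) {
  for (const auto& t : tensors) {
    if (t.requires_grad) {
      return true;
    }
  }
  return false;
}

// Sentinel for "no context"; assumed here, mirroring kInvalidContextId below.
constexpr int64_t kInvalidContextId = -1;

// True when the rpc should be wrapped with autograd metadata.
bool shouldAttachAutograd(
    int64_t currentContextId, const std::vector<Tensor>& tensors) {
  return currentContextId != kInvalidContextId && anyRequiresGrad(tensors);
}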

File tree

11 files changed: +430 −178 lines


test/dist_autograd_test.py

Lines changed: 248 additions & 65 deletions
Large diffs are not rendered by default.

torch/csrc/distributed/autograd/context/dist_autograd_container.cpp

Lines changed: 11 additions & 0 deletions
@@ -156,6 +156,17 @@ int64_t DistAutogradContainer::getMaxId() {
   return max_id_;
 }
 
+void DistAutogradContainer::setCurrentContextId(int64_t contextId) {
+  TORCH_INTERNAL_ASSERT(
+      current_context_id_ == kInvalidContextId,
+      "Already have an autograd context id for this thread.");
+  current_context_id_ = contextId;
+}
+
+void DistAutogradContainer::clearCurrentContext() {
+  current_context_id_ = -1;
+}
+
 } // namespace autograd
 } // namespace distributed
 } // namespace torch
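
A standalone analogue of the per-thread protocol these two methods establish (assuming kInvalidContextId == -1, as clearCurrentContext() suggests): an id may be installed only when none is set, and clearing restores the invalid sentinel.

#include <cassert>
#include <cstdint>

// Hypothetical analogue of DistAutogradContainer's current-context
// bookkeeping for one thread; not the real class.
class CurrentContextId {
 public:
  static constexpr int64_t kInvalidContextId = -1;

  void set(int64_t contextId) {
    // Mirrors the TORCH_INTERNAL_ASSERT above: never overwrite a live id.
    assert(id_ == kInvalidContextId &&
           "Already have an autograd context id for this thread.");
    id_ = contextId;
  }

  void clear() {
    id_ = kInvalidContextId;
  }

  bool valid() const {
    return id_ != kInvalidContextId;
  }

 private:
  int64_t id_ = kInvalidContextId;
};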

torch/csrc/distributed/autograd/context/dist_autograd_container.h

Lines changed: 7 additions & 1 deletion
@@ -59,9 +59,15 @@ class TORCH_API DistAutogradContainer {
   // can be generated by this worker.
   int64_t getMaxId();
 
-  // retrieves the worker ID for this node
+  // Retrieves the worker ID for this node
   rpc::worker_id_t getWorkerId() const;
 
+  // Can set current context id if there is no valid context yet
+  void setCurrentContextId(int64_t contextId);
+
+  // Clear current context id
+  void clearCurrentContext();
+
  private:
   DistAutogradContainer();
   ~DistAutogradContainer() = default;

torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp

Lines changed: 3 additions & 4 deletions
@@ -17,15 +17,14 @@ RpcWithAutograd::RpcWithAutograd(
     worker_id_t fromWorkerId,
     MessageType messageType,
     const AutogradMetadata& autogradMetadata,
-    std::unique_ptr<RpcCommandBase> wrappedRpc)
+    rpc::Message&& wrappedMessage)
     : fromWorkerId_(fromWorkerId),
       messageType_(messageType),
-      autogradMetadata_(autogradMetadata) {
-  TORCH_INTERNAL_ASSERT(wrappedRpc != nullptr, "wrappedRpc cannot be null!");
+      autogradMetadata_(autogradMetadata),
+      wrappedMessage_(std::move(wrappedMessage)) {
   TORCH_INTERNAL_ASSERT(
       messageType_ == MessageType::FORWARD_AUTOGRAD_REQ ||
       messageType_ == MessageType::FORWARD_AUTOGRAD_RESP);
-  wrappedMessage_ = std::move(*wrappedRpc).toMessage();
   tensors_ = wrappedMessage_.tensors();
   wrappedMessageType_ = wrappedMessage_.type();
 }

torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h

Lines changed: 11 additions & 1 deletion
@@ -18,7 +18,7 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
       rpc::worker_id_t fromWorkerId,
       rpc::MessageType messageType,
       const AutogradMetadata& autogradMetadata,
-      std::unique_ptr<rpc::RpcCommandBase> wrappedRpc);
+      rpc::Message&& wrappedMessage);
 
   // Used when receiving an RPC over the wire.
   RpcWithAutograd(
@@ -57,10 +57,20 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
   rpc::MessageType messageType_;
 
   AutogradMetadata autogradMetadata_;
+
+  // Since wrappedMessage_ is destructively constructed from wrappedRpc_,
+  // only one of the two is valid at a time. wrappedRpc_ is used when
+  // constructing a receive-side RpcWithAutograd; wrappedMessage_ is used
+  // when constructing a send-side RpcWithAutograd.
+
+  // This is valid when a receive-side RpcWithAutograd is constructed via
+  // fromMessage, and nullptr on the send side before toMessage is called.
   std::unique_ptr<RpcCommandBase> wrappedRpc_;
 
   // Serialized message representing wrappedRpc_. Used mostly as a cache to
   // avoid serializing the request twice.
+  // This is empty when a receive-side RpcWithAutograd is constructed via
+  // fromMessage, and valid on the send side before toMessage is called.
   rpc::Message wrappedMessage_;
 
   // message type of the wrappedMessage, this is stored separately since
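
The comments above describe two members that are never valid at the same time. A compilable sketch of that shape (stand-in types, not the real rpc classes), including the rvalue-qualified toMessage() that call sites invoke as std::move(*rpcWithAutograd).toMessage():

#include <memory>
#include <string>
#include <utility>

struct Message {  // stand-in for rpc::Message
  std::string payload;
};

struct RpcCommandBase {  // stand-in for rpc::RpcCommandBase
  virtual ~RpcCommandBase() = default;
};

class WithAutograd {
 public:
  // Send path: wraps an already-serialized message; wrappedRpc_ stays null.
  explicit WithAutograd(Message&& wrappedMessage)
      : wrappedMessage_(std::move(wrappedMessage)) {}

  // Receive path: wraps the deserialized command; wrappedMessage_ stays empty.
  explicit WithAutograd(std::unique_ptr<RpcCommandBase> wrappedRpc)
      : wrappedRpc_(std::move(wrappedRpc)) {}

  // Rvalue-qualified: consuming the wrapper hands out the cached message,
  // so the request is never serialized twice.
  Message toMessage() && {
    return std::move(wrappedMessage_);
  }

 private:
  std::unique_ptr<RpcCommandBase> wrappedRpc_;  // receive side only
  Message wrappedMessage_;                      // send side only
};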

torch/csrc/distributed/autograd/utils.cpp

Lines changed: 78 additions & 32 deletions
@@ -9,54 +9,100 @@ namespace torch {
 namespace distributed {
 namespace autograd {
 
+using torch::distributed::autograd::AutogradMetadata;
+using torch::distributed::autograd::RpcWithAutograd;
+using torch::distributed::rpc::FutureMessage;
 using torch::distributed::rpc::Message;
+using torch::distributed::rpc::MessageType;
+using torch::distributed::rpc::RpcAgent;
+using torch::distributed::rpc::RpcCommandBase;
+using torch::distributed::rpc::WorkerInfo;
 
 void addSendRpcBackward(
     DistAutogradContext& autogradContext,
     const AutogradMetadata& autogradMetadata,
     std::vector<torch::Tensor>& tensors,
     const rpc::worker_id_t dst) {
   // Attach the appropriate autograd edges.
-  if (torch::autograd::compute_requires_grad(tensors)) {
-    auto grad_fn = std::make_shared<SendRpcBackward>();
-    grad_fn->set_next_edges(torch::autograd::collect_next_edges(tensors));
-
-    // Add the appropriate input metadata for the grad_fn.
-    for (const auto& tensor : tensors) {
-      grad_fn->add_input_metadata(tensor);
-    }
-
-    // Record the send autograd function in our current context.
-    autogradContext.addSendFunction(
-        grad_fn, autogradMetadata.autogradMessageId);
-    // Record the workerID
-    autogradContext.addKnownWorkerId(dst);
+  auto grad_fn = std::make_shared<SendRpcBackward>();
+  grad_fn->set_next_edges(torch::autograd::collect_next_edges(tensors));
+
+  // Add the appropriate input metadata for the grad_fn.
+  for (const auto& tensor : tensors) {
+    grad_fn->add_input_metadata(tensor);
   }
+
+  // Record the send autograd function in our current context.
+  autogradContext.addSendFunction(grad_fn, autogradMetadata.autogradMessageId);
+  // Record the workerID
+  autogradContext.addKnownWorkerId(dst);
 }
 
 DistAutogradContext* addRecvRpcBackward(
     const AutogradMetadata& autogradMetadata,
     std::vector<torch::Tensor>& tensors,
     rpc::worker_id_t fromWorkerId) {
-  if (torch::autograd::compute_requires_grad(tensors)) {
-    // Initialize autograd context if necessary.
-    auto& autogradContainer = DistAutogradContainer::getInstance();
-    DistAutogradContext& autogradContext = autogradContainer.getOrCreateContext(
-        autogradMetadata.autogradContextId);
-
-    // Attach the tensors as inputs to the autograd function.
-    auto grad_fn = std::make_shared<RecvRpcBackward>(
-        autogradMetadata, autogradContext, fromWorkerId);
-    for (auto& tensor : tensors) {
-      torch::autograd::set_history(tensor, grad_fn);
-    }
-
-    // Now update the autograd context with the necessary information.
-    autogradContext.addRecvFunction(
-        grad_fn, autogradMetadata.autogradMessageId);
-    return &autogradContext;
+  TORCH_INTERNAL_ASSERT(
+      torch::autograd::compute_requires_grad(tensors),
+      "Received tensors do not require grad, addRecvRpcBackward should not be called");
+  // Initialize autograd context if necessary.
+  auto& autogradContainer = DistAutogradContainer::getInstance();
+  DistAutogradContext& autogradContext =
+      autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId);
+
+  // Attach the tensors as inputs to the autograd function.
+  auto grad_fn = std::make_shared<RecvRpcBackward>(
+      autogradMetadata, autogradContext, fromWorkerId);
+  for (auto& tensor : tensors) {
+    torch::autograd::set_history(tensor, grad_fn);
+  }
+
+  // Now update the autograd context with the necessary information.
+  autogradContext.addRecvFunction(grad_fn, autogradMetadata.autogradMessageId);
+  return &autogradContext;
+}
+
+Message getMessageWithAutograd(
+    const rpc::worker_id_t dstId,
+    torch::distributed::rpc::Message&& wrappedRpcMsg,
+    MessageType msgType) {
+  auto& autogradContainer = DistAutogradContainer::getInstance();
+
+  // If there is no valid context or no tensor requires grad, send the
+  // original rpc message. Otherwise, attach grad info and grad functions
+  // and send an rpcWithAutograd message.
+  if (!autogradContainer.hasValidContext() ||
+      !torch::autograd::compute_requires_grad(wrappedRpcMsg.tensors())) {
+    return std::move(wrappedRpcMsg);
   }
-  return nullptr;
+
+  // Retrieve the appropriate context to modify.
+  auto& autogradContext = autogradContainer.currentContext();
+
+  // Wrap the original rpc with autograd information.
+  AutogradMetadata autogradMetadata(
+      autogradContext.contextId(), autogradContainer.newAutogradMessageId());
+  auto rpcWithAutograd = c10::guts::make_unique<RpcWithAutograd>(
+      RpcAgent::getDefaultRpcAgent()->getWorkerInfo().id_,
+      msgType,
+      autogradMetadata,
+      std::move(wrappedRpcMsg));
+
+  // Record autograd information for 'send'.
+  addSendRpcBackward(
+      autogradContext, autogradMetadata, rpcWithAutograd->tensors(), dstId);
+
+  return std::move(*rpcWithAutograd).toMessage();
+}
+
+std::shared_ptr<FutureMessage> sendMessageWithAutograd(
+    RpcAgent& agent,
+    const WorkerInfo& dst,
+    torch::distributed::rpc::Message&& wrappedRpcMsg) {
+  auto msg = getMessageWithAutograd(
+      dst.id_, std::move(wrappedRpcMsg), MessageType::FORWARD_AUTOGRAD_REQ);
+
+  return agent.send(dst, std::move(msg));
 }
 
 } // namespace autograd
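
getMessageWithAutograd() is effectively a decorator with a fast path, and sendMessageWithAutograd() is the single funnel call sites go through. A standalone sketch of that control flow (bools stand in for hasValidContext() and compute_requires_grad(); the types are not the real rpc classes):

#include <functional>
#include <string>
#include <utility>

struct Message {  // stand-in for rpc::Message
  std::string payload;
  bool wrappedWithAutograd = false;  // stand-in for FORWARD_AUTOGRAD_REQ
};

// Stand-in for getMessageWithAutograd(): wrap only when both conditions hold.
Message withAutogradIfNeeded(
    Message&& wrapped, bool hasValidContext, bool tensorsRequireGrad) {
  if (!hasValidContext || !tensorsRequireGrad) {
    return std::move(wrapped);  // fast path: forward the original message
  }
  wrapped.wrappedWithAutograd = true;  // stand-in for building RpcWithAutograd
  return std::move(wrapped);
}

// Stand-in for sendMessageWithAutograd(): every rpc call site funnels
// through here, so builtin and python-udf calls get identical handling.
void sendWithAutograd(
    const std::function<void(Message&&)>& send,
    Message&& msg,
    bool hasValidContext,
    bool tensorsRequireGrad) {
  send(withAutogradIfNeeded(
      std::move(msg), hasValidContext, tensorsRequireGrad));
}

This funnel is what lets pyRpcBuiltin and pyRpcPythonUdf below shrink to single sendMessageWithAutograd calls.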

torch/csrc/distributed/autograd/utils.h

Lines changed: 16 additions & 0 deletions
@@ -32,6 +32,22 @@ TORCH_API DistAutogradContext* addRecvRpcBackward(
     std::vector<torch::Tensor>& tensors,
     rpc::worker_id_t fromWorkerId);
 
+// This utility is used internally to wrap autograd info and attach an
+// autograd function for each type of rpc call. If the call has a valid
+// context and its tensors require grads, it returns an RpcWithAutograd
+// message; otherwise it returns the original rpc message.
+TORCH_API rpc::Message getMessageWithAutograd(
+    const rpc::worker_id_t dstId,
+    rpc::Message&& wrappedRpcMsg,
+    rpc::MessageType msgType);
+
+// Send message after autograd checking
+TORCH_API std::shared_ptr<torch::distributed::rpc::FutureMessage>
+sendMessageWithAutograd(
+    rpc::RpcAgent& agent,
+    const rpc::WorkerInfo& dst,
+    rpc::Message&& wrappedRpcMsg);
+
 } // namespace autograd
 } // namespace distributed
 } // namespace torch

torch/csrc/distributed/rpc/python_functions.cpp

Lines changed: 7 additions & 28 deletions
@@ -124,28 +124,8 @@ std::shared_ptr<FutureMessage> pyRpcBuiltin(
   Stack stack;
   auto op = matchBuiltinOp(opName, args, kwargs, stack);
   auto scriptCall = c10::guts::make_unique<ScriptCall>(op, std::move(stack));
-  auto& autogradContainer = DistAutogradContainer::getInstance();
-  if (autogradContainer.hasValidContext()) {
-    // Retrieve the appropriate context to modify.
-    auto& autogradContext = autogradContainer.currentContext();
-
-    // Wrap the original rpc with autograd information.
-    AutogradMetadata autogradMetadata(
-        autogradContext.contextId(), autogradContainer.newAutogradMessageId());
-    RpcWithAutograd rpcWithAutograd(
-        agent.getWorkerInfo().id_,
-        MessageType::FORWARD_AUTOGRAD_REQ,
-        autogradMetadata,
-        std::move(scriptCall));
-
-    // Record autograd information for 'send'.
-    addSendRpcBackward(
-        autogradContext, autogradMetadata, rpcWithAutograd.tensors(), dst.id_);
-
-    return agent.send(dst, std::move(rpcWithAutograd).toMessage());
-  } else {
-    return agent.send(dst, std::move(*scriptCall).toMessage());
-  }
+  return sendMessageWithAutograd(
+      agent, dst, std::move(*scriptCall).toMessage());
 }
 
 PyRRef pyRemoteBuiltin(
@@ -179,12 +159,11 @@ std::shared_ptr<FutureMessage> pyRpcPythonUdf(
     const WorkerInfo& dst,
     std::string& pickledPythonUDF,
     std::vector<torch::Tensor>& tensors) {
-  return agent.send(
-      dst,
-      PythonUDFCall(
-          std::vector<char>(pickledPythonUDF.begin(), pickledPythonUDF.end()),
-          tensors)
-          .toMessage());
+  auto pythonUDFCall = c10::guts::make_unique<PythonUDFCall>(
+      std::vector<char>(pickledPythonUDF.begin(), pickledPythonUDF.end()),
+      tensors);
+  return sendMessageWithAutograd(
+      agent, dst, std::move(*pythonUDFCall).toMessage());
 }
 
 PyRRef pyRemotePythonUdf(

torch/csrc/distributed/rpc/request_callback.cpp

Lines changed: 18 additions & 0 deletions
@@ -20,9 +20,27 @@ Message createException(const Message& request, const std::exception& e) {
       request.id());
 }
 
+// When the request message has autograd info, processMessage() will set up a
+// valid current context id. This struct is used to clean up the current
+// context id after processMessage() is done.
+struct ClearAutogradContextGuard {
+  ClearAutogradContextGuard() = default;
+  ~ClearAutogradContextGuard() {
+    clear();
+  }
+
+  void clear() {
+    auto& autogradContainer = DistAutogradContainer::getInstance();
+    autogradContainer.clearCurrentContext();
+  }
+};
+
 } // anonymous namespace
 
 Message RequestCallback::operator()(Message& request) const {
+  // For a recv thread, the current context id should be invalid outside
+  // processMessage().
+  ClearAutogradContextGuard guard;
   try {
     return processMessage(request);
   } catch (std::exception& e) {
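
ClearAutogradContextGuard is plain RAII: its destructor runs on both the normal return and the exception path through the catch block, so a server thread cannot leak a context id across requests. A self-contained sketch of the pattern (hypothetical names, not the PyTorch types):

#include <cstdint>

namespace sketch {

constexpr int64_t kInvalidContextId = -1;
// Stand-in for DistAutogradContainer's per-thread current context id.
thread_local int64_t currentContextId = kInvalidContextId;

struct ClearGuard {
  ~ClearGuard() {
    // Runs on normal exit and during stack unwinding alike.
    currentContextId = kInvalidContextId;
  }
};

void handleRequest(int64_t contextIdFromWire) {
  ClearGuard guard;                      // armed before any processing
  currentContextId = contextIdFromWire;  // what processMessage() would set
  // ... process the request; may throw ...
}  // guard's destructor resets the id here, throw or not

} // namespace sketch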

0 commit comments