pytorch/pytorch · Commit 3244bed
Update on "Add Python RRef as args and return value"
See #23110 for model parallel design details, and #26759 for the RRef protocol. This commit adds support for using RRef as Python UDF arguments and return values. RRefs can now be shared from owner to user, from user to owner, or from user to user.

Limitations:

1. No implicit type conversion yet. (#27099)
2. No failure handling and retry. (#26116)
3. UDF is not yet blocked until all RRefs are confirmed. (#27098)
4. Internal RRef control messages are not idempotent yet. (#26116)
5. Cannot delete RRefs correctly when there are circular dependencies. (#27096)

Main changes:

1. Added `SCRIPT_REMOTE_CALL` and `PYTHON_REMOTE_CALL` to `Message.h` to represent `dist.remote` invocations.
2. Added `SCRIPT_RREF_FETCH_CALL`, `PYTHON_RREF_FETCH_CALL`, `RREF_USER_ACCEPT`, `RREF_USER_DELETE`, `RREF_CHILD_ACCEPT`, and `RREF_FORK_REQUEST` to `Message.h` as internal RRef control messages.
3. Added handling code for the new message requests to `functions.cpp`, and added the message formats in `script_remote_call.h`, `python_remote_call.h`, and `rref_proto.h`.
4. Added a `PyRRef` type in `py_rref.h` and `py_rref.cpp`, which holds a shared pointer to the C++ `RRef` type. `PyRRef` wraps the C++ API and also implements RRef pickling and unpickling. RRef fork-related control messages are sent during the RRef pickling/unpickling procedure.
5. Updated `RRef.h` and `RRef.cpp` accordingly to support `py::object` RRefs.
6. RRef context (reference count, etc.) is tracked in `rref_context.h` and `rref_context.cpp`.

Differential Revision: [D17184146](https://our.internmc.facebook.com/intern/diff/D17184146)

[ghstack-poisoned]
2 parents c50124d + a3ebaaf commit 3244bed

7 files changed, with 11 additions and 9 deletions
torch/csrc/distributed/rpc/init.cpp

Lines changed: 2 additions & 2 deletions

@@ -126,7 +126,7 @@ PyObject* rpc_init(PyObject* /* unused */) {
       "invoke_rpc_python_udf",
       [](RpcAgent& agent,
          const WorkerInfo& dst,
-         const std::string& pickledPythonUDF,
+         std::string& pickledPythonUDF,
          std::vector<torch::Tensor>& tensors) {
         return pyRpcPythonUdf(agent, dst, pickledPythonUDF, tensors);
       });
@@ -145,7 +145,7 @@ PyObject* rpc_init(PyObject* /* unused */) {
       "invoke_remote_python_udf",
       [](RpcAgent& agent,
          const WorkerInfo& dst,
-         const std::string& pickledPythonUDF,
+         std::string& pickledPythonUDF,
          std::vector<torch::Tensor>& tensors) {
         return pyRemotePythonUdf(agent, dst, pickledPythonUDF, tensors);
       });

torch/csrc/distributed/rpc/python_functions.cpp

Lines changed: 2 additions & 2 deletions

@@ -130,7 +130,7 @@ PyRRef pyRemoteBuiltin(
 std::shared_ptr<FutureMessage> pyRpcPythonUdf(
     RpcAgent& agent,
     const WorkerInfo& dst,
-    const std::string& pickledPythonUDF,
+    std::string& pickledPythonUDF,
     std::vector<torch::Tensor>& tensors) {
   std::vector<char> data(pickledPythonUDF.begin(), pickledPythonUDF.end());
 
@@ -142,7 +142,7 @@ std::shared_ptr<FutureMessage> pyRpcPythonUdf(
 PyRRef pyRemotePythonUdf(
     RpcAgent& agent,
     const WorkerInfo& dst,
-    const std::string& pickledPythonUDF,
+    std::string& pickledPythonUDF,
     std::vector<torch::Tensor>& tensors) {
   auto& ctx = RRefContext::getInstance();
   // TODO: support creaing RRefs on a local object.

torch/csrc/distributed/rpc/python_functions.h

Lines changed: 2 additions & 2 deletions

@@ -21,7 +21,7 @@ std::shared_ptr<FutureMessage> pyRpcBuiltin(
 std::shared_ptr<FutureMessage> pyRpcPythonUdf(
     RpcAgent& agent,
     const WorkerInfo& dst,
-    const std::string& pickledPythonUDF,
+    std::string& pickledPythonUDF,
     std::vector<torch::Tensor>& tensors);
 
 PyRRef pyRemoteBuiltin(
@@ -34,7 +34,7 @@ PyRRef pyRemoteBuiltin(
 PyRRef pyRemotePythonUdf(
     RpcAgent& agent,
     const WorkerInfo& dst,
-    const std::string& pickledPythonUDF,
+    std::string& pickledPythonUDF,
     std::vector<torch::Tensor>& tensors);
 
 } // namespace rpc

torch/csrc/distributed/rpc/python_rpc_handler.cpp

Lines changed: 1 addition & 1 deletion

@@ -48,7 +48,7 @@ SerializedPyObj PythonRpcHandler::serialize(const py::object& obj) {
       t[1].cast<std::vector<torch::Tensor>>());
 }
 
-py::object PythonRpcHandler::deserialize(SerializedPyObj serializedObj) {
+py::object PythonRpcHandler::deserialize(const SerializedPyObj& serializedObj) {
   AutoGIL ag;
   return loadResultFunction_(
       py::bytes(serializedObj.payload_), serializedObj.tensors_);
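
The change to `deserialize` above swaps a by-value parameter for a `const` reference. A minimal sketch of why that matters when the caller passes an lvalue follows. `Payload`, `sizeByValue`, and `sizeByConstRef` are hypothetical names invented for this illustration and are not part of the PyTorch code; the copy counter only exists to make the copy visible.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>

// Hypothetical stand-in for a serialized object; counts copy constructions.
struct Payload {
  static int copies;
  std::string bytes;

  explicit Payload(std::string b) : bytes(std::move(b)) {}
  Payload(const Payload& other) : bytes(other.bytes) { ++copies; }
};
int Payload::copies = 0;

// By-value parameter: an lvalue argument is copy-constructed into `p`.
std::size_t sizeByValue(Payload p) { return p.bytes.size(); }

// Const-reference parameter: the argument is only borrowed, never copied.
std::size_t sizeByConstRef(const Payload& p) { return p.bytes.size(); }

// Returns the number of copies made by one by-value call.
int copiesWhenPassingByValue() {
  Payload p("pickled");
  Payload::copies = 0;
  sizeByValue(p);
  return Payload::copies;
}

// Returns the number of copies made by one const-reference call.
int copiesWhenPassingByConstRef() {
  Payload p("pickled");
  Payload::copies = 0;
  sizeByConstRef(p);
  return Payload::copies;
}
```

For a type that owns a string and a vector of tensors, the avoided copy covers both members, which is presumably the motivation for the signature change, although the commit message does not say so.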

torch/csrc/distributed/rpc/python_rpc_handler.h

Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,7 @@ class PYBIND11_EXPORT PythonRpcHandler {
   // Serialized a py::object into a string
   SerializedPyObj serialize(const py::object& obj);
   // Deserialize a string into a py::object
-  py::object deserialize(SerializedPyObj serializedObj);
+  py::object deserialize(const SerializedPyObj& serializedObj);
 
  private:
   PythonRpcHandler();

torch/csrc/distributed/rpc/types.cpp

Lines changed: 2 additions & 0 deletions

@@ -64,6 +64,7 @@ std::ostream& operator<<(std::ostream& os, GloballyUniqueId const& globalId) {
 
 std::vector<at::IValue> SerializedPyObj::toIValues() const {
   std::vector<at::IValue> ivalues;
+  ivalues.reserve(tensors_.size() + 1);
   for (auto& tensor: tensors_) {
     ivalues.emplace_back(tensor);
   }
@@ -76,6 +77,7 @@ SerializedPyObj SerializedPyObj::fromIValues(std::vector<at::IValue> values) {
   std::string payload = values.back().toStringRef();
   values.pop_back();
   std::vector<at::Tensor> tensors;
+  tensors.reserve(values.size());
   for (auto& value: values) {
     tensors.emplace_back(value.toTensor());
   }
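
The two `reserve` calls above size each vector once, before the append loop. The effect can be sketched with a small helper that counts capacity changes while appending. `countReallocations` is a hypothetical function written for this note, and it uses `std::string` elements rather than tensors or `IValue`s so that it compiles without PyTorch.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Counts how many times the vector's capacity changes while appending
// n elements, optionally calling reserve(n) up front.
std::size_t countReallocations(std::size_t n, bool reserveFirst) {
  std::vector<std::string> out;
  if (reserveFirst) {
    out.reserve(n);  // one allocation that is large enough for all n elements
  }
  std::size_t reallocations = 0;
  std::size_t lastCapacity = out.capacity();
  for (std::size_t i = 0; i < n; ++i) {
    out.emplace_back("item");
    if (out.capacity() != lastCapacity) {
      ++reallocations;
      lastCapacity = out.capacity();
    }
  }
  return reallocations;
}
```

With `reserve`, the standard guarantees no reallocation while the size stays within the reserved capacity. Without it, the number of regrowths depends on the implementation's growth policy. In `toIValues` the reserved size is `tensors_.size() + 1`; the extra slot is presumably for the payload, which `fromIValues` reads back from the end of the vector.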

torch/csrc/distributed/rpc/types.h

Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ using ForkId = GloballyUniqueId;
 
 struct TORCH_API SerializedPyObj final {
   SerializedPyObj(
-      const std::string&& payload,
+      std::string&& payload,
       std::vector<at::Tensor>&& tensors)
       : payload_(std::move(payload)),
         tensors_(std::move(tensors)) {}
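
The constructor fix above matters because `std::move` applied to a `const` rvalue reference yields a `const T&&`, which cannot bind to a move constructor, so the member is silently copy-constructed instead. A minimal sketch of that behavior follows. `Probe`, `HolderConst`, and `HolderMutable` are hypothetical types made up for this illustration; they only mirror the shape of the old and new parameter types.

```cpp
#include <cassert>
#include <utility>

// Hypothetical probe type that records which constructor built it.
struct Probe {
  enum class Via { Default, Copy, Move };
  Via via = Via::Default;

  Probe() = default;
  Probe(const Probe&) : via(Via::Copy) {}
  Probe(Probe&&) noexcept : via(Via::Move) {}
};

// Old shape: const rvalue reference. std::move(p) yields const Probe&&,
// which only the copy constructor can accept.
struct HolderConst {
  explicit HolderConst(const Probe&& p) : p_(std::move(p)) {}
  Probe p_;
};

// New shape: plain rvalue reference. std::move(p) yields Probe&&,
// so the move constructor is selected.
struct HolderMutable {
  explicit HolderMutable(Probe&& p) : p_(std::move(p)) {}
  Probe p_;
};

// Reports how the member was constructed under the old parameter type.
Probe::Via viaConstRvalue() {
  Probe p;
  HolderConst h(std::move(p));
  return h.p_.via;
}

// Reports how the member was constructed under the new parameter type.
Probe::Via viaMutableRvalue() {
  Probe p;
  HolderMutable h(std::move(p));
  return h.p_.via;
}
```

Applied to `SerializedPyObj`, this suggests the old `const std::string&&` parameter copied the payload despite the `std::move` in the initializer list, while the new `std::string&&` parameter lets it be moved.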
