Update on "sync and async torch.distributed.rpc for builtin operators" · pytorch/pytorch@507a1e5
Update on "sync and async torch.distributed.rpc for builtin operators"
Features:

* sync and async RPC for builtin operators
* RpcAgent API
* ProcessGroupAgent implementation

Goal: this is the first PR for #23110, and there will be many followup ones, so let's focus on the overall API and code structure. Details like efficiency and error handling can be improved in future PRs.

* Have a minimum working and testable RPC implementation.
* Make sure the RpcAgent API is sufficient for the future ThriftAgent and TensorPipeAgent implementations.
  * For the TensorPipe implementation, it might allocate multiple underlying communication channels with different types, and might also use streaming serialization/deserialization for large tensors. To support this requirement, the current implementation only converts a BuiltinOp into a Message which contains a byte vector and a tensor table. It is up to the RpcAgent implementation to determine how it would like to serialize a Message object.
  * For ThriftAgent, as Thrift has its own request/response matching solution, the Message.id is no longer necessary. Hence the id can be dropped during serialization. All it needs to do is pass the response Message object to the Future returned by send(...).
* Support blocking and non-blocking RequestCallback.
  * Blocking means the callback won't return before sending out the response.
  * Non-blocking can be achieved by enqueuing the `(from, request, RpcAgent&)` tuple and using a different thread to process it. That is why there is an `RpcAgent&` arg in the param list.

Differential Revision: [D15194693](https://our.internmc.facebook.com/intern/diff/D15194693/)
1 parent 3c642e8 commit 507a1e5

File tree

6 files changed: +15 / -13 lines

* torch/csrc/distributed/rpc/Message.cpp
* torch/csrc/distributed/rpc/ScriptCall.cpp
* torch/csrc/distributed/rpc/ScriptCall.h
* torch/csrc/distributed/rpc/functions.cpp
* torch/csrc/distributed/rpc/functions.h
* torch/csrc/distributed/rpc/init.cpp
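Before the per-file diffs, here is a minimal, hypothetical C++ sketch of the shapes the commit message describes: a Message carrying a payload byte vector, a tensor table, a type, and an id, plus an RpcAgent whose send(...) hands back a future for the response. The member and method names mirror the diffs below, but everything else (constructors, the use of std::future, defaults) is illustrative and not necessarily the PR's exact API.

// Illustrative sketch only -- names mirror the diffs below; exact types and
// signatures in the PR may differ.
#include <cstdint>
#include <future>
#include <memory>
#include <string>
#include <vector>

#include <ATen/ATen.h>  // at::Tensor

enum class MessageType { SCRIPT_CALL, SCRIPT_RET, UNKNOWN };

class Message {
 public:
  Message(std::vector<char> payload,
          std::vector<at::Tensor> tensors,
          MessageType type,
          int64_t id = -1)
      : payload_(std::move(payload)),
        tensors_(std::move(tensors)),
        type_(type),
        id_(id) {}

  MessageType type() const { return type_; }
  int64_t id() const { return id_; }
  void setId(int64_t id) { id_ = id; }

 private:
  std::vector<char> payload_;        // pickled op name and non-tensor args
  std::vector<at::Tensor> tensors_;  // tensor table; the agent decides how to serialize it
  MessageType type_;
  int64_t id_;                       // request/response matching; a Thrift agent could drop it
};

// Each backend (ProcessGroupAgent now, ThriftAgent/TensorPipeAgent later)
// chooses its own wire format for Message; send() returns a future that is
// fulfilled with the response Message.
class RpcAgent {
 public:
  virtual ~RpcAgent() = default;
  virtual std::shared_ptr<std::future<Message>> send(
      const std::string& to, Message&& message) = 0;
};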

torch/csrc/distributed/rpc/Message.cpp

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ Message& Message::operator=(Message const& rhs) & {
 Message& Message::operator=(Message&& rhs) & {
   Message(std::move(rhs.payload_),
           std::move(rhs.tensors_),
-          std::move(rhs.type_),
+          rhs.type_,
           rhs.id_).swap(*this);
   return *this;
 }
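The same assignment, restated with explanatory comments (assuming type_ is the scalar MessageType enum that functions.cpp switches on): a temporary is built from the moved-from rhs members and then swapped into *this. For a scalar enum, std::move is a no-op, so the change copies type_ directly.

// Annotated restatement of the move-assignment above; purely explanatory.
Message& Message::operator=(Message&& rhs) & {
  Message(std::move(rhs.payload_),   // steal the byte vector
          std::move(rhs.tensors_),   // steal the tensor table
          rhs.type_,                 // scalar enum: plain copy, std::move adds nothing
          rhs.id_)                   // scalar id: plain copy
      .swap(*this);                  // swap the temporary into *this
  return *this;
}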

torch/csrc/distributed/rpc/ScriptCall.cpp

Lines changed: 6 additions & 2 deletions
@@ -11,6 +11,10 @@ using torch::jit::Unpickler;
 
 } // namespace
 
+#define BUILTIN_OP_NAMESPACE "torch.ops.aten."
+#define ATEN_PREFIX "aten::"
+#define ATEN_PREFIX_LEN 6
+
 ScriptCall::ScriptCall(
     std::shared_ptr<Operator> op, std::vector<at::IValue>&& args)
     : op_(std::move(op)), stack_(args) {}
@@ -42,9 +46,9 @@ Message ScriptCall::toMessage() {
     // insert qualified name
     auto opName = (*op_)->schema().name();
     TORCH_CHECK(opName.find("::") == opName.rfind("::")
-        && opName.rfind("aten::") == 0, "Unexpected operator name ", opName);
+        && opName.rfind(ATEN_PREFIX) == 0, "Unexpected operator name ", opName);
     // aten::add -> torch.ops.aten.add
-    opName.replace(0, 6, BUILTIN_OP_NAMESPACE);
+    opName.replace(0, ATEN_PREFIX_LEN, BUILTIN_OP_NAMESPACE);
     pickler.pushIValue(opName);
   }
   pickler.endTuple();
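The hunk above moves the macros out of the header and names the magic prefix and length. A small, self-contained illustration of the rewrite it performs, using hypothetical local constants in place of the macros: the JIT schema name "aten::add" becomes the Python-callable builtin name "torch.ops.aten.add".

// Standalone illustration of the qualified-name rewrite in toMessage().
// The constants stand in for ATEN_PREFIX / BUILTIN_OP_NAMESPACE above.
#include <cassert>
#include <string>

int main() {
  const std::string atenPrefix = "aten::";          // ATEN_PREFIX, length 6
  const std::string builtinNs = "torch.ops.aten.";  // BUILTIN_OP_NAMESPACE

  std::string opName = "aten::add";
  // Exactly one "::" and it is the "aten::" prefix.
  assert(opName.find("::") == opName.rfind("::") &&
         opName.rfind(atenPrefix, 0) == 0);
  opName.replace(0, atenPrefix.size(), builtinNs);
  assert(opName == "torch.ops.aten.add");  // aten::add -> torch.ops.aten.add
  return 0;
}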

torch/csrc/distributed/rpc/ScriptCall.h

Lines changed: 0 additions & 2 deletions
@@ -12,8 +12,6 @@ namespace rpc {
 
 using torch::jit::Operator;
 
-#define BUILTIN_OP_NAMESPACE "torch.ops.aten."
-
 // A ScriptCall instance represents an invocation of a builtin operator for a
 // TorchScript function (not implemented yet). If it is a builtin operator, it
 // contains a shared ptr to the `Operator` and a list of arguments.

torch/csrc/distributed/rpc/functions.cpp

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@ namespace distributed {
 namespace rpc {
 
 void processRequestBlocking(
-    std::string from, Message&& request, RpcAgent& agent) {
+    const std::string& from, Message&& request, RpcAgent& agent) {
   switch (request.type()) {
     case MessageType::SCRIPT_CALL: {
       ScriptCall op = ScriptCall::fromMessage(request);
@@ -18,7 +18,7 @@ void processRequestBlocking(
 
       auto response = ScriptRet(std::move(stack.front())).toMessage();
       response.setId(request.id());
-      agent.send(std::move(from), std::move(response));
+      agent.send(from, std::move(response));
       break;
     }
     default: {
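processRequestBlocking is the blocking flavor from the commit message: it does not return until the response has been handed to the agent. Below is a purely hypothetical sketch of the non-blocking flavor it enables, enqueuing the (from, request, RpcAgent&) tuple and draining the queue on a worker thread; none of the names below exist in this PR, and shutdown handling is omitted.

// Hypothetical non-blocking RequestCallback built on processRequestBlocking.
#include <condition_variable>
#include <deque>
#include <mutex>
#include <string>
#include <utility>

struct PendingRequest {
  std::string from;
  Message request;
  RpcAgent* agent;  // the RpcAgent& arg, stored as a pointer while queued
};

std::deque<PendingRequest> pending;
std::mutex pendingMutex;
std::condition_variable pendingCv;

// Returns immediately; the response is sent later from the worker thread.
void processRequestNonBlocking(
    const std::string& from, Message&& request, RpcAgent& agent) {
  {
    std::lock_guard<std::mutex> guard(pendingMutex);
    pending.push_back({from, std::move(request), &agent});
  }
  pendingCv.notify_one();
}

// Runs on a dedicated thread and reuses the blocking path for each item.
void workerLoop() {
  while (true) {
    std::unique_lock<std::mutex> lock(pendingMutex);
    pendingCv.wait(lock, [] { return !pending.empty(); });
    PendingRequest item = std::move(pending.front());
    pending.pop_front();
    lock.unlock();
    processRequestBlocking(item.from, std::move(item.request), *item.agent);
  }
}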

torch/csrc/distributed/rpc/functions.h

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ namespace distributed {
 namespace rpc {
 
 void processRequestBlocking(
-    std::string from, Message&& message, RpcAgent& agent);
+    const std::string& from, Message&& message, RpcAgent& agent);
 
 } // rpc
 } // distributed

torch/csrc/distributed/rpc/init.cpp

Lines changed: 5 additions & 5 deletions
@@ -52,11 +52,11 @@ PyObject* rpc_init(PyObject* /* unused */) {
 
   module.def("invoke_rpc", [](
       RpcAgent& agent,
-      std::string dstName,
-      std::string opName,
-      py::args args,
-      py::kwargs kwargs) {
-    return py_rpc(agent, std::move(dstName), std::move(opName), args, kwargs);
+      const std::string& dstName,
+      const std::string& opName,
+      const py::args& args,
+      const py::kwargs& kwargs) {
+    return py_rpc(agent, dstName, opName, args, kwargs);
   });
 
   Py_RETURN_TRUE;
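The binding now takes the strings and the py::args/py::kwargs handles by const reference and forwards them without std::move, avoiding per-call copies. A minimal, hypothetical pybind11 module showing the same binding style in isolation (the module and function names are made up for illustration):

// Hypothetical standalone module demonstrating const-reference parameters in
// a pybind11 lambda, mirroring the invoke_rpc binding above.
#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

PYBIND11_MODULE(rpc_binding_demo, m) {
  m.def("invoke_rpc_demo", [](
      const std::string& dstName,
      const std::string& opName,
      const py::args& args,
      const py::kwargs& kwargs) {
    // Taking the handles by const& avoids copying them before forwarding.
    return py::make_tuple(dstName, opName, py::len(args), py::len(kwargs));
  });
}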

0 commit comments
