8000 Update on "sync and async torch.distributed.rpc for builtin operators" · pytorch/pytorch@052ba85 · GitHub
[go: up one dir, main page]

Skip to content

Commit 052ba85

Browse files
committed
Update on "sync and async torch.distributed.rpc for builtin operators"
Features: * sync and async RPC for builtin operators * RpcAgent API * ProcessGroupAgent implementation Goal: This is the first PR for #23110, and there will be many followup ones. So let's focus on the overall API and code structure. Details like efficiency and error handling can be improved in future PRs. * have a minimum working and testable RPC implementation. * make sure the RpcAgent API is sufficient for future ThriftAgent and TensorPipeAgent implementation * For tensor pipe implementation, it might allocate multiple underlying communication channels with different types, and might also use streaming serialization/deserialization for large tensors. To support this requirement, the current implementation only convert a BuiltinOp into a Message which contains a byte vector and a tensor table. It is up to the RpcAgent implementation to determine how it would like to serialize a Message object. * For ThriftAgent, as Thrift has it own request/response matching solution, the Message.id is no longer necessary. Hence the id can be dropped during serialization. All it needs to do is to pass the response Message object to the Future returned by send(...). * support blocking and non-blocking RequestCallback * blocking means the callback won't return before sending out the response * non-blocking can be achieved by enqueue the `(from, request, RpcAgent&)` tuple and use a different thread to process them. That is why there is an `RpcAgent&` arg in the param list. Differential Revision: [D15194693](https://our.internmc.facebook.com/intern/diff/D15194693/)
1 parent 507a1e5 commit 052ba85

File tree

3 files changed

+8
-9
lines changed

3 files changed

+8
-9
lines changed

torch/csrc/distributed/rpc/ScriptCall.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,10 @@ using torch::jit::Unpickler;
1111

1212
} // namespace
1313

14-
#define BUILTIN_OP_NAMESPACE "torch.ops.aten."
15-
#define ATEN_PREFIX "aten::"
16-
#define ATEN_PREFIX_LEN 6
17-
1814
ScriptCall::ScriptCall(
1915
std::shared_ptr<Operator> op, std::vector<at::IValue>&& args)
2016
: op_(std::move(op)), stack_(args) {}
2117

22-
2318
std::shared_ptr<Operator> ScriptCall::op() const {
2419
return *op_;
2520
}
@@ -46,9 +41,9 @@ Message ScriptCall::toMessage() {
4641
// insert qualified name
4742
auto opName = (*op_)->schema().name();
4843
TORCH_CHECK(opName.find("::") == opName.rfind("::")
49-
&& opName.rfind(ATEN_PREFIX) == 0, "Unexpected operator name ", opName);
44+
&& opName.rfind(ATEN_PREFIX_) == 0, "Unexpected operator name ", opName);
5045
// aten::add -> torch.ops.aten.add
51-
opName.replace(0, ATEN_PREFIX_LEN, BUILTIN_OP_NAMESPACE);
46+
opName.replace(0, ATEN_PREFIX_LEN_, BUILTIN_OP_NAMESPACE_);
5247
pickler.pushIValue(opName);
5348
}
5449
pickler.endTuple();
@@ -71,7 +66,7 @@ ScriptCall ScriptCall::fromMessage(const Message& message) {
7166
"contain one IValue as the operator schema.");
7267

7368
const std::string& qualifiedName = values.back().toStringRef();
74-
if (qualifiedName.rfind(BUILTIN_OP_NAMESPACE) == 0) {
69+
if (qualifiedName.rfind(BUILTIN_OP_NAMESPACE_) == 0) {
7570
values.pop_back();
7671

7772
const std::string& str_schema = values.back().toStringRef();

torch/csrc/distributed/rpc/ScriptCall.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ class TORCH_API ScriptCall final {
3232
static std::shared_ptr<Operator> matchOperator(
3333
at::Symbol& symbol, const std::string& str_schema);
3434

35+
static constexpr char BUILTIN_OP_NAMESPACE_[] = "torch.ops.aten.";
36+
static constexpr char ATEN_PREFIX_[] = "aten::";
37+
static constexpr int ATEN_PREFIX_LEN_ = 6;
38+
3539
// This field has value if this ScriptCall represents invocation of a builtin
3640
// operator.
3741
c10::optional<std::shared_ptr<Operator>> op_;

torch/csrc/distributed/rpc/init.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ PyObject* rpc_init(PyObject* /* unused */) {
6464

6565
} // namespace
6666

67-
static PyMethodDef methods[] = {
67+
static PyMethodDef methods[] = { // NOLINT
6868
{"_rpc_init", (PyCFunction)rpc_init, METH_NOARGS, nullptr},
6969
{nullptr, nullptr, 0, nullptr}};
7070

0 commit comments

Comments
 (0)
0