Update on "sync and async torch.distributed.rpc for builtin operators" · pytorch/pytorch@c0d4b14 · GitHub
Commit c0d4b14

Update on "sync and async torch.distributed.rpc for builtin operators"
Features:
* sync and async RPC for builtin operators
* RpcAgent API
* ProcessGroupAgent implementation

Goal: This is the first PR for #23110, and there will be many follow-up ones, so let's focus on the overall API and code structure here. Details such as efficiency and error handling can be improved in future PRs.
* Have a minimum working and testable RPC implementation.
* Make sure the RpcAgent API is sufficient for the future ThriftAgent and TensorPipeAgent implementations.
  * A TensorPipe implementation might allocate multiple underlying communication channels of different types, and might also use streaming serialization/deserialization for large tensors. To support this, the current implementation only converts a BuiltinOp into a Message, which contains a byte vector and a tensor table; it is up to the RpcAgent implementation to decide how to serialize a Message object.
  * For ThriftAgent, since Thrift has its own request/response matching solution, the Message.id is no longer necessary and can be dropped during serialization. All it needs to do is pass the response Message object to the Future returned by send(...).
* Support blocking and non-blocking RequestCallback.
  * Blocking means the callback does not return before sending out the response.
  * Non-blocking can be achieved by enqueueing the `(from, request, RpcAgent&)` tuple and processing it on a different thread; that is why there is an `RpcAgent&` arg in the param list (a sketch of this pattern follows below).

Differential Revision: [D15194693](https://our.internmc.facebook.com/intern/diff/D15194693/)
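The non-blocking callback flow described above can be illustrated with a minimal, self-contained Python sketch. This is not the PR's C++ implementation: the names (ToyAgent, run_builtin_op, worker_loop) are hypothetical stand-ins, and the point is only that the callback returns right after enqueueing (from, request, agent), while a separate thread produces the response and sends it back through the agent reference.

import queue
import threading

class ToyAgent:
    # Stands in for the RpcAgent& reference: it can send a response back to a peer.
    def send(self, to, message):
        print('response to %s: %s' % (to, message))

def run_builtin_op(request):
    # Placeholder for "deserialize the Message and run the builtin operator".
    op, args = request
    return op(*args)

work_queue = queue.Queue()

def non_blocking_callback(from_worker, request, agent):
    # Non-blocking: return immediately after enqueueing; the response is
    # produced later by the worker thread, which is why the agent reference
    # must travel with the request.
    work_queue.put((from_worker, request, agent))

def worker_loop():
    while True:
        item = work_queue.get()
        if item is None:  # sentinel used to stop the worker
            return
        from_worker, request, agent = item
        agent.send(from_worker, run_builtin_op(request))

worker = threading.Thread(target=worker_loop)
worker.start()

non_blocking_callback('worker0', (sum, ([1, 2, 3],)), ToyAgent())  # returns immediately
work_queue.put(None)
worker.join()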
1 parent 97da154 commit c0d4b14

File tree

6 files changed: +43 -4 lines changed


test/test_rpc.py

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,8 @@ def wrapper(self):
                                 world_size=self.world_size, store=store)
         dist.init_rpc('worker%d' % self.rank)
         func(self)
+        dist.destroy_rpc()
+        dist.destroy_process_group(dist.group.WORLD)
 
     return wrapper
 

torch/csrc/distributed/rpc/ProcessGroupAgent.cpp

Lines changed: 10 additions & 2 deletions
@@ -78,14 +78,20 @@ ProcessGroupAgent::ProcessGroupAgent(
 }
 
 ProcessGroupAgent::~ProcessGroupAgent() {
-  //TORCH_CHECK(stop_, "Cannot destroy ProcessGroupAgent before shutdown.");
+  if (!stop_) {
+    AT_ERROR(stop_, "Must call ProcessGroupAgent::shutdown before destructor");
+  }
+}
+
+void ProcessGroupAgent::shutdown() {
   // Every process i sends a SHUTDOWN message to process i + 1. This is
   // necessary for now because:
   // 1. There is no abort API for ProcessGroup::recvAnysource yet. We have to
   //    feed it a message or kill the thread.
   // 2. A GLOO process cannot send message to itself. (there is an ongoing
   //    effort to fix this problem).
+  pg_->barrier()->wait();
   int dst = (pg_->getRank() + 1) % pg_->getSize();
   enqueue(SendWork(dst, Message({}, {}, MessageType::SHUTDOWN)));
   std::unique_lock<std::mutex> lock(sendQueueMutex_);
@@ -96,8 +102,10 @@ ProcessGroupAgent::~ProcessGroupAgent() {
   workProduceCV_.notify_all();
   sendThread_.join();
   listenerThread_.join();
+  pg_->barrier()->wait();
 }
 
+
 std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
     const std::string& to, Message&& message) {
@@ -170,7 +178,7 @@ void ProcessGroupAgent::sendLoop() {
 }
 
 void ProcessGroupAgent::listenLoop() {
-  while (!stop_) {
+  while (true) {
     // rank, tensor size
     std::vector<torch::Tensor> preamble = {torch::empty({2}, {torch::kInt64})};
     pg_->recvAnysource(preamble, pg_->getRank())->wait();
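The shutdown sequence above (barrier, send SHUTDOWN to the next rank, join the send and listen threads, barrier) exists to unblock every rank's recvAnysource call. Below is a toy single-process Python model of that ring handoff using in-process queues instead of a ProcessGroup; unlike GLOO, these queues can deliver a message to their own rank, so the world_size == 1 restriction noted in RpcAgent.h does not show up here.

import queue
import threading

WORLD_SIZE = 4
SHUTDOWN = 'SHUTDOWN'

# One inbox per rank; stands in for the peer-to-peer recv of the ProcessGroup.
inboxes = [queue.Queue() for _ in range(WORLD_SIZE)]

def listen_loop(rank):
    while True:
        msg = inboxes[rank].get()  # stands in for the blocking recvAnysource
        if msg == SHUTDOWN:
            print('rank %d: listener exiting' % rank)
            return
        print('rank %d: got %s' % (rank, msg))

listeners = [threading.Thread(target=listen_loop, args=(r,)) for r in range(WORLD_SIZE)]
for t in listeners:
    t.start()

# Every rank i sends SHUTDOWN to rank (i + 1) % WORLD_SIZE, so each listener
# receives exactly one token and unblocks.
for rank in range(WORLD_SIZE):
    inboxes[(rank + 1) % WORLD_SIZE].put(SHUTDOWN)

for t in listeners:
    t.join()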

torch/csrc/distributed/rpc/ProcessGroupAgent.h

Lines changed: 2 additions & 0 deletions
@@ -38,6 +38,8 @@ class ProcessGroupAgent : public RpcAgent {
   std::shared_ptr<FutureMessage> send(
       const std::string& to, Message&& message) override;
 
+  void shutdown() override;
+
  private:
   // put SendWork into a queue and notify the sendLoop thread
   void enqueue(SendWork work);

torch/csrc/distributed/rpc/RpcAgent.h

Lines changed: 11 additions & 0 deletions
@@ -47,6 +47,17 @@ class RpcAgent {
   virtual std::shared_ptr<FutureMessage> send(
       const std::string& to, Message&& message) = 0;
 
+  // This is a temporary solution to gracefully stop the listening loop.
+  // ProcessGroupAgent does this by sending a SHUTDOWN message to the
+  // (rank + 1) % world_size peer, which means we cannot create
+  // ProcessGroupAgent with world_size == 1. We can drop this in the future when
+  // we find a way to gracefully exit the blocking recvAnysource call.
+  //
+  // FIXME: putting its implementation in destructor sometimes causes
+  // "Connection reset by peer" error. It seems somehow ProcessGroup object get
+  // destructed before RpcAgent object?
+  virtual void shutdown() = 0;
+
  protected:
   const std::string workerName_;
   const RequestCallback cb_;

torch/csrc/distributed/rpc/init.cpp

Lines changed: 8 additions & 2 deletions
@@ -29,7 +29,10 @@ PyObject* rpc_init(PyObject* /* unused */) {
 
   auto module = py::handle(dist_module).cast<py::module>();
 
-  auto rpcAgent = shared_ptr_class_<RpcAgent>(module, "RpcAgent");
+  auto rpcAgent = shared_ptr_class_<RpcAgent>(module, "RpcAgent")
+      .def("shutdown",
+           &RpcAgent::shutdown,
+           py::call_guard<py::gil_scoped_release>());
 
   auto futureMessage = shared_ptr_class_<FutureMessage>(module, "FutureMessage")
       .def("wait",
@@ -43,7 +46,10 @@ PyObject* rpc_init(PyObject* /* unused */) {
       module, "ProcessGroupAgent", rpcAgent)
       .def(py::init<std::string,
                     std::unordered_map<std::string, int>,
-                    std::shared_ptr<::c10d::ProcessGroup>>());
+                    std::shared_ptr<::c10d::ProcessGroup>>())
+      .def("shutdown",
+           &ProcessGroupAgent::shutdown,
+           py::call_guard<py::gil_scoped_release>());
 
   module.def("invoke_rpc", [](
       RpcAgent& agent,

torch/distributed/rpc.py

Lines changed: 10 additions & 0 deletions
@@ -32,6 +32,16 @@ def _collect_worker_names(name, group):
     return names
 
 
+def destroy_rpc():
+    r"""
+    Destroy the local RPC agent. This is blocking until globally all RPC agents
+    are destroyed.
+    """
+    global _agent
+    _agent.shutdown()
+    _agent = None
+
+
 def init_rpc(name, backend='pg'):
     r"""
     Initialize the local RPC agent which immediately becomes ready to make and
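Taken together with the test change above, the intended usage of the new Python entry points looks roughly like the sketch below. It is a hedged example, not part of this diff: the backend and store arguments to init_process_group are assumptions, and running it for real requires one such process per worker.

import torch.distributed as dist

def run_worker(rank, world_size, store):
    # The process group must exist first, since ProcessGroupAgent is built on it.
    dist.init_process_group(
        backend='gloo', rank=rank, world_size=world_size, store=store)
    dist.init_rpc('worker%d' % rank)  # creates and registers the local RpcAgent

    # ... issue sync/async RPCs to peer workers here ...

    # Tear down in reverse order: destroy_rpc() blocks until all agents have
    # shut down, then the process group itself is destroyed.
    dist.destroy_rpc()
    dist.destroy_process_group(dist.group.WORLD)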

0 commit comments
