8000 Update on "Add Python RRef as args and return value" · pytorch/pytorch@3c054b4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3c054b4

Browse files
committed
Update on "Add Python RRef as args and return value"
See #23110 for model parallel design details, and #26759 for the RRef protocol. This commit adds support for using RRefs as Python UDF arguments and return values. RRefs can now be shared from owner to user, from user to owner, or from user to user. Limitations: 1. No implicit type conversion yet. 2. No failure handling and retry. 3. UDF is not yet blocked until all RRefs are confirmed. 4. Internal RRef control messages are not idempotent yet. Main changes: 1. Added `SCRIPT_REMOTE_CALL` and `PYTHON_REMOTE_CALL` to `Message.h` to represent `dist.remote` invocations. 2. Added `SCRIPT_RREF_FETCH_CALL`, `PYTHON_RREF_FETCH_CALL`, `RREF_USER_ACCEPT`, `RREF_USER_DELETE`, `RREF_CHILD_ACCEPT`, and `RREF_FORK_REQUEST` to `Message.h` as internal RRef control messages. 3. Added new message request handling code to `functions.cpp`, and added the message formats in `script_remote_call.h`, `python_remote_call.h`, and `rref_proto.h`. 4. Added a `PyRRef` type in `py_rref.h` and `py_rref.cpp` which holds a shared pointer to the C++ `RRef` type. `PyRRef` wraps the C++ API and also implements RRef pickling and unpickling. RRef fork-related control messages are sent during the RRef pickling/unpickling procedure. 5. Updated `RRef.h` and `RRef.cpp` accordingly to support `py::object` RRefs. 6. RRef context (reference count, etc.) is tracked in `rref_context.h` and `rref_context.cpp`. Differential Revision: [D17184146](https://our.internmc.facebook.com/intern/diff/D17184146) [ghstack-poisoned]
2 parents 497d3c4 + 8735e1f commit 3c054b4

File tree

3 files changed

+20
-40
lines changed

3 files changed

+20
-40
lines changed

torch/csrc/distributed/rpc/python_functions.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ PyRRef pyRemoteBuiltin(
116116
ScriptRemoteCall(
117117
op, std::move(stack), userRRef->rrefId(), userRRef->forkId())
118118
.toMessage());
119+
120+
ctx->addPendingUser(userRRef->forkId(), userRRef);
119121
fm->addCallback(finishAcceptUserRRef);
120122
return PyRRef(userRRef);
121123
}
@@ -150,6 +152,7 @@ PyRRef pyRemotePythonUdf(
150152
userRRef->forkId().toIValue())
151153
.toMessage());
152154

155+
ctx->addPendingUser(userRRef->forkId(), userRRef);
153156
fm->addCallback(finishAcceptUserRRef);
154157
return PyRRef(userRRef);
155158
}

torch/csrc/distributed/rpc/rref_context.cpp

Lines changed: 15 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,11 @@ std::shared_ptr<UserRRef<T>> RRefContext::createUserRRef(
7676
// RRefContext does not track user RRefs, it will be destructed when there
7777
// are no shared_ptrs pointing to it. NB: cannot use make_shared here as the
7878
// constructor of UserRRef is private
79-
auto userRRef =
79+
return
8080
std::shared_ptr<UserRRef<T>>(new UserRRef<T>(ownerId, rrefId, forkId));
81-
if (forkId.createdOn_ != ownerId) {
82-
addPendingUser(forkId, userRRef);
83-
}
84-
return userRRef;
81+
//if (forkId.createdOn_ != ownerId) {
82+
// addPendingUser(forkId, userRRef);
83+
//}
8584
}
8685

8786
template std::shared_ptr<UserRRef<IValue>> RRefContext::createUserRRef<IValue>(
@@ -167,13 +166,7 @@ RRefForkData RRefContext::prepareChildFork(const std::shared_ptr<RRef>& rref) {
167166
// at the owner before this dist.rpc or dist.remote call, which could
168167
// potentially trigger the `OwnerRRef` to be deleted before running the
169168
// user code.
170-
if (rref->isPyObj()) {
171-
addPendingChild(
172-
rfd.forkId_, std::static_pointer_cast<UserRRef<py::object>>(rref));
173-
} else {
174-
addPendingChild(
175-
rfd.forkId_, std::static_pointer_cast<UserRRef<IValue>>(rref));
176-
}
169+
addPendingChild(rfd.forkId_, rref);
177170
}
178171
return rfd;
179172
}
@@ -203,6 +196,7 @@ void RRefContext::notifyOwnerAndParentOfFork(
203196
agent_->getWorkerInfo(rref->owner()),
204197
RRefForkRequest(rref->rrefId(), forkId).toMessage());
205198

199+
addPendingUser(forkId, rref);
206200
fm->addCallback([this, forkId, parent](const Message& message) {
207201
handleException(message);
208202
this->finishForkRequest(forkId, parent);
@@ -211,25 +205,21 @@ void RRefContext::notifyOwnerAndParentOfFork(
211205

212206
}
213207

214-
template <typename T>
215208
void RRefContext::addPendingChild(
216209
const ForkId& forkId,
217-
const std::shared_ptr<UserRRef<T>>& rref) {
210+
const std::shared_ptr<RRef>& rref) {
211+
// see Note [Early Fork Registration]
212+
// If the parent is the owner, it should directly add the child UserRRef
213+
// as a fork.
214+
TORCH_INTERNAL_ASSERT(!rref->isOwner(),
215+
"OwnerRRef should not have a pending child.");
218216
std::lock_guard<std::mutex> lock(mutex_);
219217
TORCH_INTERNAL_ASSERT(
220218
pendingChildren_.find(forkId) == pendingChildren_.end(),
221219
"Inconsistent states: attempt to add the same child fork twice.");
222220
pendingChildren_[forkId] = rref;
223221
}
224222

225-
template void RRefContext::addPendingChild<IValue>(
226-
const ForkId& forkId,
227-
const std::shared_ptr<UserRRef<IValue>>& rref);
228-
229-
template void RRefContext::addPendingChild<py::object>(
230-
const ForkId& forkId,
231-
const std::shared_ptr<UserRRef<py::object>>& rref);
232-
233223
void RRefContext::delPendingChild(const ForkId& forkId) {
234224
std::lock_guard<std::mutex> lock(mutex_);
235225
auto iter = pendingChildren_.find(forkId);
@@ -239,25 +229,18 @@ void RRefContext::delPendingChild(const ForkId& forkId) {
239229
pendingChildren_.erase(iter);
240230
}
241231

242-
template <typename T>
243232
void RRefContext::addPendingUser(
244233
const ForkId& forkId,
245-
const std::shared_ptr<UserRRef<T>>& rref) {
234+
const std::shared_ptr<RRef>& rref) {
235+
TORCH_INTERNAL_ASSERT(!rref->isOwner(),
236+
"Attempt to add an OwnerRRef as a pending User.");
246237
std::lock_guard<std::mutex> lock(mutex_);
247238
TORCH_INTERNAL_ASSERT(
248239
pendingUsers_.find(forkId) == pendingUsers_.end(),
249240
"Inconsistent states: attempt to add the same UserRRef twice.");
250241
pendingUsers_[forkId] = rref;
251242
}
252243

253-
template void RRefContext::addPendingUser<IValue>(
254-
const ForkId& forkId,
255-
const std::shared_ptr<UserRRef<IValue>>& rref);
256-
257-
template void RRefContext::addPendingUser<py::object>(
258-
const ForkId& forkId,
259-
const std::shared_ptr<UserRRef<py::object>>& rref);
260-
261244
void RRefContext::delPendingUser(const ForkId& forkId) {
262245
std::lock_guard<std::mutex> lock(mutex_);
263246
auto iter = pendingUsers_.find(forkId);

torch/csrc/distributed/rpc/rref_context.h

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,18 +86,12 @@ class RRefContext {
8686
// previously submitted rpc/remote calls are acked before sending out the
8787
// RREF_USER_DELETE message. Otherwise, the OwnerRRef could be deleted too
8888
// soon.
89-
template <typename T>
90-
void addPendingChild(
91-
const ForkId& forkId,
92-
const std::shared_ptr<UserRRef<T>>& rref);
89+
void addPendingChild(const ForkId& forkId, const std::shared_ptr<RRef>& rref);
9390
void delPendingChild(const ForkId& forkId);
9491

9592
// When a UserRRef is created, it is added into pendingUsers_ to be held alive
9693
// until it receives RREF_USER_ACCEPT from the owner.
97-
template <typename T>
98-
void addPendingUser(
99-
const ForkId& forkId,
100-
const std::shared_ptr<UserRRef<T>>& rref);
94+
void addPendingUser(const ForkId& forkId, const std::shared_ptr<RRef>& rref);
10195
void delPendingUser(const ForkId& forkId);
10296

10397
// If there is any leak on any RRef, this method will throw an error.

0 commit comments

Comments
 (0)
0