Simplify InstanceRec mutex · metacortex/tensorflow@a0b68d1 · GitHub
[go: up one dir, main page]

Skip to content

Commit a0b68d1

Browse files
crccw, tensorflower-gardener
authored and committed
Simplify InstanceRec mutex
We no longer need to call RPCs while initializing InstanceRec, so there's no need for the complicated mutex structure, and some async methods can be changed to sync ones. PiperOrigin-RevId: 333920547 Change-Id: Ic82ffb714fb1b9e0c8fc27f6f78ed39f64783fd8
1 parent e41868c commit a0b68d1

File tree

5 files changed

+104
-299
lines changed

5 files changed

+104
-299
lines changed

tensorflow/core/common_runtime/collective_param_resolver_local.cc

Lines changed: 30 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,6 @@ limitations under the License.
3838

3939
namespace tensorflow {
4040

41-
void CollectiveParamResolverLocal::InstanceRec::WaitForOutMu(mutex_lock& lock) {
42-
while (!out_mu_available) out_cv.wait(lock);
43-
}
44-
4541
CollectiveParamResolverLocal::CollectiveParamResolverLocal(
4642
const ConfigProto& config, const DeviceMgr* dev_mgr,
4743
DeviceResolverInterface* dev_resolver, const string& task_name)
@@ -497,8 +493,7 @@ void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
497493
}
498494

499495
void CollectiveParamResolverLocal::InitInstanceSharedParams(
500-
const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
501-
const StatusCallback& done) TF_NO_THREAD_SAFETY_ANALYSIS {
496+
const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir) {
502497
std::vector<DeviceAttributes> attributes;
503498
ir->shared.instance = cp->instance;
504499
{
@@ -532,34 +527,7 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
532527
// GetDeviceAttributesAsync will use those fields to launch RPCs.
533528
CompleteTaskIsLocal(task_name_, &ir->shared);
534529

535-
// TODO(b/151232436): clean up the following code since we no longer need to
536-
// execute it in a callback.
537-
538-
// Because the callback may execute in a different thread, we release
539-
// ir->out_mu here. Before releasing, we mark it as unavailable for other
540-
// threads.
541-
ir->out_mu_available = false;
542-
const auto device_names = ir->shared.instance.device_names;
543-
const auto task_names = ir->shared.instance.task_names;
544-
ir->out_mu.unlock();
545-
auto complete_init = [this, gr, cp, ir, attributes, done](const Status& s)
546-
TF_EXCLUSIVE_LOCK_FUNCTION(ir->out_mu) {
547-
// Then we recover the lock in the callback thread
548-
// that will hold it through the rest of the call
549-
// chain. Signal the cv now, any waiting threads
550-
// will wake only when out_mu is released later.
551-
ir->out_mu.lock();
552-
DCHECK(!ir->out_mu_available);
553-
ir->out_mu_available = true;
554-
ir->out_cv.notify_all();
555-
if (s.ok()) {
556-
CompleteDefaultRanking(gr, cp, ir, attributes);
557-
done(Status::OK());
558-
} else {
559-
done(s);
560-
}
561-
};
562-
complete_init(Status::OK());
530+
CompleteDefaultRanking(gr, cp, ir, attributes);
563531
}
564532

565533
// NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks
@@ -597,46 +565,27 @@ void CollectiveParamResolverLocal::CompleteDefaultRanking(
597565
}
598566
}
599567

600-
void CollectiveParamResolverLocal::CallbackWithStatus(
601-
const InstanceRecCallback& done, InstanceRec* irec) {
602-
Status s;
603-
{
604-
mutex_lock l(irec->out_mu);
605-
irec->WaitForOutMu(l);
606-
s = irec->status;
607-
}
608-
done(s, irec);
609-
}
610-
611-
void CollectiveParamResolverLocal::FindInstanceRec(
612-
const GroupRec* gr, CollectiveParams* cp, const InstanceRecCallback& done) {
568+
CollectiveParamResolverLocal::InstanceRec*
569+
CollectiveParamResolverLocal::GetOrCreateInstanceRec(const GroupRec* gr,
570+
CollectiveParams* cp) {
613571
InstanceRec* irec = nullptr;
614-
bool exit_outside_locks = false;
615572
{
616-
bool found_instance = false;
617573
mutex_lock l(instance_mu_);
618574
auto group_it = instance_table_.find(gr->group.group_key);
619575
if (group_it != instance_table_.end()) {
620576
auto instance_it = group_it->second.find(cp->instance.instance_key);
621577
if (instance_it != group_it->second.end()) {
622578
irec = instance_it->second.get();
623-
{
624-
mutex_lock l(irec->in_mu);
625-
if (irec->is_init) {
626-
exit_outside_locks = true;
627-
} else {
628-
irec->init_waiters.push_back([this, done](InstanceRec* irec) {
629-
CallbackWithStatus(done, irec);
630-
});
631-
return;
632-
}
633-
}
634-
found_instance = true;
635579
}
636580
}
637-
if (!found_instance) {
581+
if (irec == nullptr) {
638582
// Create new InstanceRec.
639583
irec = new InstanceRec;
584+
{
585+
mutex_lock il(irec->mu);
586+
irec->known.resize(cp->group.group_size, false);
587+
}
588+
InitInstanceSharedParams(gr, cp, irec);
640589
instance_table_[gr->group.group_key][cp->instance.instance_key].reset(
641590
irec);
642591
}
@@ -647,65 +596,10 @@ void CollectiveParamResolverLocal::FindInstanceRec(
647596
status = status_;
648597
}
649598
if (!status.ok()) {
650-
mutex_lock il(irec->out_mu);
651-
irec->WaitForOutMu(il);
599+
mutex_lock l(irec->mu);
652600
irec->status = status;
653601
}
654-
if (exit_outside_locks) {
655-
CallbackWithStatus(done, irec);
656-
return;
657-
}
658-
659-
CallInitInstanceSharedParams(gr, cp, irec, done);
660-
}
661-
662-
void CollectiveParamResolverLocal::CallInitInstanceSharedParams(
663-
const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
664-
const InstanceRecCallback& done) TF_NO_THREAD_SAFETY_ANALYSIS {
665-
// This function serves merely to make a function call that should
666-
// be thread/mutex safe but violates the simple model applied by
667-
// static analysis, so we turn off analysis only within this
668-
// function body.
669-
//
670-
// A lock on ir->out_mu must be held* throughout the _bodies_ of the
671-
// chain of function calls initiated here, each of which calls
672-
// another as its last action, but it will be dropped within the
673-
// callback defined below, which means that the lock can be dropped
674-
// before all the function stack frames pop. The static analysis will
675-
// not allow that.
676-
//
677-
// *the lock is dropped just before calling GetDeviceAttributesAsync, because
678-
// there is no guarantee that the thread that executes the callback is the
679-
// same as the one that locked ir->out_mu. To prevent other threads from
680-
// grabbing ir->out_mu, we mark ir->out_mu_available as false. Hence, in
681-
// principle, the lock is held throughout.
682-
ir->out_mu.lock();
683-
DCHECK(ir->out_mu_available);
684-
ir->known.resize(cp->group.group_size, false);
685-
InitInstanceSharedParams(
686-
gr, cp, ir,
687-
[this, ir, done](const Status& s) TF_UNLOCK_FUNCTION(ir->out_mu) {
688-
DCHECK(ir->out_mu_available);
689-
ir->status.Update(s);
690-
ir->out_mu.unlock();
691-
// Prepare to invoke any waiters that accumulated during
692-
// initialization.
693-
std::vector<IRConsumer> init_waiters;
694-
{
695-
mutex_lock tl(instance_mu_);
696-
{
697-
mutex_lock l(ir->in_mu);
698-
ir->is_init = true;
699-
if (!ir->init_waiters.empty()) {
700-
std::swap(init_waiters, ir->init_waiters);
701-
}
702-
}
703-
}
704-
CallbackWithStatus(done, ir);
705-
for (auto& f : init_waiters) {
706-
f(ir);
707-
}
708-
});
602+
return irec;
709603
}
710604

711605
void CollectiveParamResolverLocal::CompleteParamsAsync(
@@ -766,29 +660,27 @@ void CollectiveParamResolverLocal::CompleteInstanceLocal(
766660
DCHECK_EQ(cp->group.device_type, gr->group.device_type);
767661
cp->group = gr->group;
768662

769-
// Get the shared InstanceRec for this instance.
770-
FindInstanceRec(gr, cp,
771-
[this, device, gr, cp, is_source, done](const Status& s,
772-
InstanceRec* ir) {
773-
if (s.ok()) {
774-
CompleteInstanceFromInitializedIRec(device, gr, cp, ir,
775-
is_source, done);
776-
} else {
777-
done(s);
778-
}
779-
});
663+
InstanceRec* ir = GetOrCreateInstanceRec(gr, cp);
664+
CompleteInstanceFromInitializedIRec(device, gr, cp, ir, is_source, done);
780665
}
781666

782667
void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
783668
const string& device, const GroupRec* gr, CollectiveParams* cp,
784669
InstanceRec* ir, bool is_source, const StatusCallback& done) {
785670
auto expected_shape = cp->instance.shape;
671+
Status status;
786672
// Populate the fields common across instance.
787673
{
788-
mutex_lock l(ir->out_mu);
789-
ir->WaitForOutMu(l);
790-
// custom operator= does a deep copy.
791-
cp->instance = ir->shared.instance;
674+
mutex_lock l(ir->mu);
675+
status = ir->status;
676+
if (status.ok()) {
677+
// custom operator= does a deep copy.
678+
cp->instance = ir->shared.instance;
679+
}
680+
}
681+
if (!status.ok()) {
682+
done(status);
683+
return;
792684
}
793685
if (expected_shape != cp->instance.shape) {
794686
done(errors::InvalidArgument(
@@ -806,7 +698,7 @@ void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
806698
CompleteTaskIsLocal(task_name_, cp);
807699

808700
CollectiveImplementationInterface* col_impl;
809-
Status status = CollectiveRegistry::LookupParamResolverInstance(
701+
status = CollectiveRegistry::LookupParamResolverInstance(
810702
cp->instance.impl_details.collective_name, &col_impl);
811703
if (!status.ok()) {
812704
done(status);
@@ -823,8 +715,7 @@ void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
823715
s = errors::Internal("Expected ir ", ir, " and irec ",
824716
irec, " to be equal");
825717
} else {
826-
mutex_lock l(irec->out_mu);
827-
irec->WaitForOutMu(l);
718+
mutex_lock l(irec->mu);
828719
s = irec->status;
829720
cp->source_rank = irec->source_rank;
830721
}
@@ -844,8 +735,7 @@ void CollectiveParamResolverLocal::WaitForGroup(InstanceRec* ir,
844735
const IRConsumer& f) {
845736
std::vector<IRConsumer> ready_waiters;
846737
do {
847-
mutex_lock l(ir->out_mu);
848-
ir->WaitForOutMu(l);
738+
mutex_lock l(ir->mu);
849739
if (!ir->status.ok()) {
850740
break;
851741
}
@@ -933,8 +823,7 @@ void CollectiveParamResolverLocal::StartAbortLocal(const Status& s) {
933823
for (InstanceRec* ir : instances) {
934824
std::vector<IRConsumer> known_waiters;
935825
{
936-
mutex_lock il(ir->out_mu);
937-
ir->WaitForOutMu(il);
826+
mutex_lock il(ir->mu);
938827
ir->status = s;
939828
known_waiters.swap(ir->known_waiters);
940829
}

0 commit comments

Comments
 (0)
0