@@ -38,10 +38,6 @@ limitations under the License.
38
38
39
39
namespace tensorflow {
40
40
41
- void CollectiveParamResolverLocal::InstanceRec::WaitForOutMu (mutex_lock& lock) {
42
- while (!out_mu_available) out_cv.wait (lock);
43
- }
44
-
45
41
CollectiveParamResolverLocal::CollectiveParamResolverLocal (
46
42
const ConfigProto& config, const DeviceMgr* dev_mgr,
47
43
DeviceResolverInterface* dev_resolver, const string& task_name)
@@ -497,8 +493,7 @@ void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
497
493
}
498
494
499
495
void CollectiveParamResolverLocal::InitInstanceSharedParams (
500
- const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
501
- const StatusCallback& done) TF_NO_THREAD_SAFETY_ANALYSIS {
496
+ const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir) {
502
497
std::vector<DeviceAttributes> attributes;
503
498
ir->shared .instance = cp->instance ;
504
499
{
@@ -532,34 +527,7 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
532
527
// GetDeviceAttributesAsync will use those fields to launch RPCs.
533
528
CompleteTaskIsLocal (task_name_, &ir->shared );
534
529
535
- // TODO(b/151232436): clean up the following code since we no longer need to
536
- // execute it in a callback.
537
-
538
- // Because the callback may execute in a different thread, we release
539
- // ir->out_mu here. Before releasing, we mark it as unavailable for other
540
- // threads.
541
- ir->out_mu_available = false ;
542
- const auto device_names = ir->shared .instance .device_names ;
543
- const auto task_names = ir->shared .instance .task_names ;
544
- ir->out_mu .unlock ();
545
- auto complete_init = [this , gr, cp, ir, attributes, done](const Status& s)
546
- TF_EXCLUSIVE_LOCK_FUNCTION (ir->out_mu ) {
547
- // Then we recover the lock in the callback thread
548
- // that will hold it through the rest of the call
549
- // chain. Signal the cv now, any waiting threads
550
- // will wake only when out_mu is released later.
551
- ir->out_mu .lock ();
552
- DCHECK (!ir->out_mu_available );
553
- ir->out_mu_available = true ;
554
- ir->out_cv .notify_all ();
555
- if (s.ok ()) {
556
- CompleteDefaultRanking (gr, cp, ir, attributes);
557
- done (Status::OK ());
558
- } else {
559
- done (s);
560
- }
561
- };
562
- complete_init (Status::OK ());
530
+ CompleteDefaultRanking (gr, cp, ir, attributes);
563
531
}
564
532
565
533
// NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks
@@ -597,46 +565,27 @@ void CollectiveParamResolverLocal::CompleteDefaultRanking(
597
565
}
598
566
}
599
567
600
- void CollectiveParamResolverLocal::CallbackWithStatus (
601
- const InstanceRecCallback& done, InstanceRec* irec) {
602
- Status s;
603
- {
604
- mutex_lock l (irec->out_mu );
605
- irec->WaitForOutMu (l);
606
- s = irec->status ;
607
- }
608
- done (s, irec);
609
- }
610
-
611
- void CollectiveParamResolverLocal::FindInstanceRec (
612
- const GroupRec* gr, CollectiveParams* cp, const InstanceRecCallback& done) {
568
+ CollectiveParamResolverLocal::InstanceRec*
569
+ CollectiveParamResolverLocal::GetOrCreateInstanceRec (const GroupRec* gr,
570
+ CollectiveParams* cp) {
613
571
InstanceRec* irec = nullptr ;
614
- bool exit_outside_locks = false ;
615
572
{
616
- bool found_instance = false ;
617
573
mutex_lock l (instance_mu_);
618
574
auto group_it = instance_table_.find (gr->group .group_key );
619
575
if (group_it != instance_table_.end ()) {
620
576
auto instance_it = group_it->second .find (cp->instance .instance_key );
621
577
if (instance_it != group_it->second .end ()) {
622
578
irec = instance_it->second .get ();
623
- {
624
- mutex_lock l (irec->in_mu );
625
- if (irec->is_init ) {
626
- exit_outside_locks = true ;
627
- } else {
628
- irec->init_waiters .push_back ([this , done](InstanceRec* irec) {
629
- CallbackWithStatus (done, irec);
630
- });
631
- return ;
632
- }
633
- }
634
- found_instance = true ;
635
579
}
636
580
}
637
- if (!found_instance ) {
581
+ if (irec == nullptr ) {
638
582
// Create new InstanceRec.
639
583
irec = new InstanceRec;
584
+ {
585
+ mutex_lock il (irec->mu );
586
+ irec->known .resize (cp->group .group_size , false );
587
+ }
588
+ InitInstanceSharedParams (gr, cp, irec);
640
589
instance_table_[gr->group .group_key ][cp->instance .instance_key ].reset (
641
590
irec);
642
591
}
@@ -647,65 +596,10 @@ void CollectiveParamResolverLocal::FindInstanceRec(
647
596
status = status_;
648
597
}
649
598
if (!status.ok ()) {
650
- mutex_lock il (irec->out_mu );
651
- irec->WaitForOutMu (il);
599
+ mutex_lock l (irec->mu );
652
600
irec->status = status;
653
601
}
654
- if (exit_outside_locks) {
655
- CallbackWithStatus (done, irec);
656
- return ;
657
- }
658
-
659
- CallInitInstanceSharedParams (gr, cp, irec, done);
660
- }
661
-
662
- void CollectiveParamResolverLocal::CallInitInstanceSharedParams (
663
- const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
664
- const InstanceRecCallback& done) TF_NO_THREAD_SAFETY_ANALYSIS {
665
- // This function serves merely to make a function call that should
666
- // be thread/mutex safe but violates the simple model applied by
667
- // static analysis, so we turn off analysis only within this
668
- // function body.
669
- //
670
- // A lock on ir->out_mu must be held* throughout the _bodies_ of the
671
- // chain of function calls initiated here, each of which calls
672
- // another as its last action, but it will be dropped within the
673
- // callback defined below, which means that the lock can be dropped
674
- // before all the function stack frames pop. The static analysis will
675
- // not allow that.
676
- //
677
- // *the lock is dropped just before calling GetDeviceAttributesAsync, because
678
- // there is no guarantee that the thread that executes the callback is the
679
- // same as the one that locked ir->out_mu. To prevent other threads from
680
- // grabbing ir->out_mu, we mark ir->out_mu_available as false. Hence, in
681
- // principle, the lock is held throughout.
682
- ir->out_mu .lock ();
683
- DCHECK (ir->out_mu_available );
684
- ir->known .resize (cp->group .group_size , false );
685
- InitInstanceSharedParams (
686
- gr, cp, ir,
687
- [this , ir, done](const Status& s) TF_UNLOCK_FUNCTION (ir->out_mu ) {
688
- DCHECK (ir->out_mu_available );
689
- ir->status .Update (s);
690
- ir->out_mu .unlock ();
691
- // Prepare to invoke any waiters that accumulated during
692
- // initialization.
693
- std::vector<IRConsumer> init_waiters;
694
- {
695
- mutex_lock tl (instance_mu_);
696
- {
697
- mutex_lock l (ir->in_mu );
698
- ir->is_init = true ;
699
- if (!ir->init_waiters .empty ()) {
700
- std::swap (init_waiters, ir->init_waiters );
701
- }
702
- }
703
- }
704
- CallbackWithStatus (done, ir);
705
- for (auto & f : init_waiters) {
706
- f (ir);
707
- }
708
- });
602
+ return irec;
709
603
}
710
604
711
605
void CollectiveParamResolverLocal::CompleteParamsAsync (
@@ -766,29 +660,27 @@ void CollectiveParamResolverLocal::CompleteInstanceLocal(
766
660
DCHECK_EQ (cp->group .device_type , gr->group .device_type );
767
661
cp->group = gr->group ;
768
662
769
- // Get the shared InstanceRec for this instance.
770
- FindInstanceRec (gr, cp,
771
- [this , device, gr, cp, is_source, done](const Status& s,
772
- InstanceRec* ir) {
773
- if (s.ok ()) {
774
- CompleteInstanceFromInitializedIRec (device, gr, cp, ir,
775
- is_source, done);
776
- } else {
777
- done (s);
778
- }
779
- });
663
+ InstanceRec* ir = GetOrCreateInstanceRec (gr, cp);
664
+ CompleteInstanceFromInitializedIRec (device, gr, cp, ir, is_source, done);
780
665
}
781
666
782
667
void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec (
783
668
const string& device, const GroupRec* gr, CollectiveParams* cp,
784
669
InstanceRec* ir, bool is_source, const StatusCallback& done) {
785
670
auto expected_shape = cp->instance .shape ;
671
+ Status status;
786
672
// Populate the fields common across instance.
787
673
{
788
- mutex_lock l (ir->out_mu );
789
- ir->WaitForOutMu (l);
790
- // custom operator= does a deep copy.
791
- cp->instance = ir->shared .instance ;
674
+ mutex_lock l (ir->mu );
675
+ status = ir->status ;
676
+ if (status.ok ()) {
677
+ // custom operator= does a deep copy.
678
+ cp->instance = ir->shared .instance ;
679
+ }
680
+ }
681
+ if (!status.ok ()) {
682
+ done (status);
683
+ return ;
792
684
}
793
685
if (expected_shape != cp->instance .shape ) {
794
686
done (errors::InvalidArgument (
@@ -806,7 +698,7 @@ void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
806
698
CompleteTaskIsLocal (task_name_, cp);
807
699
808
700
CollectiveImplementationInterface* col_impl;
809
- Status status = CollectiveRegistry::LookupParamResolverInstance (
701
+ status = CollectiveRegistry::LookupParamResolverInstance (
810
702
cp->instance .impl_details .collective_name , &col_impl);
811
703
if (!status.ok ()) {
812
704
done (status);
@@ -823,8 +715,7 @@ void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
823
715
s = errors::Internal (" Expected ir " , ir, " and irec " ,
824
716
irec, " to be equal" );
825
717
} else {
826
- mutex_lock l (irec->out_mu );
827
- irec->WaitForOutMu (l);
718
+ mutex_lock l (irec->mu );
828
719
s = irec->status ;
829
720
cp->source_rank = irec->source_rank ;
830
721
}
@@ -844,8 +735,7 @@ void CollectiveParamResolverLocal::WaitForGroup(InstanceRec* ir,
844
735
const IRConsumer& f) {
845
736
std::vector<IRConsumer> ready_waiters;
846
737
do {
847
- mutex_lock l (ir->out_mu );
848
- ir->WaitForOutMu (l);
738
+ mutex_lock l (ir->mu );
849
739
if (!ir->status .ok ()) {
850
740
break ;
851
741
}
@@ -933,8 +823,7 @@ void CollectiveParamResolverLocal::StartAbortLocal(const Status& s) {
933
823
for (InstanceRec* ir : instances) {
934
824
std::vector<IRConsumer> known_waiters;
935
825
{
936
- mutex_lock il (ir->out_mu );
937
- ir->WaitForOutMu (il);
826
+ mutex_lock il (ir->mu );
938
827
ir->status = s;
939
828
known_waiters.swap (ir->known_waiters );
940
829
}
0 commit comments