10000 [XLA] Enable option to schedule better annotations using all availabl… · IBMZ-Linux-OSS-Python/tensorflow@ea0f180 · GitHub
[go: up one dir, main page]

Skip to content

Commit ea0f180

Browse files
Marcello Maggionitensorflower-gardener
authored andcommitted
[XLA] Enable option to schedule better annotations using all available room
PiperOrigin-RevId: 766244600
1 parent eb57f13 commit ea0f180

File tree

3 files changed

+80
-5
lines changed

3 files changed

+80
-5
lines changed

third_party/xla/xla/service/latency_hiding_scheduler.cc

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1843,6 +1843,7 @@ absl::Status DefaultSchedulerCore::ScheduleAnnotation(
18431843
}
18441844
}
18451845
int64_t num_scheduled = 0;
1846+
int64_t non_ready_instr = 0;
18461847
int64_t annotation_size =
18471848
annotation_tracker_->GetNumInstructions(computation, annotation);
18481849
while (!sched_state->annotation_ready.empty()) {
@@ -1866,6 +1867,17 @@ absl::Status DefaultSchedulerCore::ScheduleAnnotation(
18661867

18671868
TF_RET_CHECK(node != nullptr)
18681869
<< "Couldn't find an annotated node to schedule.";
1870+
// Delay last instruction of annotation maybe.
1871+
if (config_.flexible_scheduling_annotation_scheduling &&
1872+
num_scheduled == annotation_size - 1 &&
1873+
async_tracker_->IsSupportedAsyncStart(node->GetInstr())) {
1874+
// Give instruction back to the scheduler to schedule.
1875+
VLOG(2) << "Non ready instr: " << node->GetInstr().name();
1876+
++non_ready_instr;
1877+
node->ClearAnnotation();
1878+
sched_state->nodes_holding_annotations.insert(node);
1879+
continue;
1880+
}
18691881
// Delete the node from the ready set.
18701882
auto node_it = std::find(sched_state->ready_set.begin(),
18711883
sched_state->ready_set.end(), node);
@@ -1883,7 +1895,7 @@ absl::Status DefaultSchedulerCore::ScheduleAnnotation(
18831895
<< annotation_size << "): " << node->GetInstr().name();
18841896
}
18851897
// Check that we scheduled all the nodes in the annotation.
1886-
TF_RET_CHECK(num_scheduled == annotation_size)
1898+
TF_RET_CHECK(num_scheduled == annotation_size - non_ready_instr)
18871899
<< "Couldn't schedule all annotated nodes in one go.";
18881900
return absl::OkStatus();
18891901
}
@@ -1910,6 +1922,7 @@ absl::StatusOr<HloGraphNode::TimeCost> DefaultSchedulerCore::ScheduleNode(
19101922
sched_state->new_sequence_reversed.push_back(
19111923
const_cast<HloInstruction*>(&n->GetInstr()));
19121924
n->SetScheduled();
1925+
sched_state->nodes_holding_annotations.erase(n);
19131926

19141927
// If this node was a successor to one or more scheduling groups, update the
19151928
// number of scheduled successors for each of those groups and add the group
@@ -2691,7 +2704,8 @@ DefaultSchedulerCore::ScheduleComputation(const HloComputation* computation) {
26912704
};
26922705
return absl::StrJoin(sched_state.ready_set, "\n", LogFormatter());
26932706
}());
2694-
if (!sched_state.ready_annotations.empty()) {
2707+
if (!sched_state.ready_annotations.empty() &&
2708+
sched_state.nodes_holding_annotations.empty()) {
26952709
// Pick the first ready annotation whose scheduling will not cross the
26962710
// overlap limit. If there is no such annotation, continue with scheduling
26972711
// non-annotated ops.

third_party/xla/xla/service/latency_hiding_scheduler.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,9 @@ struct SchedulerConfig {
137137
int64_t send_recv_host_overlap_limit = 1;
138138
int64_t copy_overlap_limit = 1;
139139
uint64_t memory_limit = UINT64_MAX;
140+
int64_t max_hops_to_closest_selective_overlap = 0;
141+
int64_t rerun = 0;
142+
int64_t parallel_collective_overlap_limit = 1;
140143
bool schedule_send_recvs = false;
141144
// Consider send recv as the same resource. Some platforms do not take well
142145
// overlapping the send/recv ops between themselves.
@@ -149,9 +152,7 @@ struct SchedulerConfig {
149152
bool resource_serializing = false;
150153
bool depth_based_memory_pressure_reduction = false;
151154
bool enable_selective_resources = false;
152-
int64_t max_hops_to_closest_selective_overlap = 0;
153-
int64_t rerun = 0;
154-
int64_t parallel_collective_overlap_limit = 1;
155+
bool flexible_scheduling_annotation_scheduling = false;
155156
};
156157

157158
// Class used estimate latency between instructions and cost of HLOs.
@@ -682,6 +683,7 @@ class HloGraphNode {
682683
annotation_ = annotation;
683684
return absl::OkStatus();
684685
}
686+
void ClearAnnotation() { annotation_ = -1; }
685687
std::string ToString(const AsyncTracker* async_tracker = nullptr) const {
686688
std::string result;
687689
absl::StrAppend(&result, "Instr: ", instr_->ToShortString(), "\n");
@@ -1244,6 +1246,9 @@ class DefaultSchedulerCore : public SchedulerCore {
12441246
ReadyQueueSet annotation_ready;
12451247
// Annotation that is currently being scheduled.
12461248
int64_t ongoing_annotation = kInvalidAnnotation;
1249+
// If this set is not empty it means that we shouldn't schedule any more
1250+
// annotated nodes until empty.
1251+
absl::flat_hash_set<HloGraphNode*> nodes_holding_annotations;
12471252
// Reference to this scheduler run configuration.
12481253
const SchedulerConfig& config;
12491254
SchedulingState(const HloInstructionSequence* instr_sequence,

third_party/xla/xla/service/latency_hiding_scheduler_test.cc

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4444,4 +4444,60 @@ TEST_F(LatencyHidingSchedulerTest, ValidScheduleWithRandomPreferences) {
44444444
// schedule.
< 6D40 /td>
44454445
TF_EXPECT_OK(hlo_module->schedule().Verify());
44464446
}
4447+
// Check that "keep_original_sequence_order_in_group" frontend attribute takes
4448+
// effect.
4449+
TEST_F(LatencyHidingSchedulerTest, FlexibleSchedulingAnnotationScheduling) {
4450+
absl::string_view hlo_string = R"(
4451+
HloModule module, is_scheduled=true
4452+
4453+
ENTRY entry {
4454+
p0 = f32[16,64,256]{2,1,0} parameter(0)
4455+
p1 = f32[128,2048,2048]{2,1,0} parameter(1)
4456+
p2 = f32[512,2048,2048]{2,1,0} parameter(2)
4457+
p3 = f32[16,256,256]{2,1,0} parameter(3)
4458+
cp1s = (f32[512,2048,2048]{2,1,0}, f32[512,2048,2048]{2,1,0}, u32[], u32[]) collective-permute-start(p2), source_target_pairs={{1,0},{0,3},{3,2}}, frontend_attributes={_scheduling_group_id="0", keep_original_sequence_order_in_group="true"}
4459+
cp1d = f32[512,2048,2048]{2,1,0} collective-permute-done(cp1s), frontend_attributes={_scheduling_group_id="0", keep_original_sequence_order_in_group="true"}
4460+
cp2s = (f32[128,2048,2048]{2,1,0}, f32[128,2048,2048]{2,1,0}, u32[], u32[]) collective-permute-start(p1), source_target_pairs={{1,0},{0,3},{3,2}}, frontend_attributes={_scheduling_group_id="0", keep_original_sequence_order_in_group="true"}
4461+
c0 = f32[16,256,256]{2,1,0} convolution(p0, p0),
4462+
window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb, frontend_attributes={_scheduling_group_id="0", keep_original_sequence_order_in_group="true"}
4463+
c1 = f32[16,256,256]{2,1,0} convolution(p3, p3),
4464+
window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb
4465+
cp2d = f32[128,2048,2048]{2,1,0} collective-permute-done(cp2s), frontend_attributes={_scheduling_group_id="0", keep_original_sequence_order_in_group="true"}
4466+
ROOT tuple.2 = (f32[16,256,256]{2,1,0}, f32[512,2048,2048]{2,1,0}, f32[16,256,256]{2,1,0}) tuple(c0, cp1d, c1)
4467+
}
4468+
)";
4469+
4470+
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string));
4471+
HloSchedule& module_schedule = hlo_module->schedule();
4472+
EXPECT_TRUE(hlo_module->has_entry_computation());
4473+
auto sched_config = GetDefaultSchedConfig();
4474+
sched_config.flexible_scheduling_annotation_scheduling = true;
4475+
sched_config.aggressive_scheduling_policies = true;
4476+
TF_EXPECT_OK(RunScheduler(hlo_module.get(), sched_config,
4477+
std::make_unique<TestLatencyEstimator>()));
4478+
EXPECT_TRUE(hlo_module->has_entry_computation());
4479+
4480+
std::vector<HloInstruction*> new_instruction_sequence =
4481+
module_schedule.sequence(hlo_module->entry_computation()).instructions();
4482+
if (VLOG_IS_ON(1)) {
4483+
for (auto* new_i : new_instruction_sequence) {
4484+
VLOG(1) << new_i->ToString();
4485+
}
4486+
}
4487+
4488+
// Check that the original sequence order is kept in the annotation group.
4489+
EXPECT_LT(GetIndex(new_instruction_sequence, "cp1s"),
4490+
GetIndex(new_instruction_sequence, "c1"));
4491+
EXPECT_LT(GetIndex(new_instruction_sequence, "c1"),
4492+
GetIndex(new_instruction_sequence, "c0"));
4493+
EXPECT_LT(GetIndex(new_instruction_sequence, "cp1s"),
4494+
GetIndex(new_instruction_sequence, "cp1d"));
4495+
EXPECT_LT(GetIndex(new_instruction_sequence, "cp1d"),
4496+
GetIndex(new_instruction_sequence, "cp2s"));
4497+
EXPECT_LT(GetIndex(new_instruction_sequence, "cp2s"),
4498+
GetIndex(new_instruction_sequence, "c0"));
4499+
EXPECT_LT(GetIndex(new_instruction_sequence, "c0"),
4500+
GetIndex(new_instruction_sequence, "cp2d"));
4501+
}
4502+
44474503
} // namespace xla

0 commit comments

Comments
 (0)
0