@@ -4444,4 +4444,60 @@ TEST_F(LatencyHidingSchedulerTest, ValidScheduleWithRandomPreferences) {
4444
4444
// schedule.
4445
4445
TF_EXPECT_OK (hlo_module->schedule ().Verify ());
4446
4446
}
4447
+ // Verifies that the "keep_original_sequence_order_in_group" frontend attribute takes
4448
+ // effect: the annotated ops of group "0" must keep their original relative order after scheduling.
4449
+ TEST_F (LatencyHidingSchedulerTest, FlexibleSchedulingAnnotationScheduling) {
4450
+ absl::string_view hlo_string = R"(
4451
+ HloModule module, is_scheduled=true
4452
+
4453
+ ENTRY entry {
4454
+ p0 = f32[16,64,256]{2,1,0} parameter(0)
4455
+ p1 = f32[128,2048,2048]{2,1,0} parameter(1)
4456
+ p2 = f32[512,2048,2048]{2,1,0} parameter(2)
4457
+ p3 = f32[16,256,256]{2,1,0} parameter(3)
4458
+ cp1s = (f32[512,2048,2048]{2,1,0}, f32[512,2048,2048]{2,1,0}, u32[], u32[]) collective-permute-start(p2), source_target_pairs={{1,0},{0,3},{3,2}}, frontend_attributes={_scheduling_group_id="0", keep_original_sequence_order_in_group="true"}
4459
+ cp1d = f32[512,2048,2048]{2,1,0} collective-permute-done(cp1s), frontend_attributes={_scheduling_group_id="0", keep_original_sequence_order_in_group="true"}
4460
+ cp2s = (f32[128,2048,2048]{2,1,0}, f32[128,2048,2048]{2,1,0}, u32[], u32[]) collective-permute-start(p1), source_target_pairs={{1,0},{0,3},{3,2}}, frontend_attributes={_scheduling_group_id="0", keep_original_sequence_order_in_group="true"}
4461
+ c0 = f32[16,256,256]{2,1,0} convolution(p0, p0),
4462
+ window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb, frontend_attributes={_scheduling_group_id="0", keep_original_sequence_order_in_group="true"}
4463
+ c1 = f32[16,256,256]{2,1,0} convolution(p3, p3),
4464
+ window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb
4465
+ cp2d = f32[128,2048,2048]{2,1,0} collective-permute-done(cp2s), frontend_attributes={_scheduling_group_id="0", keep_original_sequence_order_in_group="true"}
4466
+ ROOT tuple.2 = (f32[16,256,256]{2,1,0}, f32[512,2048,2048]{2,1,0}, f32[16,256,256]{2,1,0}) tuple(c0, cp1d, c1)
4467
+ }
4468
+ )" ; // In the HLO above, cp1s, cp1d, cp2s, c0 and cp2d are annotated with group "0"; c1 carries no frontend attributes.
4469
+ // Parse the module and take a reference to its schedule.
4470
+ TF_ASSERT_OK_AND_ASSIGN (auto hlo_module, ParseHloText (hlo_string));
4471
+ HloSchedule& module_schedule = hlo_module->schedule ();
4472
+ EXPECT_TRUE (hlo_module->has_entry_computation ());
4473
+ auto sched_config = GetDefaultSchedConfig ();
4474
+ sched_config.flexible_scheduling_annotation_scheduling = true ;
4475
+ sched_config.aggressive_scheduling_policies = true ;
4476
+ TF_EXPECT_OK (RunScheduler (hlo_module.get (), sched_config,
4477
+ std::make_unique<TestLatencyEstimator>()));
4478
+ EXPECT_TRUE (hlo_module->has_entry_computation ());
4479
+ // Read back the entry computation's scheduled sequence; dump it when VLOG level 1 is on.
4480
+ std::vector<HloInstruction*> new_instruction_sequence =
4481
+ module_schedule.sequence (hlo_module->entry_computation ()).instructions ();
4482
+ if (VLOG_IS_ON (1 )) {
4483
+ for (auto * new_i : new_instruction_sequence) {
4484
+ VLOG (1 ) << new_i->ToString ();
4485
+ }
4486
+ }
4487
+ // NOTE(review): the name literals below start with a space (e.g. " cp1s"); confirm this is what GetIndex expects -- it may be an extraction artifact.
4488
+ // Check that the original sequence order is kept in the annotation group (cp1s, cp1d, cp2s, c0, cp2d), and that the unannotated c1 lands between cp1s and c0.
4489
+ EXPECT_LT (GetIndex (new_instruction_sequence, " cp1s" ),
4490
+ GetIndex (new_instruction_sequence, " c1" ));
4491
+ EXPECT_LT (GetIndex (new_instruction_sequence, " c1" ),
4492
+ GetIndex (new_instruction_sequence, " c0" ));
4493
+ EXPECT_LT (GetIndex (new_instruction_sequence, " cp1s" ),
4494
+ GetIndex (new_instruction_sequence, " cp1d" ));
4495
+ EXPECT_LT (GetIndex (new_instruction_sequence, " cp1d" ),
4496
+ GetIndex (new_instruction_sequence, " cp2s" ));
4497
+ EXPECT_LT (GetIndex (new_instruction_sequence, " cp2s" ),
4498
+ GetIndex (new_instruction_sequence, " c0" ));
4499
+ EXPECT_LT (GetIndex (new_instruction_sequence, " c0" ),
4500
+ GetIndex (new_instruction_sequence, " cp2d" ));
4501
+ }
4502
+
4447
4503
} // namespace xla
0 commit comments