@@ -12484,65 +12484,5 @@ ENTRY main {
12484
12484
EXPECT_TRUE (instruction->sharding ().IsReplicated ());
12485
12485
}
12486
12486
12487
- TEST_F (ShardingPropagationTest, ProcessShardingInstruction1) {
12488
- const char * const hlo_string = R"(
12489
- HloModule module
12490
-
12491
- ENTRY %main.6 {
12492
- %p0 = f32[32,96] parameter(0), sharding={devices=[4,1]<=[4]}
12493
- %sine = f32[32,96] sine(%p0)
12494
- %custom-call.3 = f32[32,96] custom-call(%sine), custom_call_target="Sharding", sharding={replicated}
12495
- %cosine = f32[32,96] cosine(%sine)
12496
- ROOT %add.2 = f32[32,96] add(%custom-call.3, %cosine)
12497
- })" ;
12498
- TF_ASSERT_OK_AND_ASSIGN (auto module ,
12499
- ParseAndReturnVerifiedModule (hlo_string));
12500
- TF_ASSERT_OK_AND_ASSIGN (
12501
- bool changed,
12502
- ShardingPropagation (/* is_spmd=*/ true , /* propagate_metadata=*/ true ,
12503
- /* allow_spmd_sharding_propagation_to_output=*/ {true })
12504
- .Run (module .get ()));
12505
- EXPECT_TRUE (changed);
12506
- XLA_VLOG_LINES (1 , module ->ToString ());
12507
-
12508
- for (const HloInstruction* instruction :
12509
- module ->entry_computation ()->instructions ()) {
12510
- if (instruction->opcode () == HloOpcode::kParameter ) {
12511
- EXPECT_THAT (instruction, op::Sharding (" {devices=[4,1]<=[4]}" ));
12512
- } else {
12513
- EXPECT_THAT (instruction, op::Sharding (" {replicated}" ));
12514
- }
12515
- }
12516
- }
12517
-
12518
- TEST_F (ShardingPropagationTest, ProcessShardingInstruction2) {
12519
- const char * const hlo_string = R"(
12520
- HloModule module
12521
-
12522
- ENTRY %main.6 {
12523
- %p0 = f32[32,96] parameter(0), sharding={replicated}
12524
- %sine = f32[32,96] sine(%p0)
12525
- %custom-call.0 = f32[32,96] custom-call(%sine), custom_call_target="Sharding", sharding={devices=[2,1,2]<=[4] last_tile_dim_replicate}
12526
- %custom-call.1 = f32[32,96] custom-call(%sine), custom_call_target="Sharding", sharding={devices=[1,2,2]<=[2,2]T(1,0) last_tile_dim_replicate}
12527
- %cosine = f32[32,96] cosine(%sine)
12528
- ROOT tuple = (f32[32,96], f32[32,96], f32[32,96]) tuple(%custom-call.0, %custom-call.1, %cosine)
12529
- })" ;
12530
- TF_ASSERT_OK_AND_ASSIGN (auto module ,
12531
- ParseAndReturnVerifiedModule (hlo_string));
12532
- TF_ASSERT_OK_AND_ASSIGN (
12533
- bool changed, ShardingPropagation (
12534
- /* is_spmd=*/ true , /* propagate_metadata=*/ true ,
12535
- /* allow_spmd_sharding_propagation_to_output=*/ {true })
12536
- .Run (module .get ()));
12537
- EXPECT_TRUE (changed);
12538
- XLA_VLOG_LINES (1 , module ->ToString ());
12539
-
12540
- for (absl::string_view name : {" sine" , " cosine" }) {
12541
- HloInstruction* instruction = FindInstruction (module .get (), name);
12542
- ASSERT_NE (instruction, nullptr );
12543
- EXPECT_THAT (instruction, op::Sharding (" {devices=[2,2]<=[4]}" ));
12544
- }
12545
- }
12546
-
12547
12487
} // namespace
12548
12488
} // namespace xla
0 commit comments