8000 Reverts 52fc64b538c7291b8caa0de7b0bfdcf7762376e8 · linux-on-ibm-z/tensorflow@e77316a · GitHub
[go: up one dir, main page]

Skip to content

Commit e77316a

Browse files
Reverts 52fc64b
PiperOrigin-RevId: 730554805
1 parent 7587767 commit e77316a

File tree

2 files changed

+1
-92
lines changed

2 files changed

+1
-92
lines changed

third_party/xla/xla/service/sharding_propagation.cc

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,33 +1065,6 @@ bool IsCSEPreventionSharding(const HloSharding& sharding) {
10651065
"_sharding_propagation_cse_prevention";
10661066
}
10671067

1068-
bool ShouldApplyShardingConstraint(
1069-
const HloInstruction* sharding_constraint,
1070-
absl::flat_hash_set<const HloInstruction*>&
1071-
instructions_constrained_by_different_shardings) {
1072-
const HloInstruction* operand = sharding_constraint->operand(0);
1073-
if (operand->has_sharding()) {
1074-
return false;
1075-
}
1076-
1077-
if (instructions_constrained_by_different_shardings.contains(operand)) {
1078-
// The operand is used by multiple sharding constraints with different
1079-
// shardings.
1080-
return false;
1081-
}
1082-
1083-
for (const HloInstruction* other_user : operand->users()) {
1084-
if (other_user != sharding_constraint &&
1085-
other_user->IsCustomCall("Sharding") &&
1086-
other_user->sharding() != sharding_constraint->sharding()) {
1087-
instructions_constrained_by_different_shardings.insert(operand);
1088-
return false;
1089-
}
1090-
}
1091-
1092-
return true;
1093-
}
1094-
10951068
} // namespace
10961069

10971070
bool InferDotShardingFromOperands(
@@ -1486,8 +1459,6 @@ absl::StatusOr<bool> ProcessShardingInstruction(
14861459

14871460
for (HloComputation* computation : module->computations(execution_threads)) {
14881461
auto instructions = computation->MakeInstructionPostOrder();
1489-
absl::flat_hash_set<const HloInstruction*>
1490-
instructions_constrained_by_different_shardings;
14911462
for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) {
14921463
HloInstruction* instruction = *it;
14931464
if (instruction->IsCustomCall("Sharding")) {
@@ -1527,9 +1498,7 @@ absl::StatusOr<bool> ProcessShardingInstruction(
15271498
if (!unspec_dims.empty()) {
15281499
absl::c_sort(unspec_dims);
15291500
unspecified_dims->emplace(instruction, std::move(unspec_dims));
1530-
} else if (ShouldApplyShardingConstraint(
1531-
instruction,
1532-
instructions_constrained_by_different_shardings)) {
1501+
} else if (!instruction->operand(0)->has_sharding()) {
15331502
instruction->mutable_operand(0)->set_sharding(
15341503
instruction->sharding());
15351504
}

third_party/xla/xla/service/sharding_propagation_test.cc

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -12484,65 +12484,5 @@ ENTRY main {
1248412484
EXPECT_TRUE(instruction->sharding().IsReplicated());
1248512485
}
1248612486

12487-
TEST_F(ShardingPropagationTest, ProcessShardingInstruction1) {
12488-
const char* const hlo_string = R"(
12489-
HloModule module
12490-
12491-
ENTRY %main.6 {
12492-
%p0 = f32[32,96] parameter(0), sharding={devices=[4,1]<=[4]}
12493-
%sine = f32[32,96] sine(%p0)
12494-
%custom-call.3 = f32[32,96] custom-call(%sine), custom_call_target="Sharding", sharding={replicated}
12495-
%cosine = f32[32,96] cosine(%sine)
12496-
ROOT %add.2 = f32[32,96] add(%custom-call.3, %cosine)
12497-
})";
12498-
TF_ASSERT_OK_AND_ASSIGN(auto module,
12499-
ParseAndReturnVerifiedModule(hlo_string));
12500-
TF_ASSERT_OK_AND_ASSIGN(
12501-
bool changed,
12502-
ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true,
12503-
/*allow_spmd_sharding_propagation_to_output=*/{true})
12504-
.Run(module.get()));
12505-
EXPECT_TRUE(changed);
12506-
XLA_VLOG_LINES(1, module->ToString());
12507-
12508-
for (const HloInstruction* instruction :
12509-
module->entry_computation()->instructions()) {
12510-
if (instruction->opcode() == HloOpcode::kParameter) {
12511-
EXPECT_THAT(instruction, op::Sharding("{devices=[4,1]<=[4]}"));
12512-
} else {
12513-
EXPECT_THAT(instruction, op::Sharding("{replicated}"));
12514-
}
12515-
}
12516-
}
12517-
12518-
TEST_F(ShardingPropagationTest, ProcessShardingInstruction2) {
12519-
const char* const hlo_string = R"(
12520-
HloModule module
12521-
12522-
ENTRY %main.6 {
12523-
%p0 = f32[32,96] parameter(0), sharding={replicated}
12524-
%sine = f32[32,96] sine(%p0)
12525-
%custom-call.0 = f32[32,96] custom-call(%sine), custom_call_target="Sharding", sharding={devices=[2,1,2]<=[4] last_tile_dim_replicate}
12526-
%custom-call.1 = f32[32,96] custom-call(%sine), custom_call_target="Sharding", sharding={devices=[1,2,2]<=[2,2]T(1,0) last_tile_dim_replicate}
12527-
%cosine = f32[32,96] cosine(%sine)
12528-
ROOT tuple = (f32[32,96], f32[32,96], f32[32,96]) tuple(%custom-call.0, %custom-call.1, %cosine)
12529-
})";
12530-
TF_ASSERT_OK_AND_ASSIGN(auto module,
12531-
ParseAndReturnVerifiedModule(hlo_string));
12532-
TF_ASSERT_OK_AND_ASSIGN(
12533-
bool changed, ShardingPropagation(
12534-
/*is_spmd=*/true, /*propagate_metadata=*/true,
12535-
/*allow_spmd_sharding_propagation_to_output=*/{true})
12536-
.Run(module.get()));
12537-
EXPECT_TRUE(changed);
12538-
XLA_VLOG_LINES(1, module->ToString());
12539-
12540-
for (absl::string_view name : {"sine", "cosine"}) {
12541-
HloInstruction* instruction = FindInstruction(module.get(), name);
12542-
ASSERT_NE(instruction, nullptr);
12543-
EXPECT_THAT(instruction, op::Sharding("{devices=[2,2]<=[4]}"));
12544-
}
12545-
}
12546-
1254712487
} // namespace
1254812488
} // namespace xla

0 commit comments

Comments
 (0)
0