Allow fusion to accept more operands than parameters. This relaxation… · IBMZ-Linux-OSS-Python/tensorflow@31a9da8 · GitHub
[go: up one dir, main page]

Skip to content

Commit 31a9da8

Browse files
Allow fusion to accept more operands than parameters. This relaxation is mainly for window prefetch. With this change, memory space assignment (MSA) can simply append operands to the fusion, without inserting a custom op inside the fusion to consume the appended operands.
PiperOrigin-RevId: 673101738
1 parent 17e7b97 commit 31a9da8

File tree

6 files changed

+63
-42
lines changed

6 files changed

+63
-42
lines changed

third_party/xla/xla/service/hlo_verifier.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,9 +1262,9 @@ absl::Status ShapeVerifier::HandleFusion(HloInstruction* fusion) {
12621262
}
12631263

12641264
auto& fused_parameters = fusion->fused_parameters();
1265-
if (fused_parameters.size() != fusion->operand_count()) {
1265+
if (fused_parameters.size() > fusion->operand_count()) {
12661266
return Internal(
1267-
"Fused parameter count (%d) does not match the number of operands (%d)"
1267+
"Fused parameter count (%d) is greater than the number of operands (%d)"
12681268
" passed to the fusion instruction in: %s.",
12691269
fused_parameters.size(), fusion->operand_count(),
12701270
fusion->ToString().c_str());
@@ -2654,7 +2654,7 @@ absl::Status CheckFusionInstruction(HloInstruction* fusion) {
26542654

26552655
// Fused parameter instructions must be numbered contiguously and match up
26562656
// (shapes equal) with their respective operand.
2657-
CHECK_EQ(fusion->operands().size(), fused_parameters.size());
2657+
CHECK_GE(fusion->operands().size(), fused_parameters.size());
26582658
std::vector<bool> parameter_numbers(fused_parameters.size(), false);
26592659
for (auto fused_param : fused_parameters) {
26602660
int64_t param_no = fused_param->parameter_number();

third_party/xla/xla/service/hlo_verifier_test.cc

Lines changed: 41 additions & 0 deletions
+
HloModule test
Original file line numberDiff line numberDiff line change
@@ -2101,6 +2101,47 @@ TEST_F(HloVerifierTest, CollectivePermuteCrossPartitionTargetOOR) {
21012101
EXPECT_THAT(error_message, HasSubstr("must be < 3"));
21022102
}
21032103

2104+
TEST_F(HloVerifierTest, FusionMoreOperandsThanParameters) {
2105+
const char* const kModuleStr = R"(
2106+
HloModule test
2107+
2108+
fused_computation {
2109+
ROOT p0 = f32[10] parameter(0)
2110+
}
2111+
2112+
ENTRY entry {
2113+
p0 = f32[10] parameter(0)
2114+
p1 = f32[10] parameter(1)
2115+
ROOT out = f32[10] fusion(p0, p1), kind=kInput, calls=fused_computation
2116+
}
2117+
)";
2118+
TF_ASSERT_OK_AND_ASSIGN(auto module,
2119+
ParseAndReturnUnverifiedModule(kModuleStr));
2120+
auto status = verifier().Run(module.get()).status();
2121+
ASSERT_TRUE(status.ok());
2122+
}
2123+
2124+
TEST_F(HloVerifierTest, FusionLessOperandsThanParameters) {
2125+
const char* const kModuleStr = R"(
2126+
HloModule test
2127+
2128+
fused_computation {
2129+
p0 = f32[10] parameter(0)
2130+
p1 = f32[10] parameter(1)
2131+
ROOT out = f32[10] add(p0, p1)
2132+
}
2133+
2134+
ENTRY entry {
2135+
p0 = f32[10] parameter(0)
2136+
ROOT out = f32[10] fusion(p0), kind=kInput, calls=fused_computation
2137+
}
2138+
)";
2139+
TF_ASSERT_OK_AND_ASSIGN(auto module,
2140+
ParseAndReturnUnverifiedModule(kModuleStr));
2141+
EXPECT_THAT(verifier().Run(module.get()).status().message(),
2142+
HasSubstr("greater than the number of operands"));
2143+
}
2144+
21042145
TEST_F(HloVerifierTest, FusionShapeVerifier) {
21052146
const char* const kModuleStr = R"(
21062147
HloModule test

third_party/xla/xla/service/memory_space_assignment/allocation.cc

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -894,14 +894,6 @@ absl::Status WindowPrefetchedAllocation::InsertWindowPrefetchInstruction(
894894
layout.set_memory_space(options_.alternate_memory_space);
895895
*shape.mutable_layout() = layout;
896896

897-
// Insert a new parameter in the fused computation.
898-
HloComputation* fused_computation =
899-
use_instruction->fused_instructions_computation();
900-
const int64_t num_parameters = fused_computation->num_parameters();
901-
std::string name = absl::StrCat("window-buffer.", num_parameters);
902-
HloInstruction* param = fused_computation->AddParameter(
903-
HloInstruction::CreateParameter(num_parameters, shape, name));
904-
905897
// Insert async WindowPrefetch instructions as operands to the fusion.
906898
HloInstruction* prefetch =
907899
computation->AddInstruction(HloInstruction::CreateCustomCall(
@@ -910,24 +902,6 @@ absl::Status WindowPrefetchedAllocation::InsertWindowPrefetchInstruction(
910902
computation->CreateAsyncInstructions(prefetch, {}));
911903
use_instruction->AppendOperand(prefetch_instruction_);
912904

913-
// Insert instruction to consume the added operands and forwards the original
914-
// fusion output.
915-
auto get_or_create_consumer =
916-
[](HloComputation* computation) -> HloInstruction* {
917-
HloInstruction* root = computation->root_instruction();
918-
// If the root is already a WindowPrefetchBuffer, we don't need to create
919-
// a new one.
920-
if (root->IsCustomCall("WindowPrefetchBuffer")) {
921-
return root;
922-
}
923-
HloInstruction* new_root =
924-
computation->AddInstruction(HloInstruction::CreateCustomCall(
925-
root->shape(), {root}, "WindowPrefetchBuffer"));
926-
computation->set_root_instruction(new_root);
927-
return new_root;
928-
};
929-
HloInstruction* consumer = get_or_create_consumer(fused_computation);
930-
consumer->AppendOperand(param);
931905
return absl::OkStatus();
932906
}
933907

third_party/xla/xla/service/memory_space_assignment/allocation.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -450,10 +450,8 @@ class WindowPrefetchedAllocation final : public Allocation {
450450

451451
private:
452452
// This method is called by Process() to create window prefetch instructions.
453-
// These instructions include a pair of async WindowPrefetch outside the
454-
// fusion and a WindowPrefetchBuffer inside the fusion. The
455-
// WindowPrefetchBuffer is used for consuming the appended window buffer
456-
// operands.
453+
// These instructions include a pair of async WindowPrefetch which is passed
454+
// to the fusion.
457455
absl::Status InsertWindowPrefetchInstruction(
458456
HloInstruction* producing_instruction, HloInstruction* use_instruction,
459457
HloComputation* computation);

third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8330,14 +8330,21 @@ entry {
83308330
// plus 2 window prefetch buffers.
83318331
EXPECT_EQ(fusion->operand_count(), 5);
83328332

8333-
// The root of the fusion should be a WindowPrefetchBuffer. The first operand
8334-
// should be the original root, and the second and third operands should be
8335-
// the window prefetch buffers.
8336-
HloInstruction* root = fusion->fused_expression_root();
8337-
EXPECT_TRUE(root->IsCustomCall("WindowPrefetchBuffer"));
8338-
EXPECT_EQ(root->operand_count(), 3);
8339-
EXPECT_EQ(root->operand(1), fusion->fused_parameter(3));
8340-
EXPECT_EQ(root->operand(2), fusion->fused_parameter(4));
8333+
// The 2 added operands are async calls to WindowPrefetch.
8334+
for (int i = 3; i < 5; i++) {
8335+
const HloInstruction* async_done = fusion->operand(i);
8336+
EXPECT_EQ(async_done->opcode(), HloOpcode::kAsyncDone);
8337+
EXPECT_EQ(async_done->operand_count(), 1);
8338+
EXPECT_TRUE(async_done->async_wrapped_instruction()->IsCustomCall(
8339+
"WindowPrefetch"));
8340+
8341+
const HloInstruction* async_start = async_done->operand(0);
8342+
EXPECT_EQ(async_start->opcode(), HloOpcode::kAsyncStart);
8343+
EXPECT_EQ(async_start->operand_count(), 1);
8344+
EXPECT_TRUE(async_start->async_wrapped_instruction()->IsCustomCall(
8345+
"WindowPrefetch"));
8346+
}
8347+
83418348
VLOG(2) << "module: " << module->ToString();
83428349
}
83438350

third_party/xla/xla/service/memory_space_propagation.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ absl::StatusOr<bool> MemorySpacePropagation::Run(
4040
for (HloInstruction* instruction : computation->instructions()) {
4141
if (instruction->opcode() == HloOpcode::kFusion) {
4242
// Propagate the operand subshapes.
43-
for (int operand_idx = 0; operand_idx < instruction->operand_count();
43+
for (int operand_idx = 0;
44+
operand_idx < instruction->fused_parameters().size();
4445
++operand_idx) {
4546
ShapeUtil::ForEachLeafShape(
4647
instruction->operand(operand_idx)->shape(),

0 commit comments

Comments
 (0)
0