Disable bfloat16 propagation for dynamic slice op with operand in host memory · IBMZ-Linux-OSS-Python/tensorflow@0a2dff1 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0a2dff1

Browse files
fhoushmand authored and
tensorflower-gardener committed
Disable bfloat16 propagation for dynamic slice op with operand in host memory
PiperOrigin-RevId: 766384513
1 parent 5c31bcf commit 0a2dff1

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

third_party/xla/xla/hlo/transforms/bfloat16_propagation.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ limitations under the License.
3434
#include "xla/hlo/ir/hlo_opcode.h"
3535
#include "xla/hlo/transforms/simplifiers/hlo_dce.h"
3636
#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h"
37+
#include "xla/layout.h"
3738
#include "xla/literal.h"
3839
#include "xla/map_util.h"
3940
#include "xla/service/float_support.h"
@@ -568,6 +569,12 @@ bool BFloat16Propagation::InstructionIsCandidateForBF16Output(
568569
}
569570
}
570571
}
572+
if (hlo->opcode() == HloOpcode::kDynamicSlice &&
573+
hlo->operand(0)->shape().has_layout() &&
574+
hlo->operand(0)->shape().layout().memory_space() ==
575+
Layout::kHostMemorySpace) {
576+
return false;
577+
}
571578
return true;
572579
}
573580

third_party/xla/xla/hlo/transforms/bfloat16_propagation_test.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,6 +1252,29 @@ ENTRY main {
12521252
EXPECT_FALSE(OutputsBF16(dus));
12531253
}
12541254

1255+
TEST_F(BFloat16PropagationTest, DynamicSliceWithHostMemory) {
1256+
// A dynamic-slice whose operand carries a host-memory layout annotation
1257+
// (`S(5)` on `param` below) must keep its f32 output; bf16 is not propagated.
1258+
const std::string module_str = R"(
1259+
HloModule Module
1260+
1261+
ENTRY main {
1262+
param = f32[128,128]{1,0:S(5)} parameter(0)
1263+
constant.3 = s32[] constant(0)
1264+
dynamic-slice = f32[128,8] dynamic-slice(param, constant.3, constant.3), dynamic_slice_sizes={128,8}
1265+
ROOT dot = f32[128,128] dot(dynamic-slice, dynamic-slice), lhs_contracting_dims={1}, rhs_contracting_dims={1}
1266+
}
1267+
)";
1268+
1269+
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
1270+
ParseAndReturnVerifiedModule(module_str));
1271+
EXPECT_FALSE(PropagatePrecision(module.get()));
1272+
1273+
HloInstruction* dus =  // NOTE(review): holds the "dynamic-slice" instruction, not a dynamic-update-slice
1274+
module->entry_computation()->GetInstructionWithName("dynamic-slice");
1275+
EXPECT_FALSE(OutputsBF16(dus));
1276+
}
1277+
12551278
// This test demonstrates the need for invoking the ResolveAliasingBuffer
12561279
// multiple times via a fixed-point algorithm. The key was the aliasing of the
12571280
// two output buffers of the conditional, at subshape 0 (first element). This

0 commit comments

Comments
 (0)
0