[XLA] Emit FCMP_UNE (true if unordered) instead of FCMP_ONE (false if unordered) · staticfloat/tensorflow@183e2c9 · GitHub
[go: up one dir, main page]

Skip to content

Commit 183e2c9

Browse files
[XLA] Emit FCMP_UNE (true if unordered) instead of FCMP_ONE (false if unordered)
for the not_equal operation, to be more compliant with TF and other languages.
PiperOrigin-RevId: 155835642
1 parent 688f5a6 commit 183e2c9

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

tensorflow/compiler/xla/service/elemental_ir_emitter.cc

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -240,14 +240,18 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
240240
return ir_builder_->CreateFDiv(lhs_value, rhs_value);
241241
case HloOpcode::kRemainder:
242242
return ir_builder_->CreateFRem(lhs_value, rhs_value);
243-
244-
// The 'O' prefix on the LLVM ops means "ordered" compare where comparisons
245-
// with NAN always return false.
243+
// LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
244+
// comparisons always return false when one of the operands is NaN, whereas
245+
// unordered comparisons return true.
246+
//
247+
// We use ordered comparisons for everything except kNe, where we use an
248+
// unordered comparison. This makes x != y equivalent to !(x == y), and
249+
// matches C++'s semantics.
246250
case HloOpcode::kEq:
247251
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
248252
rhs_value, ir_builder_);
249253
case HloOpcode::kNe:
250-
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_ONE, lhs_value,
254+
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
251255
rhs_value, ir_builder_);
252256
case HloOpcode::kLt:
253257
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
@@ -739,11 +743,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
739743
const HloInstruction* operand = hlo->operand(operand_idx);
740744
auto true_block = llvm_ir::CreateBasicBlock(
741745
exit_block, tensorflow::strings::StrCat(
742-
"concat_index_from_operand", operand_idx),
746+
"concat_index_from_operand", operand_idx),
743747
ir_builder_);
744748
auto false_block = llvm_ir::CreateBasicBlock(
745749
exit_block, tensorflow::strings::StrCat(
746-
"concat_index_not_from_operand", operand_idx),
750+
"concat_index_not_from_operand", operand_idx),
747751
ir_builder_);
748752
auto concat_dim_size =
749753
llvm::ConstantInt::get(source_index[concat_dim]->getType(),

tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
486486
ComputeAndCompareR1<bool>(&builder, {}, {});
487487
}
488488

489+
TEST_F(ArrayElementwiseOpTest, CompareNeF32s) {
490+
// Disable fast-math because we're operating on NaNs.
491+
SetFastMathDisabled(true);
492+
493+
ComputationBuilder builder(client_, TestName());
494+
auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
495+
auto rhs = builder.ConstantR1<float>({10.0f, 25.5f, 1.0f, 10.0f, NAN});
496+
auto compare = builder.Ne(lhs, rhs);
497+
498+
ComputeAndCompareR1<bool>(&builder, {true, false, true, true, true}, {});
499+
}
500+
489501
TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
490502
const int32 min = std::numeric_limits<int32>::min();
491503
const int32 max = std::numeric_limits<int32>::max();

0 commit comments

Comments
 (0)
0