diff --git a/third_party/xla/xla/backends/cpu/BUILD b/third_party/xla/xla/backends/cpu/BUILD index f8f7633cafdc20..9d05bc86ef1d50 100644 --- a/third_party/xla/xla/backends/cpu/BUILD +++ b/third_party/xla/xla/backends/cpu/BUILD @@ -112,7 +112,10 @@ onednn_graph_cc_library( "//xla:xla_data_proto_cc", "//xla/backends/cpu/codegen:target_machine_features", "//xla/backends/cpu/runtime:dot_dims", + "//xla/hlo/ir:hlo", + "//xla/service/cpu:onednn_util", "//xla/tsl/mkl:onednn", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", diff --git a/third_party/xla/xla/backends/cpu/onednn_emitter.cc b/third_party/xla/xla/backends/cpu/onednn_emitter.cc index 0e51804ff92743..788de31a33a1b8 100644 --- a/third_party/xla/xla/backends/cpu/onednn_emitter.cc +++ b/third_party/xla/xla/backends/cpu/onednn_emitter.cc @@ -67,32 +67,6 @@ static absl::StatusOr OneDnnDatatype( } } -static absl::StatusOr OneDnnUnaryOperator( - const HloOpcode& opcode) { - switch (opcode) { - case HloOpcode::kExp: - return dnnl::graph::op::kind::Exp; - default: - return InvalidArgument("Unsupported oneDNN unary operator: %s", - HloOpcodeString(opcode)); - } -} - -static absl::StatusOr OneDnnBinaryOperator( - const HloOpcode& opcode) { - switch (opcode) { - case HloOpcode::kAdd: - return dnnl::graph::op::kind::Add; - case HloOpcode::kMultiply: - return dnnl::graph::op::kind::Multiply; - case HloOpcode::kDot: - return dnnl::graph::op::kind::MatMul; - default: - return InvalidArgument("Unsupported oneDNN unary operator: %s", - HloOpcodeString(opcode)); - } -} - static dnnl::graph::logical_tensor::dims OneDnnDimensions(const Shape& shape) { dnnl::graph::logical_tensor::dims dims; for (auto& dim : shape.dimensions()) { @@ -201,7 +175,7 @@ static absl::StatusOr DefineMatMul( const Shape& rhs_shape = instr->operand(1)->shape(); TF_ASSIGN_OR_RETURN( bool is_supported, - IsOneDnnDotSupported(dnums, lhs_shape, rhs_shape, instr->shape())); + IsDotSupportedByOneDnn(dnums, lhs_shape, rhs_shape, instr->shape())); if (!is_supported) { return InvalidArgument("Unsupported oneDNN Dot op variation: %s", @@ -268,7 +242,11 @@ static absl::StatusOr EmitOneDnnFusion( } break; // Unary elementwise ops. - case HloOpcode::kExp: { + case HloOpcode::kAbs: + case HloOpcode::kExp: + case HloOpcode::kLog: + case HloOpcode::kSqrt: + case HloOpcode::kTanh: { TF_ASSIGN_OR_RETURN( logical_tensors[instr], DefineUnaryOp(graph, op_id++, logical_tensors, instr)); @@ -276,7 +254,11 @@ static absl::StatusOr EmitOneDnnFusion( // Binary elementwise ops. case HloOpcode::kAdd: - case HloOpcode::kMultiply: { + case HloOpcode::kDivide: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kSubtract: { TF_ASSIGN_OR_RETURN( logical_tensors[instr], DefineBinaryOp(graph, op_id++, logical_tensors, instr)); diff --git a/third_party/xla/xla/backends/cpu/onednn_support.cc b/third_party/xla/xla/backends/cpu/onednn_support.cc index 332c0a1cf70212..6cd42cef6f6514 100644 --- a/third_party/xla/xla/backends/cpu/onednn_support.cc +++ b/third_party/xla/xla/backends/cpu/onednn_support.cc @@ -15,57 +15,57 @@ limitations under the License. #include "xla/backends/cpu/onednn_support.h" +#include "absl/base/no_destructor.h" #include "absl/log/log.h" #include "absl/status/statusor.h" -#include "dnnl.hpp" // NOLINT: for DNNL_MAX_NDIMS +#include "oneapi/dnnl/dnnl.hpp" // NOLINT: for DNNL_MAX_NDIMS +#include "oneapi/dnnl/dnnl_graph.hpp" #include "xla/backends/cpu/codegen/target_machine_features.h" #include "xla/backends/cpu/runtime/dot_dims.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout_util.h" +#include "xla/service/cpu/onednn_util.h" #include "xla/shape.h" #include "xla/xla_data.pb.h" #include "tsl/platform/cpu_info.h" namespace xla::cpu { -bool IsOneDnnSupportedDType(PrimitiveType dtype) { - using tsl::port::CPUFeature; - switch (dtype) { - case F32: - return true; - case BF16: - return TestCPUFeature(CPUFeature::AVX512F) || - TestCPUFeature(CPUFeature::AVX_NE_CONVERT) || - TestCPUFeature(CPUFeature::AMX_BF16); - case F16: - return (TestCPUFeature(CPUFeature::AVX512BW) && - (TestCPUFeature(CPUFeature::AVX512_FP16) || - TestCPUFeature(CPUFeature::AMX_FP16))) || - TestCPUFeature(CPUFeature::AVX_NE_CONVERT); - default: - return false; - } -} - bool IsOneDnnSupportedDType(PrimitiveType dtype, const TargetMachineFeatures* cpu_features) { - if (dtype == F32) { + if (dtype == F32 || IsSupportedType(dtype)) { return true; } + // Check for data type support if target machine features are provided. + // Unit tests may provide target machine features to simulate different CPU + // capabilities. + return (cpu_features != nullptr && + ((dtype == BF16 && cpu_features->has_avx512bf16()) || + (dtype == F16 && cpu_features->has_avx512fp16()))); +} - if (cpu_features == nullptr) { - return IsOneDnnSupportedDType(dtype); - } +bool IsOneDnnSupportedLayout(const Shape& shape) { + return !shape.has_layout() || LayoutUtil::HasDescendingLayout(shape.layout()); +} - if (dtype == BF16) { - return cpu_features->has_avx512bf16(); - } - if (dtype == F16) { - return cpu_features->has_avx512fp16(); - } +bool IsOneDnnSupportedTypeAndLayout(const HloInstruction* hlo, + const TargetMachineFeatures* cpu_features) { + auto is_supported = [cpu_features](const HloInstruction* hlo) { + return IsOneDnnSupportedDType(hlo->shape().element_type(), cpu_features) && + IsOneDnnSupportedLayout(hlo->shape()); + }; - return false; + if (!is_supported(hlo)) { + return false; + } + return (std::all_of(hlo->operands().begin(), hlo->operands().end(), + is_supported)); } -absl::StatusOr IsOneDnnDotSupported( +absl::StatusOr IsDotSupportedByOneDnn( const DotDimensionNumbers& dot_dimensions, const Shape& lhs_shape, const Shape& rhs_shape, const Shape& out_shape, const TargetMachineFeatures* cpu_features) { @@ -110,4 +110,122 @@ absl::StatusOr IsOneDnnDotSupported( !dot_canonical_dims.rhs_column_major; } +const absl::flat_hash_map& GetOneDnnUnaryOpMap() { + static absl::NoDestructor> + unary_op_map({ + {HloOpcode::kAbs, op::kind::Abs}, + {HloOpcode::kExp, op::kind::Exp}, + {HloOpcode::kLog, op::kind::Log}, + {HloOpcode::kSqrt, op::kind::Sqrt}, + {HloOpcode::kTanh, op::kind::Tanh}, + }); + return *unary_op_map; +} + +absl::StatusOr OneDnnUnaryOperator(const HloOpcode& opcode) { + const auto& unary_op_map = GetOneDnnUnaryOpMap(); + auto result = unary_op_map.find(opcode); + if (result == unary_op_map.end()) { + return InvalidArgument("Unsupported OneDNN unary operator: %s", + HloOpcodeString(opcode)); + } + return result->second; +} + +std::vector GetOneDnnSupportedUnaryOpsStrings() { + auto& unary_op_map = GetOneDnnUnaryOpMap(); + std::vector op_names; + op_names.reserve(unary_op_map.size()); + for (auto& pair : unary_op_map) { + op_names.push_back(HloOpcodeString(pair.first)); + } + return op_names; +} + +const absl::flat_hash_map& GetOneDnnBinaryOpMap() { + static absl::NoDestructor> + binary_op_map({ + {HloOpcode::kAdd, op::kind::Add}, + {HloOpcode::kDivide, op::kind::Divide}, + {HloOpcode::kDot, op::kind::MatMul}, + {HloOpcode::kMaximum, op::kind::Maximum}, + {HloOpcode::kMinimum, op::kind::Minimum}, + {HloOpcode::kMultiply, op::kind::Multiply}, + {HloOpcode::kSubtract, op::kind::Subtract}, + }); + return *binary_op_map; +} + +absl::StatusOr OneDnnBinaryOperator(const HloOpcode& opcode) { + const auto& binary_op_map = GetOneDnnBinaryOpMap(); + auto result = binary_op_map.find(opcode); + if (result == binary_op_map.end()) { + return InvalidArgument("Unsupported OneDNN binary operator: %s", + HloOpcodeString(opcode)); + } + return result->second; +} + +std::vector GetOneDnnSupportedBinaryOpsStrings() { + auto& binary_op_map = GetOneDnnBinaryOpMap(); + std::vector op_names; + op_names.reserve(binary_op_map.size()); + for (auto& pair : binary_op_map) { + op_names.push_back(HloOpcodeString(pair.first)); + } + return op_names; +} + +bool IsOpSupportedByOneDnn(const HloInstruction* hlo, + const TargetMachineFeatures* cpu_features) { + if (!OneDnnBinaryOperator(hlo->opcode()).ok() && + !OneDnnUnaryOperator(hlo->opcode()).ok()) { + return false; + } + if (hlo->opcode() == HloOpcode::kDot) { + return IsDotSupportedByOneDnn( + hlo->dot_dimension_numbers(), hlo->operand(0)->shape(), + hlo->operand(1)->shape(), hlo->shape(), cpu_features) + .value_or(false); + } + if (hlo->opcode() == HloOpcode::kBitcast) { + return IsBitcastOpSupportedByOneDnn(hlo, cpu_features); + } + + return IsOneDnnSupportedTypeAndLayout(hlo, cpu_features); +} + +bool IsConstantSupportedByOneDnn(const HloInstruction* hlo, + const TargetMachineFeatures* cpu_features) { + CHECK(hlo->IsConstant()); + return IsOneDnnSupportedDType(hlo->shape().element_type(), cpu_features) && + IsOneDnnSupportedLayout(hlo->shape()); +} + +bool IsBitcastOpSupportedByOneDnn(const HloInstruction* hlo, + const TargetMachineFeatures* cpu_features) { + CHECK_EQ(hlo->opcode(), HloOpcode::kBitcast); + if (!IsOneDnnSupportedTypeAndLayout(hlo, cpu_features)) { + return false; + } + const HloInstruction* input = hlo->operand(0); + return hlo->shape().element_type() == input->shape().element_type(); +} + +bool IsElementwiseOpSupportedByOneDnn( + const HloInstruction* hlo, const TargetMachineFeatures* cpu_features) { + CHECK(hlo->IsElementwise()); + if (hlo->IsConstant()) { + return IsConstantSupportedByOneDnn(hlo, cpu_features); + } + if (hlo->opcode() == HloOpcode::kBitcast) { + return IsBitcastOpSupportedByOneDnn(hlo, cpu_features); + } + if (!OneDnnBinaryOperator(hlo->opcode()).ok() && + !OneDnnUnaryOperator(hlo->opcode()).ok()) { + return false; + } + return IsOneDnnSupportedTypeAndLayout(hlo, cpu_features); +} + } // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/onednn_support.h b/third_party/xla/xla/backends/cpu/onednn_support.h index 39039506090855..9c05370025d8d5 100644 --- a/third_party/xla/xla/backends/cpu/onednn_support.h +++ b/third_party/xla/xla/backends/cpu/onednn_support.h @@ -21,25 +21,77 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "oneapi/dnnl/dnnl.hpp" +#include "oneapi/dnnl/dnnl_graph.hpp" #include "xla/backends/cpu/codegen/target_machine_features.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/shape.h" #include "xla/xla_data.pb.h" namespace xla::cpu { inline constexpr absl::string_view kOneDnnFusionKind = "__onednn_fusion"; +using dnnl::graph::op; -bool IsOneDnnSupportedDType(PrimitiveType dtype); bool IsOneDnnSupportedDType(PrimitiveType dtype, const TargetMachineFeatures* cpu_features); +// Returns true if the shape doesn't have a layout or the layout is descending. +bool IsOneDnnSupportedLayout(const Shape& shape); + +// Returns true if the HLO instruction and all its operands have supported data +// types and layouts. +bool IsOneDnnSupportedTypeAndLayout( + const HloInstruction* hlo, + const TargetMachineFeatures* cpu_features = nullptr); + // Returns true if the dot operation is supported by oneDNN. Returns an error // if the dot operation shape is invalid. -absl::StatusOr IsOneDnnDotSupported( +absl::StatusOr IsDotSupportedByOneDnn( const DotDimensionNumbers& dot_dimensions, const Shape& lhs_shape, const Shape& rhs_shape, const Shape& out_shape, const TargetMachineFeatures* cpu_features = nullptr); +// Returns the mappings from HLO opcodes to OneDNN unary operators. +const absl::flat_hash_map& GetOneDnnUnaryOpMap(); + +// Returns the OneDNN unary operator corresponding to the given HLO opcode. +// Returns `InvalidArgument` if the opcode is not supported. +absl::StatusOr OneDnnUnaryOperator(const HloOpcode& opcode); + +// Returns the names of the OneDNN supported HLO unary ops. +std::vector GetOneDnnSupportedUnaryOpsStrings(); + +// Returns the mappings from HLO opcodes to OneDNN binary operators. +const absl::flat_hash_map& GetOneDnnBinaryOpMap(); + +// Returns the OneDNN binary operator corresponding to the given HLO opcode. +// Returns `InvalidArgument` if the opcode is not supported. +absl::StatusOr OneDnnBinaryOperator(const HloOpcode& opcode); + +// Returns the names of the OneDNN supported HLO binary ops. +std::vector GetOneDnnSupportedBinaryOpsStrings(); + +// Returns true if the HLO op is supported by OneDNN. +bool IsOpSupportedByOneDnn(const HloInstruction* hlo, + const TargetMachineFeatures* cpu_features = nullptr); + +// Returns true if the constant is supported by OneDNN. +bool IsConstantSupportedByOneDnn( + const HloInstruction* hlo, + const TargetMachineFeatures* cpu_features = nullptr); + +// Returns true if the bitcast op is supported by OneDNN. +bool IsBitcastOpSupportedByOneDnn( + const HloInstruction* hlo, + const TargetMachineFeatures* cpu_features = nullptr); + +// Returns true if the elementwise op is supported by OneDNN. +bool IsElementwiseOpSupportedByOneDnn( + const HloInstruction* hlo, + const TargetMachineFeatures* cpu_features = nullptr); } // namespace xla::cpu #endif // XLA_BACKENDS_CPU_ONEDNN_SUPPORT_H_ diff --git a/third_party/xla/xla/backends/cpu/transforms/library_matcher.h b/third_party/xla/xla/backends/cpu/transforms/library_matcher.h index 381ea0c970b503..65e99d408442aa 100644 --- a/third_party/xla/xla/backends/cpu/transforms/library_matcher.h +++ b/third_party/xla/xla/backends/cpu/transforms/library_matcher.h @@ -75,6 +75,13 @@ class LibraryMatcher { return instr->shape().element_type(); } + // Returns true if there is a limit on the number of ops in the fusion and + // the maximum fusion size is already reached. + virtual bool ReachedMaxFusionSize(int fused_op_count) { return false; } + + // Return true if the library supports merging fusions. + virtual bool ShouldMergeFusions() { return true; } + // Returns a prefix string for the fusion op's name. virtual std::string fusion_prefix() const { return ""; } diff --git a/third_party/xla/xla/backends/cpu/transforms/library_rewriter.cc b/third_party/xla/xla/backends/cpu/transforms/library_rewriter.cc index 36f206a6d9ab7d..6e99baad862dfd 100644 --- a/third_party/xla/xla/backends/cpu/transforms/library_rewriter.cc +++ b/third_party/xla/xla/backends/cpu/transforms/library_rewriter.cc @@ -247,7 +247,21 @@ absl::Status LibraryRewriter::FuseNeighbors(HloFusionInstruction* fusion, // This queue only tracks original HLO instructions in the parent computation, // not any new instructions created during the fusion process. std::queue> frontier; - AddFusionCandidates(fusion, fusion, FusionDirection::kBoth, frontier); + + FusionDirection direction = FusionDirection::kBoth; + + // TODO(intel-tf): Restrict fusion direction for oneDNN till future + // release of oneDNN library with both fusion direction support. +#if XLA_ONEDNN_USE_GRAPH_API + if (lib->fusion_kind() == kOneDnnFusionKind) { + direction = FusionDirection::kDown; + } +#endif // XLA_ONEDNN_USE_GRAPH_API + + AddFusionCandidates(fusion, fusion, direction, frontier); + + // Track the number of operations added to the fusion. + int fused_op_count = 0; // BFS and fuse as many neighbors as possible. while (!frontier.empty()) { @@ -258,16 +272,34 @@ absl::Status LibraryRewriter::FuseNeighbors(HloFusionInstruction* fusion, FusionDirectionToString(dir)); } - // If `instr` is another fusion of the same library type, fuse it. - // We don't need to add its neighbors to the frontier because anything that - // can be fused would have already been fused into `instr`. - if (IsCustomFusionWithKind(instr, lib->fusion_kind())) { + // If `instr` is another fusion of the same library type and the library + // supports merging fusions, fuse it. We don't need to add its neighbors to + // the frontier because anything that can be fused would have already been + // fused into `instr`. + if (lib->ShouldMergeFusions() && + IsCustomFusionWithKind(instr, lib->fusion_kind())) { TF_ASSIGN_OR_RETURN(fusion, MergeFusionInstructions( fusion, Cast(instr), dir)); continue; } + // [TODO]: Make this generic with keeping track of fusion state, i.e., + // number of special ops already in fusion, and checking if library can + // still fuse additional same kind of op. +#if XLA_ONEDNN_USE_GRAPH_API + if (lib->fusion_kind() == kOneDnnFusionKind && + instr->opcode() == HloOpcode::kDot) { + VLOG(4) << " Only one dot op is allowed in oneDNN fusion"; + break; + } +#endif // XLA_ONEDNN_USE_GRAPH_API + + if (lib->ReachedMaxFusionSize(fused_op_count)) { + VLOG(4) << " Reached max fusion size: " << fused_op_count; + break; + } + // Skip this instruction if it can't be fused. TF_ASSIGN_OR_RETURN(bool op_supported, lib->IsOpSupported(instr)); if (!op_supported) { @@ -284,6 +316,7 @@ absl::Status LibraryRewriter::FuseNeighbors(HloFusionInstruction* fusion, GrowFusion(fusion, instr, dir)); TF_RETURN_IF_ERROR( InsertConvertIfNecessary(new_instr, lib->LibraryOpOutputType(instr))); + fused_op_count++; } return absl::OkStatus(); } diff --git a/third_party/xla/xla/backends/cpu/transforms/library_rewriter_test.cc b/third_party/xla/xla/backends/cpu/transforms/library_rewriter_test.cc index 537a1d1d2344b7..2dcdec8cd5593e 100644 --- a/third_party/xla/xla/backends/cpu/transforms/library_rewriter_test.cc +++ b/third_party/xla/xla/backends/cpu/transforms/library_rewriter_test.cc @@ -189,7 +189,11 @@ TEST_P(CpuLibraryFullParamTest, AddMatMul) { DotRewriteTestSpec spec = GetParam(); FusionProperties expected = {HloOpcode::kDot, 0, 0, false}; if (IsDotEnabledOnCPU()) { - expected = FusionProperties{HloOpcode::kDot, 3, 6, true}; + // {Add, Add, Dot} for XNN, {Dot} for oneDNN. + // TODO(Intel-tf): Update expected values when fusion is supported. + expected = spec.lib != "onednn" + ? FusionProperties{HloOpcode::kDot, 3, 6, true} + : FusionProperties{HloOpcode::kDot, 2, 3, true}; } else if (spec.fusion_mode == "greedy") { expected = FusionProperties{HloOpcode::kAdd, 2, 3, true}; } @@ -304,10 +308,7 @@ TEST_P(CpuLibraryFullParamTest, MatMulAddSubMulSameInputs) { DotRewriteTestSpec spec = GetParam(); FusionProperties expected = {HloOpcode::kMultiply, 0, 0, false}; if (IsDotEnabledOnCPU()) { - // {Dot, Add, Sub, Mul} for YNN, {Dot, Add} for oneDNN. - expected = spec.lib == "ynn" - ? FusionProperties{HloOpcode::kMultiply, 3, 7, true} - : FusionProperties{HloOpcode::kAdd, 3, 5, true}; + expected = {HloOpcode::kMultiply, 3, 7, true}; } else if (spec.fusion_mode == "greedy") { // Only Add, Sub, and Mul in the fusion. expected = {HloOpcode::kMultiply, 2, 5, true}; @@ -335,10 +336,7 @@ TEST_P(CpuLibraryFullParamTest, MatMulAddSubMulDifferentInputs) { DotRewriteTestSpec spec = GetParam(); FusionProperties expected = {HloOpcode::kMultiply, 0, 0, false}; if (IsDotEnabledOnCPU()) { - // {Dot, Add, Sub, Mul} for YNN, {Dot, Add} for oneDNN. - expected = spec.lib == "ynn" - ? FusionProperties{HloOpcode::kMultiply, 5, 9, true} - : FusionProperties{HloOpcode::kAdd, 3, 5, true}; + expected = {HloOpcode::kMultiply, 5, 9, true}; } else if (spec.fusion_mode == "greedy") { // Only Add, Sub, and Mul in the fusion. expected = {HloOpcode::kMultiply, 4, 7, true}; @@ -374,10 +372,7 @@ TEST_P(CpuLibraryFullParamTest, MatMulAddMinExpSort) { DotRewriteTestSpec spec = GetParam(); FusionProperties expected = {HloOpcode::kExp, 0, 0, false}; if (IsDotEnabledOnCPU()) { - // {Dot, Add, Min, Exp} for YNN, {Dot, Add} for oneDNN. - expected = spec.lib == "ynn" - ? FusionProperties{HloOpcode::kExp, 4, 8, true} - : FusionProperties{HloOpcode::kAdd, 3, 5, true}; + expected = {HloOpcode::kExp, 4, 8, true}; } else if (spec.fusion_mode == "greedy") { // Only {Add, Min, Exp} in the fusion. expected = {HloOpcode::kExp, 3, 6, true}; diff --git a/third_party/xla/xla/backends/cpu/transforms/onednn_matcher.h b/third_party/xla/xla/backends/cpu/transforms/onednn_matcher.h index 7b033f7ca946ec..a6dad43150bfcf 100644 --- a/third_party/xla/xla/backends/cpu/transforms/onednn_matcher.h +++ b/third_party/xla/xla/backends/cpu/transforms/onednn_matcher.h @@ -30,6 +30,8 @@ limitations under the License. #include "tsl/platform/protobuf.h" namespace xla::cpu { +// TODO(intel-tf): Use oneDNN defined constant +static const int kMaxOneDnnFusionSize = 4; class OneDnnMatcher : public LibraryMatcher { public: @@ -40,9 +42,17 @@ class OneDnnMatcher : public LibraryMatcher { // Returns the set of supported HLO instructions. absl::flat_hash_set SupportedOps() const override { - static const auto* kSupportedOps = new absl::flat_hash_set{ - HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kMultiply, - HloOpcode::kExp}; + static const auto* kSupportedOps = []() { + static auto* supported_ops = + new absl::flat_hash_set{HloOpcode::kDot}; + for (const auto& [op, _] : GetOneDnnUnaryOpMap()) { + supported_ops->insert(op); + } + for (const auto& [op, _] : GetOneDnnBinaryOpMap()) { + supported_ops->insert(op); + } + return supported_ops; + }(); return *kSupportedOps; } @@ -51,20 +61,7 @@ class OneDnnMatcher : public LibraryMatcher { if (!SupportedOps().contains(instr->opcode())) { return false; } - if (instr->opcode() == HloOpcode::kDot) { - return IsOneDnnDotSupported( - instr->dot_dimension_numbers(), instr->operand(0)->shape(), - instr->operand(1)->shape(), instr->shape(), target_machine_features_); - } - - return IsOneDnnSupportedDType(instr->shape().element_type(), - target_machine_features_) && - std::all_of(instr->operands().begin(), instr->operands().end(), - [this](const HloInstruction* operand) { - return IsOneDnnSupportedDType( - operand->shape().element_type(), - target_machine_features_); - }); + return IsOpSupportedByOneDnn(instr, target_machine_features_); } // Returns true if we should start a new fusion containing just the given HLO @@ -74,6 +71,17 @@ class OneDnnMatcher : public LibraryMatcher { return instr->opcode() == HloOpcode::kDot; } + // Returns true if there is a limit on the number of ops in the fusion and + // the maximum fusion size is already reached. + bool ReachedMaxFusionSize(int fused_op_count) override { + return fused_op_count >= kMaxOneDnnFusionSize; + } + + // oneDNN library does not support merging fusions. + // TODO(intel-tf): Evaluate if merging fusions has performance benefit for + // oneDNN. + bool ShouldMergeFusions() override { return false; } + // Returns a prefix string for the fusion op's name. std::string fusion_prefix() const override { return "onednn_"; } diff --git a/third_party/xla/xla/service/cpu/tests/BUILD b/third_party/xla/xla/service/cpu/tests/BUILD index e4c2f16e7be88f..d7b826c8a40d90 100644 --- a/third_party/xla/xla/service/cpu/tests/BUILD +++ b/third_party/xla/xla/service/cpu/tests/BUILD @@ -7,7 +7,7 @@ load("//xla/tests:build_defs.bzl", "xla_test") load("//xla/tsl:tsl.bzl", "tsl_copts") load("//xla/tsl:tsl.default.bzl", "filegroup") load("//xla/tsl/mkl:build_defs.bzl", "if_graph_api") -load("//xla/tsl/mkl:graph.bzl", "onednn_cc_test") +load("//xla/tsl/mkl:graph.bzl", "onednn_cc_test", "onednn_graph_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -460,16 +460,18 @@ xla_cc_test( ], ) -xla_cc_test( +onednn_graph_cc_test( name = "onednn_fusion_test", srcs = ["onednn_fusion_test.cc"], - local_defines = if_graph_api(["XLA_ONEDNN_USE_GRAPH_API=1"]), tags = ["pjrt_migration_candidate"], deps = [ "//xla:error_spec", + "//xla/backends/cpu:onednn_support", "//xla/service:cpu_plugin", + "//xla/service/cpu:onednn_util", "//xla/tests:hlo_test_base", "//xla/tsl/platform:test", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@tsl//tsl/platform:platform_port", diff --git a/third_party/xla/xla/service/cpu/tests/onednn_fusion_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_fusion_test.cc index 9e26ce45dcba31..dd73dd601b45f0 100644 --- a/third_party/xla/xla/service/cpu/tests/onednn_fusion_test.cc +++ b/third_party/xla/xla/service/cpu/tests/onednn_fusion_test.cc @@ -14,8 +14,13 @@ limitations under the License. ==============================================================================*/ #include +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" +#include "xla/backends/cpu/onednn_support.h" #include "xla/error_spec.h" +#include "xla/service/cpu/onednn_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/platform/test.h" #include "tsl/platform/cpu_info.h" @@ -23,8 +28,6 @@ limitations under the License. namespace xla::cpu { namespace { -using OneDnnFusionTest = HloTestBase; - inline constexpr bool IsOneDnnGraphEnabled() { #if defined(XLA_ONEDNN_USE_GRAPH_API) // Some Aarch64 CPUs have failures. Only test on x86 for now. @@ -33,129 +36,280 @@ inline constexpr bool IsOneDnnGraphEnabled() { return false; } -TEST_F(OneDnnFusionTest, Exponential) { - constexpr absl::string_view kModuleStr = R"( - HloModule exp +struct OneDnnFusionTestParams { + PrimitiveType dtype; + std::string op_type; +}; - onednn_fusion { - %p0 = f32[4] parameter(0) - ROOT %exp = f32[4] exponential(%p0) +class OneDnnFusionTestBase + : public HloTestBase, + public ::testing::WithParamInterface { + protected: + void SetUp() override { + OneDnnFusionTestParams params = GetParam(); + data_type_ = params.dtype; + op_type_ = params.op_type; + atol_ = (data_type_ == F32) ? 1e-4 : 1e-2; + rtol_ = (data_type_ == F32) ? 1e-4 : 1e-2; + + if (!IsOneDnnGraphEnabled()) { + GTEST_SKIP() << "oneDNN fusion is not supported"; } - ENTRY entry { - %p0 = f32[4] parameter(0) - ROOT %fusion = f32[4] fusion(%p0), kind=kCustom, calls=onednn_fusion, - backend_config={"fusion_config": {kind: "__onednn_fusion"}} - })"; + if (!IsSupportedType(data_type_)) { + GTEST_SKIP() << "CPU does not support dtype: " + << primitive_util::LowercasePrimitiveTypeName(data_type_); + } + } + + PrimitiveType data_type_; + std::string op_type_; + float atol_; + float rtol_; +}; - if (!IsOneDnnGraphEnabled()) { - GTEST_SKIP() << "oneDNN fusion is not supported"; +class OneDnnFusionBinaryOpTest : public OneDnnFusionTestBase { + public: + static std::string Name( + const ::testing::TestParamInfo& data) { + return absl::StrCat( + data.param.op_type, "_", + primitive_util::LowercasePrimitiveTypeName(data.param.dtype)); } - EXPECT_TRUE(RunAndCompare(kModuleStr, ErrorSpec{1e-5})); -} + protected: + void RunTest() { + std::string hlo_template = + (op_type_ == "dot") ? GetMatMulHLOTemplate() : GetBinaryOpHLOTemplate(); + std::string hlo_binary_str = absl::StrReplaceAll( + hlo_template, + {{"$dtype", primitive_util::LowercasePrimitiveTypeName(data_type_)}, + {"$op_type", op_type_}}); + EXPECT_TRUE(RunAndCompare(hlo_binary_str, ErrorSpec{atol_, rtol_})); + } -// TODO(penporn): Make a parameterized BinaryEltwiseOp test instead. -TEST_F(OneDnnFusionTest, Add) { - constexpr absl::string_view kModuleStr = R"( - HloModule add + private: + const std::string GetBinaryOpHLOTemplate() { + return R"( + HloModule binary_op onednn_fusion { - %p0 = f32[10] parameter(0) - %p1 = f32[10] parameter(1) - ROOT %add = f32[10] add(%p0, %p1) + %p0 = $dtype[10, 20] parameter(0) + %p1 = $dtype[10, 20] parameter(1) + ROOT %op = $dtype[10, 20] $op_type(%p0, %p1) } ENTRY entry { - %p0 = f32[10] parameter(0) - %p1 = f32[10] parameter(1) - ROOT %fusion = f32[10] fusion(%p0, %p1), kind=kCustom, calls=onednn_fusion, + %p0 = $dtype[10, 20] parameter(0) + %p1 = $dtype[10, 20] parameter(1) + ROOT %fusion = $dtype[10, 20] fusion(%p0, %p1), kind=kCustom, + calls=onednn_fusion, backend_config={"fusion_config": {kind: "__onednn_fusion"}} })"; - - if (!IsOneDnnGraphEnabled()) { - GTEST_SKIP() << "oneDNN fusion is not supported"; } - EXPECT_TRUE(RunAndCompare(kModuleStr, ErrorSpec{1e-5})); -} - -TEST_F(OneDnnFusionTest, Mul) { - constexpr absl::string_view kModuleStr = R"( - HloModule mul + const std::string GetMatMulHLOTemplate() { + return R"( + HloModule matmul onednn_fusion { - %p0 = f32[10] parameter(0) - %p1 = f32[10] parameter(1) - ROOT %mul = f32[10] multiply(%p0, %p1) + %p0 = $dtype[1000,200] parameter(0) + %p1 = $dtype[200,300] parameter(1) + ROOT %mul = $dtype[1000,300] $op_type(%p0, %p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} } ENTRY entry { - %p0 = f32[10] parameter(0) - %p1 = f32[10] parameter(1) - ROOT %fusion = f32[10] fusion(%p0, %p1), kind=kCustom, calls=onednn_fusion, + %p0 = $dtype[1000,200] parameter(0) + %p1 = $dtype[200,300] parameter(1) + ROOT %fusion = $dtype[1000,300] fusion(%p0, %p1), kind=kCustom, + calls=onednn_fusion, backend_config={"fusion_config": {kind: "__onednn_fusion"}} })"; + } +}; + +TEST_P(OneDnnFusionBinaryOpTest, BinaryOp) { RunTest(); } + +std::vector GetOneDnnFusionBinaryOpTestSpecs() { + std::vector specs; + for (const auto& op_type : GetOneDnnSupportedBinaryOpsStrings()) { + specs.push_back({PrimitiveType::F32, std::string(op_type)}); + } + specs.push_back({PrimitiveType::BF16, "dot"}); + specs.push_back({PrimitiveType::F16, "dot"}); + return specs; +} + +INSTANTIATE_TEST_SUITE_P( + OneDnnFusionBinaryOpTestSuite, OneDnnFusionBinaryOpTest, + ::testing::ValuesIn(GetOneDnnFusionBinaryOpTestSpecs()), + OneDnnFusionBinaryOpTest::Name); + +class OneDnnFusionUnaryOpTest : public OneDnnFusionTestBase { + public: + static std::string Name( + const ::testing::TestParamInfo& data) { + return absl::StrCat( + data.param.op_type, "_", + primitive_util::LowercasePrimitiveTypeName(data.param.dtype)); + } - if (!IsOneDnnGraphEnabled()) { - GTEST_SKIP() << "oneDNN fusion is not supported"; + protected: + void RunTest() { + absl::string_view hlo_unary_template = R"( + HloModule unary_op + + onednn_fusion { + %p0 = $dtype[40] parameter(0) + ROOT %op = $dtype[40] $op_type(%p0) + } + + ENTRY entry { + %p0 = $dtype[40] parameter(0) + ROOT %fusion = $dtype[40] fusion(%p0), kind=kCustom, + calls=onednn_fusion, + backend_config={"fusion_config": {kind: "__onednn_fusion"}} + })"; + + std::string hlo_unary_str = absl::StrReplaceAll( + hlo_unary_template, + {{"$dtype", primitive_util::LowercasePrimitiveTypeName(data_type_)}, + {"$op_type", op_type_}}); + + EXPECT_TRUE(RunAndCompare(hlo_unary_str, ErrorSpec{atol_, rtol_})); } +}; - EXPECT_TRUE(RunAndCompare(kModuleStr, ErrorSpec{1e-5})); +TEST_P(OneDnnFusionUnaryOpTest, UnaryOp) { RunTest(); } + +std::vector GetOneDnnFusionUnaryOpTestSpecs() { + std::vector specs; + for (const auto& op_type : GetOneDnnSupportedUnaryOpsStrings()) { + specs.push_back({PrimitiveType::F32, std::string(op_type)}); + } + return specs; } -TEST_F(OneDnnFusionTest, MatMul) { - constexpr absl::string_view kModuleStr = R"( - HloModule mul +INSTANTIATE_TEST_SUITE_P(OneDnnFusionUnaryOpTestSuite, OneDnnFusionUnaryOpTest, + ::testing::ValuesIn(GetOneDnnFusionUnaryOpTestSpecs()), + OneDnnFusionUnaryOpTest::Name); + +class OneDnnFusionMatMulFuseBinaryTest : public OneDnnFusionTestBase { + public: + static std::string Name( + const ::testing::TestParamInfo& data) { + return absl::StrCat( + data.param.op_type, "_", + primitive_util::LowercasePrimitiveTypeName(data.param.dtype)); + } + + protected: + void RunTest() { + absl::string_view hlo_fusion_template = R"( + HloModule matmul_fusion onednn_fusion { - %p0 = f32[10,20] parameter(0) - %p1 = f32[20,30] parameter(1) - ROOT %mul = f32[10,30] dot(%p0, %p1), + %p0 = $dtype[1000,200] parameter(0) + %p1 = $dtype[200,300] parameter(1) + %p2 = $dtype[1000,300] parameter(2) + %dot = $dtype[1000,300] dot(%p0, %p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT %root = $dtype[1000,300] $op_type(%dot, %p2) } ENTRY entry { - %p0 = f32[10,20] parameter(0) - %p1 = f32[20,30] parameter(1) - ROOT %fusion = f32[10,30] fusion(%p0, %p1), kind=kCustom, + %p0 = $dtype[1000,200] parameter(0) + %p1 = $dtype[200,300] parameter(1) + %p2 = $dtype[1000,300] parameter(2) + ROOT %fusion = $dtype[1000,300] fusion(%p0, %p1, %p2), kind=kCustom, calls=onednn_fusion, backend_config={"fusion_config": {kind: "__onednn_fusion"}} - })"; + })"; - if (!IsOneDnnGraphEnabled()) { - GTEST_SKIP() << "oneDNN fusion is not supported"; + std::string hlo_fusion_str = absl::StrReplaceAll( + hlo_fusion_template, + {{"$dtype", primitive_util::LowercasePrimitiveTypeName(data_type_)}, + {"$op_type", op_type_}}); + + EXPECT_TRUE(RunAndCompare(hlo_fusion_str, ErrorSpec{atol_, rtol_})); } +}; + +TEST_P(OneDnnFusionMatMulFuseBinaryTest, MatmulFuseWith) { RunTest(); } - EXPECT_TRUE(RunAndCompare(kModuleStr, ErrorSpec{1e-5})); +std::vector GetOneDnnFusionFuseBinaryTestSpecs() { + std::vector specs; + for (const auto& dtype : {PrimitiveType::F32}) { + for (const auto& op_type : GetOneDnnSupportedBinaryOpsStrings()) { + // oneDNN does not support fusing two dot instructions + if (op_type == HloOpcodeString(HloOpcode::kDot)) continue; + specs.push_back({dtype, std::string(op_type)}); + } + } + return specs; } -TEST_F(OneDnnFusionTest, MatMulAdd) { - constexpr absl::string_view kModuleStr = R"( - HloModule mul +INSTANTIATE_TEST_SUITE_P( + OneDnnFusionMatMulFusionTestSuite, OneDnnFusionMatMulFuseBinaryTest, + ::testing::ValuesIn(GetOneDnnFusionFuseBinaryTestSpecs()), + OneDnnFusionMatMulFuseBinaryTest::Name); + +class OneDnnFusionMatMulFuseUnaryTest : public OneDnnFusionTestBase { + public: + static std::string Name( + const ::testing::TestParamInfo& data) { + return absl::StrCat( + data.param.op_type, "_", + primitive_util::LowercasePrimitiveTypeName(data.param.dtype)); + } + + protected: + void RunTest() { + absl::string_view hlo_fusion_template = R"( + HloModule matmul_fusion + onednn_fusion { - %p0 = f32[10,20] parameter(0) - %p1 = f32[20,30] parameter(1) - %dot = f32[10,30] dot(%p0, %p1), + %p0 = $dtype[1000,200] parameter(0) + %p1 = $dtype[200,300] parameter(1) + %dot = $dtype[1000,300] dot(%p0, %p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} - %p2 = f32[10,30] parameter(2) - ROOT %add = f32[10,30] add(%dot, %p2) + ROOT %root = $dtype[1000,300] $op_type(%dot) } + ENTRY entry { - %p0 = f32[10,20] parameter(0) - %p1 = f32[20,30] parameter(1) - %p2 = f32[10,30] parameter(2) - ROOT %fusion = f32[10,30] fusion(%p0, %p1, %p2), kind=kCustom, + %p0 = $dtype[1000,200] parameter(0) + %p1 = $dtype[200,300] parameter(1) + ROOT %fusion = $dtype[1000,300] fusion(%p0, %p1), kind=kCustom, calls=onednn_fusion, backend_config={"fusion_config": {kind: "__onednn_fusion"}} - })"; + })"; + + std::string hlo_fusion_str = absl::StrReplaceAll( + hlo_fusion_template, + {{"$dtype", primitive_util::LowercasePrimitiveTypeName(data_type_)}, + {"$op_type", op_type_}}); - if (!IsOneDnnGraphEnabled()) { - GTEST_SKIP() << "oneDNN fusion is not supported"; + EXPECT_TRUE(RunAndCompare(hlo_fusion_str, ErrorSpec{atol_, rtol_})); } +}; + +TEST_P(OneDnnFusionMatMulFuseUnaryTest, MatmulFuseWith) { RunTest(); } - EXPECT_TRUE(RunAndCompare(kModuleStr, ErrorSpec{1e-5})); +std::vector GetOneDnnFusionFuseUnaryTestSpecs() { + std::vector specs; + for (const auto& dtype : {PrimitiveType::F32}) { + for (const auto& op_type : GetOneDnnSupportedUnaryOpsStrings()) { + specs.push_back({dtype, std::string(op_type)}); + } + } + return specs; } +INSTANTIATE_TEST_SUITE_P( + OneDnnFusionMatMulFusionTestSuite, OneDnnFusionMatMulFuseUnaryTest, + ::testing::ValuesIn(GetOneDnnFusionFuseUnaryTestSpecs()), + OneDnnFusionMatMulFuseUnaryTest::Name); + } // namespace } // namespace xla::cpu