8000 PR #32521: [XLA:CPU][oneDNN] Update dot plus elementwise op fusion for oneDNN by copybara-service[bot] · Pull Request #112137 · tensorflow/tensorflow · GitHub
[go: up one dir, main page]

Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions third_party/xla/xla/backends/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,10 @@ onednn_graph_cc_library(
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/codegen:target_machine_features",
"//xla/backends/cpu/runtime:dot_dims",
"//xla/hlo/ir:hlo",
"//xla/service/cpu:onednn_util",
"//xla/tsl/mkl:onednn",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
Expand Down
40 changes: 11 additions & 29 deletions third_party/xla/xla/backends/cpu/onednn_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,32 +67,6 @@ static absl::StatusOr<dnnl::graph::logical_tensor::data_type> OneDnnDatatype(
}
}

// Maps an HLO opcode to the oneDNN graph unary op kind it lowers to.
// Returns InvalidArgument for opcodes with no oneDNN unary equivalent.
static absl::StatusOr<dnnl::graph::op::kind> OneDnnUnaryOperator(
    const HloOpcode& opcode) {
  if (opcode == HloOpcode::kExp) {
    return dnnl::graph::op::kind::Exp;
  }
  return InvalidArgument("Unsupported oneDNN unary operator: %s",
                         HloOpcodeString(opcode));
}

// Maps an HLO opcode to the oneDNN graph binary op kind it lowers to.
// Returns InvalidArgument for opcodes with no oneDNN binary equivalent.
static absl::StatusOr<dnnl::graph::op::kind> OneDnnBinaryOperator(
    const HloOpcode& opcode) {
  switch (opcode) {
    case HloOpcode::kAdd:
      return dnnl::graph::op::kind::Add;
    case HloOpcode::kMultiply:
      return dnnl::graph::op::kind::Multiply;
    case HloOpcode::kDot:
      return dnnl::graph::op::kind::MatMul;
    default:
      // Fix: this is the *binary* operator mapping; the previous message said
      // "unary operator", producing misleading error text.
      return InvalidArgument("Unsupported oneDNN binary operator: %s",
                             HloOpcodeString(opcode));
  }
}

static dnnl::graph::logical_tensor::dims OneDnnDimensions(const Shape& shape) {
dnnl::graph::logical_tensor::dims dims;
for (auto& dim : shape.dimensions()) {
Expand Down Expand Up @@ -201,7 +175,7 @@ static absl::StatusOr<dnnl::graph::logical_tensor> DefineMatMul(
const Shape& rhs_shape = instr->operand(1)->shape();
TF_ASSIGN_OR_RETURN(
bool is_supported,
IsOneDnnDotSupported(dnums, lhs_shape, rhs_shape, instr->shape()));
IsDotSupportedByOneDnn(dnums, lhs_shape, rhs_shape, instr->shape()));

if (!is_supported) {
return InvalidArgument("Unsupported oneDNN Dot op variation: %s",
Expand Down Expand Up @@ -268,15 +242,23 @@ static absl::StatusOr<OneDnnFusion> EmitOneDnnFusion(
} break;

// Unary elementwise ops.
case HloOpcode::kExp: {
case HloOpcode::kAbs:
case HloOpcode::kExp:
case HloOpcode::kLog:
case HloOpcode::kSqrt:
case HloOpcode::kTanh: {
TF_ASSIGN_OR_RETURN(
logical_tensors[instr],
DefineUnaryOp(graph, op_id++, logical_tensors, instr));
} break;

// Binary elementwise ops.
case HloOpcode::kAdd:
case HloOpcode::kMultiply: {
case HloOpcode::kDivide:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kSubtract: {
TF_ASSIGN_OR_RETURN(
logical_tensors[instr],
DefineBinaryOp(graph, op_id++, logical_tensors, instr));
Expand Down
182 changes: 150 additions & 32 deletions third_party/xla/xla/backends/cpu/onednn_support.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,57 +15,57 @@ limitations under the License.

#include "xla/backends/cpu/onednn_support.h"

#include "absl/base/no_destructor.h"
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "dnnl.hpp" // NOLINT: for DNNL_MAX_NDIMS
#include "oneapi/dnnl/dnnl.hpp" // NOLINT: for DNNL_MAX_NDIMS
#include "oneapi/dnnl/dnnl_graph.hpp"
#include "xla/backends/cpu/codegen/target_machine_features.h"
#include "xla/backends/cpu/runtime/dot_dims.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/layout_util.h"
#include "xla/service/cpu/onednn_util.h"
#include "xla/shape.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/cpu_info.h"

namespace xla::cpu {

// Returns true if `dtype` is usable with oneDNN on the host CPU.
// F32 is unconditionally supported; BF16 and F16 each require one of the
// listed ISA extensions to be present on the running machine.
bool IsOneDnnSupportedDType(PrimitiveType dtype) {
  using tsl::port::CPUFeature;
  if (dtype == F32) {
    return true;
  }
  if (dtype == BF16) {
    return TestCPUFeature(CPUFeature::AVX512F) ||
           TestCPUFeature(CPUFeature::AVX_NE_CONVERT) ||
           TestCPUFeature(CPUFeature::AMX_BF16);
  }
  if (dtype == F16) {
    return (TestCPUFeature(CPUFeature::AVX512BW) &&
            (TestCPUFeature(CPUFeature::AVX512_FP16) ||
             TestCPUFeature(CPUFeature::AMX_FP16))) ||
           TestCPUFeature(CPUFeature::AVX_NE_CONVERT);
  }
  return false;
}

// Returns true if `dtype` is supported, consulting `cpu_features` (when
// provided) rather than the host CPU — unit tests use this to simulate
// different CPU capabilities.
bool IsOneDnnSupportedDType(PrimitiveType dtype,
                            const TargetMachineFeatures* cpu_features) {
  // Fix: removed a stray duplicated `if (dtype == F32) {` line (a
  // diff-extraction artifact) that left the braces unbalanced.
  if (dtype == F32 || IsSupportedType(dtype)) {
    return true;
  }
  // Check for data type support if target machine features are provided.
  // Unit tests may provide target machine features to simulate different CPU
  // capabilities.
  return (cpu_features != nullptr &&
          ((dtype == BF16 && cpu_features->has_avx512bf16()) ||
           (dtype == F16 && cpu_features->has_avx512fp16())));
}

if (cpu_features == nullptr) {
return IsOneDnnSupportedDType(dtype);
}
// A shape is layout-compatible with oneDNN when it either carries no layout
// at all or uses a descending (major-to-minor) layout.
bool IsOneDnnSupportedLayout(const Shape& shape) {
  if (!shape.has_layout()) {
    return true;
  }
  return LayoutUtil::HasDescendingLayout(shape.layout());
}

if (dtype == BF16) {
return cpu_features->has_avx512bf16();
}
if (dtype == F16) {
return cpu_features->has_avx512fp16();
}
// Returns true when `hlo` itself and every one of its operands have a
// oneDNN-supported element type and layout.
bool IsOneDnnSupportedTypeAndLayout(const HloInstruction* hlo,
                                    const TargetMachineFeatures* cpu_features) {
  auto is_supported = [cpu_features](const HloInstruction* hlo) {
    return IsOneDnnSupportedDType(hlo->shape().element_type(), cpu_features) &&
           IsOneDnnSupportedLayout(hlo->shape());
  };

  // Fix: removed a stray `return false;` (diff-extraction artifact) that made
  // the function unconditionally return false before any check ran.
  if (!is_supported(hlo)) {
    return false;
  }
  return std::all_of(hlo->operands().begin(), hlo->operands().end(),
                     is_supported);
}

absl::StatusOr<bool> IsOneDnnDotSupported(
absl::StatusOr<bool> IsDotSupportedByOneDnn(
const DotDimensionNumbers& dot_dimensions, const Shape& lhs_shape,
const Shape& rhs_shape, const Shape& out_shape,
const TargetMachineFeatures* cpu_features) {
Expand Down Expand Up @@ -110,4 +110,122 @@ absl::StatusOr<bool> IsOneDnnDotSupported(
!dot_canonical_dims.rhs_column_major;
}

// Lazily-initialized, never-destroyed table mapping HLO opcodes to the
// oneDNN graph unary op kinds they lower to.
const absl::flat_hash_map<HloOpcode, op::kind>& GetOneDnnUnaryOpMap() {
  static const absl::NoDestructor<absl::flat_hash_map<HloOpcode, op::kind>>
      kUnaryOps({
          {HloOpcode::kAbs, op::kind::Abs},
          {HloOpcode::kExp, op::kind::Exp},
          {HloOpcode::kLog, op::kind::Log},
          {HloOpcode::kSqrt, op::kind::Sqrt},
          {HloOpcode::kTanh, op::kind::Tanh},
      });
  return *kUnaryOps;
}

// Translates `opcode` to its oneDNN unary op kind.
// Returns InvalidArgument when the opcode is not in the unary op table.
absl::StatusOr<op::kind> OneDnnUnaryOperator(const HloOpcode& opcode) {
  const auto& op_map = GetOneDnnUnaryOpMap();
  if (auto it = op_map.find(opcode); it != op_map.end()) {
    return it->second;
  }
  return InvalidArgument("Unsupported OneDNN unary operator: %s",
                         HloOpcodeString(opcode));
}

std::vector<absl::string_view> GetOneDnnSupportedUnaryOpsStrings() {
auto& unary_op_map = GetOneDnnUnaryOpMap();
std::vector<absl::string_view> op_names;
op_names.reserve(unary_op_map.size());
for (auto& pair : unary_op_map) {
op_names.push_back(HloOpcodeString(pair.first));
}
return op_names;
}

// Lazily-initialized, never-destroyed table mapping HLO opcodes to the
// oneDNN graph binary op kinds they lower to. Note kDot is included here and
// maps to MatMul.
const absl::flat_hash_map<HloOpcode, op::kind>& GetOneDnnBinaryOpMap() {
  static const absl::NoDestructor<absl::flat_hash_map<HloOpcode, op::kind>>
      kBinaryOps({
          {HloOpcode::kAdd, op::kind::Add},
          {HloOpcode::kDivide, op::kind::Divide},
          {HloOpcode::kDot, op::kind::MatMul},
          {HloOpcode::kMaximum, op::kind::Maximum},
          {HloOpcode::kMinimum, op::kind::Minimum},
          {HloOpcode::kMultiply, op::kind::Multiply},
          {HloOpcode::kSubtract, op::kind::Subtract},
      });
  return *kBinaryOps;
}

// Translates `opcode` to its oneDNN binary op kind.
// Returns InvalidArgument when the opcode is not in the binary op table.
absl::StatusOr<op::kind> OneDnnBinaryOperator(const HloOpcode& opcode) {
  const auto& op_map = GetOneDnnBinaryOpMap();
  if (auto it = op_map.find(opcode); it != op_map.end()) {
    return it->second;
  }
  return InvalidArgument("Unsupported OneDNN binary operator: %s",
                         HloOpcodeString(opcode));
}

std::vector<absl::string_view> GetOneDnnSupportedBinaryOpsStrings() {
auto& binary_op_map = GetOneDnnBinaryOpMap();
std::vector<absl::string_view> op_names;
op_names.reserve(binary_op_map.size());
for (auto& pair : binary_op_map) {
op_names.push_back(HloOpcodeString(pair.first));
}
return op_names;
}

// Returns true if `hlo` can be handled by the oneDNN backend.
//
// Dispatch order: bitcasts get their dedicated check first; every other op
// must map to a known oneDNN unary/binary operator, after which dots get
// shape/dimension validation and the rest get type/layout validation.
bool IsOpSupportedByOneDnn(const HloInstruction* hlo,
                           const TargetMachineFeatures* cpu_features) {
  // Fix: kBitcast is in neither op map, so with the old ordering the
  // operator-map gate below returned false before the bitcast branch could
  // ever run (dead code). Checking bitcast first also matches the dispatch
  // order used by IsElementwiseOpSupportedByOneDnn.
  if (hlo->opcode() == HloOpcode::kBitcast) {
    return IsBitcastOpSupportedByOneDnn(hlo, cpu_features);
  }
  if (!OneDnnBinaryOperator(hlo->opcode()).ok() &&
      !OneDnnUnaryOperator(hlo->opcode()).ok()) {
    return false;
  }
  if (hlo->opcode() == HloOpcode::kDot) {
    // A dot with an unsupported shape variation yields an error status;
    // treat that the same as "not supported".
    return IsDotSupportedByOneDnn(
               hlo->dot_dimension_numbers(), hlo->operand(0)->shape(),
               hlo->operand(1)->shape(), hlo->shape(), cpu_features)
        .value_or(false);
  }

  return IsOneDnnSupportedTypeAndLayout(hlo, cpu_features);
}

bool IsConstantSupportedByOneDnn(const HloInstruction* hlo,
const TargetMachineFeatures* cpu_features) {
CHECK(hlo->IsConstant());
return IsOneDnnSupportedDType(hlo->shape().element_type(), cpu_features) &&
IsOneDnnSupportedLayout(hlo->shape());
}

bool IsBitcastOpSupportedByOneDnn(const HloInstruction* hlo,
const TargetMachineFeatures* cpu_features) {
CHECK_EQ(hlo->opcode(), HloOpcode::kBitcast);
if (!IsOneDnnSupportedTypeAndLayout(hlo, cpu_features)) {
return false;
}
const HloInstruction* input = hlo->operand(0);
return hlo->shape().element_type() == input->shape().element_type();
}

// Returns true if an elementwise `hlo` is supported by oneDNN: constants and
// bitcasts use their dedicated checks; anything else must map to a oneDNN
// unary or binary operator and have supported types/layouts.
// CHECK-fails if `hlo` is not elementwise.
bool IsElementwiseOpSupportedByOneDnn(
    const HloInstruction* hlo, const TargetMachineFeatures* cpu_features) {
  CHECK(hlo->IsElementwise());
  if (hlo->IsConstant()) {
    return IsConstantSupportedByOneDnn(hlo, cpu_features);
  }
  if (hlo->opcode() == HloOpcode::kBitcast) {
    return IsBitcastOpSupportedByOneDnn(hlo, cpu_features);
  }
  const bool has_onednn_op = OneDnnUnaryOperator(hlo->opcode()).ok() ||
                             OneDnnBinaryOperator(hlo->opcode()).ok();
  return has_onednn_op && IsOneDnnSupportedTypeAndLayout(hlo, cpu_features);
}

} // namespace xla::cpu
56 changes: 54 additions & 2 deletions third_party/xla/xla/backends/cpu/onednn_support.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,77 @@ limitations under the License.

#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "oneapi/dnnl/dnnl.hpp"
#include "oneapi/dnnl/dnnl_graph.hpp"
#include "xla/backends/cpu/codegen/target_machine_features.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/shape.h"
#include "xla/xla_data.pb.h"

namespace xla::cpu {

inline constexpr absl::string_view kOneDnnFusionKind = "__onednn_fusion";
using dnnl::graph::op;

bool IsOneDnnSupportedDType(PrimitiveType dtype);
bool IsOneDnnSupportedDType(PrimitiveType dtype,
const TargetMachineFeatures* cpu_features);

// Returns true if the shape doesn't have a layout or the layout is descending.
bool IsOneDnnSupportedLayout(const Shape& shape);

// Returns true if the HLO instruction and all its operands have supported data
// types and layouts.
bool IsOneDnnSupportedTypeAndLayout(
const HloInstruction* hlo,
const TargetMachineFeatures* cpu_features = nullptr);

// Returns true if the dot operation is supported by oneDNN. Returns an error
// if the dot operation shape is invalid.
absl::StatusOr<bool> IsOneDnnDotSupported(
absl::StatusOr<bool> IsDotSupportedByOneDnn(
const DotDimensionNumbers& dot_dimensions, const Shape& lhs_shape,
const Shape& rhs_shape, const Shape& out_shape,
const TargetMachineFeatures* cpu_features = nullptr);

// Returns the mappings from HLO opcodes to OneDNN unary operators.
const absl::flat_hash_map<HloOpcode, op::kind>& GetOneDnnUnaryOpMap();

// Returns the OneDNN unary operator corresponding to the given HLO opcode.
// Returns `InvalidArgument` if the opcode is not supported.
absl::StatusOr<op::kind> OneDnnUnaryOperator(const HloOpcode& opcode);

// Returns the names of the OneDNN supported HLO unary ops.
std::vector<absl::string_view> GetOneDnnSupportedUnaryOpsStrings();

// Returns the mappings from HLO opcodes to OneDNN binary operators.
const absl::flat_hash_map<HloOpcode, op::kind>& GetOneDnnBinaryOpMap();

// Returns the OneDNN binary operator corresponding to the given HLO opcode.
// Returns `InvalidArgument` if the opcode is not supported.
absl::StatusOr<op::kind> OneDnnBinaryOperator(const HloOpcode& opcode);

// Returns the names of the OneDNN supported HLO binary ops.
std::vector<absl::string_view> GetOneDnnSupportedBinaryOpsStrings();

// Returns true if the HLO op is supported by OneDNN.
bool IsOpSupportedByOneDnn(const HloInstruction* hlo,
const TargetMachineFeatures* cpu_features = nullptr);

// Returns true if the constant is supported by OneDNN.
bool IsConstantSupportedByOneDnn(
const HloInstruction* hlo,
const TargetMachineFeatures* cpu_features = nullptr);

// Returns true if the bitcast op is supported by OneDNN.
bool IsBitcastOpSupportedByOneDnn(
const HloInstruction* hlo,
const TargetMachineFeatures* cpu_features = nullptr);

// Returns true if the elementwise op is supported by OneDNN.
bool IsElementwiseOpSupportedByOneDnn(
const HloInstruction* hlo,
const TargetMachineFeatures* cpu_features = nullptr);
} // namespace xla::cpu

#endif // XLA_BACKENDS_CPU_ONEDNN_SUPPORT_H_
7 changes: 7 additions & 0 deletions third_party/xla/xla/backends/cpu/transforms/library_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ class LibraryMatcher {
return instr->shape().element_type();
}

// Returns true if the library caps the number of ops per fusion and
// `fused_op_count` has already reached that cap. The base implementation
// imposes no limit.
virtual bool ReachedMaxFusionSize(int fused_op_count) {
  return false;
}

// Returns true if the library supports merging adjacent fusions; the base
// implementation allows merging.
virtual bool ShouldMergeFusions() {
  return true;
}

// Returns a prefix string for the fusion op's name.
virtual std::string fusion_prefix() const { return ""; }

Expand Down
Loading
Loading
0