[xla:cpu:onednn] Support elementwise Add and Mul in oneDNN fusion thunk · linux-on-ibm-z/tensorflow@ca77b1a

Commit ca77b1a

penpornk authored and tensorflower-gardener committed
[xla:cpu:onednn] Support elementwise Add and Mul in oneDNN fusion thunk
PiperOrigin-RevId: 730899327
1 parent c42688e commit ca77b1a

File tree

2 files changed: +107 −6 lines changed

third_party/xla/xla/backends/cpu/onednn_emitter.cc

Lines changed: 48 additions & 3 deletions

@@ -17,12 +17,9 @@ limitations under the License.
 
 #include <cstddef>
 #include <cstdint>
-#include <utility>
-#include <vector>
 
 #include "oneapi/dnnl/dnnl_common.hpp"
 #include "oneapi/dnnl/dnnl_graph.hpp"
-#include "absl/container/flat_hash_map.h"
 #include "absl/functional/any_invocable.h"
 #include "absl/log/log.h"
 #include "absl/status/statusor.h"
@@ -69,6 +66,19 @@ static absl::StatusOr<dnnl::graph::op::kind> OneDnnUnaryOperator(
   }
 }
 
+static absl::StatusOr<dnnl::graph::op::kind> OneDnnBinaryOperator(
+    const HloOpcode& opcode) {
+  switch (opcode) {
+    case HloOpcode::kAdd:
+      return dnnl::graph::op::kind::Add;
+    case HloOpcode::kMultiply:
+      return dnnl::graph::op::kind::Multiply;
+    default:
+      return InvalidArgument("Unsupported oneDNN binary operator: %s",
+                             HloOpcodeString(opcode));
+  }
+}
+
 static dnnl::graph::logical_tensor::dims OneDnnDimensions(const Shape& shape) {
   dnnl::graph::logical_tensor::dims dims;
   for (auto& dim : shape.dimensions()) {
@@ -142,6 +152,32 @@ static absl::StatusOr<dnnl::graph::logical_tensor> DefineUnaryOp(
   return output;
 }
 
+static absl::StatusOr<dnnl::graph::logical_tensor> DefineBinaryOp(
+    dnnl::graph::graph& graph, size_t op_id, LogicalTensorMap& logical_tensors,
+    const HloInstruction* instr) {
+  VLOG(3) << absl::StreamFormat(
+      "Define logical tensor value for binary op: %s", instr->ToString());
+
+  TF_ASSIGN_OR_RETURN(auto binary_op, OneDnnBinaryOperator(instr->opcode()));
+
+  TF_ASSIGN_OR_RETURN(auto lhs,
+                      FindLogicalTensor(logical_tensors, instr->operand(0)));
+  TF_ASSIGN_OR_RETURN(auto rhs,
+                      FindLogicalTensor(logical_tensors, instr->operand(1)));
+
+  size_t output_id = logical_tensors.size();
+  TF_ASSIGN_OR_RETURN(auto output,
+                      CreateLogicalTensor(output_id, instr->shape()));
+
+  VLOG(3) << absl::StreamFormat("  tensors: lhs=%d, rhs=%d, output=%d",
+                                lhs.get_id(), rhs.get_id(), output.get_id());
+
+  dnnl::graph::op op(op_id, binary_op, {lhs, rhs}, {output});
+  ONEDNN_RETURN_IF_ERROR(graph.add_op(op));
+
+  return output;
+}
+
 //===----------------------------------------------------------------------===//
 // Emit oneDNN graph for the given HLO computation.
 //===----------------------------------------------------------------------===//
@@ -165,12 +201,21 @@ static absl::StatusOr<OneDnnFusion> EmitOneDnnFusion(
         TF_ASSIGN_OR_RETURN(logical_tensors[instr], DefineParameter(instr));
       } break;
 
+      // Unary elementwise ops.
       case HloOpcode::kExp: {
        TF_ASSIGN_OR_RETURN(
             logical_tensors[instr],
             DefineUnaryOp(graph, op_id++, logical_tensors, instr));
       } break;
 
+      // Binary elementwise ops.
+      case HloOpcode::kAdd:
+      case HloOpcode::kMultiply: {
+        TF_ASSIGN_OR_RETURN(
+            logical_tensors[instr],
+            DefineBinaryOp(graph, op_id++, logical_tensors, instr));
+      } break;
+
       default: {
         return InvalidArgument("Unsupported oneDNN fusion instruction: %s",
                                instr->ToString());
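The new DefineBinaryOp follows the same pattern as DefineUnaryOp: look up the operands' logical tensors, create a fresh logical tensor for the result, and connect them with a single oneDNN Graph op. For readers unfamiliar with that API, a minimal standalone sketch (not part of this commit; the tensor/op ids and the f32[10] shape are arbitrary) of the graph the emitter builds for one add:

// Standalone sketch of the oneDNN Graph construction that DefineBinaryOp
// performs for HloOpcode::kAdd. Ids and shapes are illustrative only.
#include "oneapi/dnnl/dnnl_common.hpp"
#include "oneapi/dnnl/dnnl_graph.hpp"

int main() {
  using LogicalTensor = dnnl::graph::logical_tensor;

  // Two f32[10] inputs and one f32[10] output, with arbitrary tensor ids.
  LogicalTensor lhs(0, LogicalTensor::data_type::f32, {10},
                    LogicalTensor::layout_type::strided);
  LogicalTensor rhs(1, LogicalTensor::data_type::f32, {10},
                    LogicalTensor::layout_type::strided);
  LogicalTensor out(2, LogicalTensor::data_type::f32, {10},
                    LogicalTensor::layout_type::strided);

  // Mirrors `dnnl::graph::op op(op_id, binary_op, {lhs, rhs}, {output})`
  // in the emitter.
  dnnl::graph::op add(/*id=*/0, dnnl::graph::op::kind::Add, {lhs, rhs}, {out},
                      "add");

  dnnl::graph::graph g(dnnl::engine::kind::cpu);
  g.add_op(add);
  g.finalize();

  // A single partition containing the Add op is expected.
  return g.get_partitions().size() == 1 ? 0 : 1;
}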

third_party/xla/xla/service/cpu/tests/onednn_fusion_test.cc

Lines changed: 59 additions & 3 deletions

@@ -24,6 +24,13 @@ namespace {
 
 using OneDnnFusionTest = HloTestBase;
 
+inline constexpr bool IsOneDnnGraphEnabled() {
+#if defined(XLA_ONEDNN_USE_GRAPH_API)
+  return true;
+#endif  // XLA_ONEDNN_USE_GRAPH_API
+  return false;
+}
+
 TEST_F(OneDnnFusionTest, Exponential) {
   constexpr absl::string_view kModuleStr = R"(
     HloModule exp
@@ -39,9 +46,58 @@ TEST_F(OneDnnFusionTest, Exponential) {
         backend_config={"fusion_config": {kind: "__onednn_fusion"}}
     })";
 
-#if !defined(XLA_ONEDNN_USE_GRAPH_API)
-  GTEST_SKIP() << "oneDNN fusion is not supported";
-#endif  // XLA_ONEDNN_USE_GRAPH_API
+  if (!IsOneDnnGraphEnabled()) {
+    GTEST_SKIP() << "oneDNN fusion is not supported";
+  }
+
+  EXPECT_TRUE(RunAndCompare(kModuleStr, ErrorSpec{1e-5}));
+}
+
+// TODO(penporn): Make a parameterized BinaryEltwiseOp test instead.
+TEST_F(OneDnnFusionTest, Add) {
+  constexpr absl::string_view kModuleStr = R"(
+    HloModule add
+
+    onednn_fusion {
+      %p0 = f32[10] parameter(0)
+      %p1 = f32[10] parameter(1)
+      ROOT %add = f32[10] add(%p0, %p1)
+    }
+
+    ENTRY entry {
+      %p0 = f32[10] parameter(0)
+      %p1 = f32[10] parameter(1)
+      ROOT %fusion = f32[10] fusion(%p0, %p1), kind=kCustom, calls=onednn_fusion,
+        backend_config={"fusion_config": {kind: "__onednn_fusion"}}
+    })";
+
+  if (!IsOneDnnGraphEnabled()) {
+    GTEST_SKIP() << "oneDNN fusion is not supported";
+  }
+
+  EXPECT_TRUE(RunAndCompare(kModuleStr, ErrorSpec{1e-5}));
+}
+
+TEST_F(OneDnnFusionTest, Mul) {
+  constexpr absl::string_view kModuleStr = R"(
+    HloModule mul
+
+    onednn_fusion {
+      %p0 = f32[10] parameter(0)
+      %p1 = f32[10] parameter(1)
+      ROOT %mul = f32[10] multiply(%p0, %p1)
+    }
+
+    ENTRY entry {
+      %p0 = f32[10] parameter(0)
+      %p1 = f32[10] parameter(1)
+      ROOT %fusion = f32[10] fusion(%p0, %p1), kind=kCustom, calls=onednn_fusion,
+        backend_config={"fusion_config": {kind: "__onednn_fusion"}}
+    })";
+
+  if (!IsOneDnnGraphEnabled()) {
+    GTEST_SKIP() << "oneDNN fusion is not supported";
+  }
 
   EXPECT_TRUE(RunAndCompare(kModuleStr, ErrorSpec{1e-5}));
 }
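The TODO in the test diff suggests folding the Add and Mul tests into one parameterized test. One possible shape for it, sketched here rather than taken from the commit (the class and suite names are illustrative; it would also need absl/strings/str_replace.h), substitutes the HLO opcode into a shared module template:

// Sketch of the parameterized binary-eltwise test the TODO mentions.
// Not part of this commit; names are hypothetical.
class OneDnnBinaryEltwiseTest
    : public HloTestBase,
      public ::testing::WithParamInterface<absl::string_view> {};

TEST_P(OneDnnBinaryEltwiseTest, Eltwise) {
  // "$op" is replaced with the HLO opcode under test ("add" or "multiply").
  const std::string module_str = absl::StrReplaceAll(R"(
    HloModule eltwise

    onednn_fusion {
      %p0 = f32[10] parameter(0)
      %p1 = f32[10] parameter(1)
      ROOT %r = f32[10] $op(%p0, %p1)
    }

    ENTRY entry {
      %p0 = f32[10] parameter(0)
      %p1 = f32[10] parameter(1)
      ROOT %fusion = f32[10] fusion(%p0, %p1), kind=kCustom, calls=onednn_fusion,
        backend_config={"fusion_config": {kind: "__onednn_fusion"}}
    })",
                                                     {{"$op", GetParam()}});

  if (!IsOneDnnGraphEnabled()) {
    GTEST_SKIP() << "oneDNN fusion is not supported";
  }
  EXPECT_TRUE(RunAndCompare(module_str, ErrorSpec{1e-5}));
}

INSTANTIATE_TEST_SUITE_P(BinaryOps, OneDnnBinaryEltwiseTest,
                         ::testing::Values("add", "multiply"));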
