@@ -17,12 +17,9 @@ limitations under the License.
 
 #include <cstddef>
 #include <cstdint>
-#include <utility>
-#include <vector>
 
 #include "oneapi/dnnl/dnnl_common.hpp"
 #include "oneapi/dnnl/dnnl_graph.hpp"
-#include "absl/container/flat_hash_map.h"
 #include "absl/functional/any_invocable.h"
 #include "absl/log/log.h"
 #include "absl/status/statusor.h"
@@ -69,6 +66,19 @@ static absl::StatusOr<dnnl::graph::op::kind> OneDnnUnaryOperator(
   }
 }
 
+static absl::StatusOr<dnnl::graph::op::kind> OneDnnBinaryOperator(
+    const HloOpcode& opcode) {
+  switch (opcode) {
+    case HloOpcode::kAdd:
+      return dnnl::graph::op::kind::Add;
+    case HloOpcode::kMultiply:
+      return dnnl::graph::op::kind::Multiply;
+    default:
+      return InvalidArgument("Unsupported oneDNN binary operator: %s",
+                             HloOpcodeString(opcode));
+  }
+}
+
 static dnnl::graph::logical_tensor::dims OneDnnDimensions(const Shape& shape) {
   dnnl::graph::logical_tensor::dims dims;
   for (auto& dim : shape.dimensions()) {
@@ -142,6 +152,32 @@ static absl::StatusOr<dnnl::graph::logical_tensor> DefineUnaryOp(
   return output;
 }
 
+static absl::StatusOr<dnnl::graph::logical_tensor> DefineBinaryOp(
+    dnnl::graph::graph& graph, size_t op_id, LogicalTensorMap& logical_tensors,
+    const HloInstruction* instr) {
+  VLOG(3) << absl::StreamFormat(
+      "Define logical tensor value for binary op: %s", instr->ToString());
+
+  TF_ASSIGN_OR_RETURN(auto binary_op, OneDnnBinaryOperator(instr->opcode()));
+
+  TF_ASSIGN_OR_RETURN(auto lhs,
+                      FindLogicalTensor(logical_tensors, instr->operand(0)));
+  TF_ASSIGN_OR_RETURN(auto rhs,
+                      FindLogicalTensor(logical_tensors, instr->operand(1)));
+
+  size_t output_id = logical_tensors.size();
+  TF_ASSIGN_OR_RETURN(auto output,
+                      CreateLogicalTensor(output_id, instr->shape()));
+
+  VLOG(3) << absl::StreamFormat("tensors: lhs=%d, rhs=%d, output=%d",
+                                lhs.get_id(), rhs.get_id(), output.get_id());
+
+  dnnl::graph::op op(op_id, binary_op, {lhs, rhs}, {output});
+  ONEDNN_RETURN_IF_ERROR(graph.add_op(op));
+
+  return output;
+}
+
 //===----------------------------------------------------------------------===//
 // Emit oneDNN graph for the given HLO computation.
 //===----------------------------------------------------------------------===//
@@ -165,12 +201,21 @@ static absl::StatusOr<OneDnnFusion> EmitOneDnnFusion(
         TF_ASSIGN_OR_RETURN(logical_tensors[instr], DefineParameter(instr));
       } break;
 
+      // Unary elementwise ops.
       case HloOpcode::kExp: {
         TF_ASSIGN_OR_RETURN(
             logical_tensors[instr],
             DefineUnaryOp(graph, op_id++, logical_tensors, instr));
       } break;
 
+      // Binary elementwise ops.
+      case HloOpcode::kAdd:
+      case HloOpcode::kMultiply: {
+        TF_ASSIGN_OR_RETURN(
+            logical_tensors[instr],
+            DefineBinaryOp(graph, op_id++, logical_tensors, instr));
+      } break;
+
       default: {
         return InvalidArgument("Unsupported oneDNN fusion instruction: %s",
                                instr->ToString());
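For context, below is a minimal standalone sketch (not part of this change) of the graph that the new `DefineBinaryOp` path emits for a fused `add`, written directly against the oneDNN Graph API. The tensor ids, the `[2, 3]` shape, and the `f32` data type are illustrative assumptions; the op construction mirrors `dnnl::graph::op op(op_id, binary_op, {lhs, rhs}, {output})` above.

```cpp
// Sketch only, NOT part of the change: builds the same two-input Add graph
// that DefineBinaryOp emits for HloOpcode::kAdd. Ids, dims, and the f32
// data type are illustrative assumptions.
#include <iostream>

#include "oneapi/dnnl/dnnl_common.hpp"
#include "oneapi/dnnl/dnnl_graph.hpp"

int main() {
  namespace g = dnnl::graph;
  using lt = g::logical_tensor;

  // Two inputs and one output, all f32 with a fixed [2, 3] shape.
  lt::dims shape = {2, 3};
  lt lhs(/*tid=*/0, lt::data_type::f32, shape, lt::layout_type::strided);
  lt rhs(/*tid=*/1, lt::data_type::f32, shape, lt::layout_type::strided);
  lt out(/*tid=*/2, lt::data_type::f32, shape, lt::layout_type::strided);

  // kind::Add is what OneDnnBinaryOperator returns for HloOpcode::kAdd.
  g::op add(/*op_id=*/0, g::op::kind::Add, {lhs, rhs}, {out}, "add");

  g::graph graph(dnnl::engine::kind::cpu);
  graph.add_op(add);  // throws on error; the emitter uses ONEDNN_RETURN_IF_ERROR
  graph.finalize();

  // One supported partition means oneDNN can compile and run the fused graph.
  std::cout << "partitions: " << graph.get_partitions().size() << "\n";
  return 0;
}
```

Built against a oneDNN 3.x release with the Graph API, this should report a single fusion partition containing the Add op, which is the shape of graph the emitter hands off for the kAdd/kMultiply cases added in the last hunk.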