Move XnnDatatype to xnn_fusion.h · IBMZ-Linux-OSS-Python/tensorflow@45473f6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 45473f6

Browse files
Move XnnDatatype to xnn_fusion.h
PiperOrigin-RevId: 766407853
1 parent 1171000 commit 45473f6

File tree

6 files changed

+50
-20
lines changed

6 files changed

+50
-20
lines changed

third_party/xla/xla/backends/cpu/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,13 @@ cc_library(
103103
hdrs = ["xnn_fusion.h"],
104104
deps = [
105105
"//xla:shape_util",
106+
"//xla:util",
106107
"//xla:xla_data_proto_cc",
107108
"//xla/backends/cpu/codegen:target_machine_features",
108109
"//xla/backends/cpu/runtime:dot_lib",
109110
"//xla/hlo/ir:hlo",
110111
"//xla/tsl/platform:statusor",
112+
"@XNNPACK",
111113
"@com_google_absl//absl/algorithm:container",
112114
"@com_google_absl//absl/container:flat_hash_set",
113115
"@com_google_absl//absl/status:statusor",

third_party/xla/xla/backends/cpu/transforms/xnn_graph_fusion.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ HloInstruction* XnnGraphFusion::Fuse(HloInstruction* producer,
6565
}
6666

6767
bool XnnGraphFusion::IsOpSupported(HloInstruction* instr) const {
68+
if (!XnnDatatype(instr->shape().element_type()).ok()) {
69+
return false;
70+
}
71+
6872
switch (instr->opcode()) {
6973
case HloOpcode::kAdd:
7074
case HloOpcode::kSubtract:

third_party/xla/xla/backends/cpu/transforms/xnn_graph_fusion_test.cc

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,12 @@ TEST_F(XnnGraphFusionTest, BasicFusion) {
4343
HloModule FusionDemonstration
4444
4545
ENTRY entry {
46-
%param.0 = f32[2,2]{1,0} parameter(0)
47-
%constant.0 = f32[2,2]{1,0} constant({ { 1, 2 }, { 3, 4 } })
48-
%add.0 = f32[2,2]{1,0} add(f32[2,2]{1,0} %param.0, f32[2,2]{1,0} %constant.0)
49-
%sub.0 = f32[2,2]{1,0} subtract(f32[2,2]{1,0} %param.0, f32[2,2]{1,0} %constant.0)
50-
ROOT %result = f32[2,2]{1,0} multiply(f32[2,2]{1,0} %add.0, f32[2,2]{1,0} %sub.0)
46+
%param.0 = f32[2,2] parameter(0)
47+
%constant.0 = f32[2,2] constant({ { 1, 2 }, { 3, 4 } })
48+
%add.0 = f32[2,2] add(f32[2,2] %param.0, f32[2,2]{1,0} %constant.0)
49+
%sub.0 = f32[2,2] subtract(f32[2,2] %param.0, f32[2,2] %constant.0)
50+
ROOT %result = f32[2,2] multiply(f32[2,2] %add.0, f32[2,2] %sub.0)
5151
}
52-
5352
)";
5453

5554
TF_ASSERT_OK_AND_ASSIGN(auto module,
@@ -67,5 +66,24 @@ ENTRY entry {
6766
EXPECT_EQ(backend_config.fusion_config().kind(), kXnnFusionKind);
6867
}
6968

69+
TEST_F(XnnGraphFusionTest, BasicFusionUnsupportedType) {
70+
std::string hlo_string = R"(
71+
HloModule FusionDemonstration
72+
73+
ENTRY entry {
74+
%param.0 = s2[2,2] parameter(0)
75+
%constant.0 = s2[2,2] constant({ { 0, 1 }, { 1, 0 } })
76+
%add.0 = s2[2,2] add(s2[2,2] %param.0, s2[2,2] %constant.0)
77+
%sub.0 = s2[2,2] subtract(s2[2,2] %param.0, s2[2,2] %constant.0)
78+
ROOT %result = s2[2,2] multiply(s2[2,2] %add.0, s2[2,2] %sub.0)
79+
}
80+
)";
81+
82+
TF_ASSERT_OK_AND_ASSIGN(auto module,
83+
ParseAndReturnVerifiedModule(hlo_string));
84+
TF_ASSERT_OK_AND_ASSIGN(bool changed, XnnGraphFusion().Run(module.get()));
85+
ASSERT_FALSE(changed);
86+
}
87+
7088
} // namespace
7189
} // namespace xla::cpu

third_party/xla/xla/backends/cpu/xnn_emitter.cc

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,20 +48,6 @@ using TensorIdMap = absl::flat_hash_map<const HloInstruction*, uint32_t>;
4848
// XLA <-> XNNPACK type conversion library.
4949
//===----------------------------------------------------------------------===//
5050

51-
static absl::StatusOr<xnn_datatype> XnnDatatype(const PrimitiveType& type) {
52-
switch (type) {
53-
case BF16:
54-
return xnn_datatype_bf16;
55-
case F16:
56-
return xnn_datatype_fp16;
57-
case F32:
58-
return xnn_datatype_fp32;
59-
default:
60-
return InvalidArgument("Unsupported XNNPACK data type: %s",
61-
primitive_util::LowercasePrimitiveTypeName(type));
62-
}
63-
}
64-
6551
static absl::StatusOr<xnn_unary_operator> XnnUnaryOperator(
6652
const HloOpcode& opcode) {
6753
switch (opcode) {

third_party/xla/xla/backends/cpu/xnn_fusion.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include <cstdint>
2020
#include <utility>
2121

22+
#include "xnnpack.h"
2223
#include "absl/algorithm/container.h"
2324
#include "absl/container/flat_hash_set.h"
2425
#include "absl/status/statusor.h"
@@ -27,9 +28,11 @@ limitations under the License.
2728
#include "xla/hlo/ir/hlo_computation.h"
2829
#include "xla/hlo/ir/hlo_instruction.h"
2930
#include "xla/hlo/ir/hlo_opcode.h"
31+
#include "xla/primitive_util.h"
3032
#include "xla/shape.h"
3133
#include "xla/shape_util.h"
3234
#include "xla/tsl/platform/statusor.h"
35+
#include "xla/util.h"
3336
#include "xla/xla_data.pb.h"
3437

3538
namespace xla::cpu {
@@ -132,4 +135,18 @@ absl::StatusOr<bool> IsXnnDotSupported(
132135
!dot_canonical_dims.rhs_column_major;
133136
}
134137

138+
absl::StatusOr<xnn_datatype> XnnDatatype(const PrimitiveType& type) {
139+
switch (type) {
140+
case BF16:
141+
return xnn_datatype_bf16;
142+
case F16:
143+
return xnn_datatype_fp16;
144+
case F32:
145+
return xnn_datatype_fp32;
146+
default:
147+
return InvalidArgument("Unsupported XNNPACK data type: %s",
148+
primitive_util::LowercasePrimitiveTypeName(type));
149+
}
150+
}
151+
135152
} // namespace xla::cpu

third_party/xla/xla/backends/cpu/xnn_fusion.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#ifndef XLA_BACKENDS_CPU_XNN_FUSION_H_
1717
#define XLA_BACKENDS_CPU_XNN_FUSION_H_
1818

19+
#include "xnnpack.h"
1920
#include "absl/status/statusor.h"
2021
#include "absl/strings/string_view.h"
2122
#include "xla/backends/cpu/codegen/target_machine_features.h"
@@ -41,6 +42,8 @@ absl::StatusOr<bool> IsXnnDotSupported(
4142
const Shape& rhs_shape, const Shape& out_shape,
4243
const TargetMachineFeatures* cpu_features = nullptr);
4344

45+
absl::StatusOr<xnn_datatype> XnnDatatype(const PrimitiveType& type);
46+
4447
} // namespace xla::cpu
4548

4649
#endif // XLA_BACKENDS_CPU_XNN_FUSION_H_

Comments (0)