@@ -28,9 +28,9 @@ limitations under the License.
28
28
#include " mlir/IR/OwningOpRef.h" // from @llvm-project
29
29
#include " mlir/Pass/PassManager.h" // from @llvm-project
30
30
#include " mlir/Support/LogicalResult.h" // from @llvm-project
31
- #include " tensorflow/compiler/mlir/quantization/common/test_base .h"
31
+ #include " tensorflow/compiler/mlir/quantization/common/tf_test_base .h"
32
32
#include " tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h"
33
- #include " tensorflow/compiler/mlir/quantization/stablehlo/passes/passes .h"
33
+ #include " tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes .h"
34
34
#include " tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
35
35
#include " xla/tsl/platform/status_matchers.h"
36
36
#include " tsl/platform/protobuf.h" // IWYU pragma: keep
@@ -42,6 +42,9 @@ using ::stablehlo::quantization::QuantizationResults;
42
42
using ::stablehlo::quantization::io::ReadFileToString;
43
43
using ::testing::SizeIs;
44
44
using ::testing::StrEq;
45
+ using tf_quant::QuantizationTestBase;
46
+ using tf_quant::stablehlo::createPrepareQuantizePass;
47
+ using tf_quant::stablehlo::QuantizeCompositeFunctionsPassOptions;
45
48
using ::tsl::protobuf::TextFormat;
46
49
using ::tsl::testing::IsOk;
47
50
using ::tsl::testing::StatusIs;
@@ -52,9 +55,9 @@ TEST_F(SaveQuantizationReportInstrumentationTest, SaveReport) {
52
55
constexpr absl::string_view kModuleWithCompositeDotGeneral = R"mlir(
53
56
func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> {
54
57
%cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
55
- %0 = "quantfork .stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
58
+ %0 = "quantization .stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
56
59
%1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
57
- %2 = "quantfork .stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
60
+ %2 = "quantization .stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
58
61
return %2 : tensor<1x3xf32>
59
62
}
60
63
@@ -111,9 +114,9 @@ TEST_F(SaveQuantizationReportInstrumentationTest,
111
114
constexpr absl::string_view kModuleWithCompositeDotGeneral = R"mlir(
112
115
func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> {
113
116
%cst = "stablehlo.constant"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
114
- %0 = "quantfork .stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
117
+ %0 = "quantization .stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
115
118
%1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @c
D305
omposite_dot_general_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
116
- %2 = "quantfork .stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
119
+ %2 = "quantization .stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
117
120
return %2 : tensor<1x3xf32>
118
121
}
119
122
@@ -154,9 +157,9 @@ TEST_F(SaveQuantizationReportInstrumentationTest,
154
157
constexpr absl::string_view kModuleWithCompositeDotGeneral = R"mlir(
155
158
func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> {
156
159
%cst = "stablehlo.constant"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
157
- %0 = "quantfork .stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
160
+ %0 = "quantization .stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
158
161
%1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
159
- %2 = "quantfork .stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
162
+ %2 = "quantization .stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
160
163
return %2 : tensor<1x3xf32>
161
164
}
162
165
0 commit comments