8000 cleanup for instrumentation:save_report · IBMZ-Linux-OSS-Python/tensorflow@eb55ae0 · GitHub
[go: up one dir, main page]

Skip to content

Commit eb55ae0

Browse files
cleanup for instrumentation:save_report
PiperOrigin-RevId: 766314696
1 parent 034dd63 commit eb55ae0

File tree

9 files changed

+22
-390
lines changed

9 files changed

+22
-390
lines changed

tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ cc_library(
421421
":config",
422422
":tf_pass_pipeline",
423423
"//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc",
424-
"//tensorflow/compiler/mlir/quantization/stablehlo/instrumentations:tf_save_report",
424+
"//tensorflow/compiler/mlir/quantization/stablehlo/instrumentations:save_report",
425425
"//tensorflow/compiler/mlir/quantization/tensorflow:tf_passes",
426426
"//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes",
427427
"@com_google_absl//absl/base:nullability",
@@ -483,7 +483,7 @@ cc_library(
483483
":tf_pass_pipeline",
484484
":types",
485485
"//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc",
486-
"//tensorflow/compiler/mlir/quantization/stablehlo/instrumentations:tf_save_report",
486+
"//tensorflow/compiler/mlir/quantization/stablehlo/instrumentations:save_report",
487487
"//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc",
488488
"//tensorflow/compiler/mlir/quantization/tensorflow:tf_passes",
489489
"//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes",

tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ limitations under the License.
2323
#include "mlir/Pass/PassManager.h" // from @llvm-project
2424
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h"
2525
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pass_pipeline.h"
26-
#include "tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/tf_save_report.h"
26+
#include "tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.h"
2727
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
2828
#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h"
2929
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
@@ -36,7 +36,6 @@ using ::stablehlo::quantization::PipelineConfig;
3636
using ::stablehlo::quantization::QuantizationConfig;
3737
using ::stablehlo::quantization::QuantizationSpecs;
3838
using ::tensorflow::quantization::RunPasses;
39-
using tf_quant::stablehlo::SaveQuantizationReportInstrumentation;
4039

4140
PostCalibrationComponent::PostCalibrationComponent(
4241
MLIRContext* absl_nonnull ctx)

tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ limitations under the License.
3434
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h"
3535
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pass_pipeline.h"
3636
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h"
37-
#include "tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/tf_save_report.h"
37+
#include "tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.h"
3838
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
3939
#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h"
4040
#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h"
@@ -47,7 +47,6 @@ limitations under the License.
4747
namespace mlir::quant::stablehlo {
4848

4949
using ::mlir::tf_quant::stablehlo::AddWeightOnlyQuantizationPasses;
50-
using ::mlir::tf_quant::stablehlo::SaveQuantizationReportInstrumentation;
5150
using ::stablehlo::quantization::GetReportFilePath;
5251
using ::stablehlo::quantization::QuantizationConfig;
5352
using ::tensorflow::SignatureDef;

tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/BUILD

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -9,50 +9,13 @@ package(
99
licenses = ["notice"],
1010
)
1111

12-
cc_library(
13-
name = "tf_save_report",
14-
srcs = ["tf_save_report.cc"],
15-
hdrs = ["tf_save_report.h"],
16-
compatible_with = get_compatible_with_portable(),
17-
deps = [
18-
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:tf_report",
19-
"@com_google_absl//absl/base:nullability",
20-
"@com_google_absl//absl/log",
21-
"@com_google_absl//absl/strings:string_view",
22-
"@llvm-project//mlir:IR",
23-
"@llvm-project//mlir:Pass",
24-
"@llvm-project//mlir:Support",
25-
],
26-
)
27-
28-
tf_cc_test(
29-
name = "tf_save_report_test",
30-
srcs = ["tf_save_report_test.cc"],
31-
deps = [
32-
":tf_save_report",
33-
"//tensorflow/compiler/mlir/quantization/common:tf_test_base",
34-
"//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc",
35-
"//tensorflow/compiler/mlir/quantization/stablehlo:tf_passes",
36-
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:io",
37-
"@com_google_absl//absl/status",
38-
"@com_google_absl//absl/status:statusor",
39-
"@com_google_absl//absl/strings",
40-
"@com_google_googletest//:gtest_main",
41-
"@llvm-project//mlir:IR",
42-
"@llvm-project//mlir:Pass",
43-
"@llvm-project//mlir:Support",
44-
"@local_tsl//tsl/platform:protobuf",
45-
"@local_xla//xla/tsl/platform:status_matchers",
46-
],
47-
)
48-
4912
cc_library(
5013
name = "save_report",
5114
srcs = ["save_report.cc"],
5215
hdrs = ["save_report.h"],
5316
compatible_with = get_compatible_with_portable(),
5417
deps = [
55-
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:report",
18+
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:tf_report",
5619
"@com_google_absl//absl/base:nullability",
5720
"@com_google_absl//absl/log",
5821
"@com_google_absl//absl/strings:string_view",
@@ -67,9 +30,9 @@ tf_cc_test(
6730
srcs = ["save_report_test.cc"],
6831
deps = [
6932
":save_report",
70-
"//tensorflow/compiler/mlir/quantization/common:test_base",
71-
"//tensorflow/compiler/mlir/quantization/stablehlo:passes",
33+
"//tensorflow/compiler/mlir/quantization/common:tf_test_base",
7234
"//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc",
35+
"//tensorflow/compiler/mlir/quantization/stablehlo:tf_passes",
7336
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:io",
7437
"@com_google_absl//absl/status",
7538
"@com_google_absl//absl/status:statusor",

tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ limitations under the License.
2424
#include "mlir/IR/Operation.h" // from @llvm-project
2525
#include "mlir/Pass/Pass.h" // from @llvm-project
2626
#include "mlir/Support/LLVM.h" // from @llvm-project
27-
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h"
27+
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_report.h"
2828

2929
namespace mlir::quant::stablehlo {
3030
namespace {
@@ -37,14 +37,16 @@ std::optional<std::string> OptionalStringViewToOptionalString(
3737
return std::make_optional<std::string>(*view);
3838
}
3939

40+
using tf_quant::stablehlo::QuantizationReport;
41+
4042
// Whether the pass is `QuantizeCompositeFunctionPass`.
4143
bool IsQuantizeCompositeFunctionPass(Pass* absl_nullable pass,
4244
Operation* absl_nullable op) {
4345
// It is known that `op` is `ModuleOp` when `pass` is
4446
// `QuantizeCompositeFunctionPass`, but the check is still performed to be
4547
// defensive.
4648
return pass != nullptr &&
47-
pass->getArgument() == "stablehlo-quantize-composite-functions" &&
49+
pass->getArgument() == "tf-stablehlo-quantize-composite-functions" &&
4850
isa_and_nonnull<ModuleOp>(op);
4951
}
5052

tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report_test.cc

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ limitations under the License.
2828
#include "mlir/IR/OwningOpRef.h" // from @llvm-project
2929
#include "mlir/Pass/PassManager.h" // from @llvm-project
3030
#include "mlir/Support/LogicalResult.h" // from @llvm-project
31-
#include "tensorflow/compiler/mlir/quantization/common/test_base.h"
31+
#include "tensorflow/compiler/mlir/quantization/common/tf_test_base.h"
3232
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h"
33-
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h"
33+
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h"
3434
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
3535
#include "xla/tsl/platform/status_matchers.h"
3636
#include "tsl/platform/protobuf.h" // IWYU pragma: keep
@@ -42,6 +42,9 @@ using ::stablehlo::quantization::QuantizationResults;
4242
using ::stablehlo::quantization::io::ReadFileToString;
4343
using ::testing::SizeIs;
4444
using ::testing::StrEq;
45+
using tf_quant::QuantizationTestBase;
46+
using tf_quant::stablehlo::createPrepareQuantizePass;
47+
using tf_quant::stablehlo::QuantizeCompositeFunctionsPassOptions;
4548
using ::tsl::protobuf::TextFormat;
4649
using ::tsl::testing::IsOk;
4750
using ::tsl::testing::StatusIs;
@@ -52,9 +55,9 @@ TEST_F(SaveQuantizationReportInstrumentationTest, SaveReport) {
5255
constexpr absl::string_view kModuleWithCompositeDotGeneral = R"mlir(
5356
func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> {
5457
%cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
55-
%0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
58+
%0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
5659
%1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
57-
%2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
60+
%2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
5861
return %2 : tensor<1x3xf32>
5962
}
6063
@@ -111,9 +114,9 @@ TEST_F(SaveQuantizationReportInstrumentationTest,
111114
constexpr absl::string_view kModuleWithCompositeDotGeneral = R"mlir(
112115
func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> {
113116
%cst = "stablehlo.constant"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
114-
%0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
117+
%0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
115118
%1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @c D305 omposite_dot_general_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
116-
%2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
119+
%2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
117120
return %2 : tensor<1x3xf32>
118121
}
119122
@@ -154,9 +157,9 @@ TEST_F(SaveQuantizationReportInstrumentationTest,
154157
constexpr absl::string_view kModuleWithCompositeDotGeneral = R"mlir(
155158
func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> {
156159
%cst = "stablehlo.constant"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
157-
%0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
160+
%0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
158161
%1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
159-
%2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
162+
%2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
160163
return %2 : tensor<1x3xf32>
161164
}
162165

tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/tf_save_report.cc

Lines changed: 0 additions & 95 deletions
This file was deleted.

tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/tf_save_report.h

Lines changed: 0 additions & 52 deletions
This file was deleted.

0 commit comments

Comments
 (0)
0