remove old copy of quantization_patterns and rename new one · IBMZ-Linux-OSS-Python/tensorflow@1d6e51e

Commit 1d6e51e

ecalubaquib authored and tensorflower-gardener committed
remove old copy of quantization_patterns and rename new one
PiperOrigin-RevId: 766422317
1 parent 3e7742d commit 1d6e51e

6 files changed: 43 additions, 1,351 deletions


tensorflow/compiler/mlir/quantization/stablehlo/BUILD

Lines changed: 4 additions & 35 deletions
@@ -94,9 +94,9 @@ cc_library(
         ":optimize_graph_inc_gen",
         ":quantization_config_proto_cc",
         ":quantization_options_proto_cc",
+        ":quantization_patterns",
         ":remove_sharding_custom_call_inc_gen",
         ":stablehlo_type_utils",
-        ":tf_quantization_patterns",
         ":tf_stablehlo_passes_inc_gen",
         "//tensorflow/compiler/mlir/quantization/common:func",
         "//tensorflow/compiler/mlir/quantization/common:tf_attrs_and_constraints",
@@ -260,10 +260,10 @@ cc_library(
 )
 
 cc_library(
-    name = "tf_quantization_patterns",
-    srcs = ["passes/tf_quantization_patterns.cc"],
+    name = "quantization_patterns",
+    srcs = ["passes/quantization_patterns.cc"],
     hdrs = [
-        "passes/tf_quantization_patterns.h",
+        "passes/quantization_patterns.h",
     ],
     compatible_with = get_compatible_with_portable(),
     deps = [
@@ -289,37 +289,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "quantization_patterns",
-    srcs = ["passes/quantization_patterns.cc"],
-    hdrs = [
-        "passes/quantization_patterns.h",
-    ],
-    compatible_with = get_compatible_with_portable(),
-    deps = [
-        ":quantization_config_proto_cc",
-        "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps",
-        "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints",
-        "//tensorflow/compiler/mlir/quantization/common:lift_as_function_call",
-        "//tensorflow/compiler/mlir/quantization/common:uniform_quantized_types",
-        "//tensorflow/compiler/mlir/quantization/common/quantization_lib",
-        "//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec",
-        "//tensorflow/compiler/mlir/quantization/tensorflow:passes",
-        "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc",
-        "//tensorflow/compiler/mlir/tensorflow",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/platform:path",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:FuncDialect",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:QuantOps",
-        "@llvm-project//mlir:Support",
-        "@stablehlo//:stablehlo_ops",
-    ],
-)
-
 td_library(
     name = "quant_td_files",
     srcs = [

tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc

Lines changed: 27 additions & 19 deletions
@@ -46,12 +46,11 @@ limitations under the License.
 #include "mlir/Support/LLVM.h" // from @llvm-project
 #include "mlir/Support/LogicalResult.h" // from @llvm-project
 #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep
-#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
-#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h"
-#include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h"
-#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h"
+#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h"
+#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h"
+#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h"
+#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h"
 #include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h"
-#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h"
 #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
 #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@@ -62,8 +61,6 @@ namespace mlir::quant::stablehlo {
 
 namespace {
 
-using ::mlir::quant::FindUserOfType;
-using ::mlir::quant::TryCast;
 using ::mlir::stablehlo::AddOp;
 using ::mlir::stablehlo::BroadcastInDimOp;
 using ::mlir::stablehlo::ConcatenateOp;
@@ -74,6 +71,16 @@ using ::mlir::stablehlo::GatherOp;
 using ::mlir::stablehlo::GetDimensionSizeOp;
 using ::mlir::stablehlo::ReshapeOp;
 using ::mlir::stablehlo::UniformQuantizeOp;
+using ::mlir::tf_quant::FindUserOfType;
+using ::mlir::tf_quant::GetDotGeneralQuantizationDim;
+using ::mlir::tf_quant::GetQuantizationMethodOrDefault;
+using ::mlir::tf_quant::HasWeightOnlyPtqMethod;
+using ::mlir::tf_quant::IsHybridQuantizedOp;
+using ::mlir::tf_quant::kCompositeFuncPrefix;
+using ::mlir::tf_quant::kQuantizationMethodAttr;
+using ::mlir::tf_quant::kQuantizedFuncPrefix;
+using ::mlir::tf_quant::kQuantTraitAttrName;
+using ::mlir::tf_quant::TryCast;
 using ::stablehlo::quantization::Method;
 using ::stablehlo::quantization::QuantizedDimension;
 using ::stablehlo::quantization::QuantizedType;
@@ -724,12 +731,12 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern<TF::XlaCallModuleOp> {
 // Quantizes only when the nested region consists of ops whose quantization
 // parameters can be propagated from outside.
 class QuantizeOpWithRegionPattern
-    : public OpRewritePattern<quantfork::DequantizeCastOp> {
+    : public OpRewritePattern<mlir::quant::ir::DequantizeCastOp> {
  public:
   explicit QuantizeOpWithRegionPattern(MLIRContext& ctx)
-      : OpRewritePattern<quantfork::DequantizeCastOp>(&ctx) {};
+      : OpRewritePattern<mlir::quant::ir::DequantizeCastOp>(&ctx) {};
 
-  LogicalResult matchAndRewrite(quantfork::DequantizeCastOp op,
+  LogicalResult matchAndRewrite(mlir::quant::ir::DequantizeCastOp op,
                                 PatternRewriter& rewriter) const final {
     if (match(op).failed()) {
       return failure();
@@ -739,7 +746,7 @@ class QuantizeOpWithRegionPattern
   }
 
  private:
-  LogicalResult match(quantfork::DequantizeCastOp op) const {
+  LogicalResult match(mlir::quant::ir::DequantizeCastOp op) const {
     // Match only when there is one user of the dequantize op.
     if (!op.getResult().hasOneUse()) {
       return failure();
@@ -767,7 +774,7 @@ class QuantizeOpWithRegionPattern
     return success();
   }
 
-  void rewrite(quantfork::DequantizeCastOp op,
+  void rewrite(mlir::quant::ir::DequantizeCastOp op,
                PatternRewriter& rewriter) const {
     // Rewrite the floating-point ops to the quantized version, by fusing
     // preceding dequantize ops and succeding quantize ops.
@@ -785,7 +792,7 @@ class QuantizeOpWithRegionPattern
 
       const Type element_type =
           mlir::cast<TensorType>(operand.getType()).getElementType();
-      if (auto dq_op = dyn_cast_or_null<quantfork::DequantizeCastOp>(
+      if (auto dq_op = dyn_cast_or_null<mlir::quant::ir::DequantizeCastOp>(
              operand.getDefiningOp())) {
         inputs.push_back(dq_op.getOperand());
       } else if (isa<IntegerType>(element_type)) {
@@ -813,8 +820,9 @@ class QuantizeOpWithRegionPattern
           mlir::cast<TensorType>(result.getType()).getElementType();
       // If the user is the QuantizeOp, it must be the only user.
       if (result.hasOneUse() &&
-          isa<quantfork::QuantizeCastOp>(*result.user_begin())) {
-        auto user = cast<quantfork::QuantizeCastOp>(*result.user_begin());
+          isa<mlir::quant::ir::QuantizeCastOp>(*result.user_begin())) {
+        auto user =
+            cast<mlir::quant::ir::QuantizeCastOp>(*result.user_begin());
         outputs_replaced.push_back(user.getResult());
         output_types.push_back(user.getType());
       } else if (isa<IntegerType>(result_element_type)) {
@@ -944,8 +952,8 @@ bool IsQuantizedCompositeFunction(func::CallOp call_op) {
 
 bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) {
   for (const Value operand : same_scale_op->getOperands()) {
-    auto dq_op =
-        dyn_cast_or_null<quantfork::DequantizeCastOp>(operand.getDefiningOp());
+    auto dq_op = dyn_cast_or_null<mlir::quant::ir::DequantizeCastOp>(
+        operand.getDefiningOp());
     if (!dq_op) continue;
 
     Operation* preceding_op = dq_op.getArg().getDefiningOp();
@@ -973,11 +981,11 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) {
   for (const Value result : same_scale_op->getResults()) {
     // If the user is the Quantize op, it must be the only user.
     if (!result.hasOneUse() ||
-        !isa<quantfork::QuantizeCastOp>(*result.user_begin())) {
+        !isa<mlir::quant::ir::QuantizeCastOp>(*result.user_begin())) {
       continue;
     }
 
-    auto q_op = cast<quantfork::QuantizeCastOp>(*result.user_begin());
+    auto q_op = cast<mlir::quant::ir::QuantizeCastOp>(*result.user_begin());
     for (Operation* following_op : q_op->getUsers()) {
       // Check whether the following op is a quantized composite function.
       if (isa<func::CallOp>(following_op)) {
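
Beyond the include and using-declaration updates, the mechanical change in this file is replacing quantfork:: with mlir::quant::ir:: in every rewrite pattern that anchors on the quantize/dequantize cast ops; the pattern structure itself is unchanged. As a minimal, self-contained sketch of that pattern shape (not part of this commit; the pattern name and the specific fold are illustrative, and only the accessors visible in the diff above are assumed), a rewrite over the relocated DequantizeCastOp could look like this:

// Illustrative sketch only, not from this commit. Folds a
// quantize(dequantize(x)) round trip back to x when the types match,
// using the relocated mlir::quant::ir op classes.
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h"

namespace mlir::quant::stablehlo {

class FoldDequantizeQuantizePairExample
    : public OpRewritePattern<mlir::quant::ir::DequantizeCastOp> {
 public:
  using OpRewritePattern<mlir::quant::ir::DequantizeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::quant::ir::DequantizeCastOp op,
                                PatternRewriter& rewriter) const override {
    // Match only when the dequantize result has a single user and that user
    // is a QuantizeCastOp that reproduces the original quantized type.
    if (!op.getResult().hasOneUse()) return failure();
    auto q_op = dyn_cast<mlir::quant::ir::QuantizeCastOp>(
        *op.getResult().user_begin());
    if (!q_op || q_op.getType() != op.getArg().getType()) return failure();
    // The round trip is a no-op, so forward the quantized value and drop
    // both cast ops.
    rewriter.replaceOp(q_op, op.getArg());
    rewriter.eraseOp(op);
    return success();
  }
};

}  // namespace mlir::quant::stablehlo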

tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h

Lines changed: 7 additions & 3 deletions
@@ -40,13 +40,17 @@ limitations under the License.
 #include "mlir/Support/LLVM.h" // from @llvm-project
 #include "mlir/Support/LogicalResult.h" // from @llvm-project
 #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo
-#include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h"
-#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h"
-#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h"
+#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h"
+#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h"
+#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.h"
 #include "tensorflow/core/framework/types.pb.h"
 
 namespace mlir::quant::stablehlo {
 
+using ::mlir::tf_quant::IsWeightOnlyQuantizableOp;
+using ::mlir::tf_quant::stablehlo::GetStableHloQuantConstraints;
+using ::mlir::tf_quant::stablehlo::IsOpQuantizableStableHlo;
+
 // Checks whether an op is connected with a quantized composite function. If
 // not, the same-scale op will not be quantized. This decision is based on the
 // current assumption that the performance gain of the same-scale op itself
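
The using declarations added here keep unqualified references compiling inside mlir::quant::stablehlo now that the helpers live in the ::mlir::tf_quant namespaces. A minimal sketch of what that buys a caller (ShouldSkipOp is a hypothetical helper, and the bool(Operation*) signature of IsOpQuantizableStableHlo is an assumption based on its use in these passes):

// Illustrative sketch only, not from this commit. ShouldSkipOp is a
// hypothetical helper; it shows that the tf_quant symbols resolve without
// qualification inside mlir::quant::stablehlo thanks to the using
// declarations above.
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h"

namespace mlir::quant::stablehlo {

inline bool ShouldSkipOp(Operation* op) {
  // Resolves to ::mlir::tf_quant::stablehlo::IsOpQuantizableStableHlo via
  // the using declaration added in quantization_patterns.h (signature
  // assumed: bool(Operation*)).
  return !IsOpQuantizableStableHlo(op);
}

}  // namespace mlir::quant::stablehlo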
