@@ -46,12 +46,11 @@ limitations under the License.
46
46
#include " mlir/Support/LLVM.h" // from @llvm-project
47
47
#include " mlir/Support/LogicalResult.h" // from @llvm-project
48
48
#include " stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep
49
- #include " tensorflow/compiler/mlir/lite/ quantization/ir/QuantOps.h"
50
- #include " tensorflow/compiler/mlir/quantization/common/attrs_and_constraints .h"
51
- #include " tensorflow/compiler/mlir/quantization/common/lift_as_function_call .h"
52
- #include " tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils .h"
49
+ #include " tensorflow/compiler/mlir/quantization/common /ir/QuantOps.h"
50
+ #include " tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints .h"
51
+ #include " tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call .h"
52
+ #include " tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils .h"
53
53
#include " tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h"
54
- #include " tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h"
55
54
#include " tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
56
55
#include " tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h"
57
56
#include " tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@@ -62,8 +61,6 @@ namespace mlir::quant::stablehlo {
62
61
63
62
namespace {
64
63
65
- using ::mlir::quant::FindUserOfType;
66
- using ::mlir::quant::TryCast;
67
64
using ::mlir::stablehlo::AddOp;
68
65
using ::mlir::stablehlo::BroadcastInDimOp;
69
66
using ::mlir::stablehlo::ConcatenateOp;
@@ -74,6 +71,16 @@ using ::mlir::stablehlo::GatherOp;
74
71
using ::mlir::stablehlo::GetDimensionSizeOp;
75
72
using ::mlir::stablehlo::ReshapeOp;
76
73
using ::mlir::stablehlo::UniformQuantizeOp;
74
+ using ::mlir::tf_quant::FindUserOfType;
75
+ using ::mlir::tf_quant::GetDotGeneralQuantizationDim;
76
+ using ::mlir::tf_quant::GetQuantizationMethodOrDefault;
77
+ using ::mlir::tf_quant::HasWeightOnlyPtqMethod;
78
+ using ::mlir::tf_quant::IsHybridQuantizedOp;
79
+ using ::mlir::tf_quant::kCompositeFuncPrefix ;
80
+ using ::mlir::tf_quant::kQuantizationMethodAttr ;
81
+ using ::mlir::tf_quant::kQuantizedFuncPrefix ;
82
+ using ::mlir::tf_quant::kQuantTraitAttrName ;
83
+ using ::mlir::tf_quant::TryCast;
77
84
using ::stablehlo::quantization::Method;
78
85
using ::stablehlo::quantization::QuantizedDimension;
79
86
using ::stablehlo::quantization::QuantizedType;
@@ -724,12 +731,12 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern<TF::XlaCallModuleOp> {
724
731
// Quantizes only when the nested region consists of ops whose quantization
725
732
// parameters can be propagated from outside.
726
733
class QuantizeOpWithRegionPattern
727
- : public OpRewritePattern<quantfork ::DequantizeCastOp> {
734
+ : public OpRewritePattern<mlir::quant::ir ::DequantizeCastOp> {
728
735
public:
729
736
explicit QuantizeOpWithRegionPattern (MLIRContext& ctx)
730
- : OpRewritePattern<quantfork ::DequantizeCastOp>(&ctx) {};
737
+ : OpRewritePattern<mlir::quant::ir ::DequantizeCastOp>(&ctx) {};
731
738
732
- LogicalResult matchAndRewrite (quantfork ::DequantizeCastOp op,
739
+ LogicalResult matchAndRewrite (mlir::quant::ir ::DequantizeCastOp op,
733
740
PatternRewriter& rewriter) const final {
734
741
if (match (op).failed ()) {
735
742
return failure ();
@@ -739,7 +746,7 @@ class QuantizeOpWithRegionPattern
739
746
}
740
747
741
748
private:
742
- LogicalResult match (quantfork ::DequantizeCastOp op) const {
749
+ LogicalResult match (mlir::quant::ir ::DequantizeCastOp op) const {
743
750
// Match only when there is one user of the dequantize op.
744
751
if (!op.getResult ().hasOneUse ()) {
745
752
return failure ();
@@ -767,7 +774,7 @@ class QuantizeOpWithRegionPattern
767
774
return success ();
768
775
}
769
776
770
- void rewrite (quantfork ::DequantizeCastOp op,
777
+ void rewrite (mlir::quant::ir ::DequantizeCastOp op,
771
778
PatternRewriter& rewriter) const {
772
779
// Rewrite the floating-point ops to the quantized version, by fusing
773
780
// preceding dequantize ops and succeeding quantize ops.
@@ -785,7 +792,7 @@ class QuantizeOpWithRegionPattern
785
792
786
793
const Type element_type =
787
794
mlir::cast<TensorType>(operand.getType ()).getElementType ();
788
- if (auto dq_op = dyn_cast_or_null<quantfork ::DequantizeCastOp>(
795
+ if (auto dq_op = dyn_cast_or_null<mlir::quant::ir ::DequantizeCastOp>(
789
796
operand.getDefiningOp ())) {
790
797
inputs.push_back (dq_op.getOperand ());
791
798
} else if (isa<IntegerType>(element_type)) {
@@ -813,8 +820,9 @@ class QuantizeOpWithRegionPattern
813
820
mlir::cast<TensorType>(result.getType ()).getElementType ();
814
821
// If the user is the QuantizeOp, it must be the only user.
815
822
if (result.hasOneUse () &&
816
- isa<quantfork::QuantizeCastOp>(*result.user_begin ())) {
817
- auto user = cast<quantfork::QuantizeCastOp>(*result.user_begin ());
823
+ isa<mlir::quant::ir::QuantizeCastOp>(*result.user_begin ())) {
824
+ auto user =
825
+ cast<mlir::quant::ir::QuantizeCastOp>(*result.user_begin ());
818
826
outputs_replaced.push_back (user.getResult ());
819
827
output_types.push_back (user.getType ());
820
828
} else if (isa<IntegerType>(result_element_type)) {
@@ -944,8 +952,8 @@ bool IsQuantizedCompositeFunction(func::CallOp call_op) {
944
952
945
953
bool IsConnectedWithQuantizedCompsiteFunction (Operation* same_scale_op) {
946
954
for (const Value operand : same_scale_op->getOperands ()) {
947
- auto dq_op =
948
- dyn_cast_or_null<quantfork::DequantizeCastOp>( operand.getDefiningOp ());
955
+ auto dq_op = dyn_cast_or_null<mlir::quant::ir::DequantizeCastOp>(
956
+ operand.getDefiningOp ());
949
957
if (!dq_op) continue ;
950
958
951
959
Operation* preceding_op = dq_op.getArg ().getDefiningOp ();
@@ -973,11 +981,11 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) {
973
981
for (const Value result : same_scale_op->getResults ()) {
974
982
// If the user is the Quantize op, it must be the only user.
975
983
if (!result.hasOneUse () ||
976
- !isa<quantfork ::QuantizeCastOp>(*result.user_begin ())) {
984
+ !isa<mlir::quant::ir ::QuantizeCastOp>(*result.user_begin ())) {
977
985
continue ;
978
986
}
979
987
980
- auto q_op = cast<quantfork ::QuantizeCastOp>(*result.user_begin ());
988
+ auto q_op = cast<mlir::quant::ir ::QuantizeCastOp>(*result.user_begin ());
981
989
for (Operation* following_op : q_op->getUsers ()) {
982
990
// Check whether the following op is a quantized composite function.
983
991
if (isa<func::CallOp>(following_op)) {
0 commit comments