[quant][pt2e][bc-breaking] Set `fold_quantize` to True in `convert_pt2e` · pytorch/pytorch@d92de5f

Commit d92de5f

jerryzh168 authored and facebook-github-bot committed
[quant][pt2e][bc-breaking] Set fold_quantize to True in convert_pt2e (#119425)
Summary:
X-link: pytorch/executorch#1882

This is a follow-up to #118605 to set the `fold_quantize` flag to True in `convert_pt2e`.

Test Plan: CI

Reviewed By: digantdesai

Differential Revision: D53550237
1 parent 91f0381 commit d92de5f

11 files changed (+29, -36 lines)
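With this change, quantize-op folding happens in `convert_pt2e` by default, so the test callsites below simply drop the explicit `fold_quantize=True` argument. A minimal sketch of the post-change PT2E flow, mirroring the call pattern used in the tests; the toy module, example input, and exact import paths are assumptions for illustration, not part of this commit:

import torch
from torch._export import capture_pre_autograd_graph  # import path assumed for this PyTorch version
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):  # hypothetical toy model, not from this commit
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

example_inputs = (torch.randn(1, 8),)
m = M().eval()

# Export the model and annotate it for post-training quantization
m = capture_pre_autograd_graph(m, example_inputs)
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)

# Calibrate
m(*example_inputs)

# fold_quantize now defaults to True, so weight quantize ops are constant-folded
# into frozen quantized parameters without passing the flag explicitly.
m = convert_pt2e(m)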

test/inductor/test_mkldnn_pattern_matcher.py (1 addition, 1 deletion)

@@ -106,7 +106,7 @@ def _generate_qdq_quantized_model(self, mod, inputs, is_qat=False):
             else prepare_pt2e(export_model, quantizer)
         )
         prepare_model(*inputs)
-        convert_model = convert_pt2e(prepare_model, fold_quantize=True)
+        convert_model = convert_pt2e(prepare_model)
         torch.ao.quantization.move_exported_model_to_eval(convert_model)
         return convert_model

test/quantization/pt2e/test_duplicate_dq.py (1 addition, 1 deletion)

@@ -110,7 +110,7 @@ def _test_duplicate_dq(
         m = prepare_pt2e(m, quantizer)
         # Calibrate
         m(*example_inputs)
-        m = convert_pt2e(m, fold_quantize=True)
+        m = convert_pt2e(m)

         pt2_quant_output = m(*example_inputs)
         for n in m.graph.nodes:

test/quantization/pt2e/test_generate_numeric_debug_handle.py (1 addition, 1 deletion)

@@ -94,6 +94,6 @@ def test_quantize_pt2e_preserve_handle(self):
         debug_handle_map = _extract_conv2d_pattern_debug_handle_map(m)
         self.assertEqual(debug_handle_map, debug_handle_map_ref)
         m(*example_inputs)
-        m = convert_pt2e(m, fold_quantize=True)
+        m = convert_pt2e(m)
         debug_handle_map = _extract_conv2d_pattern_debug_handle_map(m)
         self.assertEqual(debug_handle_map, debug_handle_map_ref)

test/quantization/pt2e/test_metadata_porting.py (1 addition, 1 deletion)

@@ -110,7 +110,7 @@ def _test_metadata_porting(
         m = prepare_pt2e(m, quantizer)
         # Calibrate
         m(*example_inputs)
-        m = convert_pt2e(m, fold_quantize=True)
+        m = convert_pt2e(m)

         pt2_quant_output = m(*example_inputs)
         recorded_node_tags = {}

test/quantization/pt2e/test_quantize_pt2e.py (6 additions, 6 deletions)

@@ -676,7 +676,7 @@ def validate(self, model: torch.fx.GraphModule) -> None:
         assert conv_output_obs[0] == conv_output_obs[1]

         m(*example_inputs)
-        m = convert_pt2e(m, fold_quantize=True)
+        m = convert_pt2e(m)

         node_occurrence = {
             # two for input of the first conv, one for output for the first conv
@@ -739,7 +739,7 @@ def _test_transitive_sharing_with_cat_helper(self, quantizer):
         assert conv_output_obs[0] == conv_output_obs[1]

         m(*example_inputs)
-        m = convert_pt2e(m, fold_quantize=True)
+        m = convert_pt2e(m)

         node_occurrence = {
             # two for input of the first conv, one for output for the first conv
@@ -1202,7 +1202,7 @@ def forward(self, x):

         m = prepare_pt2e(m, quantizer)
         m(*example_inputs)
-        m = convert_pt2e(m, fold_quantize=True)
+        m = convert_pt2e(m)

         for n in m.graph.nodes:
             if n.op == "get_attr" and "frozen_param" in n.target:
@@ -1619,7 +1619,7 @@ def test_disallow_eval_train(self):
         m.train()

         # After convert: still not OK
-        m = convert_pt2e(m, fold_quantize=True)
+        m = convert_pt2e(m)
         with self.assertRaises(NotImplementedError):
             m.eval()
         with self.assertRaises(NotImplementedError):
@@ -1706,12 +1706,12 @@ def test_reentrant(self):
         m.conv_bn_relu = capture_pre_autograd_graph(m.conv_bn_relu, example_inputs)
         m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer)
         m(*example_inputs)
-        m.conv_bn_relu = convert_pt2e(m.conv_bn_relu, fold_quantize=True)
+        m.conv_bn_relu = convert_pt2e(m.conv_bn_relu)

         quantizer = XNNPACKQuantizer().set_module_type(torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False))
         m = capture_pre_autograd_graph(m, example_inputs)
         m = prepare_pt2e(m, quantizer)
-        m = convert_pt2e(m, fold_quantize=True)
+        m = convert_pt2e(m)

         node_occurrence = {
             ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 4,

test/quantization/pt2e/test_quantize_pt2e_qat.py (6 additions, 6 deletions)

@@ -161,7 +161,7 @@ def _verify_symmetric_xnnpack_qat_numerics_helper(
         if verify_convert:
             # We don't want to impose any ordering requirements between move_exported_model_to_eval and convert_pt2e
             torch.ao.quantization.move_exported_model_to_eval(model_pt2e)
-            model_pt2e = convert_pt2e(model_pt2e, fold_quantize=True)
+            model_pt2e = convert_pt2e(model_pt2e)
             quant_result_pt2e = model_pt2e(*example_inputs)
             model_fx.eval()
             model_fx = _convert_to_reference_decomposed_fx(
@@ -631,7 +631,7 @@ def forward(self, x):
         m = capture_pre_autograd_graph(m, example_inputs)
         m = prepare_qat_pt2e(m, quantizer)
         m(*example_inputs)
-        m = convert_pt2e(m, fold_quantize=True)
+        m = convert_pt2e(m)

         # Extract the conv and relu nodes (bn was folded into conv)
         first_conv, first_relu, second_conv, second_relu = None, None, None, None
@@ -690,7 +690,7 @@ def test_qat_conv_bn_bias_derived_qspec(self):
         quantizer = ConvBnDerivedBiasQuantizer()
         m = prepare_qat_pt2e(m, quantizer)
         m(*example_inputs)
-        m = convert_pt2e(m, fold_quantize=True)
+        m = convert_pt2e(m)
         m(*example_inputs)

         # Assert that both weight and bias are quantized
@@ -737,7 +737,7 @@ def test_qat_per_channel_weight_custom_dtype(self):
         quantizer = ConvBnInt32WeightQuantizer()
         m = prepare_qat_pt2e(m, quantizer)
         m(*example_inputs)
-        m = convert_pt2e(m, fold_quantize=True)
+        m = convert_pt2e(m)
         m(*example_inputs)

         # Assert that conv weight is quantized per channel
@@ -972,7 +972,7 @@ def _convert_qat_linears(self, model):
         for name, child in model.named_children():
             if isinstance(child, torch.fx.GraphModule):
                 torch.ao.quantization.move_exported_model_to_eval(child)
-                converted_child = convert_pt2e(child, fold_quantize=True)
+                converted_child = convert_pt2e(child)
                 setattr(model, name, converted_child)
             else:
                 self._convert_qat_linears(child)
@@ -999,7 +999,7 @@ def test_mixing_qat_ptq(self):
         quantizer.set_global(quantization_config)
         model_pt2e = prepare_pt2e(model_pt2e, quantizer)
         after_prepare_result_pt2e = model_pt2e(*example_inputs)
-        model_pt2e = convert_pt2e(model_pt2e, fold_quantize=True)
+        model_pt2e = convert_pt2e(model_pt2e)
         quant_result_pt2e = model_pt2e(*example_inputs)

         exported_model = torch.export.export(model_pt2e, example_inputs)

test/quantization/pt2e/test_representation.py (2 additions, 6 deletions)

@@ -42,9 +42,7 @@ def _test_representation(
         model = prepare_pt2e(model, quantizer)
         # Calibrate
         model(*example_inputs)
-        model = convert_pt2e(
-            model, use_reference_representation=True, fold_quantize=True
-        )
+        model = convert_pt2e(model, use_reference_representation=True)
         self.checkGraphModuleNodes(model, expected_node_occurrence=ref_node_occurrence)
         # make sure it runs
         pt2e_quant_output = model(*example_inputs)
@@ -54,9 +52,7 @@ def _test_representation(
         model_copy = prepare_pt2e(model_copy, quantizer)
         # Calibrate
         model_copy(*example_inputs)
-        model_copy = convert_pt2e(
-            model_copy, use_reference_representation=False, fold_quantize=True
-        )
+        model_copy = convert_pt2e(model_copy, use_reference_representation=False)
         self.checkGraphModuleNodes(
             model_copy, expected_node_occurrence=non_ref_node_occurrence
         )
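The representation tests above exercise the other `convert_pt2e` knob, `use_reference_representation`, which this commit leaves untouched. A hedged one-line sketch; the `calibrated_model` name is an assumption:

from torch.ao.quantization.quantize_pt2e import convert_pt2e

# With use_reference_representation=True the converted graph is rewritten into
# the reference representation instead of keeping explicit q/dq ops;
# fold_quantize still defaults to True after this commit.
ref_model = convert_pt2e(calibrated_model, use_reference_representation=True)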

test/quantization/pt2e/test_x86inductor_quantizer.py (1 addition, 1 deletion)

@@ -326,7 +326,7 @@ def _test_quantizer(
         # Calibrate
         m(*example_inputs)
         prepare_model = copy.deepcopy(m)
-        m = convert_pt2e(m, fold_quantize=True)
+        m = convert_pt2e(m)
         convert_model = copy.deepcopy(m)
         pt2_quant_output = m(*example_inputs)
         node_occurrence = {

test/quantization/pt2e/test_xnnpack_quantizer.py (4 additions, 4 deletions)

@@ -472,7 +472,7 @@ def test_propagate_annotation(self):
                 output_act = getattr(m, next(iter(n.users)).target)
                 self.assertIs(input_act, output_act)

-        m = convert_pt2e(m, fold_quantize=True)
+        m = convert_pt2e(m)
         node_occurrence = {
             # input and output are using quantize_per_tensor and weight is using quantize_per_channel
             ns.call_function(
@@ -723,7 +723,7 @@ def forward(self, input_tensor, hidden_tensor):
         quantizer.set_global(quantization_config)
         model_graph = prepare_pt2e(model_graph, quantizer)
         model_graph(*example_inputs)
-        model_graph = convert_pt2e(model_graph, fold_quantize=True)
+        model_graph = convert_pt2e(model_graph)
         self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs))

     def test_linear_gru(self):
@@ -787,7 +787,7 @@ def forward(self, input_tensor, hidden_tensor):
         quantizer.set_global(quantization_config)
         model_graph = prepare_pt2e(model_graph, quantizer)
         model_graph(*example_inputs)
-        model_graph = convert_pt2e(model_graph, fold_quantize=True)
+        model_graph = convert_pt2e(model_graph)
         self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs))

     def test_add_and_inplace_add(self):
@@ -968,7 +968,7 @@ def test_resnet18(self):
             id(m.activation_post_process_3), id(m.activation_post_process_2)
         )
         after_prepare_result = m(*example_inputs)
-        m = convert_pt2e(m, fold_quantize=True)
+        m = convert_pt2e(m)

         after_quant_result = m(*example_inputs)

torch/ao/quantization/quantize_pt2e.py (4 additions, 7 deletions)

@@ -201,18 +201,14 @@ def _quant_node_constraint(n: Node) -> bool:
 def convert_pt2e(
     model: GraphModule,
     use_reference_representation: bool = False,
-    fold_quantize: bool = False,
+    fold_quantize: bool = True,
 ) -> GraphModule:
     """Convert a calibrated/trained model to a quantized model

     Args:
       * `model` (torch.fx.GraphModule): calibrated/trained model
       * `use_reference_representation` (bool): boolean flag to indicate whether to produce referece representation or not
-      * `fold_quantize` (bool): boolean flag to indicate whether fold the quantize op or not
-
-    Note: please set `fold_quantize` to True whenever you can, we'll deprecate this flag and
-    make True the default option in the future, to make sure the change doesn't break BC for you, it's
-    better to set the flag to True now.
+      * `fold_quantize` (bool): boolean flag for whether fold the quantize op or not

     Returns:
         quantized model, either in q/dq representation or reference representation
@@ -243,7 +239,8 @@ def convert_pt2e(
     pm = PassManager([PortNodeMetaForQDQ()])
     model = pm(model).graph_module

-    constant_fold(model, _quant_node_constraint)
+    if fold_quantize:
+        constant_fold(model, _quant_node_constraint)

     if use_reference_representation:
         model = reference_representation_rewrite(model)
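Since `constant_fold` is now gated on `fold_quantize`, a caller who still wants standalone quantize ops on weights in the converted graph has to opt out explicitly. A small sketch under that assumption; `prepared_model` stands for a prepared and calibrated GraphModule, as in the tests above:

from torch.ao.quantization.quantize_pt2e import convert_pt2e

# Opt out of the new default: weight quantize ops stay in the graph instead of
# being constant-folded into frozen quantized parameters.
converted = convert_pt2e(prepared_model, fold_quantize=False)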
