Remove capture_pre_autograd_graph references in quantization · pytorch/pytorch@1277be9 · GitHub

Commit 1277be9

yushangdi authored and facebook-github-bot committed
Remove capture_pre_autograd_graph references in quantization
Summary: As title. We remove the deprecated API references in code, docs, and tests. We also removed two tests that are specific to the capture_pre_autograd_graph API.

Test Plan: CI

Differential Revision: D65351887
1 parent 6fc63b4 commit 1277be9

File tree

4 files changed: +4 −122 lines changed

docs/source/quantization.rst

Lines changed: 2 additions & 2 deletions
@@ -508,7 +508,7 @@ API Example::
 
   import torch
   from torch.ao.quantization.quantize_pt2e import prepare_pt2e
-  from torch._export import capture_pre_autograd_graph
+  from torch.export import export_for_training
   from torch.ao.quantization.quantizer import (
     XNNPACKQuantizer,
     get_symmetric_quantization_config,

@@ -535,7 +535,7 @@ API Example::
   # Step 1. program capture
   # NOTE: this API will be updated to torch.export API in the future, but the captured
   # result should mostly stay the same
-  m = capture_pre_autograd_graph(m, *example_inputs)
+  m = export_for_training(m, *example_inputs).module()
   # we get a model with aten ops
 
   # Step 2. quantization
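
For reference, a minimal end-to-end sketch of the updated documentation flow: capture with torch.export.export_for_training, then prepare/calibrate/convert with the PT2E APIs. The toy model, the calibration call, and the xnnpack_quantizer import path are illustrative assumptions (the doc example above imports the quantizer from torch.ao.quantization.quantizer; the exact path can vary by PyTorch version), not part of this commit.

import torch
from torch.export import export_for_training
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.linear(x)

example_inputs = (torch.randn(1, 5),)
m = M().eval()

# Step 1. program capture: export_for_training takes the example inputs as a tuple
# and replaces the deprecated capture_pre_autograd_graph call.
m = export_for_training(m, example_inputs).module()

# Step 2. quantization: annotate, calibrate, convert.
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # calibration with representative inputs
m = convert_pt2e(m)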

test/quantization/pt2e/test_quantize_pt2e_qat.py

Lines changed: 0 additions & 111 deletions
@@ -568,84 +568,6 @@ def forward(self, x):
         m = M(self.conv_class)
         self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)
 
-    def test_prepare_qat_conv_bn_fusion_getitem_placeholder(self):
-        """
-        Test the case where the placeholder node for the [conv - bn - getitem] pattern
-        is also a getitem node:
-
-            some_op -> unrelated_getitem -> conv -> bn -> conv_bn_getitem
-
-        We want the metadata to be copied from the `conv_bn_getitem` node, not from
-        the `unrelated_getitem` node, which is not part of the conv-bn pattern but
-        is returned as part of the match anyway (as a placeholder).
-        """
-        from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
-
-        # T199018392
-        # remove this test after we kill capture_pre_autograd_graph()
-        if capture_pre_autograd_graph_using_training_ir():
-            self.skipTest("Not applicable to training IR")
-
-        class M(torch.nn.Module):
-            def __init__(self, conv_class, bn_class):
-                super().__init__()
-                self.bn1 = bn_class(3)
-                self.conv = conv_class(3, 3, 3)
-                self.bn2 = bn_class(3)
-
-            def forward(self, x):
-                x = self.bn1(x)
-                x = self.conv(x)
-                x = self.bn2(x)
-                return x
-
-        def _get_getitem_nodes(m: torch.fx.GraphModule):
-            """
-            Return a 2-tuple of (unrelated_getitem_node, conv_bn_getitem_node) from the graph.
-            """
-            unrelated_getitem_node, conv_bn_getitem_node = None, None
-            for node in m.graph.nodes:
-                if (
-                    node.target != operator.getitem
-                    or node.args[0].target
-                    != torch.ops.aten._native_batch_norm_legit.default
-                ):
-                    continue
-                if node.args[0].args[0].op == "placeholder":
-                    unrelated_getitem_node = node
-                else:
-                    conv_bn_getitem_node = node
-            assert (
-                unrelated_getitem_node is not None
-            ), "did not find unrelated getitem node, bad test setup"
-            assert (
-                conv_bn_getitem_node is not None
-            ), "did not find conv bn getitem node, bad test setup"
-            return (unrelated_getitem_node, conv_bn_getitem_node)
-
-        # Program capture
-        m = M(self.conv_class, self.bn_class)
-        m = torch._export.capture_pre_autograd_graph(m, self.example_inputs)
-        m.graph.eliminate_dead_code()
-        m.recompile()
-        (_, original_conv_bn_getitem_node) = _get_getitem_nodes(m)
-
-        # Prepare QAT
-        quantizer = XNNPACKQuantizer()
-        quantizer.set_global(
-            get_symmetric_quantization_config(is_per_channel=False, is_qat=True)
-        )
-        m = prepare_qat_pt2e(m, quantizer)
-        (unrelated_getitem_node, conv_bn_getitem_node) = _get_getitem_nodes(m)
-
-        # Verify that the metadata was copied from `conv_bn_getitem`, not `unrelated_getitem`
-        original_conv_bn_getitem_meta = original_conv_bn_getitem_node.meta[
-            "quantization_annotation"
-        ]
-        conv_bn_getitem_meta = conv_bn_getitem_node.meta["quantization_annotation"]
-        self.assertEqual(conv_bn_getitem_meta, original_conv_bn_getitem_meta)
-        self.assertTrue("quantization_annotation" not in unrelated_getitem_node.meta)
-
     def test_qat_update_shared_qspec(self):
         """
         Test the case where nodes used in SharedQuantizationSpec were replaced

@@ -926,39 +848,6 @@ def test_fold_bn_erases_bn_node(self):
         self.assertTrue(conv_node is not None)
         self.assertTrue(bn_node is None)
 
-    def test_preserve_capture_pre_autograd_graph_tag(self):
-        """
-        Ensure the capture_pre_autograd_graph_tag node meta is preserved.
-        TODO: Remove this test after training IR migration.
-        T199018392
-        """
-        from torch._export import capture_pre_autograd_graph
-        from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
-
-        if capture_pre_autograd_graph_using_training_ir():
-            self.skipTest(
-                "test doesn't apply when capture_pre_autograd_graph is using training IR"
-            )
-
-        m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
-        m = capture_pre_autograd_graph(m, self.example_inputs)
-
-        for node in m.graph.nodes:
-            self.assertTrue(node.meta.get("capture_pre_autograd_graph_tag", False))
-        quantizer = XNNPACKQuantizer()
-        quantizer.set_global(
-            get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
-        )
-        m = prepare_qat_pt2e(m, quantizer)
-        m = convert_pt2e(m)
-        has_tag = False
-        for node in m.graph.nodes:
-            if not node.meta.get("capture_pre_autograd_graph_tag", False):
-                has_tag = True
-                break
-        self.assertTrue(has_tag)
-        torch.export.export(m, self.example_inputs)
-
 
 @skipIfNoQNNPACK
 class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base):
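
For context only (not part of the diff): the QAT flow these deleted tests exercised now starts from torch.export.export_for_training instead of capture_pre_autograd_graph. A minimal sketch follows, assuming a toy conv-bn module and the xnnpack_quantizer import path (both illustrative).

import torch
from torch.export import export_for_training
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class ConvBn(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(self.conv(x))

example_inputs = (torch.randn(1, 3, 8, 8),)
m = ConvBn()

# Program capture on the training IR (was: torch._export.capture_pre_autograd_graph).
m = export_for_training(m, example_inputs).module()

# Prepare QAT with a symmetric QAT config, then convert.
quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=False, is_qat=True)
)
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)  # stand-in for actual QAT training steps
m = convert_pt2e(m)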

torch/ao/quantization/pt2e/utils.py

Lines changed: 1 addition & 6 deletions
@@ -5,7 +5,6 @@
 
 import torch
 import torch.nn.functional as F
-from torch._export import capture_pre_autograd_graph
 
 # Makes sure that quantized_decomposed ops are registered
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

@@ -381,11 +380,7 @@ def _get_aten_graph_module_for_pattern(
             kwargs,
         ).module()
     else:
-        aten_pattern = capture_pre_autograd_graph(
-            pattern,  # type: ignore[arg-type]
-            example_inputs,
-            kwargs,
-        )
+        raise RuntimeError("capture_pre_autograd_graph is deprecated and will be deleted soon. Please use torch.export.export_for_training instead.")
     aten_pattern.graph.eliminate_dead_code()
     aten_pattern.recompile()
 
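
The removed else-branch was the internal fallback to capture_pre_autograd_graph; the surviving branch captures the pattern with torch.export.export_for_training and takes .module(). A rough, self-contained sketch of that capture path (the pattern module and example inputs below are illustrative, not the library's internal ones):

import torch
from torch.export import export_for_training

class ConvReLUPattern(torch.nn.Module):
    # Illustrative stand-in for a quantization fusion pattern.
    def forward(self, x, weight, bias):
        x = torch.nn.functional.conv2d(x, weight, bias)
        return torch.nn.functional.relu(x)

example_inputs = (
    torch.randn(1, 1, 3, 3),
    torch.randn(1, 1, 1, 1),
    torch.randn(1),
)

# Capture the pattern as an aten-level GraphModule, then clean it up,
# mirroring what the remaining if-branch in the diff does.
aten_pattern = export_for_training(ConvReLUPattern(), example_inputs).module()
aten_pattern.graph.eliminate_dead_code()
aten_pattern.recompile()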

torch/ao/quantization/quantize_pt2e.py

Lines changed: 1 addition & 3 deletions
@@ -36,7 +36,7 @@ def prepare_pt2e(
 
     Args:
       * `model` (torch.fx.GraphModule): a model captured by `torch.export` API
-        in the short term we are using `torch._export.capture_pre_autograd_graph`,
+        in the short term we are using `torch.export.export_for_training`,
         in the long term we'll migrate to some `torch.export` API
       * `quantizer`: A backend specific quantizer that conveys how user want the
         model to be quantized. Tutorial for how to write a quantizer can be found here:

@@ -49,7 +49,6 @@ def prepare_pt2e(
 
         import torch
         from torch.ao.quantization.quantize_pt2e import prepare_pt2e
-        from torch._export import capture_pre_autograd_graph
         from torch.ao.quantization.quantizer import (
           XNNPACKQuantizer,
           get_symmetric_quantization_config,

@@ -122,7 +121,6 @@ def prepare_qat_pt2e(
     Example::
         import torch
        from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e
-       from torch._export import capture_pre_autograd_graph
        from torch.ao.quantization.quantizer import (
          XNNPACKQuantizer,
          get_symmetric_quantization_config,

0 commit comments