@@ -568,84 +568,6 @@ def forward(self, x):
         m = M(self.conv_class)
         self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

-    def test_prepare_qat_conv_bn_fusion_getitem_placeholder(self):
-        """
-        Test the case where the placeholder node for the [conv - bn - getitem] pattern
-        is also a getitem node:
-
-          some_op -> unrelated_getitem -> conv -> bn -> conv_bn_getitem
-
-        We want the metadata to be copied from the `conv_bn_getitem` node, not from
-        the `unrelated_getitem` node, which is not part of the conv-bn pattern but
-        is returned as part of the match anyway (as a placeholder).
-        """
-        from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
-
-        # T199018392
-        # remove this test after we kill capture_pre_autograd_graph()
-        if capture_pre_autograd_graph_using_training_ir():
-            self.skipTest("Not applicable to training IR")
-
-        class M(torch.nn.Module):
-            def __init__(self, conv_class, bn_class):
-                super().__init__()
-                self.bn1 = bn_class(3)
-                self.conv = conv_class(3, 3, 3)
-                self.bn2 = bn_class(3)
-
-            def forward(self, x):
-                x = self.bn1(x)
-                x = self.conv(x)
-                x = self.bn2(x)
-                return x
-
-        def _get_getitem_nodes(m: torch.fx.GraphModule):
-            """
-            Return a 2-tuple of (unrelated_getitem_node, conv_bn_getitem_node) from the graph.
-            """
-            unrelated_getitem_node, conv_bn_getitem_node = None, None
-            for node in m.graph.nodes:
-                if (
-                    node.target != operator.getitem
-                    or node.args[0].target
-                    != torch.ops.aten._native_batch_norm_legit.default
-                ):
-                    continue
-                if node.args[0].args[0].op == "placeholder":
-                    unrelated_getitem_node = node
-                else:
-                    conv_bn_getitem_node = node
-            assert (
-                unrelated_getitem_node is not None
-            ), "did not find unrelated getitem node, bad test setup"
-            assert (
-                conv_bn_getitem_node is not None
-            ), "did not find conv bn getitem node, bad test setup"
-            return (unrelated_getitem_node, conv_bn_getitem_node)
-
-        # Program capture
-        m = M(self.conv_class, self.bn_class)
-        m = torch._export.capture_pre_autograd_graph(m, self.example_inputs)
-        m.graph.eliminate_dead_code()
-        m.recompile()
-        (_, original_conv_bn_getitem_node) = _get_getitem_nodes(m)
-
-        # Prepare QAT
-        quantizer = XNNPACKQuantizer()
-        quantizer.set_global(
-            get_symmetric_quantization_config(is_per_channel=False, is_qat=True)
-        )
-        m = prepare_qat_pt2e(m, quantizer)
-        (unrelated_getitem_node, conv_bn_getitem_node) = _get_getitem_nodes(m)
-
-        # Verify that the metadata was copied from `conv_bn_getitem`, not `unrelated_getitem`
-        original_conv_bn_getitem_meta = original_conv_bn_getitem_node.meta[
-            "quantization_annotation"
-        ]
-        conv_bn_getitem_meta = conv_bn_getitem_node.meta["quantization_annotation"]
-        self.assertEqual(conv_bn_getitem_meta, original_conv_bn_getitem_meta)
-        self.assertTrue("quantization_annotation" not in unrelated_getitem_node.meta)
-
     def test_qat_update_shared_qspec(self):
         """
         Test the case where nodes used in SharedQuantizationSpec were replaced
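Background on the deletion above: `capture_pre_autograd_graph()` lowers batch norm to `_native_batch_norm_legit`, whose tuple output is unpacked through `operator.getitem`, so the conv-bn matcher could pick up an unrelated getitem as its placeholder. The training IR that replaces this capture API does not produce that pattern, which is why the test can be removed (tracked in T199018392). For orientation, here is a minimal sketch of the replacement flow on the training IR, assuming PyTorch 2.5+ where `torch.export.export_for_training` is available; the toy model and shapes are illustrative, not taken from this PR:

```python
# Sketch of the training-IR capture + QAT flow that replaces
# capture_pre_autograd_graph(). Assumes PyTorch >= 2.5.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

m = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.BatchNorm2d(3))
example_inputs = (torch.randn(1, 3, 8, 8),)

# Capture on the training IR: an ExportedProgram whose .module() is the
# fx.GraphModule that prepare_qat_pt2e() expects.
m = torch.export.export_for_training(m, example_inputs).module()

quantizer = XNNPACKQuantizer()
quantizer.set_global(
    get_symmetric_quantization_config(is_per_channel=False, is_qat=True)
)
m = prepare_qat_pt2e(m, quantizer)  # fuses conv-bn, inserts fake quantizers
m(*example_inputs)                  # stand-in for a QAT training step
m = convert_pt2e(m)                 # lowers to quantize/dequantize ops
```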
@@ -926,39 +848,6 @@ def test_fold_bn_erases_bn_node(self):
         self.assertTrue(conv_node is not None)
         self.assertTrue(bn_node is None)

-    def test_preserve_capture_pre_autograd_graph_tag(self):
-        """
-        Ensure the capture_pre_autograd_graph_tag node meta is preserved.
-        TODO: Remove this test after training IR migration.
-        T199018392
-        """
-        from torch._export import capture_pre_autograd_graph
-        from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
-
-        if capture_pre_autograd_graph_using_training_ir():
-            self.skipTest(
-                "test doesn't apply when capture_pre_autograd_graph is using training IR"
-            )
-
-        m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
-        m = capture_pre_autograd_graph(m, self.example_inputs)
-
-        for node in m.graph.nodes:
-            self.assertTrue(node.meta.get("capture_pre_autograd_graph_tag", False))
-        quantizer = XNNPACKQuantizer()
-        quantizer.set_global(
-            get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
-        )
-        m = prepare_qat_pt2e(m, quantizer)
-        m = convert_pt2e(m)
-        has_tag = False
-        for node in m.graph.nodes:
-            if not node.meta.get("capture_pre_autograd_graph_tag", False):
-                has_tag = True
-                break
-        self.assertTrue(has_tag)
-        torch.export.export(m, self.example_inputs)
-

 @skipIfNoQNNPACK
 class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base):
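The removed test exercised a general fx idiom: tag every node's `meta` at capture time, run a transform, and then check which nodes still carry the tag (nodes the transform inserts will not). A minimal self-contained sketch of that idiom follows; `my_capture_tag` is a made-up key for illustration, not the real `capture_pre_autograd_graph_tag` mechanism:

```python
# Sketch: node.meta entries travel with fx nodes across graph passes,
# so a capture-time tag distinguishes original nodes from inserted ones.
# "my_capture_tag" is a hypothetical key, not a PyTorch API.
import torch
import torch.fx

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)

gm = torch.fx.symbolic_trace(Tiny())
for node in gm.graph.nodes:
    node.meta["my_capture_tag"] = True  # tag everything at "capture" time

# ... a pass that rewrites the graph would run here; nodes it creates
# start with a fresh meta dict and therefore lack the tag ...

tagged = [n.name for n in gm.graph.nodes if n.meta.get("my_capture_tag", False)]
assert len(tagged) == len(list(gm.graph.nodes)), "tag should survive on untouched nodes"
```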