@@ -143,20 +143,31 @@ def _test_common(
     ):
         counters.clear()
         torch._dynamo.reset()
+        is_xpu = False
+        for input in inputs:
+            is_xpu = is_xpu or (input.device.type == "xpu")
         assert matcher_check_fn is not None or (
             matcher_count is not None and matcher_nodes is not None
         )
         if (
-            check_autocast == torch.bfloat16
-            and torch.ops.mkldnn._is_mkldnn_bf16_supported()
+            (check_autocast == torch.bfloat16
+            and torch.ops.mkldnn._is_mkldnn_bf16_supported()) or
+            is_xpu
         ):
-            maybe_autocast = torch.cpu.amp.autocast(dtype=torch.bfloat16)
+            if is_xpu:
+                maybe_autocast = torch.amp.autocast(device_type="xpu", dtype=torch.bfloat16)
+            else:
+                maybe_autocast = torch.cpu.amp.autocast(dtype=torch.bfloat16)
             atol, rtol = 1e-2, 1e-2
         elif (
-            check_autocast == torch.float16
-            and torch.ops.mkldnn._is_mkldnn_fp16_supported()
+            (check_autocast == torch.float16
+            and torch.ops.mkldnn._is_mkldnn_fp16_supported()) or
+            is_xpu
         ):
-            maybe_autocast = torch.cpu.amp.autocast(dtype=torch.float16)
+            if is_xpu:
+                maybe_autocast = torch.amp.autocast(device_type="xpu", dtype=torch.float16)
+            else:
+                maybe_autocast = torch.cpu.amp.autocast(dtype=torch.float16)
             atol, rtol = 1e-2, 1e-2
         else:
             assert check_autocast == torch.float32
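
(For context outside the diff: the device-aware autocast selection added above can be sketched as a standalone helper. `pick_autocast` is a hypothetical name, not part of the patch, and the sketch keys the dtype off `check_autocast`; note that the hunk as written enters the bfloat16 branch for any XPU input, regardless of `check_autocast`.)

    import torch

    def pick_autocast(inputs, check_autocast):
        # Route through the XPU branch if any input tensor lives on an Intel GPU.
        is_xpu = any(t.device.type == "xpu" for t in inputs)
        if is_xpu:
            # Generic AMP API: the device type is passed explicitly.
            return torch.amp.autocast(device_type="xpu", dtype=check_autocast)
        # CPU path mirrors the original mkldnn-gated behavior.
        return torch.cpu.amp.autocast(dtype=check_autocast)
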
@@ -196,6 +207,7 @@ def _test_common(
         )
         if matcher_check_fn is not None:
             matcher_check_fn()
+        print("===== Finish one test =====")

     def _test_code_common(
         self,
@@ -713,11 +725,11 @@ def test_qconv2d_mkldnn(self, device):
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
-    def test_qconv2d_int8_mixed_bf16(self):
+    def test_qconv2d_int8_mixed_bf16(self, device="cpu"):
         r"""
         This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization.
         """
-        self._qconv2d_cpu_test_helper(int8_mixed_bf16=True)
+        self._qconv2d_cpu_test_helper(device=device, int8_mixed_bf16=True)

     def _qconv2d_unary_cpu_test_helper(
         self,
@@ -776,15 +788,15 @@ def test_qconv2d_relu_mkldnn(self, device):
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
-    def test_qconv2d_relu_int8_mixed_bf16(self):
+    def test_qconv2d_relu_int8_mixed_bf16(self, device="cpu"):
         r"""
         This testcase will quantize Conv2d->ReLU pattern with int8_mixed_bf16 quantization.
         """
-        self._qconv2d_unary_cpu_test_helper(int8_mixed_bf16=True)
+        self._qconv2d_unary_cpu_test_helper(device=device, int8_mixed_bf16=True)

     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
-    def test_qconv2d_relu6_cpu(self, device):
+    def test_qconv2d_relu6_mkldnn(self, device):
         r"""
         This testcase will quantize Conv2d->ReLU6 pattern.
         """
@@ -801,14 +813,15 @@ def test_qconv2d_hardtanh_mkldnn(self, device):
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
-    def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self):
+    def test_qconv2d_hardtanh_int8_mixed_bf16_mkldnn(self, device="cpu"):
         r"""
         This testcase will quantize Conv2d->Hardtanh pattern.
         Match.nodes:
             [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor]
             [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type]
         """
         self._qconv2d_unary_cpu_test_helper(
+            device=device,
             unary_op=torch.nn.Hardtanh(),
             int8_mixed_bf16=True,
             qconv2d_unary_matcher_nodes=11,
@@ -825,7 +838,7 @@ def test_qconv2d_hardswish_mkldnn(self, device):
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
-    def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self):
+    def test_qconv2d_hardswish_int8_mixed_bf16_mkldnn(self, device="cpu"):
         r"""
         This testcase will quantize Conv2d->Hardswish pattern.
         Match.nodes:
@@ -834,6 +847,7 @@ def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self):
             [qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type]
         """
         self._qconv2d_unary_cpu_test_helper(
+            device=device,
             unary_op=torch.nn.Hardswish(),
             int8_mixed_bf16=True,
             qconv2d_unary_matcher_nodes=17,
@@ -850,7 +864,7 @@ def test_qconv2d_silu_mkldnn(self, device):
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
-    def test_qconv2d_silu_int8_mixed_bf16_cpu(self):
+    def test_qconv2d_silu_int8_mixed_bf16_mkldnn(self, device="cpu"):
         r"""
         This testcase will quantize Conv2d->SiLU pattern.
         Match.nodes:
@@ -859,6 +873,7 @@ def test_qconv2d_silu_int8_mixed_bf16_cpu(self):
             [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type]
         """
         self._qconv2d_unary_cpu_test_helper(
+            device=device,
             unary_op=torch.nn.SiLU(),
             int8_mixed_bf16=True,
             qconv2d_unary_matcher_nodes=11,
@@ -941,8 +956,8 @@ def test_qconv2d_add_mkldnn(self, device="cpu"):
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
-    def test_qconv2d_add_int8_mixed_bf16(self):
-        self._qconv2d_add_cpu_test_helper(int8_mixed_bf16=True)
+    def test_qconv2d_add_int8_mixed_bf16(self, device="cpu"):
+        self._qconv2d_add_cpu_test_helper(device=device, int8_mixed_bf16=True)

     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -952,8 +967,8 @@ def test_qconv2d_add_relu_mkldnn(self, device="cpu"):
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
-    def test_qconv2d_add_relu_int8_mixed_bf16(self):
-        self._qconv2d_add_cpu_test_helper(use_relu=True, int8_mixed_bf16=True)
+    def test_qconv2d_add_relu_int8_mixed_bf16(self, device="cpu"):
+        self._qconv2d_add_cpu_test_helper(device=device, use_relu=True, int8_mixed_bf16=True)

     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -2826,7 +2841,7 @@ def matcher_check_fn():
             quantizer=quantizer,
         )

-device_types = ("xpu")
+device_types = ("xpu", "cpu")

 instantiate_device_type_tests(TestPatternMatcher, globals(), only_for=device_types, allow_xpu=True)

 if __name__ == "__main__":
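
(Aside on the last hunk: in Python, `("xpu")` is just a parenthesized string, not a one-element tuple, so the old `device_types` was a plain `str`; the replacement is a real tuple and additionally enables the CPU variants. A quick interpreter check:

    >>> type(("xpu"))      # parentheses alone do not make a tuple
    <class 'str'>
    >>> type(("xpu", "cpu"))
    <class 'tuple'>

`instantiate_device_type_tests(..., only_for=device_types, allow_xpu=True)` then stamps out each `test_*(self, device)` method once per listed device, which is why the tests in the earlier hunks gained a `device` parameter.)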
0 commit comments