@@ -156,17 +156,24 @@ def _test_common(
     ):
         counters.clear()
         torch._dynamo.reset()
-        if (
-            check_autocast == torch.bfloat16
-            and torch.ops.mkldnn._is_mkldnn_bf16_supported()
+        has_xpu = any(
+            isinstance(input, torch.Tensor) and input.device.type == "xpu"
+            for input in inputs
+        )
+        device_type = "xpu" if has_xpu else "cpu"
+        if check_autocast == torch.bfloat16 and (
+            torch.ops.mkldnn._is_mkldnn_bf16_supported() or has_xpu
         ):
-            maybe_autocast = torch.amp.autocast("cpu", dtype=torch.bfloat16)
+            maybe_autocast = torch.amp.autocast(
+                device_type=device_type, dtype=torch.bfloat16
+            )
             atol, rtol = 1e-2, 1e-2
-        elif (
-            check_autocast == torch.float16
-            and torch.ops.mkldnn._is_mkldnn_fp16_supported()
+        elif check_autocast == torch.float16 and (
+            torch.ops.mkldnn._is_mkldnn_fp16_supported() or has_xpu
         ):
-            maybe_autocast = torch.amp.autocast("cpu", dtype=torch.float16)
+            maybe_autocast = torch.amp.autocast(
+                device_type=device_type, dtype=torch.float16
+            )
             atol, rtol = 1e-2, 1e-2
         else:
             assert check_autocast == torch.float32
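The rewritten branch selects the autocast device from the test inputs instead of hardcoding "cpu". A minimal standalone sketch of that selection logic, assuming `inputs` is the flattened input list (the `pick_autocast` helper name is hypothetical):

```python
import torch

def pick_autocast(inputs, dtype=torch.bfloat16):
    # Mirrors the hunk above: any XPU tensor among the inputs switches
    # the autocast context from "cpu" to "xpu".
    has_xpu = any(
        isinstance(t, torch.Tensor) and t.device.type == "xpu" for t in inputs
    )
    device_type = "xpu" if has_xpu else "cpu"
    return torch.amp.autocast(device_type=device_type, dtype=dtype)
```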
@@ -1044,6 +1051,16 @@ def test_qconv2d_int8_mixed_bf16(self):
         """
         self._qconv2d_test_helper(int8_mixed_bf16=True)
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoXPU
+    def test_qconv2d_int8_mixed_bf16_xpu(self):
+        r"""
+        This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization.
+        """
+        self._qconv2d_test_helper(device="xpu", int8_mixed_bf16=True)
+
     def _qconv2d_unary_test_helper(
         self,
         device="cpu",
@@ -1122,7 +1139,7 @@ def test_qconv2d_relu_xpu(self):
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
-    def test_qconv2d_relu_int8_mixed_bf16(self):
+    def test_qconv2d_relu_int8_mixed_bf16_xpu(self):
         r"""
         This testcase will quantize Conv2d->ReLU pattern with int8_mixed_bf16 quantization.
         """
@@ -1178,6 +1195,24 @@ def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self):
             qconv2d_unary_matcher_nodes=11,
         )
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoXPU
+    def test_qconv2d_hardtanh_int8_mixed_bf16_xpu(self):
+        r"""
+        This testcase will quantize Conv2d->Hardtanh pattern.
+        Match.nodes:
+        [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor]
+        [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type]
+        """
+        self._qconv2d_unary_test_helper(
+            device="xpu",
+            unary_op=torch.nn.Hardtanh(),
+            int8_mixed_bf16=True,
+            qconv2d_unary_matcher_nodes=11,
+        )
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     def test_qconv2d_hardswish_cpu(self):
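The clamp_min/clamp_max entries in the Match.nodes list above come from Hardtanh lowering to a clamp pair (default bounds -1.0 and 1.0); a quick sketch of the equivalence:

```python
import torch

x = torch.randn(8)
# Hardtanh is exactly a clamp pair, hence the clamp_min/clamp_max nodes.
assert torch.equal(torch.nn.Hardtanh()(x), x.clamp_min(-1.0).clamp_max(1.0))
```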
@@ -1212,6 +1247,25 @@ def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self):
             qconv2d_unary_matcher_nodes=17,
         )
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoXPU
+    def test_qconv2d_hardswish_int8_mixed_bf16_xpu(self):
+        r"""
+        This testcase will quantize Conv2d->Hardswish pattern.
+        Match.nodes:
+        [qconv2d_pointwise_default, convert_element_type, add, clamp_min,
+         clamp_max, mul, div, convert_element_type, quantize_per_tensor]
+        [qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type]
+        """
+        self._qconv2d_unary_test_helper(
+            device="xpu",
+            unary_op=torch.nn.Hardswish(),
+            int8_mixed_bf16=True,
+            qconv2d_unary_matcher_nodes=17,
+        )
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     def test_qconv2d_silu_cpu(self):
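Similarly, the add/clamp_min/clamp_max/mul/div nodes above reflect the Hardswish decomposition x * clamp(x + 3, 0, 6) / 6:

```python
import torch

x = torch.randn(8)
# Hardswish decomposes into add, clamp, mul, and div nodes.
expected = x * (x + 3.0).clamp_min(0.0).clamp_max(6.0) / 6.0
assert torch.allclose(torch.nn.Hardswish()(x), expected)
```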
@@ -1246,6 +1300,25 @@ def test_qconv2d_silu_int8_mixed_bf16_cpu(self):
             qconv2d_unary_matcher_nodes=11,
         )
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoXPU
+    def test_qconv2d_silu_int8_mixed_bf16_xpu(self):
+        r"""
+        This testcase will quantize Conv2d->SiLU pattern.
+        Match.nodes:
+        [qconv2d_pointwise_default, convert_element_type, sigmoid, mul,
+         convert_element_type, quantize_per_tensor]
+        [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type]
+        """
+        self._qconv2d_unary_test_helper(
+            device="xpu",
+            unary_op=torch.nn.SiLU(),
+            int8_mixed_bf16=True,
+            qconv2d_unary_matcher_nodes=11,
+        )
+
     def _qconv2d_add_test_helper(
         self, device="cpu", use_relu=False, int8_mixed_bf16=False
     ):
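And the sigmoid/mul nodes above follow from SiLU being x * sigmoid(x):

```python
import torch

x = torch.randn(8)
# SiLU lowers to a sigmoid followed by a multiply.
assert torch.allclose(torch.nn.SiLU()(x), x * torch.sigmoid(x))
```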
@@ -1441,6 +1514,13 @@ def test_qconv2d_add_int8_mixed_bf16(self):
         self._qconv2d_add_test_helper(int8_mixed_bf16=True)
         self._qconv2d_add_test_helper2(int8_mixed_bf16=True)
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoXPU
+    def test_qconv2d_add_int8_mixed_bf16_xpu(self):
+        self._qconv2d_add_test_helper(device="xpu", int8_mixed_bf16=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     def test_qconv2d_add_relu_cpu(self):
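The add variants target the residual pattern Conv2d -> add(extra input), with an optional trailing ReLU in the _relu tests. An illustrative sketch of that module shape (this is not the helper's actual module, just the pattern it quantizes):

```python
import torch

class ConvAdd(torch.nn.Module):
    # Conv2d -> add -> (optional ReLU): the shape the add helpers exercise.
    def __init__(self, use_relu=False):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.relu = torch.nn.ReLU() if use_relu else torch.nn.Identity()

    def forward(self, x):
        return self.relu(self.conv(x) + x)
```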
@@ -1461,6 +1541,13 @@ def test_qconv2d_add_relu_int8_mixed_bf16(self):
         self._qconv2d_add_test_helper(use_relu=True, int8_mixed_bf16=True)
         self._qconv2d_add_test_helper2(use_relu=True, int8_mixed_bf16=True)
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoXPU
+    def test_qconv2d_add_relu_int8_mixed_bf16_xpu(self):
+        self._qconv2d_add_test_helper(device="xpu", use_relu=True, int8_mixed_bf16=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     def test_qconv2d_add_broadcast_shapes_cpu(self):