[Intel GPU] qconv.pointwise with mixed dtype XPU support (#135465) · pytorch/pytorch@075b91b · GitHub

Commit 075b91b

ZhiweiYan-96 authored and guangyey committed
[Intel GPU] qconv.pointwise with mixed dtype XPU support (#135465)
# Motivation

This PR adds mixed data type (AMP) support for the `qconv_pointwise` op. With this change, `qconv` kernels can output a BF16 tensor rather than only FP32/INT8.

# UT verification

```bash
DNNL_VERBOSE=1 python test/inductor/test_mkldnn_pattern_matcher.py -v \
  -k test_qconv2d_int8_mixed_bf16_xpu \
  -k test_qconv2d_relu_int8_mixed_bf16_xpu \
  -k test_qconv2d_hardtanh_int8_mixed_bf16_xpu \
  -k test_qconv2d_hardswish_int8_mixed_bf16_xpu \
  -k test_qconv2d_silu_int8_mixed_bf16_xpu \
  -k test_qconv2d_add_int8_mixed_bf16_xpu \
  -k test_qconv2d_add_relu_int8_mixed_bf16_xpu
```

# Runtime verification

```bash
# qconv + bf16
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,forward_training,src_s8::blocked:acdb::f0 wei_s8::blocked:abcd::f0 bia_f32::blocked:a::f0 dst_bf16::blocked:acdb::f0,attr-scratchpad:user attr-scales:src0:0:f32+dst:0:f32+wei:1:f32 attr-zero-points:src0:0:s32,alg:convolution_direct,mb1_ic128oc128_ih6oh4kh3sh1dh0ph0_iw6ow4kw3sw1dw0pw0,0.0539551
# qconv_silu + bf16
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,forward_training,src_s8::blocked:acdb::f0 wei_s8::blocked:abcd::f0 bia_undef::undef::: dst_bf16::blocked:acdb::f0,attr-scratchpad:user attr-scales:src0:0:f32+dst:0:f32+wei:1:f32 attr-zero-points:src0:0:s32 attr-post-ops:eltwise_swish:1,alg:convolution_direct,mb1_ic128oc128_ih6oh4kh3sh1dh0ph0_iw6ow4kw3sw1dw0pw0,0.0588379
# qconv_hardswish + bf16
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,forward_training,src_s8::blocked:acdb::f0 wei_s8::blocked:abcd::f0 bia_undef::undef::: dst_bf16::blocked:acdb::f0,attr-scratchpad:user attr-scales:src0:0:f32+dst:0:f32+wei:1:f32 attr-zero-points:src0:0:s32 attr-post-ops:eltwise_hardswish:0.166667:0.5,alg:convolution_direct,mb1_ic128oc128_ih6oh4kh3sh1dh0ph0_iw6ow4kw3sw1dw0pw0,0.0568848
```

The `dst_bf16::blocked:acdb::f0` attribute in the oneDNN verbose output demonstrates that the output tensor is successfully computed as BF16.

Pull Request resolved: #135465
Approved by: https://github.com/liangan1, https://github.com/EikanWang, https://github.com/guangyey, https://github.com/desertfire, https://github.com/jerryzh168

ghstack dependencies: #133307, #135189, #135337

Co-authored-by: guangyey <guangye.yu@intel.com>
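For intuition, the output-dtype rule this PR wires into the XPU path can be written down in a few lines of Python. This is only an illustrative mirror of the new `qconv_decide_out_dtype` C++ helper; the function name and example tensors below are made up for the sketch and are not part of the op's public API.

```python
from typing import Optional

import torch


def decide_qconv_out_dtype(
    act: torch.Tensor, output_dtype: Optional[torch.dtype]
) -> torch.dtype:
    """Mirror of the C++ helper added in qconv.cpp: an explicit fp32 or bf16
    request wins; otherwise the destination keeps the activation's dtype."""
    if output_dtype is torch.float32:
        return torch.float32
    if output_dtype is torch.bfloat16:
        return torch.bfloat16
    return act.dtype


# int8 activations plus a bf16 request -> bf16 destination, matching the
# dst_bf16 field in the oneDNN verbose logs above.
act = torch.zeros(1, 128, 6, 6, dtype=torch.int8)
assert decide_qconv_out_dtype(act, torch.bfloat16) == torch.bfloat16
assert decide_qconv_out_dtype(act, None) == torch.int8
```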
1 parent ffa19b9 commit 075b91b

File tree

2 files changed: +114 −15 lines changed

aten/src/ATen/native/mkldnn/xpu/qconv.cpp

Lines changed: 18 additions & 6 deletions

```diff
@@ -1,13 +1,24 @@
 #include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
 #include <c10/core/MemoryFormat.h>
+#include <c10/core/ScalarType.h>
 #include <torch/library.h>

-#include <iostream>
-
 using namespace at::native::onednn;
 namespace at::native::xpu {

+static inline c10::ScalarType qconv_decide_out_dtype(
+    const at::Tensor& act,
+    const std::optional<c10::ScalarType> output_dtype) {
+  bool fp32_output = output_dtype.has_value() && (output_dtype == c10::kFloat);
+  bool bfloat16_output =
+      output_dtype.has_value() && (output_dtype == c10::kBFloat16);
+  auto dst_dtype = fp32_output
+      ? c10::kFloat
+      : (bfloat16_output ? c10::kBFloat16 : act.scalar_type());
+  return dst_dtype;
+}
+
 at::Tensor qconv_prepack_xpu(
     at::Tensor weight,
     at::Tensor weight_scales,
@@ -75,8 +86,9 @@ class QConvoneDNNXPU final {
         stride.vec(),
         dilation.vec());

-    Tensor output = at::empty(
-        dst_tz, act.options().dtype(output_dtype).memory_format(mfmt));
+    auto dst_dtype = qconv_decide_out_dtype(act, output_dtype);
+    Tensor output =
+        at::empty(dst_tz, act.options().dtype(dst_dtype).memory_format(mfmt));

     return quantized_convolution(
         act,
@@ -155,11 +167,11 @@ class QConvoneDNNXPU final {
         stride.vec(),
         dilation.vec());

+    auto dst_dtype = qconv_decide_out_dtype(act, output_dtype);
     bool has_accum_postop_sum = binary_attr == "sum";
     Tensor output = has_accum_postop_sum
         ? accum
-        : at::empty(
-            dst_tz, act.options().dtype(output_dtype).memory_format(mfmt));
+        : at::empty(dst_tz, act.options().dtype(dst_dtype).memory_format(mfmt));

     output = quantized_convolution(
         act,
```
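The changed allocation, `at::empty(dst_tz, act.options().dtype(dst_dtype).memory_format(mfmt))`, has a straightforward Python analogue. The sketch below is not code from this PR; the shape is taken from the conv geometry in the verbose log above (6×6 input, 3×3 kernel, stride 1, no padding → 4×4 output), and `channels_last` stands in for the `acdb` layout.

```python
import torch

# Activation analogous to the verbose log: N=1, C=128, H=W=6, int8.
act = torch.zeros(1, 128, 6, 6, dtype=torch.int8)

# Dtype decided as in qconv_decide_out_dtype when a bf16 output is requested.
dst_dtype = torch.bfloat16

# Destination buffer for the fused conv: shape from the conv geometry,
# dtype from the decision rule, NHWC-like memory format.
dst = torch.empty(
    (1, 128, 4, 4),
    dtype=dst_dtype,
    memory_format=torch.channels_last,
)
assert dst.dtype == torch.bfloat16
assert dst.is_contiguous(memory_format=torch.channels_last)
```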

test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 96 additions & 9 deletions

```diff
@@ -156,17 +156,24 @@ def _test_common(
     ):
         counters.clear()
         torch._dynamo.reset()
-        if (
-            check_autocast == torch.bfloat16
-            and torch.ops.mkldnn._is_mkldnn_bf16_supported()
+        has_xpu = any(
+            isinstance(input, torch.Tensor) and input.device.type == "xpu"
+            for input in inputs
+        )
+        device_type = "xpu" if has_xpu else "cpu"
+        if check_autocast == torch.bfloat16 and (
+            torch.ops.mkldnn._is_mkldnn_bf16_supported() or has_xpu
         ):
-            maybe_autocast = torch.amp.autocast("cpu", dtype=torch.bfloat16)
+            maybe_autocast = torch.amp.autocast(
+                device_type=device_type, dtype=torch.bfloat16
+            )
             atol, rtol = 1e-2, 1e-2
-        elif (
-            check_autocast == torch.float16
-            and torch.ops.mkldnn._is_mkldnn_fp16_supported()
+        elif check_autocast == torch.float16 and (
+            torch.ops.mkldnn._is_mkldnn_fp16_supported() or has_xpu
         ):
-            maybe_autocast = torch.amp.autocast("cpu", dtype=torch.float16)
+            maybe_autocast = torch.amp.autocast(
+                device_type=device_type, dtype=torch.float16
+            )
             atol, rtol = 1e-2, 1e-2
         else:
             assert check_autocast == torch.float32
@@ -1044,6 +1051,16 @@ def test_qconv2d_int8_mixed_bf16(self):
         """
         self._qconv2d_test_helper(int8_mixed_bf16=True)

+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoXPU
+    def test_qconv2d_int8_mixed_bf16_xpu(self):
+        r"""
+        This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization.
+        """
+        self._qconv2d_test_helper(device="xpu", int8_mixed_bf16=True)
+
     def _qconv2d_unary_test_helper(
         self,
         device="cpu",
@@ -1122,7 +1139,7 @@ def test_qconv2d_relu_xpu(self):
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
-    def test_qconv2d_relu_int8_mixed_bf16(self):
+    def test_qconv2d_relu_int8_mixed_bf16_xpu(self):
         r"""
         This testcase will quantize Conv2d->ReLU pattern with int8_mixed_bf16 quantization.
         """
@@ -1178,6 +1195,24 @@ def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self):
             qconv2d_unary_matcher_nodes=11,
         )

+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoXPU
+    def test_qconv2d_hardtanh_int8_mixed_bf16_xpu(self):
+        r"""
+        This testcase will quantize Conv2d->Hardtanh pattern.
+        Match.nodes:
+            [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor]
+            [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type]
+        """
+        self._qconv2d_unary_test_helper(
+            device="xpu",
+            unary_op=torch.nn.Hardtanh(),
+            int8_mixed_bf16=True,
+            qconv2d_unary_matcher_nodes=11,
+        )
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     def test_qconv2d_hardswish_cpu(self):
@@ -1212,6 +1247,25 @@ def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self):
             qconv2d_unary_matcher_nodes=17,
         )

+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoXPU
+    def test_qconv2d_hardswish_int8_mixed_bf16_xpu(self):
+        r"""
+        This testcase will quantize Conv2d->Hardswish pattern.
+        Match.nodes:
+            [qconv2d_pointwise_default, convert_element_type, add, clamp_min,
+             clamp_max, mul, div, convert_element_type, quantize_per_tensor]
+            [qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type]
+        """
+        self._qconv2d_unary_test_helper(
+            device="xpu",
+            unary_op=torch.nn.Hardswish(),
+            int8_mixed_bf16=True,
+            qconv2d_unary_matcher_nodes=17,
+        )
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     def test_qconv2d_silu_cpu(self):
@@ -1246,6 +1300,25 @@ def test_qconv2d_silu_int8_mixed_bf16_cpu(self):
             qconv2d_unary_matcher_nodes=11,
         )

+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoXPU
+    def test_qconv2d_silu_int8_mixed_bf16_xpu(self):
+        r"""
+        This testcase will quantize Conv2d->SiLU pattern.
+        Match.nodes:
+            [qconv2d_pointwise_default, convert_element_type, sigmoid, mul,
+             convert_element_type, quantize_per_tensor]
+            [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type]
+        """
+        self._qconv2d_unary_test_helper(
+            device="xpu",
+            unary_op=torch.nn.SiLU(),
+            int8_mixed_bf16=True,
+            qconv2d_unary_matcher_nodes=11,
+        )
+
     def _qconv2d_add_test_helper(
         self, device="cpu", use_relu=False, int8_mixed_bf16=False
     ):
@@ -1441,6 +1514,13 @@ def test_qconv2d_add_int8_mixed_bf16(self):
         self._qconv2d_add_test_helper(int8_mixed_bf16=True)
         self._qconv2d_add_test_helper2(int8_mixed_bf16=True)

+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoXPU
+    def test_qconv2d_add_int8_mixed_bf16_xpu(self):
+        self._qconv2d_add_test_helper(device="xpu", int8_mixed_bf16=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     def test_qconv2d_add_relu_cpu(self):
@@ -1461,6 +1541,13 @@ def test_qconv2d_add_relu_int8_mixed_bf16(self):
         self._qconv2d_add_test_helper(use_relu=True, int8_mixed_bf16=True)
         self._qconv2d_add_test_helper2(use_relu=True, int8_mixed_bf16=True)

+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoXPU
+    def test_qconv2d_add_relu_int8_mixed_bf16_xpu(self):
+        self._qconv2d_add_test_helper(device="xpu", use_relu=True, int8_mixed_bf16=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     def test_qconv2d_add_broadcast_shapes_cpu(self):
```
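The `_test_common` change above boils down to picking the autocast device from the test inputs. Below is a standalone sketch of that selection using only public `torch.amp` APIs; the helper name is made up for illustration and does not appear in the test file.

```python
import torch


def pick_autocast_device(inputs) -> str:
    """Return "xpu" if any input tensor lives on an XPU device, else "cpu",
    mirroring the has_xpu/device_type logic added to _test_common."""
    has_xpu = any(
        isinstance(t, torch.Tensor) and t.device.type == "xpu" for t in inputs
    )
    return "xpu" if has_xpu else "cpu"


inputs = [torch.randn(1, 3, 8, 8)]  # CPU tensors -> autocast on "cpu"
with torch.amp.autocast(
    device_type=pick_autocast_device(inputs), dtype=torch.bfloat16
):
    pass  # the model forward would run here under bf16 autocast
```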
