[Intel GPU] qconv.pointwise with mixed dtype XPU support by ZhiweiYan-96 · Pull Request #135465 · pytorch/pytorch · GitHub
[Intel GPU] qconv.pointwise with mixed dtype XPU support #135465

Closed
wants to merge 55 commits
Commits (55)
694bfad  Update  ZhiweiYan-96  Sep 9, 2024
4a80acc  Update  ZhiweiYan-96  Oct 9, 2024
88f6856  Update  ZhiweiYan-96  Oct 17, 2024
c8c1969  Update  ZhiweiYan-96  Oct 21, 2024
1576caf  Update  ZhiweiYan-96  Oct 23, 2024
c0b47f0  Update  ZhiweiYan-96  Oct 23, 2024
0701594  Update  ZhiweiYan-96  Oct 24, 2024
1e550cf  Update  ZhiweiYan-96  Oct 24, 2024
9c1d66c  Update  ZhiweiYan-96  Oct 26, 2024
b1560e4  Update  ZhiweiYan-96  Oct 27, 2024
7fcaa7f  Update  ZhiweiYan-96  Oct 29, 2024
e97f014  Update  ZhiweiYan-96  Oct 29, 2024
97b63ae  Update  ZhiweiYan-96  Oct 29, 2024
6428ae3  Update  ZhiweiYan-96  Oct 30, 2024
11e9a58  Update  ZhiweiYan-96  Oct 31, 2024
2ddadf1  Update  ZhiweiYan-96  Oct 31, 2024
67fd6d9  Update  ZhiweiYan-96  Nov 2, 2024
b855e13  Update  ZhiweiYan-96  Nov 3, 2024
1e59171  Update  ZhiweiYan-96  Nov 4, 2024
698a5f3  Update  ZhiweiYan-96  Nov 4, 2024
0c140df  Update  ZhiweiYan-96  Nov 4, 2024
c36d756  Update  ZhiweiYan-96  Nov 4, 2024
93995e5  Update  ZhiweiYan-96  Nov 4, 2024
162cfb8  Update  ZhiweiYan-96  Nov 4, 2024
a3dfcca  Update  ZhiweiYan-96  Nov 4, 2024
2ead5fc  Update  ZhiweiYan-96  Nov 5, 2024
aaabc4e  Update  ZhiweiYan-96  Nov 5, 2024
5324f17  Update  ZhiweiYan-96  Nov 21, 2024
a886324  Update  ZhiweiYan-96  Nov 28, 2024
c014f6c  Update  ZhiweiYan-96  Dec 30, 2024
7756743  Update  ZhiweiYan-96  Jan 2, 2025
c6012a0  Update  ZhiweiYan-96  Jan 2, 2025
23e1b63  Update  ZhiweiYan-96  Jan 3, 2025
d4694ea  Update  ZhiweiYan-96  Jan 6, 2025
4fe4ec6  Update  ZhiweiYan-96  Jan 6, 2025
588090c  Update  ZhiweiYan-96  Jan 7, 2025
8478edf  Update  ZhiweiYan-96  Jan 7, 2025
61b972a  Update  ZhiweiYan-96  Jan 7, 2025
78e3aa3  Update  ZhiweiYan-96  Jan 8, 2025
b974e82  Update  ZhiweiYan-96  Jan 8, 2025
aaae7c8  Update  ZhiweiYan-96  Jan 9, 2025
5efe5b2  Update  ZhiweiYan-96  Jan 10, 2025
68e0626  Update  ZhiweiYan-96  Jan 16, 2025
917da03  Update  ZhiweiYan-96  Jan 17, 2025
51d3ede  Update  ZhiweiYan-96  Jan 17, 2025
7f3c85b  Update  ZhiweiYan-96  Jan 20, 2025
318ca14  Update  ZhiweiYan-96  Jan 20, 2025
d5ba21b  Update  ZhiweiYan-96  Jan 22, 2025
c2b2fb7  Update  ZhiweiYan-96  Jan 23, 2025
d335dec  Update  ZhiweiYan-96  Feb 10, 2025
2ac8f68  Update  guangyey  Feb 10, 2025
b99d38d  Update  guangyey  Feb 11, 2025
babc6d6  Update  ZhiweiYan-96  Feb 11, 2025
9f806a5  Update  ZhiweiYan-96  Feb 12, 2025
94f36e3  Update  ZhiweiYan-96  Feb 17, 2025
24 changes: 18 additions & 6 deletions aten/src/ATen/native/mkldnn/xpu/qconv.cpp
@@ -1,13 +1,24 @@
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
Collaborator: Place this under line 3.

Collaborator Author: modified
#include <torch/library.h>

#include <iostream>

using namespace at::native::onednn;
namespace at::native::xpu {

static inline c10::ScalarType qconv_decide_out_dtype(
const at::Tensor& act,
const std::optional<c10::ScalarType> output_dtype) {
bool fp32_output = output_dtype.has_value() && (output_dtype == c10::kFloat);
bool bfloat16_output =
output_dtype.has_value() && (output_dtype == c10::kBFloat16);
auto dst_dtype = fp32_output
? c10::kFloat
: (bfloat16_output ? c10::kBFloat16 : act.scalar_type());
return dst_dtype;
}

at::Tensor qconv_prepack_xpu(
at::Tensor weight,
at::Tensor weight_scales,
@@ -75,8 +86,9 @@ class QConvoneDNNXPU final {
stride.vec(),
dilation.vec());

Tensor output = at::empty(
dst_tz, act.options().dtype(output_dtype).memory_format(mfmt));
auto dst_dtype = qconv_decide_out_dtype(act, output_dtype);
Tensor output =
at::empty(dst_tz, act.options().dtype(dst_dtype).memory_format(mfmt));

return quantized_convolution(
act,
@@ -155,11 +167,11 @@ class QConvoneDNNXPU final {
stride.vec(),
dilation.vec());

auto dst_dtype = qconv_decide_out_dtype(act, output_dtype);
bool has_accum_postop_sum = binary_attr == "sum";
Tensor output = has_accum_postop_sum
? accum
: at::empty(
dst_tz, act.options().dtype(output_dtype).memory_format(mfmt));
: at::empty(dst_tz, act.options().dtype(dst_dtype).memory_format(mfmt));

output = quantized_convolution(
act,
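Aside (illustration only, not part of the diff above): the new qconv_decide_out_dtype helper centralizes the mixed-dtype rule for the quantized convolution output: an explicit float32 or bfloat16 request wins, otherwise the output keeps the activation's dtype. A minimal Python sketch of the same rule, with made-up example tensors:

import torch

def qconv_decide_out_dtype(act: torch.Tensor, output_dtype=None) -> torch.dtype:
    # Mirrors the C++ helper: explicit float32/bfloat16 wins, else keep act's dtype.
    if output_dtype is torch.float32:
        return torch.float32
    if output_dtype is torch.bfloat16:
        return torch.bfloat16
    return act.dtype

# An int8 activation with output_dtype=torch.bfloat16 gets a bf16 output buffer;
# with no output_dtype requested, the output stays int8.
act = torch.zeros(1, 3, 8, 8, dtype=torch.int8)
assert qconv_decide_out_dtype(act, torch.bfloat16) is torch.bfloat16
assert qconv_decide_out_dtype(act) is torch.int8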
105 changes: 96 additions & 9 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -156,17 +156,24 @@ def _test_common(
):
counters.clear()
torch._dynamo.reset()
if (
check_autocast == torch.bfloat16
and torch.ops.mkldnn._is_mkldnn_bf16_supported()
has_xpu = any(
isinstance(input, torch.Tensor) and input.device.type == "xpu"
for input in inputs
)
device_type = "xpu" if has_xpu else "cpu"
if check_autocast == torch.bfloat16 and (
torch.ops.mkldnn._is_mkldnn_bf16_supported() or has_xpu
):
maybe_autocast = torch.amp.autocast("cpu", dtype=torch.bfloat16)
maybe_autocast = torch.amp.autocast(
device_type=device_type, dtype=torch.bfloat16
)
atol, rtol = 1e-2, 1e-2
elif (
check_autocast == torch.float16
and torch.ops.mkldnn._is_mkldnn_fp16_supported()
elif check_autocast == torch.float16 and (
torch.ops.mkldnn._is_mkldnn_fp16_supported() or has_xpu
):
maybe_autocast = torch.amp.autocast("cpu", dtype=torch.float16)
maybe_autocast = torch.amp.autocast(
device_type=device_type, dtype=torch.float16
)
atol, rtol = 1e-2, 1e-2
else:
assert check_autocast == torch.float32
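Aside (illustration only, not part of the diff): _test_common now derives the autocast device from the inputs instead of hard-coding "cpu", so XPU tensors are exercised under torch.amp.autocast on "xpu". A minimal sketch of that selection; pick_autocast and the sample inputs are hypothetical names used only here:

import torch

def pick_autocast(inputs, dtype=torch.bfloat16):
    # Use XPU autocast when any input tensor lives on an XPU device, else fall back to CPU.
    has_xpu = any(
        isinstance(x, torch.Tensor) and x.device.type == "xpu" for x in inputs
    )
    device_type = "xpu" if has_xpu else "cpu"
    return torch.amp.autocast(device_type=device_type, dtype=dtype)

inputs = [torch.randn(1, 3, 8, 8)]  # CPU tensors here, so this selects CPU autocast
with pick_autocast(inputs, torch.bfloat16):
    pass  # run the low-precision portion of the test under the selected context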
@@ -1044,6 +1051,16 @@ def test_qconv2d_int8_mixed_bf16(self):
"""
self._qconv2d_test_helper(int8_mixed_bf16=True)

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
@skipIfNoXPU
def test_qconv2d_int8_mixed_bf16_xpu(self):
r"""
This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization.
"""
self._qconv2d_test_helper(device="xpu", int8_mixed_bf16=True)

def _qconv2d_unary_test_helper(
self,
device="cpu",
@@ -1122,7 +1139,7 @@ def test_qconv2d_relu_xpu(self):
@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
def test_qconv2d_relu_int8_mixed_bf16(self):
def test_qconv2d_relu_int8_mixed_bf16_xpu(self):
r"""
This testcase will quantize Conv2d->ReLU pattern with int8_mixed_bf16 quantization.
"""
@@ -1178,6 +1195,24 @@ def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self):
qconv2d_unary_matcher_nodes=11,
)

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
@skipIfNoXPU
def test_qconv2d_hardtanh_int8_mixed_bf16_xpu(self):
r"""
This testcase will quantize Conv2d->Hardtanh pattern.
Match.nodes:
[qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor]
[qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type]
"""
self._qconv2d_unary_test_helper(
device="xpu",
unary_op=torch.nn.Hardtanh(),
int8_mixed_bf16=True,
qconv2d_unary_matcher_nodes=11,
)
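Aside (illustration only, not part of the diff): the qconv2d_unary_matcher_nodes=11 count presumably corresponds to the two matched sequences listed in the docstring above, six nodes in the branch that re-quantizes plus five in the branch that does not:

# Hypothetical sanity check of the node count; the list entries copy the docstring above.
with_requant = [
    "qconv2d_pointwise_default", "convert_element_type", "clamp_min",
    "clamp_max", "convert_element_type", "quantize_per_tensor",
]
without_requant = [
    "qconv2d_pointwise_default", "convert_element_type", "clamp_min",
    "clamp_max", "convert_element_type",
]
assert len(with_requant) + len(without_requant) == 11  # qconv2d_unary_matcher_nodes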

@skipIfNoDynamoSupport
@skipIfNoONEDNN
def test_qconv2d_hardswish_cpu(self):
@@ -1212,6 +1247,25 @@ def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self):
qconv2d_unary_matcher_nodes=17,
)

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
@skipIfNoXPU
def test_qconv2d_hardswish_int8_mixed_bf16_xpu(self):
r"""
This testcase will quantize Conv2d->Hardswish pattern.
Match.nodes:
[qconv2d_pointwise_default, convert_element_type, add, clamp_min,
clamp_max, mul, div, convert_element_type, quantize_per_tensor]
[qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type]
"""
self._qconv2d_unary_test_helper(
device="xpu",
unary_op=torch.nn.Hardswish(),
int8_mixed_bf16=True,
qconv2d_unary_matcher_nodes=17,
)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
def test_qconv2d_silu_cpu(self):
@@ -1246,6 +1300,25 @@ def test_qconv2d_silu_int8_mixed_bf16_cpu(self):
qconv2d_unary_matcher_nodes=11,
)

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
@skipIfNoXPU
def test_qconv2d_silu_int8_mixed_bf16_xpu(self):
r"""
This testcase will quantize Conv2d->SiLU pattern.
Match.nodes:
[qconv2d_pointwise_default, convert_element_type, sigmoid, mul,
convert_element_type, quantize_per_tensor]
[qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type]
"""
self._qconv2d_unary_test_helper(
device="xpu",
unary_op=torch.nn.SiLU(),
int8_mixed_bf16=True,
qconv2d_unary_matcher_nodes=11,
)

def _qconv2d_add_test_helper(
self, device="cpu", use_relu=False, int8_mixed_bf16=False
):
@@ -1441,6 +1514,13 @@ def test_qconv2d_add_int8_mixed_bf16(self):
self._qconv2d_add_test_helper(int8_mixed_bf16=True)
self._qconv2d_add_test_helper2(int8_mixed_bf16=True)

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
@skipIfNoXPU
def test_qconv2d_add_int8_mixed_bf16_xpu(self):
self._qconv2d_add_test_helper(device="xpu", int8_mixed_bf16=True)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
def test_qconv2d_add_relu_cpu(self):
@@ -1461,6 +1541,13 @@ def test_qconv2d_add_relu_int8_mixed_bf16(self):
self._qconv2d_add_test_helper(use_relu=True, int8_mixed_bf16=True)
self._qconv2d_add_test_helper2(use_relu=True, int8_mixed_bf16=True)

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
@skipIfNoXPU
def test_qconv2d_add_relu_int8_mixed_bf16_xpu(self):
self._qconv2d_add_test_helper(device="xpu", use_relu=True, int8_mixed_bf16=True)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
def test_qconv2d_add_broadcast_shapes_cpu(self):