[Intel GPU] qconv.pointwise with mixed dtype XPU support · pytorch/pytorch@6682aa0 · GitHub

Commit 6682aa0

[Intel GPU] qconv.pointwise with mixed dtype XPU support
ghstack-source-id: 26a8e4a
Pull Request resolved: #135465
1 parent c11bd4f commit 6682aa0

3 files changed: +40 -24 lines changed

aten/src/ATen/native/mkldnn/xpu/qconv_pt2e.cpp

Lines changed: 5 additions & 4 deletions
@@ -4,6 +4,7 @@
 #include <torch/library.h>
 
 #include <iostream>
+#include "c10/core/ScalarType.h"
 
 using namespace at::native::onednn;
 namespace at {
@@ -164,11 +165,11 @@ class QConvoneDNNXPU final {
         stride.vec(),
         dilation.vec());
 
-    // TODO: handle difference of this dtype with argument dtype
-    // auto dtype =
-    //     (act.scalar_type() == c10::ScalarType::Byte) ? c10::kByte : c10::kChar;
+    bool fp32_output = output_dtype.has_value() && (output_dtype == c10::kFloat);
+    bool bfloat16_output = output_dtype.has_value() && (output_dtype == c10::kBFloat16);
+    auto dst_dtype = fp32_output ? c10::kFloat : (bfloat16_output ? c10::kBFloat16 : c10::kByte);
     Tensor output = at::empty(
-        dst_tz, device(c10::kXPU).dtype(output_dtype).memory_format(mfmt));
+        dst_tz, device(c10::kXPU).dtype(dst_dtype).memory_format(mfmt));
 
     return quantized_convolution_pt2(
         act,
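For readers less familiar with the ATen side, the destination-dtype selection introduced above can be paraphrased in Python. This is only an illustrative sketch of the decision the C++ code makes; the helper name select_dst_dtype is made up, and torch.uint8 stands in for c10::kByte:

    import torch
    from typing import Optional

    def select_dst_dtype(output_dtype: Optional[torch.dtype]) -> torch.dtype:
        # Mirror of the fp32_output / bfloat16_output flags in qconv_pt2e.cpp:
        # default to an int8 (uint8) output, but honor an explicit request for
        # fp32 or bf16 output in the mixed-dtype case.
        fp32_output = output_dtype == torch.float32
        bfloat16_output = output_dtype == torch.bfloat16
        return torch.float32 if fp32_output else (torch.bfloat16 if bfloat16_output else torch.uint8)

    assert select_dst_dtype(None) == torch.uint8
    assert select_dst_dtype(torch.bfloat16) == torch.bfloat16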

test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 34 additions & 19 deletions
@@ -143,20 +143,31 @@ def _test_common(
     ):
         counters.clear()
         torch._dynamo.reset()
+        is_xpu = False
+        for input in inputs:
+            is_xpu = is_xpu or (input.device.type == "xpu")
         assert matcher_check_fn is not None or (
             matcher_count is not None and matcher_nodes is not None
         )
         if (
-            check_autocast == torch.bfloat16
-            and torch.ops.mkldnn._is_mkldnn_bf16_supported()
+            (check_autocast == torch.bfloat16
+            and torch.ops.mkldnn._is_mkldnn_bf16_supported()) or
+            is_xpu
         ):
-            maybe_autocast = torch.cpu.amp.autocast(dtype=torch.bfloat16)
+            if is_xpu:
+                maybe_autocast = torch.amp.autocast(device_type="xpu", dtype=torch.bfloat16)
+            else:
+                maybe_autocast = torch.cpu.amp.autocast(dtype=torch.bfloat16)
             atol, rtol = 1e-2, 1e-2
         elif (
-            check_autocast == torch.float16
-            and torch.ops.mkldnn._is_mkldnn_fp16_supported()
+            (check_autocast == torch.float16
+            and torch.ops.mkldnn._is_mkldnn_fp16_supported()) or
+            is_xpu
         ):
-            maybe_autocast = torch.cpu.amp.autocast(dtype=torch.float16)
+            if is_xpu:
+                maybe_autocast = torch.amp.autocast(device_type="xpu", dtype=torch.float16)
+            else:
+                maybe_autocast = torch.cpu.amp.autocast(dtype=torch.float16)
             atol, rtol = 1e-2, 1e-2
         else:
             assert check_autocast == torch.float32
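The key behavioral change in this hunk is that bf16/fp16 autocast is now entered on the device the inputs live on rather than unconditionally on CPU. A minimal standalone sketch of that selection (pick_autocast is a hypothetical name; the test uses the equivalent inline logic above):

    import contextlib
    import torch

    def pick_autocast(check_autocast, inputs):
        # XPU inputs use the generic torch.amp.autocast with device_type="xpu";
        # CPU inputs keep the original torch.cpu.amp.autocast path.
        is_xpu = any(t.device.type == "xpu" for t in inputs)
        if check_autocast in (torch.bfloat16, torch.float16):
            if is_xpu:
                return torch.amp.autocast(device_type="xpu", dtype=check_autocast)
            return torch.cpu.amp.autocast(dtype=check_autocast)
        return contextlib.nullcontext()

    # CPU inputs select the CPU autocast context; XPU tensors would select the XPU path.
    ctx = pick_autocast(torch.bfloat16, [torch.randn(8)])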
@@ -196,6 +207,7 @@ def _test_common(
             )
             if matcher_check_fn is not None:
                 matcher_check_fn()
+            print("===== Finish one test =====")
 
     def _test_code_common(
         self,
@@ -713,11 +725,11 @@ def test_qconv2d_mkldnn(self, device):
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
-    def test_qconv2d_int8_mixed_bf16(self):
+    def test_qconv2d_int8_mixed_bf16(self, device="cpu"):
         r"""
         This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization.
         """
-        self._qconv2d_cpu_test_helper(int8_mixed_bf16=True)
+        self._qconv2d_cpu_test_helper(device=device, int8_mixed_bf16=True)
 
     def _qconv2d_unary_cpu_test_helper(
         self,
@@ -776,15 +788,15 @@ def test_qconv2d_relu_mkldnn(self, device):
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
-    def test_qconv2d_relu_int8_mixed_bf16(self):
+    def test_qconv2d_relu_int8_mixed_bf16(self, device="cpu"):
         r"""
         This testcase will quantize Conv2d->ReLU pattern with int8_mixed_bf16 quantization.
         """
-        self._qconv2d_unary_cpu_test_helper(int8_mixed_bf16=True)
+        self._qconv2d_unary_cpu_test_helper(device=device, int8_mixed_bf16=True)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
-    def test_qconv2d_relu6_cpu(self, device):
+    def test_qconv2d_relu6_mkldnn(self, device):
         r"""
         This testcase will quantize Conv2d->ReLU6 pattern.
         """
@@ -801,14 +813,15 @@ def test_qconv2d_hardtanh_mkldnn(self, device):
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
-    def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self):
+    def test_qconv2d_hardtanh_int8_mixed_bf16_mkldnn(self, device="cpu"):
         r"""
         This testcase will quantize Conv2d->Hardtanh pattern.
         Match.nodes:
             [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor]
             [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type]
         """
         self._qconv2d_unary_cpu_test_helper(
+            device=device,
             unary_op=torch.nn.Hardtanh(),
             int8_mixed_bf16=True,
             qconv2d_unary_matcher_nodes=11,
@@ -825,7 +838,7 @@ def test_qconv2d_hardswish_mkldnn(self, device):
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
-    def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self):
+    def test_qconv2d_hardswish_int8_mixed_bf16_mkldnn(self, device="cpu"):
         r"""
         This testcase will quantize Conv2d->Hardswish pattern.
         Match.nodes:
@@ -834,6 +847,7 @@ def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self):
             [qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type]
         """
         self._qconv2d_unary_cpu_test_helper(
+            device=device,
             unary_op=torch.nn.Hardswish(),
             int8_mixed_bf16=True,
             qconv2d_unary_matcher_nodes=17,
@@ -850,7 +864,7 @@ def test_qconv2d_silu_mkldnn(self, device):
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
-    def test_qconv2d_silu_int8_mixed_bf16_cpu(self):
+    def test_qconv2d_silu_int8_mixed_bf16_mkldnn(self, device="cpu"):
         r"""
         This testcase will quantize Conv2d->SiLU pattern.
         Match.nodes:
@@ -859,6 +873,7 @@ def test_qconv2d_silu_int8_mixed_bf16_cpu(self):
             [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type]
         """
         self._qconv2d_unary_cpu_test_helper(
+            device=device,
             unary_op=torch.nn.SiLU(),
             int8_mixed_bf16=True,
             qconv2d_unary_matcher_nodes=11,
@@ -941,8 +956,8 @@ def test_qconv2d_add_mkldnn(self, device="cpu"):
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
-    def test_qconv2d_add_int8_mixed_bf16(self):
-        self._qconv2d_add_cpu_test_helper(int8_mixed_bf16=True)
+    def test_qconv2d_add_int8_mixed_bf16(self, device="cpu"):
+        self._qconv2d_add_cpu_test_helper(device=device, int8_mixed_bf16=True)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -952,8 +967,8 @@ def test_qconv2d_add_relu_mkldnn(self, device="cpu"):
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
-    def test_qconv2d_add_relu_int8_mixed_bf16(self):
-        self._qconv2d_add_cpu_test_helper(use_relu=True, int8_mixed_bf16=True)
+    def test_qconv2d_add_relu_int8_mixed_bf16(self, device="cpu"):
+        self._qconv2d_add_cpu_test_helper(device=device, use_relu=True, int8_mixed_bf16=True)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
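All of the renamed *_mkldnn tests and the new device="cpu" defaults funnel the target device into the shared helpers, so one test body can drive both backends. A rough, self-contained sketch of that plumbing pattern (the helper below is hypothetical, not the file's actual _qconv2d_unary_cpu_test_helper):

    import contextlib
    import torch

    def _qconv2d_test_helper(device="cpu", int8_mixed_bf16=False):
        # Build the module and example input on the requested device so the same
        # helper serves both the CPU and the XPU instantiations of the test.
        mod = torch.nn.Conv2d(3, 16, kernel_size=3).to(device).eval()
        x = torch.randn(1, 3, 32, 32, device=device)
        ctx = (
            torch.amp.autocast(device_type=device, dtype=torch.bfloat16)
            if int8_mixed_bf16
            else contextlib.nullcontext()
        )
        with ctx, torch.no_grad():
            return mod(x)

    out = _qconv2d_test_helper()  # fp32 CPU path by default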
@@ -2826,7 +2841,7 @@ def matcher_check_fn():
             quantizer=quantizer,
         )
 
-device_types = ("xpu")
+device_types = ("xpu", "cpu")
 instantiate_device_type_tests(TestPatternMatcher, globals(), only_for=device_types, allow_xpu=True)
 
 if __name__ == "__main__":
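Adding "cpu" back to device_types means instantiate_device_type_tests once again generates a per-device copy of the test class, so each test above runs as both a CPU and an XPU variant (when an XPU build and device are available). A toy example of the mechanism, with only the framework call itself taken from the diff:

    import torch
    from torch.testing._internal.common_device_type import instantiate_device_type_tests
    from torch.testing._internal.common_utils import TestCase, run_tests

    class ToyPatternTests(TestCase):
        def test_conv_shape(self, device):
            # `device` is injected by the framework as "cpu" or "xpu".
            x = torch.randn(1, 3, 8, 8, device=device)
            w = torch.randn(4, 3, 3, 3, device=device)
            self.assertEqual(torch.nn.functional.conv2d(x, w).shape, (1, 4, 6, 6))

    # Produces ToyPatternTestsCPU and, on XPU builds, ToyPatternTestsXPU.
    instantiate_device_type_tests(
        ToyPatternTests, globals(), only_for=("xpu", "cpu"), allow_xpu=True
    )

    if __name__ == "__main__":
        run_tests()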

torch/_inductor/mkldnn_ir.py

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ def _original_deconv_weight_size(
         else:
             output_stride = make_channels_last_strides_for(output_size)
 
-        assert x.get_device().type in ["xpu", "xpu"] and weight.get_device().type in ["cpu", "xpu"]
+        assert x.get_device().type in ["xpu", "cpu"] and weight.get_device().type in ["cpu", "xpu"]
         inputs = [x, weight]
 
         kernel_layout = FixedLayout(
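The one-line change above fixes an evident typo: the previous check ["xpu", "xpu"] could never accept an activation on CPU, even though the weight check already allowed both backends. A trivial illustration of the corrected invariant (devices chosen arbitrarily):

    import torch

    # Activation and weight may each live on CPU or XPU, independently.
    x_dev = torch.device("cpu")
    w_dev = torch.device("xpu")
    assert x_dev.type in ["xpu", "cpu"] and w_dev.type in ["cpu", "xpu"]
    # The pre-fix check `x_dev.type in ["xpu", "xpu"]` would reject this CPU activation.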
