8000 [Intel GPU] qlinear.pointwise with mixed dtype support by ZhiweiYan-96 · Pull Request #136753 · pytorch/pytorch · GitHub

[Intel GPU] qlinear.pointwise with mixed dtype support #136753


Closed. This pull request wants to merge 58 commits.

Changes from 1 commit (58 commits total):
b1eb10e  Update  ZhiweiYan-96  Sep 26, 2024
dae8179  Update  ZhiweiYan-96  Oct 9, 2024
9c0eda1  Update  ZhiweiYan-96  Oct 9, 2024
db98241  Update  ZhiweiYan-96  Oct 17, 2024
833de1e  Update  ZhiweiYan-96  Oct 21, 2024
21ee3c3  Update  ZhiweiYan-96  Oct 23, 2024
8f1ee44  Update  ZhiweiYan-96  Oct 23, 2024
80e557a  Update  ZhiweiYan-96  Oct 24, 2024
73179b4  Update  ZhiweiYan-96  Oct 24, 2024
2ef0fdf  Update  ZhiweiYan-96  Oct 26, 2024
5d50dbc  Update  ZhiweiYan-96  Oct 27, 2024
7b1979b  Update  ZhiweiYan-96  Oct 29, 2024
cb5ac81  Update  ZhiweiYan-96  Oct 29, 2024
72e233c  Update  ZhiweiYan-96  Oct 29, 2024
8ae02a4  Update  ZhiweiYan-96  Oct 30, 2024
6eb4147  Update  ZhiweiYan-96  Oct 31, 2024
f276e7a  Update  ZhiweiYan-96  Oct 31, 2024
4b079ec  Update  ZhiweiYan-96  Oct 31, 2024
b9693ad  Update  ZhiweiYan-96  Nov 2, 2024
4039680  Update  ZhiweiYan-96  Nov 3, 2024
c81073c  Update  ZhiweiYan-96  Nov 4, 2024
9cf8dc8  Update  ZhiweiYan-96  Nov 4, 2024
36f84ff  Update  ZhiweiYan-96  Nov 4, 2024
8c6b9a4  Update  ZhiweiYan-96  Nov 4, 2024
e597d92  Update  ZhiweiYan-96  Nov 4, 2024
0671273  Update  ZhiweiYan-96  Nov 4, 2024
62a73eb  Update  ZhiweiYan-96  Nov 4, 2024
a7d21f9  Update  ZhiweiYan-96  Nov 5, 2024
d1e60e8  Update  ZhiweiYan-96  Nov 5, 2024
ac8e729  Update  ZhiweiYan-96  Nov 21, 2024
f6c2f09  Update  ZhiweiYan-96  Nov 28, 2024
64d364c  Update  ZhiweiYan-96  Dec 30, 2024
525e0e5  Update  ZhiweiYan-96  Jan 2, 2025
2638b20  Update  ZhiweiYan-96  Jan 2, 2025
2bd304f  Update  ZhiweiYan-96  Jan 3, 2025
3277e09  Update  ZhiweiYan-96  Jan 6, 2025
0aca22a  Update  ZhiweiYan-96  Jan 6, 2025
9bc38e6  Update  ZhiweiYan-96  Jan 7, 2025
a57e04b  Update  ZhiweiYan-96  Jan 7, 2025
d50865a  Update  ZhiweiYan-96  Jan 7, 2025
3636740  Update  ZhiweiYan-96  Jan 8, 2025
628487a  Update  ZhiweiYan-96  Jan 8, 2025
3fafa15  Update  ZhiweiYan-96  Jan 9, 2025
1db9b8d  Update  ZhiweiYan-96  Jan 10, 2025
9da2d41  Update  ZhiweiYan-96  Jan 16, 2025
1f1bd79  Update  ZhiweiYan-96  Jan 17, 2025
9aaa7cc  Update  ZhiweiYan-96  Jan 17, 2025
09ba9bf  Update  ZhiweiYan-96  Jan 20, 2025
eea72ec  Update  ZhiweiYan-96  Jan 20, 2025
a2410dd  Update  ZhiweiYan-96  Jan 22, 2025
9fa7182  Update  ZhiweiYan-96  Jan 23, 2025
dc1149b  Update  ZhiweiYan-96  Feb 10, 2025
384279a  Update  guangyey  Feb 10, 2025
e51af94  Update  guangyey  Feb 11, 2025
cef0193  Update  ZhiweiYan-96  Feb 11, 2025
56fa0cf  Update  ZhiweiYan-96  Feb 11, 2025
7fbe418  Update  ZhiweiYan-96  Feb 12, 2025
41c7207  Update  ZhiweiYan-96  Feb 17, 2025
Commit message: Update [ghstack-poisoned]
ZhiweiYan-96 committed Feb 12, 2025
commit 7fbe418f6b099f4cfa1aa32cc7a88d40b589f77f
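In brief, this commit factors the output-dtype selection that was repeated across the four qlinear entry points in aten/src/ATen/native/mkldnn/xpu/qlinear.cpp into a single qlinear_decide_out_dtype helper, renames the test helper _qlinear_dequant_promotion_cpu_test_helper to _qlinear_dequant_promotion_test_helper now that it serves both the CPU and XPU cases, and adds a @skipIfNoDynamoSupport decorator to test_qlinear_relu_int8_mixed_bf16.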
40 changes: 16 additions & 24 deletions aten/src/ATen/native/mkldnn/xpu/qlinear.cpp
@@ -7,6 +7,18 @@ using namespace at::native::onednn;
 
 namespace at::native::xpu {
 
+static inline c10::ScalarType qlinear_decide_out_dtype(
+    const at::Tensor& act,
+    const std::optional<c10::ScalarType> output_dtype) {
+  bool fp32_output = output_dtype.has_value() && (output_dtype == c10::kFloat);
+  bool bfloat16_output =
+      output_dtype.has_value() && (output_dtype == c10::kBFloat16);
+  auto dst_dtype = fp32_output
+      ? c10::kFloat
+      : (bfloat16_output ? c10::kBFloat16 : act.scalar_type());
+  return dst_dtype;
+}
+
 Tensor q_linear_pointwise(
     Tensor act,
     double act_scale,
@@ -38,12 +50,7 @@ Tensor q_linear_pointwise(
   std::vector<int64_t> src_dims = {M, K};
   std::vector<int64_t> dst_dims = {M, N};
 
-  bool fp32_output = output_dtype.has_value() && (output_dtype == c10::kFloat);
-  bool bfloat16_output =
-      output_dtype.has_value() && (output_dtype == c10::kBFloat16);
-  auto dst_dtype = fp32_output
-      ? c10::kFloat
-      : (bfloat16_output ? c10::kBFloat16 : act.scalar_type());
+  auto dst_dtype = qlinear_decide_out_dtype(act, output_dtype);
   Tensor qout = at::empty(dst_dims, act.options().dtype(dst_dtype));
 
   quantized_matmul(
@@ -102,12 +109,7 @@ Tensor q_linear_pointwise_tensor(
   std::vector<int64_t> src_dims = {M, K};
   std::vector<int64_t> dst_dims = {M, N};
 
-  bool fp32_output = output_dtype.has_value() && (output_dtype == c10::kFloat);
-  bool bfloat16_output =
-      output_dtype.has_value() && (output_dtype == c10::kBFloat16);
-  auto dst_dtype = fp32_output
-      ? c10::kFloat
-      : (bfloat16_output ? c10::kBFloat16 : act.scalar_type());
+  auto dst_dtype = qlinear_decide_out_dtype(act, output_dtype);
   Tensor qout = at::empty(dst_dims, act.options().dtype(dst_dtype));
 
   quantized_matmul(
@@ -169,12 +171,7 @@ Tensor q_linear_pointwise_binary(
 
   std::vector<int64_t> src_dims = {M, K};
   std::vector<int64_t> dst_dims = {M, N};
-  bool fp32_output = output_dtype.has_value() && (output_dtype == c10::kFloat);
-  bool bfloat16_output =
-      output_dtype.has_value() && (output_dtype == c10::kBFloat16);
-  auto dst_dtype = fp32_output
-      ? c10::kFloat
-      : (bfloat16_output ? c10::kBFloat16 : act.scalar_type());
+  auto dst_dtype = qlinear_decide_out_dtype(act, output_dtype);
   Tensor qout = at::empty(dst_dims, act.options().dtype(dst_dtype));
 
   quantized_matmul(
@@ -236,12 +233,7 @@ Tensor q_linear_pointwise_binary_tensor(
 
   std::vector<int64_t> src_dims = {M, K};
   std::vector<int64_t> dst_dims = {M, N};
-  bool fp32_output = output_dtype.has_value() && (output_dtype == c10::kFloat);
-  bool bfloat16_output =
-      output_dtype.has_value() && (output_dtype == c10::kBFloat16);
-  auto dst_dtype = fp32_output
-      ? c10::kFloat
-      : (bfloat16_output ? c10::kBFloat16 : act.scalar_type());
+  auto dst_dtype = qlinear_decide_out_dtype(act, output_dtype);
   Tensor qout = at::empty(dst_dims, act.options().dtype(dst_dtype));
 
   quantized_matmul(
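The new helper centralizes the output-dtype choice that each of the four qlinear entry points previously spelled out inline: an explicit fp32 or bf16 request wins, and anything else falls back to the activation's dtype. A minimal Python sketch of the same decision logic (hypothetical function name; the real helper is the C++ above):

import torch

# Sketch of the decision mirrored from qlinear_decide_out_dtype.
def decide_out_dtype(act: torch.Tensor, output_dtype=None) -> torch.dtype:
    # An explicit float32 request wins...
    if output_dtype == torch.float32:
        return torch.float32
    # ...then an explicit bfloat16 request (the mixed-dtype case)...
    if output_dtype == torch.bfloat16:
        return torch.bfloat16
    # ...otherwise fall back to the activation's own dtype.
    return act.dtype

# Example: an int8 activation keeps int8 output unless bf16/fp32 is requested.
act = torch.zeros(2, 4, dtype=torch.int8)
assert decide_out_dtype(act) == torch.int8
assert decide_out_dtype(act, torch.bfloat16) == torch.bfloat16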
21 changes: 11 additions & 10 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -2461,6 +2461,7 @@ def test_qlinear_relu_xpu(self):
             (torch.randn((2, 4)).to(device="xpu"),), device="xpu"
         )
 
+    @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
     def test_qlinear_relu_int8_mixed_bf16(self):
@@ -2782,7 +2783,7 @@ def test_qlinear_add_int8_mixed_bf16_xpu(self, use_relu, is_qat, is_dynamic):
             is_dynamic=is_dynamic,
         )
 
-    def _qlinear_dequant_promotion_cpu_test_helper(
+    def _qlinear_dequant_promotion_test_helper(
         self,
         inputs,
         device="cpu",
@@ -2848,7 +2849,7 @@ def test_qlinear_dequant_promotion_cpu(self):
              |
              Y
         """
-        self._qlinear_dequant_promotion_cpu_test_helper((torch.randn((2, 4)),))
+        self._qlinear_dequant_promotion_test_helper((torch.randn((2, 4)),))
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -2866,7 +2867,7 @@ def test_qlinear_dequant_promotion_xpu(self):
              |
              Y
         """
-        self._qlinear_dequant_promotion_cpu_test_helper(
+        self._qlinear_dequant_promotion_test_helper(
             (torch.randn((2, 4)).to(device="xpu"),), device="xpu"
         )
 
@@ -2887,7 +2888,7 @@ def test_qlinear_dequant_promotion_int8_mixed_bf16(self):
              |
              Y
         """
-        self._qlinear_dequant_promotion_cpu_test_helper(
+        self._qlinear_dequant_promotion_test_helper(
             (torch.randn((2, 4)),), int8_mixed_bf16=True
         )
 
@@ -2909,7 +2910,7 @@ def test_qlinear_dequant_promotion_int8_mixed_bf16_xpu(self):
              |
              Y
         """
-        self._qlinear_dequant_promotion_cpu_test_helper(
+        self._qlinear_dequant_promotion_test_helper(
             (torch.randn((2, 4)).to(device="xpu"),), device="xpu", int8_mixed_bf16=True
         )
 
@@ -2928,7 +2929,7 @@ def test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2(self):
              |
              Y
         """
-        self._qlinear_dequant_promotion_cpu_test_helper((torch.randn((2, 3, 4)),))
+        self._qlinear_dequant_promotion_test_helper((torch.randn((2, 3, 4)),))
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -2946,7 +2947,7 @@ def test_qlinear_dequant_promotion_input_dim_exceeds_2_xpu(self):
              |
              Y
         """
-        self._qlinear_dequant_promotion_cpu_test_helper(
+        self._qlinear_dequant_promotion_test_helper(
             (torch.randn((2, 3, 4)).to(device="xpu"),), device="xpu"
         )
 
@@ -2967,7 +2968,7 @@ def test_qlinear_dequant_promotion_int8_mixed_bf16_input_dim_exceeds_2(self):
              |
              Y
         """
-        self._qlinear_dequant_promotion_cpu_test_helper(
+        self._qlinear_dequant_promotion_test_helper(
             (torch.randn((2, 3, 4)),), int8_mixed_bf16=True
         )
 
@@ -2989,7 +2990,7 @@ def test_qlinear_dequant_promotion_int8_mixed_bf16_input_dim_exceeds_2_xpu(self)
              |
              Y
         """
-        self._qlinear_dequant_promotion_cpu_test_helper(
+        self._qlinear_dequant_promotion_test_helper(
             (torch.randn((2, 3, 4)).to(device="xpu"),),
             device="xpu",
             int8_mixed_bf16=True,
@@ -3019,7 +3020,7 @@ def matcher_check_fn():
             counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3
         )
 
-        self._qlinear_dequant_promotion_cpu_test_helper(
+        self._qlinear_dequant_promotion_test_helper(
            (torch.randn((2, 4)),),
            matcher_check_fn=matcher_check_fn,
            is_dynamic=True,
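For context, the dequant-promotion tests renamed above all exercise the graph shape sketched in their docstrings: one linear output feeding two downstream linear branches, so a single dequantize node gains multiple users and must be duplicated (promoted) before each consumer can fuse into its own qlinear. A toy module with that shape might look like the following sketch (hypothetical names, not the PR's test code):

import torch

class SharedDequantToyModel(torch.nn.Module):
    # Hypothetical toy module: the output of linear1 is consumed by two
    # branches, so after quantization one dequantize node has two users,
    # which is the shape the dequant-promotion pass must handle.
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)
        self.linear3 = torch.nn.Linear(4, 4)

    def forward(self, x):
        t = self.linear1(x)
        # Two consumers of t: the dequant feeding them must be duplicated
        # so each branch can fuse into its own qlinear.
        return self.linear2(t) + self.linear3(t)

Three linears in such a module would also line up with the qlinear_weight_prepack_matcher_count of 3 asserted in the dynamic-shapes test above.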