[Intel GPU] qlinear at XPU backend #133307
**File 1: XPU `qlinear` implementation (C++)**
```diff
@@ -1,6 +1,7 @@
 #include <torch/library.h>

 #include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
+#include "c10/core/ScalarType.h"

 using namespace at::native::onednn;

@@ -22,6 +23,8 @@ Tensor q_linear_pointwise(
     c10::string_view post_op_name,
     torch::List<std::optional<at::Scalar>> post_op_args,
     c10::string_view post_op_algorithm) {
   Tensor b_raw = bias.has_value() ? bias.value() : at::Tensor();

   const int64_t dim = act.dim();

@@ -34,21 +37,79 @@ Tensor q_linear_pointwise(
   std::vector<int64_t> dst_dims = {M, N};
   Tensor qout = at::empty(dst_dims, device(c10::kXPU).dtype(c10::kByte));

-  Attr attr = Attr();
+  quantized_matmul_pt2(
+      act.contiguous(),
+      act_scale,
+      act_zero_point,
+      weight.contiguous(),
+      weight_scales,
+      weight_zero_points,
+      b_raw,
+      qout,
+      output_scale,
+      output_zero_point,
+      output_dtype,
+      /*other*/ std::nullopt,
+      /*other scale*/ 1.0,
+      /*other zp*/ 0,
+      /*binary post op*/ "none",
+      /*binary alpha*/ 1.0,
+      post_op_name,
+      post_op_args,
+      post_op_algorithm);

   return qout;
 }
```
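The shape handling above folds every leading activation dimension into M before the [M, K] x [K, N] matmul. A standalone numeric check of that arithmetic (illustration only, not PR code); the `.tensor` overload that follows differs only in how the activation scale and zero point are passed:

```python
import torch

act = torch.randn(2, 12, 4)        # any rank; the last dim is K
K = act.size(-1)                   # 4
M = act.numel() // K               # 2 * 12 = 24
weight = torch.randn(K, 3)         # prepacked [K, N] layout, N = 3
out = act.reshape(M, K) @ weight   # dst_dims = {M, N}, as in the C++ above
assert out.shape == (24, 3)
```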
```diff
+Tensor q_linear_pointwise_tensor(
+    Tensor act,
+    Tensor act_scale,
+    Tensor act_zero_point,
+    Tensor weight,
+    Tensor weight_scales,
+    Tensor weight_zero_points,
+    std::optional<Tensor> bias,
+    double output_scale,
+    int64_t output_zero_point,
+    std::optional<c10::ScalarType> output_dtype,
+    c10::string_view post_op_name,
+    torch::List<std::optional<at::Scalar>> post_op_args,
+    c10::string_view post_op_algorithm) {
+  Tensor b_raw = bias.has_value() ? bias.value() : at::Tensor();
+
+  const int64_t dim = act.dim();
+  int64_t K = act.size(dim - 1);
+  int64_t M = act.numel() / K;
+  // [M, K] x [K, N]
+  int64_t N = weight.size(1);
+
+  std::vector<int64_t> src_dims = {M, K};
+  std::vector<int64_t> dst_dims = {M, N};
+  Tensor qout = at::empty(dst_dims, device(c10::kXPU).dtype(c10::kByte));
+
-  quantized_matmul_pt2(
-      qout,
-      act,
-      weight,
-      b_raw,
-      /*m2_trans=*/false,
-      act_scale,
-      act_zero_point,
-      weight_scales,
-      weight_zero_points,
-      output_scale,
-      output_zero_point,
-      attr);
+  quantized_matmul_pt2(
+      act.contiguous(),
+      act_scale.item().toDouble(),
+      act_zero_point.item().toLong(),
+      weight.contiguous(),
+      weight_scales,
+      weight_zero_points,
+      b_raw,
+      qout,
+      output_scale,
+      output_zero_point,
+      output_dtype,
+      /*other*/ std::nullopt,
+      /*other scale*/ 1.0,
+      /*other zp*/ 0,
+      /*binary post op*/ "none",
+      /*binary alpha*/ 1.0,
+      post_op_name,
+      post_op_args,
+      post_op_algorithm);
+
+  return qout;
+}
```
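The tensor overload reduces its 0-dim scale and zero-point tensors back to C++ scalars via `item()`. A minimal Python illustration of that equivalence (values are assumptions, not from the PR):

```python
import torch

act_scale = torch.tensor(0.1, dtype=torch.float64)   # 0-dim tensor form
act_zero_point = torch.tensor(128)                   # 0-dim tensor form

# Mirrors act_scale.item().toDouble() / act_zero_point.item().toLong()
# in the C++ overload above.
assert act_scale.item() == 0.1
assert act_zero_point.item() == 128
```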
```diff
@@ -57,14 +118,18 @@ Tensor q_linear_pointwise(
 at::Tensor q_linear_prepack_onednn(
     at::Tensor weight,
     c10::optional<torch::List<int64_t>> input_shape) {
-  return weight;
+  at::Tensor weight_transposed = weight.transpose(0, 1);
+  return weight_transposed;
 }
```

> **Author (on the unused `input_shape` argument):** Yes, but we need to keep this argument, as the […]
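On XPU the prepack therefore amounts to a plain transpose from `[N, K]` (the layout `nn.Linear` stores) to `[K, N]`. A standalone sketch of the equivalent operation (shapes assumed for illustration):

```python
import torch

w = torch.randn(3, 4)        # [N, K], as nn.Linear stores its weight
packed = w.transpose(0, 1)   # [K, N]: all the XPU prepack does here
assert packed.shape == (4, 3)
```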
```diff
 TORCH_LIBRARY_IMPL(onednn, XPU, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise"),
       TORCH_FN(q_linear_pointwise));
+  m.impl(
+      TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.tensor"),
+      TORCH_FN(q_linear_pointwise_tensor));
   m.impl(
       TORCH_SELECTIVE_NAME("onednn::qlinear_prepack"),
       TORCH_FN(q_linear_prepack_onednn));
 }
```
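For orientation, a hedged sketch of how these registrations might be exercised from Python on an XPU-enabled build. The op names come from the registrations above, but the dtypes, shapes, and quantization parameters here are illustrative assumptions, not code from this PR:

```python
import torch

# Assumed quantized inputs on an XPU device (requires an XPU-enabled build).
act = torch.randint(0, 256, (2, 4), dtype=torch.uint8, device="xpu")
w = torch.randint(-128, 128, (3, 4), dtype=torch.int8, device="xpu")  # [N, K]
w_scales = torch.ones(3, device="xpu")
w_zps = torch.zeros(3, dtype=torch.int64, device="xpu")

# Prepack: per the diff, currently a transpose to [K, N] on XPU.
packed_w = torch.ops.onednn.qlinear_prepack(w, list(act.shape))

out = torch.ops.onednn.qlinear_pointwise(
    act, 0.1, 128,              # activation, its scale and zero point
    packed_w, w_scales, w_zps,  # weight and its per-channel params
    None,                       # bias
    1.0, 0,                     # output scale / zero point
    torch.float32,              # output dtype
    "none", [], "",             # unary post-op: name, args, algorithm
)
```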
**File 2: Inductor quantization pattern-matcher tests (Python)**
```diff
@@ -1419,6 +1419,7 @@ def matcher_check_fn():
     def _qlinear_cpu_test_helper(
         self,
         inputs,
+        device="cpu",
         int8_mixed_bf16=False,
         do_permute=False,
         matcher_check_fn=None,
```
```diff
@@ -1438,7 +1439,7 @@ def forward(self, x):
                 x = torch.reshape(torch.permute(x, (0, 2, 3, 1)), (2, 12, 4))
                 return self.linear2(self.linear(x))

-        mod = M(bias, do_permute=do_permute).eval().xpu()
+        mod = M(bias, do_permute=do_permute).eval().to(device=device)

         def _default_matcher_check_fn():
             self.assertEqual(
```

> **Review suggestion:** replace the hard-coded `.xpu()` with `.to(device=device)`.
>
> **Author:** modified, thanks
```diff
@@ -1459,12 +1460,12 @@ def _default_matcher_check_fn():

     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
-    def test_qlinear_cpu(self):
+    def test_qlinear_mkldnn(self, device="cpu"):
         r"""
         This testcase will quantize a single Linear module.
         """
         for bias in [True, False]:
-            self._qlinear_cpu_test_helper((torch.randn((2, 4)).xpu(),), bias=bias)
+            self._qlinear_cpu_test_helper((torch.randn((2, 4)).to(device=device),), device=device, bias=bias)

     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
```
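The renamed tests keep `device="cpu"` as the default, so existing CPU runs are unchanged, while an XPU run only needs to pass `device="xpu"`. A minimal standalone sketch of the parametrization pattern the diff adopts (generic names, not PR code):

```python
import torch

def run_case(device="cpu"):
    # Same test body for every backend; the caller picks the device,
    # which is the device-generic pattern the diff introduces.
    x = torch.randn((2, 4)).to(device=device)
    mod = torch.nn.Linear(4, 4).eval().to(device=device)
    with torch.no_grad():
        return mod(x)

run_case("cpu")
# run_case("xpu")  # on an XPU-enabled PyTorch build
```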
```diff
@@ -1587,7 +1588,7 @@ def matcher_check_fn():
         )

     def _qlinear_unary_cpu_test_helper(
-        self, inputs, unary_op=torch.nn.ReLU(), int8_mixed_bf16=False
+        self, inputs, unary_op=torch.nn.ReLU(), device="cpu", int8_mixed_bf16=False
     ):
         class M(torch.nn.Module):
             def __init__(self, use_bias):
```
```diff
@@ -1603,7 +1604,7 @@ def forward(self, x):

         bias_list = [True, False]
         for bias in bias_list:
-            mod = M(bias).eval()
+            mod = M(bias).eval().to(device=device)

             def matcher_check_fn():
                 # 1. dequant-linear pattern matched in quantization weight prepack
```
```diff
@@ -1623,11 +1624,11 @@ def matcher_check_fn():

     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
-    def test_qlinear_relu_cpu(self):
+    def test_qlinear_relu_mkldnn(self, device="cpu"):
         r"""
         This testcase will quantize a Linear->ReLU pattern.
         """
-        self._qlinear_unary_cpu_test_helper((torch.randn((2, 4)),))
+        self._qlinear_unary_cpu_test_helper((torch.randn((2, 4)).to(device=device),), device=device)

     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
```
```diff
@@ -1661,12 +1662,12 @@ def test_qlinear_relu_int8_mixed_bf16_input_dim_exceeds_2(self):

     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
-    def test_qlinear_gelu_cpu(self):
+    def test_qlinear_gelu_mkldnn(self, device="cpu"):
         r"""
         This testcase will quantize a Linear->GELU pattern.
         """
         for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]:
-            self._qlinear_unary_cpu_test_helper((torch.randn((2, 4)),), gelu)
+            self._qlinear_unary_cpu_test_helper((torch.randn((2, 4)).to(device=device),), gelu, device=device)

     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
```
```diff
@@ -2831,4 +2832,3 @@ def matcher_check_fn():
 if __name__ == "__main__":
     if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available():
         run_tests()
```
> **Review:** Does the file name align with other similar file names? If yes, please paste the path.
>
> **Author:** All the file names have been changed; thanks for the reminder.