[Intel GPU] qlinear at XPU backend (#133307) · pytorch/pytorch@59915b8 · GitHub
Commit 59915b8

ZhiweiYan-96 and guangyey authored and committed
[Intel GPU] qlinear at XPU backend (#133307)
# Motivation

This PR enables `onednn.qlinear` and `onednn.qlinear_unary` at Intel GPU. We register the qlinear ops at the C++ backend via `TORCH_LIBRARY_IMPL`; the ops registered in this PR are `onednn::qlinear_pointwise`, `onednn::qlinear_pointwise.tensor`, and `onednn::qlinear_prepack`. The prepack op transposes the weight to fit oneDNN's layout requirement and achieve higher performance. We also lift the restriction in the corresponding annotation method of `XPUInductorQuantizer` (`torch/ao/quantization/quantizer/xpu_inductor_quantizer.py`) to allow linear conversion on GPU. We add the `kChar` (`torch.int8`) dtype in `torch/_inductor/fx_passes/quantization.py` and `torch/_inductor/mkldnn_ir.py`, since signed int8 is the default INT8 data type on the GPU side. We verified the ops through UTs and e2e model testing (e.g., ResNet18, ResNet50).

# UT verification

```
DNNL_VERBOSE=0 TORCH_COMPILE_DEBUG=0 python test/inductor/test_mkldnn_pattern_matcher.py -v \
  -k test_qlinear_xpu \
  -k test_qlinear_relu_xpu \
  -k test_qlinear_gelu_xpu
```

# Runtime exemplification

Here is the oneDNN verbose output collected while running the above UTs:

```
// pure int8 gemm
onednn_verbose,primitive,exec,gpu:0,matmul,jit:gemm:any,undef,src_s8::blocked:ab::f0 wei_s8::blocked:ab::f0 dst_s8::blocked:ab::f0,attr-scratchpad:user attr-scales:src0:0:f32+dst:0:f32+wei:2:f32 attr-zero-points:src0:0:s32+dst:0:s32,,2x4:4x3,0.187988
// post-relu fusion
onednn_verbose,primitive,exec,gpu:0,matmul,jit:gemm:any,undef,src_s8::blocked:ab::f0 wei_s8::blocked:ab::f0 bia_f32::blocked:ab::f0_mask2 dst_f32::blocked:ab::f0,attr-scratchpad:user attr-scales:src0:0:f32+dst:0:f32+wei:2:f32 attr-zero-points:src0:0:s32 attr-post-ops:eltwise_relu,,2x4:4x4,0.115234
// post-gelu fusion
onednn_verbose,primitive,exec,gpu:0,matmul,jit:gemm:any,undef,src_s8::blocked:ab::f0 wei_s8::blocked:ab::f0 dst_f32::blocked:ab::f0,attr-scratchpad:user attr-scales:src0:0:f32+dst:0:f32+wei:2:f32 attr-zero-points:src0:0:s32 attr-post-ops:eltwise_gelu_tanh,,2x4:4x4,0.170898
```

Pull Request resolved: #133307
Approved by: https://github.com/liangan1, https://github.com/guangyey, https://github.com/EikanWang, https://github.com/jerryzh168

Co-authored-by: guangyey <guangye.yu@intel.com>
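For orientation, below is a minimal, hedged sketch of how the new XPU qlinear path is expected to be exercised end to end through the PT2E quantization flow. It assumes an XPU-enabled PyTorch build; the helper name `get_default_xpu_inductor_quantization_config` and the `torch.export.export_for_training` capture step mirror the X86 Inductor quantizer workflow and are assumptions here rather than something this commit adds.

```python
# Hedged sketch: quantize a small Linear+ReLU model on XPU via PT2E, then let
# Inductor lower the matched pattern to the onednn qlinear ops registered here.
# Assumptions: an XPU build of PyTorch, and that the XPU quantizer exposes a
# default-config helper analogous to the X86 one.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xpu_inductor_quantizer import (
    XPUInductorQuantizer,
    get_default_xpu_inductor_quantization_config,  # assumed helper name
)


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.nn.functional.relu(self.linear(x))


model = TinyLinear().eval().to("xpu")
example = (torch.randn(2, 4, device="xpu"),)

quantizer = XPUInductorQuantizer()
quantizer.set_global(get_default_xpu_inductor_quantization_config())

# Capture, annotate, calibrate, and convert (standard PT2E flow).
exported = torch.export.export_for_training(model, example).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example)  # one calibration pass is enough for this toy example
converted = convert_pt2e(prepared)

with torch.no_grad():
    compiled = torch.compile(converted)  # Inductor matches the qlinear pattern
    out = compiled(*example)
print(out.shape)
```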
1 parent bb8c4ec commit 59915b8

File tree

9 files changed: +587 −38 lines changed
Lines changed: 318 additions & 0 deletions
@@ -0,0 +1,318 @@
#include <ATen/Tensor.h>
#include <ATen/core/Tensor.h>
#include <c10/core/ScalarType.h>

#include <ATen/native/mkldnn/xpu/detail/Attr.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>

#include <oneapi/dnnl/dnnl.hpp>

namespace at::native::onednn {

at::Tensor broadcast_bias2D(
    at::Tensor& dst,
    at::Tensor& bias,
    int64_t m,
    int64_t n) {
  switch (bias.dim()) {
    case 1:
      TORCH_CHECK(
          bias.size(0) == n || bias.size(0) == 1,
          "matmul supports [n] or [1] when bias dim is 1, but b.size() is:",
          bias.size(0));
      break;
    case 2:
      if ((bias.size(0) == m && bias.size(1) == n) ||
          (bias.size(0) == m && bias.size(1) == 1) ||
          (bias.size(0) == 1 && bias.size(1) == n))
        return bias; // No need to broadcast
      TORCH_CHECK(
          bias.size(0) == 1 && bias.size(1) == 1,
          "matmul supports [m, n] or [1, n] or [m, 1] or [1, 1] when bias dim is 2 ...");
      break;
    case 0:
      TORCH_CHECK(
          bias.numel() == 1, "matmul supports 1 numel when bias dim is [] ...");
      break;
    default:
      TORCH_CHECK(0, "unsupported bias dim in matmul ...");
  }
  bias = bias.expand({1, n}).contiguous();
  return bias;
}

at::Tensor broadcast_bias3D(
    at::Tensor& dst,
    at::Tensor bias,
    int64_t mb,
    int64_t m,
    int64_t n) {
  switch (bias.dim()) {
    case 1:
      TORCH_CHECK(
          bias.size(0) == n || bias.size(0) == 1,
          "matmul supports [n] or [1] when bias dim is 1, but b.size() is:",
          bias.size(0));
      break;
    case 3:
      TORCH_CHECK(
          are_expandable({mb, m, n}, bias.sizes()),
          "matmul bias must be expandable to:",
          dst.sizes(),
          " but got:",
          bias.sizes());
      break;
    case 0:
      TORCH_CHECK(
          bias.numel() == 1, "matmul supports 1 numel when bias dim is [] ...");
      break;
    default:
      TORCH_CHECK(0, "unsupported bias dim in matmul ...");
  }
  bias = bias.expand({mb, m, n}).contiguous();
  return bias;
}

at::Tensor broadcast_bias(
    at::Tensor& dst,
    at::Tensor bias,
    int64_t mb,
    int64_t m,
    int64_t n) {
  if (dst.dim() == 2) {
    return broadcast_bias2D(dst, bias, m, n);
  } else {
    return broadcast_bias3D(dst, bias, mb, m, n);
  }
}

void quantized_matmul(
    at::Tensor mat1, // act
    double input_scale,
    int64_t input_zero_point,
    at::Tensor mat2, // weight
    at::Tensor& weight_scales,
    at::Tensor& weight_zero_points,
    at::Tensor& bias,
    at::Tensor result, // output
    double output_scale,
    int64_t output_zero_point,
    std::optional<c10::ScalarType> output_dtype,
    std::optional<at::Tensor> other, // extra input for binary-post-op
    double other_scale,
    int64_t other_zero_point,
    const c10::string_view& binary_post_op,
    double binary_alpha,
    const c10::string_view& unary_post_op,
    torch::List<std::optional<at::Scalar>>& unary_post_op_args,
    c10::string_view unary_post_op_algorithm,
    bool m2_trans) {
  // [Note] Quantized Matrix Multiplication at XPU
  // The following code integrates oneDNN quantized gemm. The quantization
  // config we support:
  // activation: s8&u8; per tensor calibrated; symmetric&asymmetric
  // weight: s8; per_tensor/per_channel calibrated; symmetric
  auto attr = Attr(1.0 / output_scale, output_zero_point);
  construct_attr_by_post_op(
      binary_post_op,
      binary_alpha,
      input_scale,
      input_zero_point,
      unary_post_op,
      unary_post_op_args,
      unary_post_op_algorithm,
      attr);

  size_t dims = result.dim();
  at::Device cur_device = at::Device(at::kXPU, c10::xpu::current_device());
  auto engine = GpuEngineManager::Instance().get_engine(cur_device);
  auto stream = GpuStreamManager::Instance().get_stream();

  at::Tensor m1 = is_onednn_matmul_strides(mat1) ? mat1 : mat1.contiguous();
  at::Tensor m2 = is_onednn_matmul_strides(mat2) ? mat2 : mat2.contiguous();
  at::Tensor dst =
      is_onednn_matmul_strides(result, true) ? result : result.contiguous();

  int64_t m = dst.size(-2);
  int64_t n = dst.size(-1);
  int64_t k = m1.size(-1);
  int64_t mb = 1;

  if (dims == 3) {
    mb = dst.size(0);
    TORCH_CHECK(
        mb == m1.size(0) && mb == m2.size(0),
        "batch size mismatch, dst mb: ",
        mb,
        "m1 mb",
        m1.size(0),
        " m2 mb: ",
        m2.size(0));
  }

  bool with_bias = false;
  at::Tensor b = bias;
  if (b.defined()) {
    with_bias = true;
    b = broadcast_bias(dst, b, mb, m, n);
  }
  // bias is fused in post-op for quantized path
  b = b.contiguous(); // avoid reorder 2 times

  auto m1_usr_dt = get_onednn_dtype(m1);
  auto m2_usr_dt = get_onednn_dtype(m2);
  auto dst_usr_dt = get_onednn_dtype(dst);

  auto m1_dt = m1_usr_dt;
  auto m2_dt = m2_usr_dt;
  auto dst_dt = dst_usr_dt;
  dnnl::memory::data_type bias_dt;

  dnnl::memory::desc m1_md, m1_usr_md;
  dnnl::memory::desc m2_md, m2_usr_md;
  dnnl::memory::desc dst_md, dst_usr_md;
  dnnl::memory::desc b_md;

  dnnl::memory::dims m1_dims, m2_dims, dst_dims, bias_dims;
  dnnl::memory::dims m1_strides, m2_strides, dst_strides, bias_strides;
  if (dims == 2) {
    m1_dims = {m, k};
    m2_dims = {k, n}; // (n, 1) (1, n)
    dst_dims = {m, n};

    m1_strides = {m1.stride(0), m1.stride(1)};
    if (m2_trans) {
      m2_strides = {m2.stride(0), m2.stride(1)};
    } else {
      m2_strides = {m2.stride(1), m2.stride(0)};
    }
    dst_strides = {dst.stride(0), dst.stride(1)};
  } else {
    m1_dims = {mb, m, k};
    m2_dims = {mb, k, n};
    dst_dims = {mb, m, n};

    m1_strides = {m1.stride(0), m1.stride(1), m1.stride(2)};
    if (m2_trans) {
      m2_strides = {m2.stride(0), m2.stride(1), m2.stride(2)};
    } else {
      m2_strides = {m2.stride(0), m2.stride(2), m2.stride(1)};
    }
    dst_strides = {dst.stride(0), dst.stride(1), dst.stride(2)};
  }

  if (with_bias) {
    bias_dims = get_onednn_dims(b);
    bias_dt = get_onednn_dtype(b);
    bias_strides = get_onednn_strides(b);
  }

  std::unordered_map<int, dnnl::memory> args;

  dnnl::post_ops po;
  po = attr.extract_post_ops(
      dst,
      true,
      dst.scalar_type() == at::kByte || dst.scalar_type() == at::kChar);
  bool m1_need_zp = (input_zero_point != 0);
  bool wgh_is_per_channel = weight_scales.numel() > 1;

  dnnl::matmul matmul_p;
  dnnl::matmul::primitive_desc matmul_pd;

  m1_md = dnnl::memory::desc(m1_dims, m1_dt, m1_strides);
  m2_md = dnnl::memory::desc(m2_dims, m2_dt, m2_strides);
  dst_md = dnnl::memory::desc(dst_dims, dst_dt, dst_strides);
  dnnl::primitive_attr pattr;
  pattr.set_post_ops(po);
  pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

  at::Tensor m2_sc;
  if (!wgh_is_per_channel) {
    // per-tensor weight scale
    pattr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
  } else {
    // per-channel weight scales
    pattr.set_scales_mask(DNNL_ARG_WEIGHTS, 1 << 1);
  }

  at::Tensor m1_sc;
  dnnl::memory::desc m1_sc_md = dnnl::memory::desc(
      {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
  int mask_ac = 0;
  pattr.set_scales_mask(DNNL_ARG_SRC, mask_ac);
  if (m1_need_zp) {
    pattr.set_zero_points_mask(DNNL_ARG_SRC, mask_ac);
  }

  if (with_bias) {
    b_md = dnnl::memory::desc(bias_dims, bias_dt, bias_strides);
    matmul_pd =
        dnnl::matmul::primitive_desc(engine, m1_md, m2_md, b_md, dst_md, pattr);
  } else {
    matmul_pd =
        dnnl::matmul::primitive_desc(engine, m1_md, m2_md, dst_md, pattr);
  }

  matmul_p = dnnl::matmul(matmul_pd);

  m1_usr_md = dnnl::memory::desc(m1_dims, m1_usr_dt, m1_strides);
  m2_usr_md = dnnl::memory::desc(m2_dims, m2_usr_dt, m2_strides);
  dst_usr_md = dnnl::memory::desc(dst_dims, dst_usr_dt, dst_strides);

  auto m1_usr_m = make_onednn_memory(m1_usr_md, engine, m1.data_ptr());
  auto m2_usr_m = make_onednn_memory(m2_usr_md, engine, m2.data_ptr());
  auto dst_usr_m = make_onednn_memory(dst_usr_md, engine, dst.data_ptr());

  auto expected_m1_md = matmul_pd.src_desc();
  auto expected_m2_md = matmul_pd.weights_desc();
  auto expected_dst_md = matmul_pd.dst_desc();

  dnnl::memory m1_m = m1_usr_m, m2_m = m2_usr_m, dst_m = dst_usr_m;
  at::Tensor m1_, m2_, dst_;

  int scratchpad_size = matmul_pd.scratchpad_desc().get_size();
  at::Tensor scratchpad_tensor =
      at::empty({scratchpad_size}, m1.options().dtype(at::kByte), c10::nullopt);
  auto scratchpad_memory = make_onednn_memory(
      matmul_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr());
  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_memory});

  if (attr.with_binary())
    attr.construct_post_binary(matmul_pd, args);

  args.insert({DNNL_ARG_SRC, m1_m});
  args.insert({DNNL_ARG_WEIGHTS, m2_m});
  args.insert({DNNL_ARG_DST, dst_m});
  if (b.defined()) {
    auto b_m = make_onednn_memory(b_md, engine, b.data_ptr());
    args.insert({DNNL_ARG_BIAS, b_m});
  }

  // Add scale/zp md
  weight_scales = weight_scales.to(at::kFloat);
  dnnl::memory m2_sc_m, m2_zp_m;
  dnnl::memory::desc m2_sc_md = dnnl::memory::desc(
      get_onednn_dims(weight_scales),
      dnnl::memory::data_type::f32,
      dnnl::memory::format_tag::x);
  m2_sc_m = make_onednn_memory(m2_sc_md, engine, weight_scales.data_ptr());
  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, m2_sc_m});

  // Wrap the host-side activation scale/zero-point as oneDNN memories.
  dnnl::memory m1_sc_m, m1_zp_m;
  Tensor m1_sc_tensor, m1_zp_tensor;
  m1_sc_m = dnnl_memory_from_host_scalar(
      static_cast<float>(input_scale), m1_sc_tensor, engine);
  m1_zp_m = dnnl_memory_from_host_scalar(
      static_cast<int32_t>(input_zero_point), m1_zp_tensor, engine);

  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, m1_sc_m});
  if (m1_need_zp) {
    args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, m1_zp_m});
  }

  auto qmatmul_event = dnnl::sycl_interop::execute(matmul_p, stream, args);

  if (!dst.is_same(result))
    result.copy_(dst);
}

} // namespace at::native::onednn
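
As a reading aid for the kernel above, the following pure-PyTorch sketch spells out the arithmetic the fused int8 GEMM performs for the supported config (s8/u8 per-tensor, possibly asymmetric activation; s8 per-tensor or per-channel symmetric weight; bias and eltwise post-ops fused into the epilogue). The function name `ref_qlinear` and its signature are illustrative only and are not part of the commit.

```python
# Hedged reference of the math implemented by the fused oneDNN int8 GEMM above.
import torch


def ref_qlinear(x_q, x_scale, x_zp, w_q, w_scales, bias,
                out_scale, out_zp, out_dtype=torch.int8):
    # Dequantize: per-tensor (possibly asymmetric) activation, symmetric weight.
    x_fp = (x_q.to(torch.float32) - x_zp) * x_scale
    # w_scales is a 0-dim tensor (per-tensor) or an [out_features] vector (per-channel).
    w_fp = w_q.to(torch.float32) * w_scales.reshape(-1, 1)
    y = torch.nn.functional.linear(x_fp, w_fp, bias)
    # A fused unary post-op (relu / gelu) would be applied to `y` here.
    if out_dtype in (torch.int8, torch.uint8):
        # Requantize with the output scale/zero-point; the kernel achieves the
        # same effect by passing 1.0 / output_scale to its Attr.
        info = torch.iinfo(out_dtype)
        q = torch.round(y / out_scale) + out_zp
        return q.clamp(info.min, info.max).to(out_dtype)
    return y  # float output when the epilogue keeps fp32/bf16
```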

aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h

Lines changed: 22 additions & 0 deletions
@@ -133,4 +133,26 @@ at::Tensor quantized_convolution(
     torch::List<std::optional<at::Scalar>> unary_scalars,
     std::optional<std::string_view> unary_algorithm);
 
+void quantized_matmul(
+    at::Tensor mat1, // act
+    double input_scale,
+    int64_t input_zero_point,
+    at::Tensor mat2, // weight
+    at::Tensor& weight_scales,
+    at::Tensor& weight_zero_points,
+    at::Tensor& b_raw,
+    at::Tensor result, // output
+    double output_scale,
+    int64_t output_zero_point,
+    std::optional<c10::ScalarType> output_dtype,
+    std::optional<at::Tensor> other, // extra input for binary-post-op
+    double other_scale,
+    int64_t other_zero_point,
+    const c10::string_view& binary_post_op,
+    double binary_alpha,
+    const c10::string_view& unary_post_op,
+    torch::List<std::optional<at::Scalar>>& unary_post_op_args,
+    c10::string_view unary_post_op_algorithm,
+    bool m2_trans);
+
 } // namespace at::native::onednn
