[Intel GPU] qlinear at XPU backend · pytorch/pytorch@60389e8 · GitHub

Commit 60389e8

[Intel GPU] qlinear at XPU backend
ghstack-source-id: e5eda45 Pull Request resolved: #133307
1 parent eff349b commit 60389e8

File tree

4 files changed: +355 −1 lines changed
Lines changed: 263 additions & 0 deletions
@@ -0,0 +1,263 @@

#include <ATen/Tensor.h>
#include <ATen/core/Tensor.h>

#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
#include <ATen/native/mkldnn/xpu/detail/Attr.h>

#include <oneapi/dnnl/dnnl.hpp>

namespace at::native::onednn {

void quantized_matmul_pt2(
    at::Tensor& result,
    const at::Tensor& mat1,
    const at::Tensor& mat2,
    const at::Tensor& b_raw,
    bool m2_trans,
    double input_scale,
    int64_t input_zero_point,
    at::Tensor& weight_scales,
    at::Tensor& weight_zero_points,
    double output_scale,
    int64_t output_zero_point,
    Attr attr) {
  size_t dims = result.dim();
  at::Device curDevice = at::Device(at::kXPU, c10::xpu::current_device());
  auto engine = GpuEngineManager::Instance().get_engine(curDevice);
  // the engine index indicates which device the engine was created on
  auto engine_index = curDevice.index();
  auto strm = GpuStreamManager::Instance().get_stream();

  at::Tensor m1 = is_onednn_matmul_strides(mat1) ? mat1 : mat1.contiguous();
  at::Tensor m2 = is_onednn_matmul_strides(mat2) ? mat2 : mat2.contiguous();
  at::Tensor dst =
      is_onednn_matmul_strides(result, true) ? result : result.contiguous();

  int64_t m = dst.size(-2);
  int64_t n = dst.size(-1);
  int64_t k = m1.size(-1);
  int64_t mb = 1;

  if (dims == 3) {
    mb = dst.size(0);
    TORCH_CHECK(
        mb == m1.size(0) && mb == m2.size(0),
        "batch size mismatch, dst mb: ",
        mb,
        " m1 mb: ",
        m1.size(0),
        " m2 mb: ",
        m2.size(0));
  }

  bool with_bias = false;
  at::Tensor b = b_raw;
  if (b.defined()) {
    with_bias = true;
    if (b.dim() == 1) {
      TORCH_CHECK(
          b.size(0) == n || b.size(0) == 1,
          "matmul supports [n] or [1] when bias dim is 1 ...");
      if (b.size(0) == 0) {
        with_bias = false;
      } else if (m1.dim() == 3) {
        b = b.expand({mb, m, n}).contiguous();
      } else if (m1.dim() == 2) {
        b = b.expand({1, n}).contiguous();
      }
    } else if (b.dim() == 2) {
      TORCH_CHECK(
          (b.size(0) == m && b.size(1) == n) ||
              (b.size(0) == 1 && b.size(1) == n) ||
              (b.size(0) == m && b.size(1) == 1) ||
              (b.size(0) == 1 && b.size(1) == 1),
          "matmul supports [m, n] or [1, n] or [m, 1] or [1, 1] when bias dim is 2 ...");
      if (b.size(0) == 1 && b.size(1) == 1)
        b = b.expand({1, n}).contiguous();
    } else if (b.dim() == 3) {
      TORCH_CHECK(
          are_expandable({mb, m, n}, b.sizes()),
          "matmul bias must be expandable to: ",
          dst.sizes(),
          " but got: ",
          b.sizes());
      b = b.expand({mb, m, n}).contiguous();
    } else if (b.dim() == 0) {
      TORCH_CHECK(
          b.numel() == 1, "matmul supports 1 numel when bias dim is [] ...");
      if (m1.dim() == 3) {
        b = b.expand({mb, m, n}).contiguous();
      } else {
        b = b.expand({1, n}).contiguous();
      }
    } else {
      TORCH_CHECK(0, "unsupported bias dim in matmul ...");
    }
  }

  // bias is fused as a post-op on the quantized path
  b = b.contiguous(); // avoid reordering twice

  // ipex matmul supports both ab/ba shapes for the m2 tensor, so no further
  // shape check is needed here

  auto m1_usr_dt = get_onednn_dtype(m1);
  auto m2_usr_dt = get_onednn_dtype(m2);
  auto dst_usr_dt = get_onednn_dtype(dst);

  auto m1_dt = m1_usr_dt;
  auto m2_dt = m2_usr_dt;
  auto dst_dt = dst_usr_dt;
  dnnl::memory::data_type bias_dt;

  dnnl::memory::desc m1_md, m1_usr_md, m1_any_md;
  dnnl::memory::desc m2_md, m2_usr_md, m2_any_md;
  dnnl::memory::desc dst_md, dst_usr_md, dst_any_md;
  dnnl::memory::desc b_md;

  dnnl::memory::dims m1_dims, m2_dims, dst_dims, bias_dims;
  dnnl::memory::dims m1_strides, m2_strides, dst_strides, bias_strides;
  if (dims == 2) {
    m1_dims = {m, k};
    m2_dims = {k, n}; // (n, 1) (1, n)
    dst_dims = {m, n};

    m1_strides = {m1.stride(0), m1.stride(1)};
    if (m2_trans) {
      m2_strides = {m2.stride(0), m2.stride(1)};
    } else {
      m2_strides = {m2.stride(1), m2.stride(0)};
    }
    dst_strides = {dst.stride(0), dst.stride(1)};
  } else {
    m1_dims = {mb, m, k};
    m2_dims = {mb, k, n};
    dst_dims = {mb, m, n};

    m1_strides = {m1.stride(0), m1.stride(1), m1.stride(2)};
    if (m2_trans) {
      m2_strides = {m2.stride(0), m2.stride(1), m2.stride(2)};
    } else {
      m2_strides = {m2.stride(0), m2.stride(2), m2.stride(1)};
    }
    dst_strides = {dst.stride(0), dst.stride(1), dst.stride(2)};
  }

  if (with_bias) {
    bias_dims = get_onednn_dims(b);
    bias_dt = get_onednn_dtype(b);
    bias_strides = get_onednn_strides(b);
  }

  std::unordered_map<int, dnnl::memory> args;

  dnnl::post_ops po;
  // attr.extract_post_ops(dst, true);
  bool m1_need_zp = (input_zero_point != 0);
  // wgh should never have a zero point
  bool wgh_is_per_channel = weight_scales.numel() > 1;

  // STEP3: create primitive
  dnnl::matmul matmul_p;
  dnnl::matmul::primitive_desc matmul_pd;

  m1_md = dnnl::memory::desc(m1_dims, m1_dt, m1_strides);
  m2_md = dnnl::memory::desc(m2_dims, m2_dt, m2_strides);
  dst_md = dnnl::memory::desc(dst_dims, dst_dt, dst_strides);
  // STEP2: create attribute
  dnnl::primitive_attr pattr;
  pattr.set_post_ops(po);

  pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

  at::Tensor m2_sc;
  if (!wgh_is_per_channel) {
    pattr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
  } else {
    pattr.set_scales_mask(DNNL_ARG_WEIGHTS, 1 << 1);
  }

  at::Tensor m1_sc;
  dnnl::memory::desc m1_sc_md = dnnl::memory::desc(
      {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
  int mask_ac = 0;
  pattr.set_scales_mask(DNNL_ARG_SRC, mask_ac);
  if (m1_need_zp) {
    pattr.set_zero_points_mask(DNNL_ARG_SRC, mask_ac);
  }

  if (with_bias) {
    b_md = dnnl::memory::desc(bias_dims, bias_dt, bias_strides);
    matmul_pd =
        dnnl::matmul::primitive_desc(engine, m1_md, m2_md, b_md, dst_md, pattr);
  } else {
    matmul_pd =
        dnnl::matmul::primitive_desc(engine, m1_md, m2_md, dst_md, pattr);
  }

  matmul_p = dnnl::matmul(matmul_pd);

  m1_usr_md = dnnl::memory::desc(m1_dims, m1_usr_dt, m1_strides);
  m2_usr_md = dnnl::memory::desc(m2_dims, m2_usr_dt, m2_strides);
  dst_usr_md = dnnl::memory::desc(dst_dims, dst_usr_dt, dst_strides);
  // STEP4: create memory
  auto m1_usr_m = make_onednn_memory(m1_usr_md, engine, m1.data_ptr());
  auto m2_usr_m = make_onednn_memory(m2_usr_md, engine, m2.data_ptr());
  auto dst_usr_m = make_onednn_memory(dst_usr_md, engine, dst.data_ptr());

  auto expected_m1_md = matmul_pd.src_desc();
  auto expected_m2_md = matmul_pd.weights_desc();
  auto expected_dst_md = matmul_pd.dst_desc();

  dnnl::memory m1_m = m1_usr_m, m2_m = m2_usr_m, dst_m = dst_usr_m;
  at::Tensor m1_, m2_, dst_;

  int scratchpad_size = matmul_pd.scratchpad_desc().get_size();
  at::Tensor scratchpad_tensor = at::empty(
      {scratchpad_size}, m1.options().dtype(at::kByte), c10::nullopt);
  auto scratchpad_memory = make_onednn_memory(
      matmul_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr());
  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_memory});

  // bias add for gen12hp platform
  if (attr.with_binary())
    attr.construct_post_binary(matmul_pd, args);

  args.insert({DNNL_ARG_SRC, m1_m});
  args.insert({DNNL_ARG_WEIGHTS, m2_m});
  args.insert({DNNL_ARG_DST, dst_m});
  if (b.defined()) {
    auto b_m = make_onednn_memory(b_md, engine, b.data_ptr());
    args.insert({DNNL_ARG_BIAS, b_m});
  }

  // Add scale/zero-point memory arguments
  dnnl::memory m2_sc_m, m2_zp_m;
  dnnl::memory::desc m2_sc_md = dnnl::memory::desc(
      {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
  m2_sc_m = make_onednn_memory(m2_sc_md, engine, weight_scales.data_ptr());
  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, m2_sc_m});

  dnnl::memory m1_sc_m, m1_zp_m;
  Tensor m1_sc_tensor, m1_zp_tensor;
  m1_sc_m = dnnl_memory_from_host_scalar(
      static_cast<float>(input_scale), m1_sc_tensor, engine);
  m1_zp_m = dnnl_memory_from_host_scalar(
      static_cast<int32_t>(input_zero_point), m1_zp_tensor, engine);

  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, m1_sc_m});
  if (m1_need_zp) {
    args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, m1_zp_m});
  }

  auto qmatmul_event = dnnl::sycl_interop::execute(matmul_p, strm, args);

  if (!dst.is_same(result))
    result.copy_(dst);
}

} // namespace at::native::onednn
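For reference, the computation this primitive is configured for (per-tensor source scale and zero point passed via DNNL_ARG_ATTR_SCALES/ZERO_POINTS | DNNL_ARG_SRC, per-tensor or per-channel weight scales via DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, bias supplied as DNNL_ARG_BIAS) corresponds to the usual dequantize-then-matmul semantics. The sketch below is illustrative only and not part of the commit; the helper name and shapes are assumptions.

# Reference semantics for the quantized matmul above (illustrative sketch,
# not part of the commit). Activation: per-tensor scale/zero point.
# Weight: per-channel scales along N (mask 1 << 1) or a single per-tensor scale.
import torch

def ref_qlinear(act_i8, act_scale, act_zp, w_i8, w_scales, bias=None):
    a = (act_i8.to(torch.float32) - act_zp) * act_scale   # [M, K] dequantized
    w = w_i8.to(torch.float32) * w_scales.view(1, -1)     # [K, N] scaled per column
    out = a @ w
    if bias is not None:
        out = out + bias.to(torch.float32)
    # A real kernel would requantize with output_scale / output_zero_point.
    return out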

aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h

Lines changed: 14 additions & 0 deletions
@@ -133,4 +133,18 @@ at::Tensor quantized_convolution_pt2(
     torch::List<c10::optional<at::Scalar>> unary_scalars,
     c10::optional<c10::string_view> unary_algorithm);
 
+void quantized_matmul_pt2(
+    at::Tensor& result,
+    const at::Tensor& mat1,
+    const at::Tensor& mat2,
+    const at::Tensor& b_raw,
+    bool m2_trans,
+    double input_scale,
+    int64_t input_zero_point,
+    at::Tensor& weight_scales,
+    at::Tensor& weight_zero_points,
+    double output_scale,
+    int64_t output_zero_point,
+    Attr attr);
+
 } // namespace at::native::onednn
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
#include <torch/library.h>

#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>

using namespace at::native::onednn;

namespace at {
namespace native {
namespace xpu {
// Operators for pt2.0
Tensor q_linear_pointwise(
    Tensor act, // int8 XPU tensor
    double act_scale,
    int64_t act_zero_point,
    Tensor weight,
    Tensor weight_scales,
    Tensor weight_zero_points,
    c10::optional<Tensor> bias,
    double output_scale,
    int64_t output_zero_point,
    std::optional<c10::ScalarType> output_dtype,
    c10::string_view post_op_name,
    torch::List<std::optional<at::Scalar>> post_op_args,
    c10::string_view post_op_algorithm) {
  Tensor b_raw = bias.has_value() ? bias.value() : at::Tensor();

  const int64_t dim = act.dim();
  int64_t K = act.size(dim - 1);
  int64_t M = act.numel() / K;
  // [M, K] x [K, N]
  int64_t N = weight.size(1);

  std::vector<int64_t> src_dims = {M, K};
  std::vector<int64_t> dst_dims = {M, N};
  Tensor qout = at::empty(dst_dims, device(c10::kXPU).dtype(c10::kByte));

  Attr attr = Attr();

  quantized_matmul_pt2(
      qout,
      act,
      weight,
      b_raw,
      /*m2_trans=*/false,
      act_scale,
      act_zero_point,
      weight_scales,
      weight_zero_points,
      output_scale,
      output_zero_point,
      attr);

  return qout;
}

at::Tensor q_linear_prepack_onednn(
    at::Tensor weight,
    c10::optional<torch::List<int64_t>> input_shape) {
  return weight;
}

TORCH_LIBRARY_IMPL(onednn, XPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise"),
      TORCH_FN(q_linear_pointwise));
  m.impl(
      TORCH_SELECTIVE_NAME("onednn::qlinear_prepack"),
      TORCH_FN(q_linear_prepack_onednn));
}

} // namespace xpu
} // namespace native
} // namespace at
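The two impls above plug the XPU backend into the existing onednn library ops. A minimal sketch of driving them from Python is shown below; it assumes the onednn::qlinear_pointwise / onednn::qlinear_prepack schemas already registered by the CPU path, an XPU-enabled build, and arbitrary tensor contents.

# Hypothetical end-to-end call into the ops registered above (sketch only;
# argument order follows the kernel signature in this commit).
import torch

act = torch.randint(0, 128, (4, 16), dtype=torch.uint8, device="xpu")
weight = torch.randint(-128, 128, (16, 8), dtype=torch.int8, device="xpu")
w_scales = torch.ones(8, device="xpu")
w_zps = torch.zeros(8, dtype=torch.int64, device="xpu")

packed = torch.ops.onednn.qlinear_prepack(weight, None)  # no-op on XPU (see above)
out = torch.ops.onednn.qlinear_pointwise(
    act, 0.1, 0,            # activation scale / zero point
    packed, w_scales, w_zps,
    None,                   # bias
    1.0, 0,                 # output scale / zero point
    torch.uint8,            # output dtype
    "none", [], "",         # post-op name / args / algorithm
)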

torch/_inductor/mkldnn_ir.py

Lines changed: 2 additions & 1 deletion
@@ -206,7 +206,8 @@ def _prepare_linear_fusion_create(
     req_stride_order = list(reversed(range(len(x.get_size()))))
 
     x = cls.require_stride_order(x, req_stride_order)
-    assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
+    assert x.get_device().type in ["cpu", "xpu"] and weight.get_device().type in ["cpu", "xpu"]
+    assert x.get_device().type == weight.get_device().type
     inputs = [x, weight]
 
     output_stride = FlexibleLayout.contiguous_strides(output_size)
