#include <ATen/Tensor.h>
#include <ATen/core/Tensor.h>

#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
#include <ATen/native/mkldnn/xpu/detail/Attr.h>

#include <oneapi/dnnl/dnnl.hpp>

namespace at::native::onednn {

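// Computes a quantized matmul on an XPU device through oneDNN.
// mat1 is the quantized activation (its scale / zero point arrive as
// input_scale / input_zero_point), mat2 is the quantized weight with
// per-tensor or per-channel weight_scales, b_raw is an optional bias,
// and attr carries additional post-ops. The output is written to result.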
void quantized_matmul_pt2(
    at::Tensor& result,
    const at::Tensor& mat1,
    const at::Tensor& mat2,
    const at::Tensor& b_raw,
    bool m2_trans,
    double input_scale,
    int64_t input_zero_point,
    at::Tensor& weight_scales,
    at::Tensor& weight_zero_points,
    double output_scale,
    int64_t output_zero_point,
    Attr attr) {
  size_t dims = result.dim();
  at::Device curDevice = at::Device(at::kXPU, c10::xpu::current_device());
  auto engine = GpuEngineManager::Instance().get_engine(curDevice);
  // engine index records which device the engine was created on
  auto engine_index = curDevice.index();
  auto strm = GpuStreamManager::Instance().get_stream();

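  // oneDNN matmul only accepts certain stride layouts; fall back to
  // contiguous copies when the given strides are not supported.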
  at::Tensor m1 = is_onednn_matmul_strides(mat1) ? mat1 : mat1.contiguous();
  at::Tensor m2 = is_onednn_matmul_strides(mat2) ? mat2 : mat2.contiguous();
  at::Tensor dst =
      is_onednn_matmul_strides(result, true) ? result : result.contiguous();

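  // Logical GEMM sizes: dst is (m, n), m1 is (m, k); mb is the batch size
  // for the 3-D (batched) case.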
  int64_t m = dst.size(-2);
  int64_t n = dst.size(-1);
  int64_t k = m1.size(-1);
  int64_t mb = 1;

  if (dims == 3) {
    mb = dst.size(0);
    TORCH_CHECK(
        mb == m1.size(0) && mb == m2.size(0),
        "batch size mismatch, dst mb: ",
        mb,
        " m1 mb: ",
        m1.size(0),
        " m2 mb: ",
        m2.size(0));
  }

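  // Normalize the optional bias to a shape oneDNN can consume directly:
  // expand 0-d/1-d/2-d biases to [1, n] or [mb, m, n] as appropriate, and
  // treat an empty 1-d bias as "no bias".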
  bool with_bias = false;
  at::Tensor b = b_raw;
  if (b.defined()) {
    with_bias = true;
    if (b.dim() == 1) {
      TORCH_CHECK(
          b.size(0) == n || b.size(0) == 1,
          "matmul supports [n] or [1] when bias dim is 1 ...");
      if (b.size(0) == 0) {
        with_bias = false;
      } else if (m1.dim() == 3) {
        b = b.expand({mb, m, n}).contiguous();
      } else if (m1.dim() == 2) {
        b = b.expand({1, n}).contiguous();
      }
    } else if (b.dim() == 2) {
      TORCH_CHECK(
          (b.size(0) == m && b.size(1) == n) ||
              (b.size(0) == 1 && b.size(1) == n) ||
              (b.size(0) == m && b.size(1) == 1) ||
              (b.size(0) == 1 && b.size(1) == 1),
          "matmul supports [m, n] or [1, n] or [m, 1] or [1, 1] when bias dim is 2 ...");
      if (b.size(0) == 1 && b.size(1) == 1)
        b = b.expand({1, n}).contiguous();
    } else if (b.dim() == 3) {
      TORCH_CHECK(
          are_expandable({mb, m, n}, b.sizes()),
          "matmul bias must be expandable to: ",
          dst.sizes(),
          " but got: ",
          b.sizes());
      b = b.expand({mb, m, n}).contiguous();
    } else if (b.dim() == 0) {
      TORCH_CHECK(
          b.numel() == 1, "matmul supports 1 numel when bias dim is [] ...");
      if (m1.dim() == 3) {
        b = b.expand({mb, m, n}).contiguous();
      } else {
        b = b.expand({1, n}).contiguous();
      }
    } else {
      TORCH_CHECK(0, "unsupported bias dim in matmul ...");
    }
  }

  // bias is fused in post-op for the quantized path
  b = b.contiguous(); // avoid reordering twice

  // ipex matmul supports both ab/ba shapes for m2, so no further layout checks are done here

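  // Query the oneDNN data types of the (already quantized) inputs/outputs;
  // the primitive is built directly on the user dtypes and strides.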
  auto m1_usr_dt = get_onednn_dtype(m1);
  auto m2_usr_dt = get_onednn_dtype(m2);
  auto dst_usr_dt = get_onednn_dtype(dst);

  auto m1_dt = m1_usr_dt;
  auto m2_dt = m2_usr_dt;
  auto dst_dt = dst_usr_dt;
  dnnl::memory::data_type bias_dt;

  dnnl::memory::desc m1_md, m1_usr_md, m1_any_md;
  dnnl::memory::desc m2_md, m2_usr_md, m2_any_md;
  dnnl::memory::desc dst_md, dst_usr_md, dst_any_md;
  dnnl::memory::desc b_md;

  dnnl::memory::dims m1_dims, m2_dims, dst_dims, bias_dims;
  dnnl::memory::dims m1_strides, m2_strides, dst_strides, bias_strides;
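  // Fill dims/strides for the 2-D (m, k) x (k, n) case or the batched
  // 3-D (mb, m, k) x (mb, k, n) case; when m2 is not pre-transposed the
  // last two strides are swapped instead of materializing a transpose.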
  if (dims == 2) {
    m1_dims = {m, k};
    m2_dims = {k, n}; // (n, 1) (1, n)
    dst_dims = {m, n};

    m1_strides = {m1.stride(0), m1.stride(1)};
    if (m2_trans) {
      m2_strides = {m2.stride(0), m2.stride(1)};
    } else {
      m2_strides = {m2.stride(1), m2.stride(0)};
    }
    dst_strides = {dst.stride(0), dst.stride(1)};
  } else {
    m1_dims = {mb, m, k};
    m2_dims = {mb, k, n};
    dst_dims = {mb, m, n};

    m1_strides = {m1.stride(0), m1.stride(1), m1.stride(2)};
    if (m2_trans) {
      m2_strides = {m2.stride(0), m2.stride(1), m2.stride(2)};
    } else {
      m2_strides = {m2.stride(0), m2.stride(2), m2.stride(1)};
    }
    dst_strides = {dst.stride(0), dst.stride(1), dst.stride(2)};
  }

  if (with_bias) {
    bias_dims = get_onednn_dims(b);
    bias_dt = get_onednn_dtype(b);
    bias_strides = get_onednn_strides(b);
  }

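  // args collects all runtime arguments for the primitive. The activation
  // zero point is only passed to oneDNN when it is non-zero, and the weight
  // is treated as per-channel quantized when more than one scale is given.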
  std::unordered_map<int, dnnl::memory> args;

  dnnl::post_ops po;
  // attr.extract_post_ops(dst, true);
  bool m1_need_zp = (input_zero_point != 0);
  // wgh should never have zero point
  bool wgh_is_per_channel = weight_scales.numel() > 1;

  // STEP3: create primitive
  dnnl::matmul matmul_p;
  dnnl::matmul::primitive_desc matmul_pd;
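
  // Describe src/weights/dst with their real dims, dtypes, and strides, then
  // attach the quantization attributes before creating the primitive descriptor.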
  m1_md = dnnl::memory::desc(m1_dims, m1_dt, m1_strides);
  m2_md = dnnl::memory::desc(m2_dims, m2_dt, m2_strides);
  dst_md = dnnl::memory::desc(dst_dims, dst_dt, dst_strides);
  // STEP2: create attribute
  dnnl::primitive_attr pattr;
  pattr.set_post_ops(po);

  pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

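  // Scale masks: 0 means a single per-tensor scale; 1 << 1 means the scales
  // vary along dimension 1 (the output-channel dimension of the weights).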
  at::Tensor m2_sc;
  if (!wgh_is_per_channel) {
    pattr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
  } else {
    pattr.set_scales_mask(DNNL_ARG_WEIGHTS, 1 << 1);
  }

  at::Tensor m1_sc;
  dnnl::memory::desc m1_sc_md = dnnl::memory::desc(
      {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
  int mask_ac = 0;
  pattr.set_scales_mask(DNNL_ARG_SRC, mask_ac);
  if (m1_need_zp) {
    pattr.set_zero_points_mask(DNNL_ARG_SRC, mask_ac);
  }

  if (with_bias) {
    b_md = dnnl::memory::desc(bias_dims, bias_dt, bias_strides);
    matmul_pd = dnnl::matmul::primitive_desc(
        engine, m1_md, m2_md, b_md, dst_md, pattr);
  } else {
    matmul_pd = dnnl::matmul::primitive_desc(engine, m1_md, m2_md, dst_md, pattr);
  }

  matmul_p = dnnl::matmul(matmul_pd);

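  // Wrap the existing device buffers as oneDNN memories using the user
  // dtypes/strides; no reorder is performed here.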
  m1_usr_md = dnnl::memory::desc(m1_dims, m1_usr_dt, m1_strides);
  m2_usr_md = dnnl::memory::desc(m2_dims, m2_usr_dt, m2_strides);
  dst_usr_md = dnnl::memory::desc(dst_dims, dst_usr_dt, dst_strides);
  // STEP4: create memory
  auto m1_usr_m = make_onednn_memory(m1_usr_md, engine, m1.data_ptr());
  auto m2_usr_m = make_onednn_memory(m2_usr_md, engine, m2.data_ptr());
  auto dst_usr_m = make_onednn_memory(dst_usr_md, engine, dst.data_ptr());

  auto expected_m1_md = matmul_pd.src_desc();
  auto expected_m2_md = matmul_pd.weights_desc();
  auto expected_dst_md = matmul_pd.dst_desc();

  dnnl::memory m1_m = m1_usr_m, m2_m = m2_usr_m, dst_m = dst_usr_m;
  at::Tensor m1_, m2_, dst_;

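  // Allocate the scratchpad requested by the primitive as a byte tensor so
  // it goes through the normal ATen allocator.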
  size_t scratchpad_size = matmul_pd.scratchpad_desc().get_size();
  at::Tensor scratchpad_tensor = at::empty(
      {static_cast<int64_t>(scratchpad_size)},
      m1.options().dtype(at::kByte),
      c10::nullopt);
  auto scratchpad_memory = make_onednn_memory(
      matmul_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr());
  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_memory});

  // bias add for gen12hp platform
  if (attr.with_binary())
    attr.construct_post_binary(matmul_pd, args);

  args.insert({DNNL_ARG_SRC, m1_m});
  args.insert({DNNL_ARG_WEIGHTS, m2_m});
  args.insert({DNNL_ARG_DST, dst_m});
  if (b.defined()) {
    auto b_m = make_onednn_memory(b_md, engine, b.data_ptr());
    args.insert({DNNL_ARG_BIAS, b_m});
  }

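  // The actual scale/zero-point values are supplied at execution time through
  // DNNL_ARG_ATTR_SCALES / DNNL_ARG_ATTR_ZERO_POINTS runtime arguments.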
  // Add scale/zp memories
  dnnl::memory m2_sc_m, m2_zp_m;
  dnnl::memory::desc m2_sc_md = dnnl::memory::desc(
      {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
  m2_sc_m = make_onednn_memory(m2_sc_md, engine, weight_scales.data_ptr());
  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, m2_sc_m});

  dnnl::memory m1_sc_m, m1_zp_m;
  at::Tensor m1_sc_tensor, m1_zp_tensor;
  m1_sc_m = dnnl_memory_from_host_scalar(
      static_cast<float>(input_scale), m1_sc_tensor, engine);
  m1_zp_m = dnnl_memory_from_host_scalar(
      static_cast<int32_t>(input_zero_point), m1_zp_tensor, engine);

  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, m1_sc_m});
  if (m1_need_zp) {
    args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, m1_zp_m});
  }
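
  // Submit the quantized matmul to the XPU stream; if a contiguous temporary
  // was used for dst, copy it back into the caller's result tensor.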
  auto qmatmul_event = dnnl::sycl_interop::execute(matmul_p, strm, args);

  if (!dst.is_same(result))
    result.copy_(dst);
}

} // namespace at::native::onednn