@@ -107,6 +107,25 @@ static bool use_mkldnn_bf32_matmul() {
107
107
return use_mkldnn_bf16_matmul () && at::globalContext ().float32MatmulPrecision () == at::Float32MatmulPrecision::MEDIUM;
108
108
}
109
109
110
+ // returns an ideep::tensor
111
+ // - dims: shape e.g: {M,N}
112
+ // - idtype: ideep data type e.g: (f32, bf16, f16)
113
+ // - strides: Memory layout
114
+ // - data: data pointer
115
+ template <typename scalar_t >
116
+ inline ideep::tensor make_ideep_tensor (
117
+ std::vector<int64_t > dims,
118
+ ideep::tensor::data_type idtype,
119
+ ideep::tensor::dims& strides,
120
+ scalar_t *data){
121
+ ideep::tensor res ({
122
+ dims,
123
+ idtype,
124
+ strides
125
+ },
126
+ data);
127
+ return res;
128
+ }
110
129
111
130
template <typename scalar_t >
112
131
static inline typename std::enable_if_t <
@@ -155,35 +174,74 @@ mkldnn_gemm(
155
174
idtype = ideep::tensor::data_type::f32;
156
175
}
157
176
158
- ideep::tensor a ({
159
- /* sizes=*/ {k, m},
160
- idtype,
161
- /* strides=*/ a_strides},
162
- const_cast <scalar_t *>(a_data));
163
- ideep::tensor b ({
164
- /* sizes=*/ {n, k},
165
- idtype,
166
- /* strides=*/ b_strides},
167
- const_cast <scalar_t *>(b_data));
168
- ideep::tensor c ({
169
- /* sizes=*/ {n, m},
170
- idtype,
171
- /* strides=*/ c_strides},
172
- c_data);
177
+ ideep::tensor a = make_ideep_tensor<scalar_t >({k, m}, idtype, a_strides, const_cast <scalar_t *>(a_data));
178
+ ideep::tensor b = make_ideep_tensor<scalar_t >({n, k}, idtype, b_strides, const_cast <scalar_t *>(b_data));
179
+ ideep::tensor c = make_ideep_tensor<scalar_t >({n, m}, idtype, c_strides, c_data);
173
180
174
181
ideep::matmul_forward::compute (
175
182
b, a, c, alpha, beta,
176
183
ideep::scale_t (), ideep::scale_t (), ideep::scale_t (), op_attr);
177
184
178
185
if (c.get_data_handle () != c_data){
186
+ // ideep will query oneDNN expect format of output
187
+ // if given output format is not expected, ideep will re-init an output buffer
188
+ // under this case, we need copy the re-inited buffer back to given buffer
189
+ ideep::tensor real_output = make_ideep_tensor<scalar_t >({n,m}, idtype, c_strides, c_data);
190
+ c.reorder_to (real_output);
191
+ }
192
+ return true ;
193
+ }
194
+
195
+ template <typename scalar_t >
196
+ inline typename std::enable_if_t <
197
+ std::is_same_v<scalar_t , c10::BFloat16>,
198
+ bool >
199
+ mkldnn_gemm (
200
+ TransposeType transa, TransposeType transb,
201
+ int64_t m, int64_t n, int64_t k,
202
+ float alpha,
203
+ const scalar_t *a_data, int64_t lda,
204
+ const scalar_t *b_data, int64_t ldb,
205
+ float beta,
206
+ float * c_data, int64_t ldc) {
207
+ // introduce heuristic to validate dispatch to MKLDNN
208
+ // (m * n * k <= 16 * 16 * 16)
209
+ bool bf16_usable = use_mkldnn_bf16_matmul ();
210
+ if (!bf16_usable) {
211
+ return false ;
212
+ }
213
+
214
+ ideep::attr_t op_attr;
215
+ // Use mkldnn post ops to perform the add.
216
+ if (beta != 0 .0f ) {
217
+ op_attr = ideep::attr_t::fuse_sum ();
218
+ }
219
+
220
+ // NOTE: View as c-contiguous to avoid extra reordering in mkldnn
221
+ // Use identity: C = AB <=> C^T = B^T A^T
222
+ ideep::tensor::dims a_strides{{lda, 1 }}, b_strides{{ldb, 1 }}, c_strides{{ldc, 1 }};
223
+ if (transa != TransposeType::NoTranspose) {
224
+ std::swap (a_strides[0 ], a_strides[1 ]);
225
+ }
226
+ if (transb != TransposeType::NoTranspose) {
227
+ std::swap (b_strides[0 ], b_strides[1 ]);
228
+ }
229
+
230
+ auto idtype = ideep::tensor::data_type::bf16;
231
+
232
+ ideep::tensor a = make_ideep_tensor<scalar_t >({k, m}, idtype, a_strides, const_cast <scalar_t *>(a_data));
233
+ ideep::tensor b = make_ideep_tensor<scalar_t >({n, k}, idtype, b_strides, const_cast <scalar_t *>(b_data));
234
+ ideep::tensor c = make_ideep_tensor<float >({n, m}, ideep::tensor::data_type::f32, c_strides, c_data);
235
+
236
+ ideep::matmul_forward::compute (
237
+ b, a, c, alpha, beta,
238
+ ideep::scale_t (), ideep::scale_t (), ideep::scale_t (), op_attr);
239
+
240
+ if (c.get_data_handle () != c_data){
179
241
// ideep will query onednn expect format of output
180
242
// if given output format is not expected, ideep will re-init an output buffer
181
243
// under this case, we need copy the re-inited buffer back to given buffer
182
- ideep::tensor real_output ({
183
- /* sizes=*/ {n, m},
184
- idtype,
185
- /* strides=*/ c_strides},
186
- c_data);
244
+ ideep::tensor real_output = make_ideep_tensor<float >({n,m}, idtype, c_strides, c_data);
187
245
c.reorder_to (real_output);
188
246
}
189
247
@@ -201,6 +259,17 @@ bool mkldnn_bf16_gemm(
201
259
return mkldnn_gemm<c10::BFloat16>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
202
260
}
203
261
262
// GEMM with bf16 inputs and an fp32 output buffer, routed through the
// MKLDNN/oneDNN path. Thin forwarder: all work happens in the
// mkldnn_gemm<c10::BFloat16> overload selected by the float* c parameter.
// Returns false when MKLDNN cannot be used, so the caller can fall back
// to a regular BLAS kernel; returns true once the result is in `c`.
bool mkldnn_bf16f32_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const c10::BFloat16 *a, int64_t lda,
    const c10::BFloat16 *b, int64_t ldb,
    float beta,
    float *c, int64_t ldc) {
  return mkldnn_gemm<c10::BFloat16>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
272
+
204
273
bool mkldnn_fp16_gemm (
205
274
TransposeType transa, TransposeType transb,
206
275
int64_t m, int64_t n, int64_t k,
0 commit comments