pytorch/pytorch · Commit cfee904

cpu: enable gemm-bf16f32 for SDPA BF16 (#140159)

aditew01 authored and pytorchmergebot committed
This PR enables SDPA BF16 (gemm:bf16f32) for aarch64, which gives faster inference for models with attention layers in autocast (bf16) mode.

Benchmark results from the [PyTorch CI HUD - branch](https://hud.pytorch.org/benchmark/huggingface/inductor_no_cudagraphs?dashboard=torchinductor&startTime=Fri%2C%2028%20Mar%202025%2021%3A26%3A20%20GMT&stopTime=Fri%2C%2004%20Apr%202025%2020%3A26%3A20%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cpu%20(aarch64)&lBranch=adi/gemm_bf16f32&lCommit=d5aeab452e4b1f0580a4636b15a604c77a02c57b&rBranch=main&rCommit=bc72420bcb37390af3fced885e019903e6e425bd). Overall geometric-mean speedup in the HUD dashboard: Huggingface `[0.48x → 0.58x]`, Blueberries `[0.88x → 1.13x]`.

Benchmark numbers for `torch.nn.functional.scaled_dot_product_attention` on Neoverse™ V1 with `batch_size = 1, num_attention_heads = 64, sequence_length = 512, attention_head_size = 128` and `threads=16`:

<img width="319" alt="Screenshot 2024-12-20 at 16 23 22" src="https://github.com/user-attachments/assets/c863f97d-0761-4fb8-aa6c-fc67b22ac3f9" />

Script to benchmark & profile SDPA:

```python
import torch
import torch.nn as nn
import time
import numpy as np
from torch.profiler import profile, record_function, ProfilerActivity


class SimpleAttentionModel(nn.Module):
    def __init__(self, query, key, value):
        super(SimpleAttentionModel, self).__init__()
        self.query = query
        self.key = key
        self.value = value

    def forward(self, attn_mask=None):
        torch.nn.functional.scaled_dot_product_attention(
            self.query, self.key, self.value, attn_mask=attn_mask)


# batch_size = 1, num_attention_heads = 64, sequence_length = 512, hidden_size = 128
def bench_sdpa(batch_size=1, num_attention_heads=64, sequence_length=512,
               query_sequence_length=128, hidden_size=128, precision=torch.float32):
    with torch.no_grad():
        attention_head_size = int(hidden_size / num_attention_heads)
        query = torch.rand(size=(batch_size, num_attention_heads, query_sequence_length, attention_head_size), dtype=precision)
        key = torch.rand(size=(batch_size, num_attention_heads, sequence_length, attention_head_size), dtype=precision)
        value = torch.rand(size=(batch_size, num_attention_heads, sequence_length, attention_head_size), dtype=precision)

        model = SimpleAttentionModel(query, key, value)
        model.eval()

        # Warm-up iterations.
        for _ in range(10):
            model()

        # Timed iterations, recorded in microseconds.
        times = []
        n_iters = 100
        for _ in range(n_iters):
            s = time.time_ns()
            model()
            times.append((time.time_ns() - s) / 1e3)
        min_times = np.min(times)
        mean_times = np.mean(times)
        print(f"Min Times = {min_times} us")
        print(f"Mean Times = {mean_times} us")
        print("Times = ", times)


print("BF16 mode:")
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("model_inference"):
        bench_sdpa(precision=torch.bfloat16)
profile_data = prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total")
print(profile_data)
```

Pull Request resolved: #140159

Approved by: https://github.com/jgong5, https://github.com/malfet, https://github.com/nikhil-arm, https://github.com/leslie-fang-intel, https://github.com/CaoE, https://github.com/cfRod, https://github.com/fadara01
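For context, and not part of the patch: a minimal sketch of the scenario the commit message targets, namely SDPA running on BF16 inputs on CPU, which is what autocast (bf16) inference produces for attention layers. The shapes below are illustrative only.

```python
import torch

# Illustrative attention shapes: (batch, heads, seq_len, head_dim).
q = torch.rand(1, 64, 512, 128, dtype=torch.bfloat16)
k = torch.rand(1, 64, 512, 128, dtype=torch.bfloat16)
v = torch.rand(1, 64, 512, 128, dtype=torch.bfloat16)

with torch.no_grad():
    # BF16 SDPA on CPU; under torch.autocast(device_type="cpu", dtype=torch.bfloat16)
    # the inputs of an attention layer end up in bf16 the same way.
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

print(out.dtype)  # torch.bfloat16
```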
1 parent 236b08c commit cfee904

File tree

3 files changed: +105 −20 lines

- aten/src/ATen/native/CPUBlas.cpp
- aten/src/ATen/native/mkldnn/Matmul.cpp
- aten/src/ATen/native/mkldnn/Matmul.h

aten/src/ATen/native/CPUBlas.cpp

Lines changed: 7 additions & 0 deletions

```diff
@@ -435,6 +435,13 @@ void gemm(
     return;
   }
 #endif
+#if AT_MKLDNN_ACL_ENABLED()
+  // add heuristic based on shape to dispatch to sbgemm_ vs MKLDNN
+  if (mkldnn_bf16f32_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) {
+    return;
+  }
+#endif //AT_MKLDNN_ACL_ENABLED
+
 #ifdef MKL_HAS_SBGEMM
   if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
     int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
```
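The new branch above is compiled only when ATen is built with oneDNN plus the Arm Compute Library backend (`AT_MKLDNN_ACL_ENABLED()`). As a small, hedged check that is not part of the patch, the relevant build flags can be inspected from Python; the exact strings in the build summary vary by build:

```python
import torch

# True when PyTorch was built with oneDNN (MKL-DNN) support.
print(torch.backends.mkldnn.is_available())

# The build summary lists flags such as USE_MKLDNN; scan it for the
# oneDNN / BLAS configuration of this particular build.
for line in torch.__config__.show().splitlines():
    if "MKLDNN" in line or "BLAS" in line:
        print(line)
```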

aten/src/ATen/native/mkldnn/Matmul.cpp

Lines changed: 89 additions & 20 deletions

```diff
@@ -107,6 +107,25 @@ static bool use_mkldnn_bf32_matmul() {
   return use_mkldnn_bf16_matmul() && at::globalContext().float32MatmulPrecision() == at::Float32MatmulPrecision::MEDIUM;
 }
 
+// returns an ideep::tensor
+// - dims: shape e.g: {M,N}
+// - idtype: ideep data type e.g: (f32, bf16, f16)
+// - strides: Memory layout
+// - data: data pointer
+template <typename scalar_t>
+inline ideep::tensor make_ideep_tensor(
+    std::vector<int64_t> dims,
+    ideep::tensor::data_type idtype,
+    ideep::tensor::dims& strides,
+    scalar_t *data){
+  ideep::tensor res({
+      dims,
+      idtype,
+      strides
+    },
+    data);
+  return res;
+}
 
 template<typename scalar_t>
 static inline typename std::enable_if_t<
@@ -155,35 +174,74 @@ mkldnn_gemm(
     idtype = ideep::tensor::data_type::f32;
   }
 
-  ideep::tensor a({
-    /*sizes=*/{k, m},
-    idtype,
-    /*strides=*/a_strides},
-    const_cast<scalar_t*>(a_data));
-  ideep::tensor b({
-    /*sizes=*/{n, k},
-    idtype,
-    /*strides=*/b_strides},
-    const_cast<scalar_t*>(b_data));
-  ideep::tensor c({
-    /*sizes=*/{n, m},
-    idtype,
-    /*strides=*/c_strides},
-    c_data);
+  ideep::tensor a = make_ideep_tensor<scalar_t>({k, m}, idtype, a_strides, const_cast<scalar_t*>(a_data));
+  ideep::tensor b = make_ideep_tensor<scalar_t>({n, k}, idtype, b_strides, const_cast<scalar_t*>(b_data));
+  ideep::tensor c = make_ideep_tensor<scalar_t>({n, m}, idtype, c_strides, c_data);
 
   ideep::matmul_forward::compute(
       b, a, c, alpha, beta,
       ideep::scale_t(), ideep::scale_t(), ideep::scale_t(), op_attr);
 
   if (c.get_data_handle() != c_data){
+    // ideep will query oneDNN expect format of output
+    // if given output format is not expected, ideep will re-init an output buffer
+    // under this case, we need copy the re-inited buffer back to given buffer
+    ideep::tensor real_output = make_ideep_tensor<scalar_t>({n,m}, idtype, c_strides, c_data);
+    c.reorder_to(real_output);
+  }
+  return true;
+}
+
+template<typename scalar_t>
+inline typename std::enable_if_t<
+    std::is_same_v<scalar_t, c10::BFloat16>,
+    bool>
+mkldnn_gemm(
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    float alpha,
+    const scalar_t *a_data, int64_t lda,
+    const scalar_t *b_data, int64_t ldb,
+    float beta,
+    float* c_data, int64_t ldc) {
+  // introduce heuristic to validate dispatch to MKLDNN
+  // (m * n * k <= 16 * 16 * 16)
+  bool bf16_usable = use_mkldnn_bf16_matmul();
+  if (!bf16_usable) {
+    return false;
+  }
+
+  ideep::attr_t op_attr;
+  // Use mkldnn post ops to perform the add.
+  if (beta != 0.0f) {
+    op_attr = ideep::attr_t::fuse_sum();
+  }
+
+  // NOTE: View as c-contiguous to avoid extra reordering in mkldnn
+  // Use identity: C = AB <=> C^T = B^T A^T
+  ideep::tensor::dims a_strides{{lda, 1}}, b_strides{{ldb, 1}}, c_strides{{ldc, 1}};
+  if (transa != TransposeType::NoTranspose) {
+    std::swap(a_strides[0], a_strides[1]);
+  }
+  if (transb != TransposeType::NoTranspose) {
+    std::swap(b_strides[0], b_strides[1]);
+  }
+
+  auto idtype = ideep::tensor::data_type::bf16;
+
+  ideep::tensor a = make_ideep_tensor<scalar_t>({k, m}, idtype, a_strides, const_cast<scalar_t*>(a_data));
+  ideep::tensor b = make_ideep_tensor<scalar_t>({n, k}, idtype, b_strides, const_cast<scalar_t*>(b_data));
+  ideep::tensor c = make_ideep_tensor<float>({n, m}, ideep::tensor::data_type::f32, c_strides, c_data);
+
+  ideep::matmul_forward::compute(
+      b, a, c, alpha, beta,
+      ideep::scale_t(), ideep::scale_t(), ideep::scale_t(), op_attr);
+
+  if(c.get_data_handle() != c_data){
     // ideep will query onednn expect format of output
     // if given output format is not expected, ideep will re-init an output buffer
     // under this case, we need copy the re-inited buffer back to given buffer
-    ideep::tensor real_output({
-      /*sizes=*/{n, m},
-      idtype,
-      /*strides=*/c_strides},
-      c_data);
+    ideep::tensor real_output = make_ideep_tensor<float>({n,m}, idtype, c_strides, c_data);
     c.reorder_to(real_output);
   }
 
@@ -201,6 +259,17 @@ bool mkldnn_bf16_gemm(
   return mkldnn_gemm<c10::BFloat16>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
 }
 
+bool mkldnn_bf16f32_gemm(
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    float alpha,
+    const c10::BFloat16 *a, int64_t lda,
+    const c10::BFloat16 *b, int64_t ldb,
+    float beta,
+    float *c, int64_t ldc) {
+  return mkldnn_gemm<c10::BFloat16>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
+}
+
 bool mkldnn_fp16_gemm(
     TransposeType transa, TransposeType transb,
     int64_t m, int64_t n, int64_t k,
```
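The new `mkldnn_gemm` specialization above takes BF16 `a`/`b` and writes an FP32 `c`. As an illustration only, not part of the patch, the snippet below shows the precision that is retained when the result of a bf16 matmul is kept in fp32 rather than rounded back to bf16:

```python
import torch

torch.manual_seed(0)
a = torch.randn(512, 128).to(torch.bfloat16)
b = torch.randn(128, 512).to(torch.bfloat16)

# fp32 result computed from the same bf16 inputs.
ref = a.float() @ b.float()

# bf16 @ bf16 -> bf16: the result is stored in bf16, then upcast for comparison.
out_bf16 = (a @ b).float()

# The difference comes from storing the output in bf16, which a
# bf16-in/fp32-out gemm avoids.
print((out_bf16 - ref).abs().max().item())
```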

aten/src/ATen/native/mkldnn/Matmul.h

Lines changed: 9 additions & 0 deletions

```diff
@@ -39,6 +39,15 @@ bool mkldnn_bf16_gemm(
     float beta,
     c10::BFloat16 *c, int64_t ldc);
 
+bool mkldnn_bf16f32_gemm(
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    float alpha,
+    const c10::BFloat16 *a, int64_t lda,
+    const c10::BFloat16 *b, int64_t ldb,
+    float beta,
+    float *c, int64_t ldc);
+
 bool mkldnn_fp16_gemm(
     TransposeType transa, TransposeType transb,
     int64_t m, int64_t n, int64_t k,
```

0 commit comments