fallback SDPA implicit causal mask to explicit for fp32 for performance · pytorch/pytorch@469b77c · GitHub

Commit 469b77c

fallback SDPA implicit causal mask to explicit for fp32 for performance
1 parent 817b225 commit 469b77c

File tree

1 file changed (+21, -0)

aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp

Lines changed: 21 additions & 0 deletions
@@ -329,6 +329,27 @@ void gpu_float_sdpa(
       {c10::kXPU, c10::xpu::current_device()});
   auto strm = GpuStreamManager::Instance().get_stream();
 
+  const auto get_tril_mask = [&]() {
+    auto opts = query.options();
+    auto bool_tril =
+        at::ones_symint(
+            {query.sym_size(-2), key.sym_size(-2)}, opts.dtype(at::kBool))
+            .tril();
+    return at::where(
+        bool_tril,
+        0.f,
+        at::scalar_tensor(-std::numeric_limits<float>::infinity(), opts));
+  };
+
+  // OneDNN doesn't support an fp32 ukernel for the implicit causal mask,
+  // and the reference implementation is worse than aten math + explicit
+  // causal mask. Fall back to an explicit causal mask until OneDNN v3.9,
+  // which has an fp32 ukernel for the implicit causal mask.
+  if (is_causal && query.dtype() == at::kFloat) {
+    attn_mask = get_tril_mask();
+    is_causal = false;
+  }
+
   std::vector<dnnl::graph::logical_tensor> l_inputs, l_outputs;
   std::optional<dnnl::graph::compiled_partition> compiled_partition;

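For context on why the two paths are interchangeable: is_causal=true asks the backend to mask the upper triangle of the attention scores internally, while the fallback materializes the same mask as an additive tensor, 0 on and below the diagonal and -inf above it (for a 3x3 case: [[0, -inf, -inf], [0, 0, -inf], [0, 0, 0]]). Below is a minimal standalone libtorch sketch, not part of this commit, that checks the equivalence through the public scaled_dot_product_attention op; the shapes, CPU device, and tolerances are illustrative assumptions.

// sdpa_mask_equivalence.cpp -- sketch only; build against libtorch.
#include <torch/torch.h>

#include <iostream>
#include <limits>

int main() {
  const int64_t L = 4;  // query (target) sequence length
  const int64_t S = 4;  // key/value (source) sequence length
  const int64_t E = 8;  // head dimension
  auto q = torch::randn({1, 1, L, E});
  auto k = torch::randn({1, 1, S, E});
  auto v = torch::randn({1, 1, S, E});

  // Explicit causal mask, built the same way as get_tril_mask() above:
  // 0 where attention is allowed (lower triangle), -inf elsewhere.
  auto bool_tril = torch::ones({L, S}, torch::kBool).tril();
  auto mask = torch::where(
      bool_tril,
      torch::zeros({}, q.options()),
      torch::full({}, -std::numeric_limits<float>::infinity(), q.options()));

  // Implicit path: the backend applies causal masking internally.
  auto implicit_out = at::scaled_dot_product_attention(
      q, k, v, /*attn_mask=*/{}, /*dropout_p=*/0.0, /*is_causal=*/true);

  // Explicit path: same math, with the mask added to the attention scores.
  auto explicit_out = at::scaled_dot_product_attention(
      q, k, v, mask, /*dropout_p=*/0.0, /*is_causal=*/false);

  std::cout << std::boolalpha << "outputs match: "
            << torch::allclose(implicit_out, explicit_out, 1e-5, 1e-5)
            << std::endl;
}

Both calls compute softmax over identically masked scores, so they agree up to floating-point tolerance; the commit trades the implicit flag for the materialized mask only where OneDNN lacks the fp32 ukernel.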
0 commit comments
