File tree Expand file tree Collapse file tree 1 file changed +21
-0
lines changed
aten/src/ATen/native/mkldnn/xpu/detail Expand file tree Collapse file tree 1 file changed +21
-0
lines changed Original file line number Diff line number Diff line change @@ -329,6 +329,27 @@ void gpu_float_sdpa(
329
329
{c10::kXPU , c10::xpu::current_device ()});
330
330
auto strm = GpuStreamManager::Instance ().get_stream ();
331
331
332
+ const auto get_tril_mask = [&]() {
333
+ auto opts = query.options ();
334
+ auto bool_tril =
335
+ at::ones_symint (
336
+ {query.sym_size (-2 ), key.sym_size (-2 )}, opts.dtype (at::kBool ))
337
+ .tril ();
338
+ return at::where (
339
+ bool_tril,
340
+ 0 .f ,
341
+ at::scalar_tensor (-std::numeric_limits<float >::infinity (), opts));
342
+ };
343
+
344
+ // OneDNN doesn't support fp32 ukernel for implicit causal mask,
345
+ // and the reference implementation is worse than aten math + explict causal
346
+ // mask. Fall back to explict causal mask until OneDNN v3.9 which has fp32
347
+ // ukernel for implicit causal mask.
348
+ if (is_causal && query.dtype () == at::kFloat ) {
349
+ attn_mask = get_tril_mask ();
350
+ is_causal = false ;
351
+ }
352
+
332
353
std::vector<dnnl::graph::logical_tensor> l_inputs, l_outputs;
333
354
std::optional<dnnl::graph::compiled_partition> compiled_partition;
334
355
You can’t perform that action at this time.
0 commit comments