diff --git a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp
index 89b69ec70ddb7..e27978e683f19 100644
--- a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp
+++ b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp
@@ -24,11 +24,13 @@ bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) {
     }
     return false;
   }
-  if (query_size_last > 256) {
+  constexpr int MAX_HEAD_DIM = 576;
+  if (query_size_last > MAX_HEAD_DIM) {
     if (debug) {
       TORCH_WARN(
-          "OneDNN attention requires q,k,v to have head dimension less than 256.",
-          " Got ",
+          "OneDNN attention requires q,k,v to have head dimension less than ",
+          MAX_HEAD_DIM,
+          ". Got ",
           query_size_last,
           " instead.");
     }
diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp
index 32394a200fca7..b9a7712cb5172 100644
--- a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp
+++ b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp
@@ -1,9 +1,11 @@
+#include
 #include
 #include
 #include
-
 #include
 
+namespace {
+
 using namespace at::native::onednn;
 using logical_tensor = dnnl::graph::logical_tensor;
 using data_type = logical_tensor::data_type;
@@ -11,7 +13,13 @@ using dims = logical_tensor::dims;
 using op = dnnl::graph::op;
 using partition = dnnl::graph::partition;
 
-namespace {
+inline data_type to_logical_tensor_data_type(c10::ScalarType scalar_type) {
+  return scalar_type == c10::ScalarType::Float ? data_type::f32
+      : scalar_type == c10::ScalarType::Half ? data_type::f16
+      : scalar_type == c10::ScalarType::BFloat16 ? data_type::bf16
+      : data_type::undef;
+}
+
 struct SDPALogicalParams {
   enum class TensorID {
     query,
@@ -39,11 +47,7 @@ struct SDPALogicalParams {
       const std::optional& attn_mask_,
       const at::Tensor& output_,
       bool is_causal) {
-    const data_type dtype = // to logical_tensor data type
-        query_.scalar_type() == c10::ScalarType::Float ? data_type::f32
-        : query_.scalar_type() == c10::ScalarType::Half ? data_type::f16
-        : query_.scalar_type() == c10::ScalarType::BFloat16 ? data_type::bf16
-        : data_type::undef;
+    const data_type dtype = to_logical_tensor_data_type(query_.scalar_type());
     TORCH_INTERNAL_ASSERT(
         (dtype != data_type::undef),
         "Only FP16/BF16/FP32 datatypes are currently supported");
@@ -61,22 +65,27 @@ struct SDPALogicalParams {
         key_.strides().vec()};
     scale = {
         static_cast(TensorID::scale),
-        dtype,
+        to_logical_tensor_data_type(at::toOpMathType(query_.scalar_type())),
        scalar_shape,
         logical_tensor::layout_type::strided,
         logical_tensor::property_type::constant};
     if (is_causal) {
       neg_inf = {
           static_cast(TensorID::neg_inf),
-          dtype,
+          to_logical_tensor_data_type(at::toOpMathType(query_.scalar_type())),
           scalar_shape,
           logical_tensor::layout_type::strided,
           logical_tensor::property_type::constant};
     }
     if (attn_mask_.has_value()) {
+      const data_type mask_dtype =
+          to_logical_tensor_data_type(attn_mask_->scalar_type());
+      TORCH_INTERNAL_ASSERT(
+          (mask_dtype != data_type::undef),
+          "Only FP16/BF16/FP32 datatypes are currently supported for attn_mask");
       attn_mask = {
           static_cast(TensorID::attn_mask),
-          dtype,
+          mask_dtype,
           attn_mask_->sizes().vec(),
           attn_mask_->strides().vec()};
     }
@@ -124,7 +133,12 @@ partition create_sdpa_graph_partition(
   size_t lt_id = static_cast(SDPALogicalParams::TensorID::end);
   size_t op_id = 0;
 
-  logical_tensor matmul_qk_out{lt_id++, dtype};
+  // OneDNN graph has an optimized implementation for `f16` or `bf16` SDPA
+  // with `f32` intermediate data type on Intel Graphics Products with
+  // Intel(R) Xe Matrix Extensions (Intel(R) XMX) support, which means the
+  // Q/K/V tensors have bf16 or f16 data type while the output of the first
+  // MatMul, Scale, Mask, and the input of SoftMax are in f32 data type.
+  logical_tensor matmul_qk_out{lt_id++, data_type::f32};
   op matmul_qk{
       op_id++,
       op::kind::MatMul,
@@ -133,7 +147,7 @@ partition create_sdpa_graph_partition(
       "matmul_qk"};
   matmul_qk.set_attr(op::attr::transpose_b, true);
 
-  logical_tensor scaled_qk_out{lt_id++, dtype};
+  logical_tensor scaled_qk_out{lt_id++, data_type::f32};
   op scale_mul{
       op_id++,
       op::kind::Multiply,
@@ -158,7 +172,7 @@ partition create_sdpa_graph_partition(
   if (params.attn_mask.has_value()) {
     TORCH_INTERNAL_ASSERT(
         !is_causal, "Additive mask cannot use with is_causal.");
-    masked_qk_out = {lt_id++, dtype};
+    masked_qk_out = {lt_id++, data_type::f32};
     mask_add = {
         op_id++,
         op::kind::Add,
@@ -193,7 +207,7 @@ partition create_sdpa_graph_partition(
         {mask_gt_out.value()},
         "mask_gt"};
 
-    masked_qk_out = {lt_id++, dtype};
+    masked_qk_out = {lt_id++, data_type::f32};
     mask_select = {
         op_id++,
         op::kind::Select,
@@ -327,24 +341,16 @@ void gpu_float_sdpa(
         at::scalar_tensor(-std::numeric_limits::infinity(), opts));
   };
 
-  static bool driver_support_implict_causal = true;
-  if (attn_mask.has_value()) {
-    TORCH_INTERNAL_ASSERT(
-        !is_causal,
-        "scaled_dot_product_fused_attention_overrideable_xpu: "
-        "attn_mask cannot present with is_causal");
-  } else {
-    // Currenetly implict mask only supports square fp16 cases
-    const bool support_implict_causal = driver_support_implict_causal &&
-        (query.dtype() == at::kHalf || query.dtype() == at::kBFloat16) &&
-        seq_len_q == seq_len_k;
-    if (is_causal && !support_implict_causal) {
-      attn_mask = get_tril_mask();
-      is_causal = false;
-    }
+  // OneDNN doesn't support an fp32 ukernel for the implicit causal mask, and
+  // the reference implementation is worse than aten math + explicit causal
+  // mask. Fall back to explicit causal mask until OneDNN v3.9, which has an
+  // fp32 ukernel for implicit causal mask.
+  if (is_causal && query.dtype() == at::kFloat) {
+    attn_mask = get_tril_mask();
+    is_causal = false;
   }
 
-  std::vector l_inputs, l_outputs;
+  std::vector l_inputs, l_outputs;
   std::optional compiled_partition;
 
   auto get_compiled_partition = [&]() {
@@ -366,24 +372,18 @@ void gpu_float_sdpa(
     return compiled_partition;
   };
 
-  // maybe retry without causal mask
-  try {
-    compiled_partition = get_compiled_partition();
-  } catch (std::exception& e) {
-    if (is_causal) {
-      attn_mask = get_tril_mask();
-      is_causal = false;
-      compiled_partition = get_compiled_partition();
-      driver_support_implict_causal = false;
-    } else {
-      throw e;
-    }
-  }
+  compiled_partition = get_compiled_partition();
 
-  Tensor softmax_scale1 = at::full({}, softmax_scale, query.options());
+  Tensor softmax_scale1 = at::full(
+      {},
+      softmax_scale,
+      query.options().dtype(at::toOpMathType(query.scalar_type())));
   std::optional neg_inf;
   if (is_causal) {
-    neg_inf = at::full({}, -INFINITY, query.options());
+    neg_inf = at::full(
+        {},
+        -INFINITY,
+        query.options().dtype(at::toOpMathType(query.scalar_type())));
   }
 
   std::vector outputs = {
diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake
index ae0a512b16583..9ec3418e924e1 100644
--- a/cmake/Modules/FindMKLDNN.cmake
+++ b/cmake/Modules/FindMKLDNN.cmake
@@ -49,7 +49,7 @@ IF(NOT MKLDNN_FOUND)
     endif()
     ExternalProject_Add(xpu_mkldnn_proj
       GIT_REPOSITORY https://github.com/oneapi-src/oneDNN
-      GIT_TAG v3.7.1
+      GIT_TAG rls-v3.8
       PREFIX ${XPU_MKLDNN_DIR_PREFIX}
       BUILD_IN_SOURCE 0
       CMAKE_ARGS -DCMAKE_C_COMPILER=icx
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 798095e065785..5164c27a3653d 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -3920,11 +3920,11 @@ def test_fused_attention_different_dk_dv(self, device):
         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
 
-    def test_onednn_attention_fail_d256(self, device):
-        # Test that onednn graph attention dispatching correctly bails out on d > 256
+    def test_onednn_attention_fail_d576(self, device):
+        # Test that onednn graph attention dispatching correctly bails out on d > 576
         b, h = 1, 2
         s_q, s_kv = 128, 128
-        d_qk, d_v = 512, 512
+        d_qk, d_v = 1024, 1024
         q = torch.randn(b, h, s_q, d_qk, device=device, dtype=torch.bfloat16)
         k = torch.randn(b, h, s_kv, d_qk, device=device, dtype=torch.bfloat16)