[WIP][Intel GPU][CI] Acceptance test for OneDNN v3.8.0 upgrading [DONT MERGE] by LuFinch · Pull Request #153228 · pytorch/pytorch · GitHub
[WIP][Intel GPU][CI] Acceptance test for OneDNN v3.8.0 upgrading [DONT MERGE] #153228


Draft · wants to merge 5 commits into base: release/2.7
Changes from all commits
8 changes: 5 additions & 3 deletions aten/src/ATen/native/mkldnn/xpu/Attention.cpp
@@ -24,11 +24,13 @@ bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) {
}
return false;
}
if (query_size_last > 256) {
constexpr int MAX_HEAD_DIM = 576;
if (query_size_last > MAX_HEAD_DIM) {
if (debug) {
TORCH_WARN(
"OneDNN attention requires q,k,v to have head dimension less than 256.",
" Got ",
"OneDNN attention requires q,k,v to have head dimension less than ",
MAX_HEAD_DIM,
". Got ",
query_size_last,
" instead.");
}
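For context, a minimal sketch of what this check governs (not part of the PR; it assumes an XPU build and an available "xpu" device): head dimensions up to 576 are now eligible for the oneDNN fused attention path, while larger ones make the check above return false so SDPA dispatch falls back to another backend.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only: d = 512 <= 576, so the oneDNN fused attention
# path is eligible; with d > 576 the check above rejects the inputs and
# SDPA dispatch falls back to a different backend instead.
b, h, s, d = 1, 2, 128, 512
q = torch.randn(b, h, s, d, device="xpu", dtype=torch.bfloat16)
k = torch.randn(b, h, s, d, device="xpu", dtype=torch.bfloat16)
v = torch.randn(b, h, s, d, device="xpu", dtype=torch.bfloat16)

out = F.scaled_dot_product_attention(q, k, v)
```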
90 changes: 45 additions & 45 deletions aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp
@@ -1,17 +1,25 @@
#include <ATen/OpMathType.h>
#include <ATen/native/mkldnn/xpu/detail/Attr.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>

#include <oneapi/dnnl/dnnl.hpp>

namespace {

using namespace at::native::onednn;
using logical_tensor = dnnl::graph::logical_tensor;
using data_type = logical_tensor::data_type;
using dims = logical_tensor::dims;
using op = dnnl::graph::op;
using partition = dnnl::graph::partition;

namespace {
inline data_type to_logical_tensor_data_type(c10::ScalarType scalar_type) {
return scalar_type == c10::ScalarType::Float ? data_type::f32
: scalar_type == c10::ScalarType::Half ? data_type::f16
: scalar_type == c10::ScalarType::BFloat16 ? data_type::bf16
: data_type::undef;
}

struct SDPALogicalParams {
enum class TensorID {
query,
@@ -39,11 +47,7 @@ struct SDPALogicalParams {
const std::optional<at::Tensor>& attn_mask_,
const at::Tensor& output_,
bool is_causal) {
const data_type dtype = // to logical_tensor data type
query_.scalar_type() == c10::ScalarType::Float ? data_type::f32
: query_.scalar_type() == c10::ScalarType::Half ? data_type::f16
: query_.scalar_type() == c10::ScalarType::BFloat16 ? data_type::bf16
: data_type::undef;
const data_type dtype = to_logical_tensor_data_type(query_.scalar_type());
TORCH_INTERNAL_ASSERT(
(dtype != data_type::undef),
"Only FP16/BF16/FP32 datatypes are currently supported");
@@ -61,22 +65,27 @@
key_.strides().vec()};
scale = {
static_cast<size_t>(TensorID::scale),
dtype,
to_logical_tensor_data_type(at::toOpMathType(query_.scalar_type())),
scalar_shape,
logical_tensor::layout_type::strided,
logical_tensor::property_type::constant};
if (is_causal) {
neg_inf = {
static_cast<size_t>(TensorID::neg_inf),
dtype,
to_logical_tensor_data_type(at::toOpMathType(query_.scalar_type())),
scalar_shape,
logical_tensor::layout_type::strided,
logical_tensor::property_type::constant};
}
if (attn_mask_.has_value()) {
const data_type mask_dtype =
to_logical_tensor_data_type(attn_mask_->scalar_type());
TORCH_INTERNAL_ASSERT(
(mask_dtype != data_type::undef),
"Only FP16/BF16/FP32 datatypes are currently supported for attn_mask");
attn_mask = {
static_cast<size_t>(TensorID::attn_mask),
dtype,
mask_dtype,
attn_mask_->sizes().vec(),
attn_mask_->strides().vec()};
}
@@ -124,7 +133,12 @@ partition create_sdpa_graph_partition(
size_t lt_id = static_cast<size_t>(SDPALogicalParams::TensorID::end);
size_t op_id = 0;

logical_tensor matmul_qk_out{lt_id++, dtype};
// OneDNN graph has optimized implementation for `f16` or `bf16` SDPA with
// `f32` intermediate data type on Intel Graphics Products with Intel(R) Xe
// Matrix Extensions (Intel(R) XMX) support, which means the
// Q/K/V tensors have bf16 or f16 data type while the output of the first
// MatMul, Scale, Mask, and the input of SoftMax are in f32 data type.
logical_tensor matmul_qk_out{lt_id++, data_type::f32};
op matmul_qk{
op_id++,
op::kind::MatMul,
@@ -133,7 +147,7 @@
"matmul_qk"};
matmul_qk.set_attr<bool>(op::attr::transpose_b, true);

logical_tensor scaled_qk_out{lt_id++, dtype};
logical_tensor scaled_qk_out{lt_id++, data_type::f32};
op scale_mul{
op_id++,
op::kind::Multiply,
@@ -158,7 +172,7 @@
if (params.attn_mask.has_value()) {
TORCH_INTERNAL_ASSERT(
!is_causal, "Additive mask cannot use with is_causal.");
masked_qk_out = {lt_id++, dtype};
masked_qk_out = {lt_id++, data_type::f32};
mask_add = {
op_id++,
op::kind::Add,
@@ -193,7 +207,7 @@
{mask_gt_out.value()},
"mask_gt"};

masked_qk_out = {lt_id++, dtype};
masked_qk_out = {lt_id++, data_type::f32};
mask_select = {
op_id++,
op::kind::Select,
@@ -327,24 +341,16 @@ void gpu_float_sdpa(
at::scalar_tensor(-std::numeric_limits<float>::infinity(), opts));
};

static bool driver_support_implict_causal = true;
if (attn_mask.has_value()) {
TORCH_INTERNAL_ASSERT(
!is_causal,
"scaled_dot_product_fused_attention_overrideable_xpu: "
"attn_mask cannot present with is_causal");
} else {
// Currenetly implict mask only supports square fp16 cases
const bool support_implict_causal = driver_support_implict_causal &&
(query.dtype() == at::kHalf || query.dtype() == at::kBFloat16) &&
seq_len_q == seq_len_k;
if (is_causal && !support_implict_causal) {
attn_mask = get_tril_mask();
is_causal = false;
}
// OneDNN doesn't support fp32 ukernel for implicit causal mask,
// and the reference implementation is worse than aten math + explicit causal
// mask. Fall back to explicit causal mask until OneDNN v3.9 which has fp32
// ukernel for implicit causal mask.
if (is_causal && query.dtype() == at::kFloat) {
attn_mask = get_tril_mask();
is_causal = false;
}

std::vector<logical_tensor> l_inputs, l_outputs;
std::vector<dnnl::graph::logical_tensor> l_inputs, l_outputs;
std::optional<dnnl::graph::compiled_partition> compiled_partition;

auto get_compiled_partition = [&]() {
@@ -366,24 +372,18 @@
return compiled_partition;
};

// maybe retry without causal mask
try {
compiled_partition = get_compiled_partition();
} catch (std::exception& e) {
if (is_causal) {
attn_mask = get_tril_mask();
is_causal = false;
compiled_partition = get_compiled_partition();
driver_support_implict_causal = false;
} else {
throw e;
}
}
compiled_partition = get_compiled_partition();

Tensor softmax_scale1 = at::full({}, softmax_scale, query.options());
Tensor softmax_scale1 = at::full(
{},
softmax_scale,
query.options().dtype(at::toOpMathType(query.scalar_type())));
std::optional<at::Tensor> neg_inf;
if (is_causal) {
neg_inf = at::full({}, -INFINITY, query.options());
neg_inf = at::full(
{},
-INFINITY,
query.options().dtype(at::toOpMathType(query.scalar_type())));
}

std::vector<dnnl::graph::tensor> outputs = {
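The fp32 fallback above replaces is_causal with an explicit lower-triangular mask before the graph is built. A minimal Python sketch of the equivalent behavior (an illustration only, not the PR's C++ code; the helper name is made up):

```python
import torch
import torch.nn.functional as F

def sdpa_with_explicit_causal_mask(q, k, v, scale=None):
    # Equivalent of rewriting is_causal=True as an explicit tril mask:
    # True entries may attend, the upper triangle is masked out.
    s_q, s_kv = q.size(-2), k.size(-2)
    attn_mask = torch.tril(
        torch.ones(s_q, s_kv, dtype=torch.bool, device=q.device)
    )
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, is_causal=False, scale=scale
    )
```

Relatedly, the scale and neg_inf scalars in this file are now created in the op math type (f32 for f16/bf16 queries), which matches the f32 intermediate data type used inside the fused SDPA graph.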
2 changes: 1 addition & 1 deletion cmake/Modules/FindMKLDNN.cmake
@@ -49,7 +49,7 @@ IF(NOT MKLDNN_FOUND)
endif()
ExternalProject_Add(xpu_mkldnn_proj
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN
GIT_TAG v3.7.1
GIT_TAG rls-v3.8
PREFIX ${XPU_MKLDNN_DIR_PREFIX}
BUILD_IN_SOURCE 0
CMAKE_ARGS -DCMAKE_C_COMPILER=icx
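After bumping the pinned tag, one quick way (a hedged sketch, not part of the PR) to confirm which oneDNN version a local build was actually compiled against is to inspect the build summary:

```python
import torch

# torch.__config__.show() returns the compile-time configuration string;
# the summary typically lists the oneDNN (MKL-DNN) version the build links.
print(torch.__config__.show())
```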
6 changes: 3 additions & 3 deletions test/test_transformers.py
@@ -3920,11 +3920,11 @@ def test_fused_attention_different_dk_dv(self, device):

self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)

def test_onednn_attention_fail_d256(self, device):
# Test that onednn graph attention dispatching correctly bails out on d > 256
def test_onednn_attention_fail_d576(self, device):
# Test that onednn graph attention dispatching correctly bails out on d > 576
b, h = 1, 2
s_q, s_kv = 128, 128
d_qk, d_v = 512, 512
d_qk, d_v = 1024, 1024

q = torch.randn(b, h, s_q, d_qk, device=device, dtype=torch.bfloat16)
k = torch.randn(b, h, s_kv, d_qk, device=device, dtype=torch.bfloat16)
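A standalone sketch of the updated test's intent (assumptions: an XPU device is available and the build provides torch.nn.attention.sdpa_kernel): with d = 1024 > 576 the oneDNN graph backend must bail out, and the always-available math backend remains a valid reference path.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

b, h = 1, 2
s_q, s_kv = 128, 128
d_qk, d_v = 1024, 1024  # exceeds the new 576 head-dim limit

q = torch.randn(b, h, s_q, d_qk, device="xpu", dtype=torch.bfloat16)
k = torch.randn(b, h, s_kv, d_qk, device="xpu", dtype=torch.bfloat16)
v = torch.randn(b, h, s_kv, d_v, device="xpu", dtype=torch.bfloat16)

# Restrict dispatch to the math backend to compute a reference output for
# shapes that the fused oneDNN attention path refuses.
with sdpa_kernel([SDPBackend.MATH]):
    ref = F.scaled_dot_product_attention(q, k, v)
```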