[Intel GPU] Add SDPA implementation on XPU with OneDNN (#147612) · pytorch/pytorch@cde1220

Commit cde1220

DDEle authored and pytorchmergebot committed
[Intel GPU] Add SDPA implementation on XPU with OneDNN (#147612)
Add the XPU implementation of the OneDNN-based SDPA operator. It will be integrated and enabled later. Depends on the BUILD_GRAPH switch in #147608.

Pull Request resolved: #147612
Approved by: https://github.com/EikanWang
1 parent 576ed1e commit cde1220

6 files changed: +717 −9 lines changed

Lines changed: 230 additions & 0 deletions
@@ -0,0 +1,230 @@
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <c10/util/Array.h>
#include <torch/library.h>

namespace {
bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) {
  const auto query_size_last = params.query.sym_size(-1);
  const auto key_size_last = params.key.sym_size(-1);
  const auto value_size_last = params.value.sym_size(-1);
  if ((query_size_last != key_size_last) ||
      (query_size_last != value_size_last)) {
    if (debug) {
      TORCH_WARN(
          "OneDNN attention requires q,k,v to have the same last dimension.",
          " Got Query.size(-1): ",
          query_size_last,
          ", Key.size(-1): ",
          key_size_last,
          ", Value.size(-1): ",
          value_size_last,
          " instead.");
    }
    return false;
  }
  if (query_size_last > 256) {
    if (debug) {
      TORCH_WARN(
          "OneDNN attention requires q,k,v to have head dimension less than 256.",
          " Got ",
          query_size_last,
          " instead.");
    }
    return false;
  }
  return true;
}

bool check_no_grad(sdp::sdp_params const& params, bool debug) {
  const bool any_inputs_require_grad = params.query.requires_grad() ||
      params.key.requires_grad() || params.value.requires_grad();
  const bool gradmode_enabled = at::GradMode::is_enabled();
  if (debug && any_inputs_require_grad && gradmode_enabled) {
    TORCH_WARN("Backward or grad to be supported.");
  }
  return !any_inputs_require_grad || !gradmode_enabled;
}

bool use_overrideable_xpu(sdp::sdp_params const& params, bool debug) {
  constexpr auto supported_dtypes = c10::array_of<at::ScalarType>(
      at::kFloat, at::kBFloat16, at::kHalf); // double is not supported

  // Define gate functions that determine whether the oneDNN kernel can be run
  constexpr auto constraints = c10::array_of<bool (*)(
      sdp::sdp_params const&, bool)>(
      sdp::check_nested_tensor,
      sdp::check_for_dropout,
      sdp::check_tensor_shapes,
      sdp::check_batch_size_and_num_heads_dense<true /*supports GQA*/>,
      sdp::check_attn_mask_shape,
      sdp::check_nonzero_sequence_lengths_dense,
      sdp::check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim*/>,
      check_head_dim_size_xpu,
      check_no_grad);
  for (auto& constraint : constraints) {
    if (!constraint(params, debug)) {
      return false;
    }
  }
  return sdp::check_tensor_dtype(params, supported_dtypes, debug);
}

sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
  // This function defines the priority order of the different sdp backends:
  // 1. Overrideable (oneDNN)
  // 2. Math fallback
  auto& ctx = at::globalContext();
  // the overrideable backend is backed by oneDNN on XPU
  if (!ctx.userEnabledMathSDP() && !ctx.userEnabledOverrideableSDP()) {
    return sdp::SDPBackend::error;
  }

  // Get ideal kernel ordering
  const std::array<sdp::SDPBackend, 2> priority_order{
      sdp::SDPBackend::overrideable,
      sdp::SDPBackend::math,
  };

  // Because TORCH_CHECK checks if the condition is true, we negate debug so
  // that the statements will be printed when debug is true
  bool print_debug = false;
  for (auto& backend : priority_order) {
    switch (backend) {
      case sdp::SDPBackend::overrideable:
        if (ctx.userEnabledOverrideableSDP() &&
            use_overrideable_xpu(kernel_params, print_debug)) {
          return sdp::SDPBackend::overrideable;
        }
        break;
      case sdp::SDPBackend::math:
        if (ctx.userEnabledMathSDP()) {
          return sdp::SDPBackend::math;
        }
        break;
      default:
        TORCH_CHECK(false, "Invalid backend");
    }
  }
  // If we have gotten to this point then two things have happened:
  // 1. use_overrideable_xpu did not satisfy the constraints to be run
  // 2. The user has explicitly disabled the math kernel
  // We then re-run the kernel checks with debug enabled to print out the
  // reason why the kernel was not selected

  print_debug = true;
  TORCH_WARN("OneDNN kernel not used because:");
  use_overrideable_xpu(kernel_params, print_debug);
  TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.")
  return sdp::SDPBackend::error;
}
} // namespace

namespace at::native {
int64_t _fused_sdp_choice_xpu(
    const at::Tensor& query_,
    const at::Tensor& key,
    const at::Tensor& value,
    const std::optional<at::Tensor>& attn_mask_,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale,
    bool enable_gqa) {
  sdp::sdp_params kernel_params{
      query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa};
  auto backend = select_sdp_backend_xpu(kernel_params);

  if (backend == sdp::SDPBackend::error) {
    TORCH_CHECK(
        false,
        "No viable backend for scaled_dot_product_attention was found. ",
        "This is likely due to turning off both the math kernel and the fused kernels.");
  }
  return static_cast<int64_t>(backend);
}

std::tuple<
    at::Tensor,
    at::Tensor,
    at::Tensor,
    at::Tensor,
    c10::SymInt,
    c10::SymInt,
    at::Tensor,
    at::Tensor,
    at::Tensor>
_scaled_dot_product_fused_attention_overrideable_xpu(
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const std::optional<at::Tensor>& attn_bias,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    std::optional<double> scale) {
  TORCH_INTERNAL_ASSERT(
      query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
      "scaled_dot_product_fused_attention_overrideable_xpu: Accept only 4 dims inputs shape of {(B), H, T, K}");
  TORCH_INTERNAL_ASSERT(
      (key.size(0) == value.size(0)) && (key.size(1) == value.size(1)) &&
          (key.size(2) == value.size(2)),
      "scaled_dot_product_fused_attention_overrideable_xpu: K/V should have the same batch / seq / num_head");
  TORCH_INTERNAL_ASSERT(
      query.size(3) == key.size(3),
      "scaled_dot_product_fused_attention_overrideable_xpu: Q/K should have the same head_dim");
  TORCH_INTERNAL_ASSERT(
      dropout_p == 0.0,
      "scaled_dot_product_fused_attention_overrideable_xpu: Currently do not support dropout > 0");
  TORCH_INTERNAL_ASSERT(
      !(attn_bias.has_value() && is_causal),
      "scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot present with is_causal");

  const int64_t batch_size = query.size(0);
  const int64_t num_head = query.size(1);
  const int64_t num_head_kv = key.size(1);
  const int64_t head_dim = query.size(3);
  const int64_t head_dim_v = value.size(3);
  const int64_t seq_len_q = query.size(2);
  const int64_t seq_len_kv = key.size(2);

  auto opts = query.options();
  auto output = at::empty({batch_size, num_head, seq_len_q, head_dim}, opts);
  // auto logsumexp =
  //     at::empty({batch_size, num_head, seq_len_q}, opts.dtype(at::kFloat));
  auto logsumexp = at::empty({}, opts.dtype(at::kFloat));

  at::native::onednn::gpu_float_sdpa(
      batch_size,
      seq_len_q,
      seq_len_kv,
      num_head,
      num_head_kv,
      head_dim,
      head_dim_v,
      query,
      key,
      value,
      attn_bias,
      is_causal,
      scale.has_value() ? scale.value() : (1.0 / std::sqrt(head_dim)),
      output);

  // rng and debug mask not used
  auto philox_seed = at::empty({}, at::dtype(at::kLong));
  auto philox_offset = at::empty({}, at::dtype(at::kLong));
  auto debug_attn_mask = at::empty(
      {batch_size, num_head, seq_len_q, seq_len_kv}, at::dtype(at::kFloat));

  return std::make_tuple(
      output,
      logsumexp,
      /* cum_seq_q */ at::Tensor(),
      /* cum_seq_k */ at::Tensor(),
      seq_len_q,
      seq_len_kv,
      philox_seed,
      philox_offset,
      debug_attn_mask);
}
} // namespace at::native
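
For orientation only (not part of this commit): a minimal C++ sketch of how the new XPU SDPA path is expected to be exercised through the public ATen entry point once the overrideable backend is integrated and enabled, which the PR description says happens later. The shapes, dtype, zero dropout, and grad-free inputs below are chosen to satisfy the gate checks in use_overrideable_xpu; the calls used (at::hasXPU, at::scaled_dot_product_attention) are existing ATen APIs, and the whole snippet is illustrative rather than code from this PR.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Skip quietly if this build has no Intel GPU (XPU) backend.
  if (!at::hasXPU()) {
    std::cout << "XPU not available, skipping SDPA demo\n";
    return 0;
  }
  // Dense 4-D inputs in (B, H, T, K) layout; half precision is one of the
  // dtypes accepted by the gate checks, and head_dim = 64 stays within the
  // 256 limit enforced by check_head_dim_size_xpu.
  auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kXPU);
  auto q = at::randn({2, 8, 128, 64}, opts);
  auto k = at::randn({2, 8, 128, 64}, opts);
  auto v = at::randn({2, 8, 128, 64}, opts);

  // Inputs do not require grad and dropout_p is 0.0, matching check_no_grad
  // and sdp::check_for_dropout, so select_sdp_backend_xpu would prefer the
  // overrideable (oneDNN) kernel over the math fallback once it is wired up.
  auto out = at::scaled_dot_product_attention(
      q, k, v, /*attn_mask=*/{}, /*dropout_p=*/0.0, /*is_causal=*/true);
  std::cout << out.sizes() << '\n';
  return 0;
}

Until that integration lands, the same call simply follows whatever path the existing dispatcher picks; the sketch only illustrates the constraints the new gate functions impose.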
