#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <c10/util/Array.h>
#include <torch/library.h>

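// Helper checks for the XPU (oneDNN) scaled-dot-product-attention backend.
// Each predicate returns true when its constraint is satisfied; when `debug`
// is true, a failing check emits a TORCH_WARN explaining the rejection.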
namespace {
bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) {
  const auto query_size_last = params.query.sym_size(-1);
  const auto key_size_last = params.key.sym_size(-1);
  const auto value_size_last = params.value.sym_size(-1);
  if ((query_size_last != key_size_last) ||
      (query_size_last != value_size_last)) {
    if (debug) {
      TORCH_WARN(
          "OneDNN attention requires q, k, v to have the same last dimension.",
          " Got Query.size(-1): ",
          query_size_last,
          ", Key.size(-1): ",
          key_size_last,
          ", Value.size(-1): ",
          value_size_last,
          " instead.");
    }
    return false;
  }
  if (query_size_last > 256) {
    if (debug) {
      TORCH_WARN(
          "OneDNN attention requires q, k, v to have a head dimension of at most 256.",
          " Got ",
          query_size_last,
          " instead.");
    }
    return false;
  }
  return true;
}

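// The oneDNN fused kernel is forward-only: reject the fused path whenever a
// backward pass could be requested (an input requires grad while grad mode
// is enabled), so dispatch falls through to the math backend instead.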
bool check_no_grad(sdp::sdp_params const& params, bool debug) {
  const bool any_inputs_require_grad = params.query.requires_grad() ||
      params.key.requires_grad() || params.value.requires_grad();
  const bool gradmode_enabled = at::GradMode::is_enabled();
  if (debug && any_inputs_require_grad && gradmode_enabled) {
    TORCH_WARN(
        "OneDNN attention does not support backward yet;",
        " got inputs that require grad while grad mode is enabled.");
  }
  return !any_inputs_require_grad || !gradmode_enabled;
}

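// Gate for the overrideable (oneDNN) backend: runs every structural
// constraint below, then finally validates the input dtypes.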
bool use_overrideable_xpu(sdp::sdp_params const& params, bool debug) {
  constexpr auto supported_dtypes = c10::array_of<at::ScalarType>(
      at::kFloat, at::kBFloat16, at::kHalf); // double is not supported

  // Gate functions that determine whether the fused oneDNN kernel can be run
  constexpr auto constraints = c10::array_of<bool (*)(
      sdp::sdp_params const&, bool)>(
      sdp::check_nested_tensor,
      sdp::check_for_dropout,
      sdp::check_tensor_shapes,
      sdp::check_batch_size_and_num_heads_dense<true /*supports GQA*/>,
      sdp::check_attn_mask_shape,
      sdp::check_nonzero_sequence_lengths_dense,
      sdp::check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim*/>,
      check_head_dim_size_xpu,
      check_no_grad);
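  // All constraints must hold; evaluation stops at the first failed check.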
  for (auto& constraint : constraints) {
    if (!constraint(params, debug)) {
      return false;
    }
  }
  return sdp::check_tensor_dtype(params, supported_dtypes, debug);
}

sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
  // This function defines the priority order of the available SDP backends:
  // 1. oneDNN fused attention (the overrideable backend)
  // 2. Math fallback
  auto& ctx = at::globalContext();
  // On XPU the overrideable backend is implemented with oneDNN.
  if (!ctx.userEnabledMathSDP() && !ctx.userEnabledOverrideableSDP()) {
    return sdp::SDPBackend::error;
  }

  // Get ideal kernel ordering
  const std::array<sdp::SDPBackend, 2> priority_order{
      sdp::SDPBackend::overrideable,
      sdp::SDPBackend::math,
  };

  // Debug warnings are suppressed on the first pass; if no backend is
  // selected, the checks are re-run with print_debug = true so the failing
  // constraints are reported.
  bool print_debug = false;
  for (auto& backend : priority_order) {
    switch (backend) {
      case sdp::SDPBackend::overrideable:
        if (ctx.userEnabledOverrideableSDP() &&
            use_overrideable_xpu(kernel_params, print_debug)) {
          return sdp::SDPBackend::overrideable;
        }
        break;
      case sdp::SDPBackend::math:
        if (ctx.userEnabledMathSDP()) {
          return sdp::SDPBackend::math;
        }
        break;
      default:
        TORCH_CHECK(false, "Invalid backend");
    }
  }
  // If we have gotten to this point then two things have happened:
  // 1. use_overrideable_xpu did not satisfy the constraints to be run
  // 2. The user has explicitly disabled the math kernel
  // We then re-run the kernel checks with debug enabled to print out the
  // reason why the kernel was not selected.

  print_debug = true;
  TORCH_WARN("OneDNN kernel not used because:");
  use_overrideable_xpu(kernel_params, print_debug);
  TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.");
  return sdp::SDPBackend::error;
}
} // namespace

namespace at::native {
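// Dispatcher hook: returns the chosen backend as an int64_t so the generic
// scaled_dot_product_attention entry point can route to it.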
int64_t _fused_sdp_choice_xpu(
    const at::Tensor& query_,
    const at::Tensor& key,
    const at::Tensor& value,
    const std::optional<at::Tensor>& attn_mask_,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale,
    bool enable_gqa) {
  sdp::sdp_params kernel_params{
      query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa};
  auto backend = select_sdp_backend_xpu(kernel_params);

  if (backend == sdp::SDPBackend::error) {
    TORCH_CHECK(
        false,
        "No viable backend for scaled_dot_product_attention was found. ",
        "This is likely due to turning off both the math kernel and the fused kernels.");
  }
  return static_cast<int64_t>(backend);
}

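// Forward implementation of the overrideable fused-attention op on XPU.
// Inputs are dense 4-D tensors laid out as {batch, num_heads, seq_len,
// head_dim}; the wide return tuple mirrors the generic fused-attention
// schema even though several slots (cum_seq_q/k, rng state, debug mask)
// are unused placeholders here.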
std::tuple<
    at::Tensor,
    at::Tensor,
    at::Tensor,
    at::Tensor,
    c10::SymInt,
    c10::SymInt,
    at::Tensor,
    at::Tensor,
    at::Tensor>
_scaled_dot_product_fused_attention_overrideable_xpu(
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const std::optional<at::Tensor>& attn_bias,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    std::optional<double> scale) {
  TORCH_INTERNAL_ASSERT(
      query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
      "scaled_dot_product_fused_attention_overrideable_xpu: only 4-D inputs with shape {B, H, T, K} are accepted");
  TORCH_INTERNAL_ASSERT(
      (key.size(0) == value.size(0)) && (key.size(1) == value.size(1)) &&
          (key.size(2) == value.size(2)),
      "scaled_dot_product_fused_attention_overrideable_xpu: key and value must match in batch size, number of heads, and sequence length");
  TORCH_INTERNAL_ASSERT(
      query.size(3) == key.size(3),
      "scaled_dot_product_fused_attention_overrideable_xpu: query and key must have the same head_dim");
  TORCH_INTERNAL_ASSERT(
      dropout_p == 0.0,
      "scaled_dot_product_fused_attention_overrideable_xpu: dropout is not currently supported (dropout_p must be 0)");
  TORCH_INTERNAL_ASSERT(
      !(attn_bias.has_value() && is_causal),
      "scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot be combined with is_causal");

  const int64_t batch_size = query.size(0);
  const int64_t num_head = query.size(1);
  const int64_t num_head_kv = key.size(1);
  const int64_t head_dim = query.size(3);
  const int64_t head_dim_v = value.size(3);
  const int64_t seq_len_q = query.size(2);
  const int64_t seq_len_kv = key.size(2);

  auto opts = query.options();
  auto output = at::empty({batch_size, num_head, seq_len_q, head_dim}, opts);
  // The forward-only oneDNN kernel does not produce logsumexp (a full
  // implementation would allocate {batch_size, num_head, seq_len_q});
  // return an empty placeholder to satisfy the op schema.
  auto logsumexp = at::empty({}, opts.dtype(at::kFloat));

  at::native::onednn::gpu_float_sdpa(
      batch_size,
      seq_len_q,
      seq_len_kv,
      num_head,
      num_head_kv,
      head_dim,
      head_dim_v,
      query,
      key,
      value,
      attn_bias,
      is_causal,
      // default softmax scale is 1/sqrt(head_dim)
      scale.has_value() ? scale.value() : (1.0 / std::sqrt(head_dim)),
      output);

  // rng and debug mask not used
  auto philox_seed = at::empty({}, at::dtype(at::kLong));
  auto philox_offset = at::empty({}, at::dtype(at::kLong));
  auto debug_attn_mask = at::empty(
      {batch_size, num_head, seq_len_q, seq_len_kv}, at::dtype(at::kFloat));

  return std::make_tuple(
      output,
      logsumexp,
      /* cum_seq_q */ at::Tensor(),
      /* cum_seq_k */ at::Tensor(),
      seq_len_q,
      seq_len_kv,
      philox_seed,
      philox_offset,
      debug_attn_mask);
}
} // namespace at::native
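// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of this file): with this backend
// registered for the XPU dispatch key, callers reach it through the public
// SDPA entry point. The device, dtype, and {B, H, T, K} shapes below are
// hypothetical.
//
//   auto opts = at::TensorOptions().device(at::kXPU).dtype(at::kHalf);
//   auto q = at::randn({2, 8, 128, 64}, opts);
//   auto k = at::randn({2, 8, 128, 64}, opts);
//   auto v = at::randn({2, 8, 128, 64}, opts);
//   // causal attention; no mask, no dropout (dropout is unsupported here)
//   auto out = at::scaled_dot_product_attention(
//       q, k, v, /*attn_mask=*/{}, /*dropout_p=*/0.0, /*is_causal=*/true);
// ---------------------------------------------------------------------------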