[SDPA-CUDNN] Make CuDNN Attention Opt in · pytorch/pytorch@1448700 · GitHub


Commit 1448700

[SDPA-CUDNN] Make CuDNN Attention Opt in
ghstack-source-id: 30a6d89
Pull Request resolved: #138522
1 parent 7786869 commit 1448700
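In practice this change means cuDNN attention is no longer chosen by the automatic backend selection; callers opt in through the sdpa_kernel context manager, as the updated test below does. A minimal sketch of the opt-in, assuming a CUDA build where the cuDNN backend accepts the given inputs (the tensor shapes are illustrative):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative shapes: (batch, num_heads, seq_len, head_dim), half precision on GPU.
q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3))

# Default selection now tries FlashAttention first; cuDNN sits last in the priority order.
out_default = F.scaled_dot_product_attention(q, k, v)

# Opt in to cuDNN attention by restricting the allowed backends.
with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
    out_cudnn = F.scaled_dot_product_attention(q, k, v)

Restricting the backend set this way is the opt-in path the commit leaves in place; the default priority order alone no longer reaches cuDNN.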

File tree (2 files changed, +8 -9 lines):

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
test/test_transformers.py

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 3 additions & 8 deletions
@@ -68,16 +68,11 @@ bool check_prefer_cudnn_attention() {
 std::array<SDPBackend, num_backends> priority_order(sdp_params const& params) {
   constexpr std::array<SDPBackend, num_backends> default_order{
       SDPBackend::flash_attention,
-      SDPBackend::cudnn_attention,
       SDPBackend::efficient_attention,
-      SDPBackend::math};
-  constexpr std::array<SDPBackend, num_backends> cudnn_order{
+      SDPBackend::math,
       SDPBackend::cudnn_attention,
-      SDPBackend::flash_attention,
-      SDPBackend::efficient_attention,
-      SDPBackend::math};
-  static const bool prefer_cudnn = check_prefer_cudnn_attention();
-  return prefer_cudnn ? cudnn_order : default_order;
+  };
+  return default_order;
 }
 
 bool use_tensor_cores(sdp_params const& params, cudaDeviceProp* dprops, bool is_half) {
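The net effect of this hunk is that a single priority list now ends with cudnn_attention, replacing the cuDNN-first ordering that check_prefer_cudnn_attention() used to enable. A rough Python sketch of how such an ordered list drives selection; select_backend and can_use are hypothetical stand-ins for the real per-backend eligibility checks, not PyTorch APIs:

from enum import Enum

class SDPBackend(Enum):
    FLASH_ATTENTION = 0
    EFFICIENT_ATTENTION = 1
    MATH = 2
    CUDNN_ATTENTION = 3

# After this commit there is only one ordering, and cuDNN comes last, so it is
# never reached unless the caller narrows the allowed set (e.g. via sdpa_kernel).
DEFAULT_ORDER = [
    SDPBackend.FLASH_ATTENTION,
    SDPBackend.EFFICIENT_ATTENTION,
    SDPBackend.MATH,
    SDPBackend.CUDNN_ATTENTION,
]

def select_backend(allowed, can_use):
    # Pick the first backend that is both allowed and usable for the given inputs.
    for backend in DEFAULT_ORDER:
        if backend in allowed and can_use(backend):
            return backend
    raise RuntimeError("No viable SDPA backend for these inputs")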

test/test_transformers.py

Lines changed: 5 additions & 1 deletion
@@ -2808,8 +2808,12 @@ def test_fused_sdp_choice(self, device, type: str):
         value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
 
+        # TODO: we are currently disabling this by default; assert that this returns
+        # FlashAttention. This needs to change when we remove the opt-in for cuDNN.
         if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater:
-            self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value)
+            self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value)
+            with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
+                self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value)
         elif PLATFORM_SUPPORTS_FLASH_ATTENTION:
             self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value)
         elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION:  # e.g., we're on Windows
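On hardware where PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater both hold, the updated assertions can typically be exercised with a filtered run along the lines of pytest test/test_transformers.py -k test_fused_sdp_choice; the exact invocation depends on the local CUDA/cuDNN setup.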
