Update on "[SDPA-CUDNN] Make CuDNN Attention Opt in" · pytorch/pytorch@132ea4a · GitHub

Commit 132ea4a

Update on "[SDPA-CUDNN] Make CuDNN Attention Opt in"
# Summary

Currently we have a `cudnn_order` that says on H100 with a new enough CuDNN backend (we ship a 9.1 version in OSS) to try to run CuDNN attention first. We have already encountered a few bugs with the release of 2.5:

1. #138529
2. huggingface/diffusers#9704
3. #138354

In light of the above, we are going to make the CuDNN backend opt-in by default. This can be done easily with the context manager for choosing backends, i.e.:

```Python
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```

This PR puts the CuDNN backend at the lowest precedence in the backend list, meaning that the Math backend will always be chosen unless it is disabled (which is done via the context manager).

cc atalman
cc mikaylagawarecki

[ghstack-poisoned]
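For readers who want a process-wide switch instead of the context manager, recent PyTorch builds also expose per-backend toggles under `torch.backends.cuda`; a minimal sketch, assuming the `cudnn_sdp_enabled`/`enable_cudnn_sdp` helpers are present in your build:

```Python
# Sketch: inspecting and toggling the CuDNN SDPA path globally.
# Assumes torch.backends.cuda.cudnn_sdp_enabled() / enable_cudnn_sdp()
# exist in this build (they ship with recent PyTorch releases).
import torch

print(torch.backends.cuda.cudnn_sdp_enabled())  # is the CuDNN backend allowed?
torch.backends.cuda.enable_cudnn_sdp(False)     # disallow it process-wide
torch.backends.cuda.enable_cudnn_sdp(True)      # allow it again
```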
1 parent aecd92a · commit 132ea4a

File tree

1 file changed: 5 additions, 1 deletion


test/test_transformers.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -2808,8 +2808,12 @@ def test_fused_sdp_choice(self, device, type: str):
         value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
 
+        # TODO: we are currently disabling this by default; assert that this returns
+        # FlashAttention. This needs to change when we remove the opt-in for cudnn.
         if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater:
-            self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value)
+            self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value)
+            with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
+                self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value)
         elif PLATFORM_SUPPORTS_FLASH_ATTENTION:
             self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value)
         elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION:  # e.g., we're on Windows
```
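Outside the test suite, the new behavior can be reproduced with a small sketch; `torch._fused_sdp_choice` is the private helper the test exercises, and the tensor shapes here are illustrative assumptions (an SM90 GPU with CuDNN attention support is presumed):

```Python
# Sketch of what the updated test asserts, assuming an H100 (SM90) GPU
# and a CuDNN-attention-capable build; tensor shapes are illustrative.
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

# With CuDNN demoted to lowest precedence, the default pick on H100
# is FlashAttention rather than CuDNN attention.
assert torch._fused_sdp_choice(q, k, v) == SDPBackend.FLASH_ATTENTION.value

# Opting in via the context manager makes CuDNN the selected backend.
with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
    assert torch._fused_sdp_choice(q, k, v) == SDPBackend.CUDNN_ATTENTION.value
```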
