[cuDNN][SDPA] Bail out of cuDNN SDPA for seqlen 1 inputs (#138531) · rahulsingh-intel/pytorch@cc59e91 · GitHub
Commit cc59e91

eqy authored and malfet committed
[cuDNN][SDPA] Bail out of cuDNN SDPA for seqlen 1 inputs (pytorch#138531)
Forwarded pytorch#138529 to the cuDNN team, but for now we want to avoid dispatching to unsupported cases.

Pull Request resolved: pytorch#138531
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
1 parent 611a307 commit cc59e91
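
For context, a minimal sketch (not part of this commit) of the kind of input being routed away from cuDNN, assuming a CUDA build with cuDNN attention available and the (batch, num_heads, seq_len, head_dim) layout that scaled_dot_product_attention expects:

import torch
import torch.nn.functional as F

# A query with sequence length 1 -- the kind of input this change keeps
# away from the cuDNN SDPA backend.
q = torch.randn(2, 8, 1, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 8, 2, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 8, 2, 64, device="cuda", dtype=torch.bfloat16)

# With the new guard, check_cudnn_tensor_shapes() returns false for
# seq_len == 1, so default dispatch falls through to another backend
# (flash / memory-efficient / math) instead of the unsupported kernel.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
                                     dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 8, 1, 64])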

2 files changed, +31 -0 lines changed

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 6 additions & 0 deletions
@@ -409,6 +409,12 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) {
       return false;
     }
   }
+  if (s_q == 1 || s_k == 1) {
+    if (debug) {
+      TORCH_WARN_ONCE("cudnn SDPA does not support sequence length 1.");
+    }
+    return false;
+  }
   return true;
 }
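
As a rough sketch of how the debug branch above surfaces to Python users (an assumption based on the usual SDPA dispatch behavior of re-running the checks with debug output once no backend qualifies; not part of this commit):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 1, 64, device="cuda", dtype=torch.float16)  # seq_len 1
k = torch.randn(2, 8, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 8, 64, device="cuda", dtype=torch.float16)

# Restrict dispatch to the cuDNN backend only. Since the new guard rejects
# seq_len == 1, no backend qualifies; the warning added above should then
# be emitted along with the "no available kernel" RuntimeError.
with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
    try:
        F.scaled_dot_product_attention(q, k, v)
    except RuntimeError as e:
        print(e)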

test/test_transformers.py

Lines changed: 25 additions & 0 deletions
@@ -2463,6 +2463,31 @@ def test_cudnn_attention_different_dk_dv(self, device):
 
         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
 
+    @skipIfRocm  # No cuDNN Attention
+    @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
+    def test_fused_attention_different_dk_dv(self, device):
+        dtype = torch.bfloat16
+        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
+        batch, num_heads, head_dim_k, head_dim_v = 32, 16, 128, 64
+        seq_len = 640
+        q_shape = SdpaShape(batch, num_heads, 1, head_dim_k)
+        k_shape = SdpaShape(batch, num_heads, 2, head_dim_k)
+        v_shape = SdpaShape(batch, num_heads, 2, head_dim_v)
+        query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)
+
+        # test that we do not dispatch to cuDNN for an unsupported case
+        actual = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
+        with sdpa_kernel(backends=[SDPBackend.MATH]):
+            math_ref = torch.nn.functional.scaled_dot_product_attention(
+                query.contiguous().to(torch.float32),
+                key.contiguous().to(torch.float32),
+                value.contiguous().to(torch.float32),
+                attn_mask=None, dropout_p=0.0, is_causal=False)
+
+        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
+
+
     @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_cudnn_attention_fail_d128(self, device):
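
As a usage note (not part of the diff): the new test should be runnable on its own with something like python test/test_transformers.py -k test_fused_attention_different_dk_dv, on a CUDA build where PLATFORM_SUPPORTS_CUDNN_ATTENTION holds; on ROCm it is skipped via skipIfRocm, matching the decorators above.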
