-
Notifications
You must be signed in to change notification settings - Fork 24.8k
[cuDNN][SDPA] Match query
's memory layout ordering for output
in cuDNN SDPA
#138354
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
c719766
check in
eqy f9ea122
turn off default
eqy 5c0c0d7
fix
eqy a098629
lint
eqy 571f008
lint
eqy cfaf426
lint
eqy dd68a10
Update sdp_utils.cpp
eqy 5844d6f
update
eqy c1ef029
rework test
eqy 8351fb7
lint
eqy 9c2e329
Update MHA.cpp
eqy d29c855
cleanup permute
eqy a55d601
rework stride matching
eqy 13a521d
check in
eqy 4d53d98
Update MHA.cpp
eqy 447d63b
Update MHA.cpp
eqy 13957c8
fix
eqy 27360a9
lint
eqy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2529,9 +2529,9 @@ def test_cudnn_attention_trivial_output_transpose(self, device): | |
def test_cudnn_attention_nonmodulo64seqlen(self, device): | ||
# see also: https://github.com/pytorch/pytorch/issues/137347 | ||
mask = torch.randint(0, 2, (2, 1, 157, 6404)).to(device="cuda", dtype=torch.bool) | ||
q = torch.randn(2, 32, 157, 128, device='cuda', dtype=torch.bfloat16, requires_grad=True) | ||
k = torch.randn(2, 32, 6404, 128, device='cuda', dtype=torch.bfloat16, requires_grad=True) | ||
v = torch.randn(2, 32, 6404, 128, device='cuda', dtype=torch.bfloat16, requires_grad=True) | ||
q = torch.randn(2, 32, 157, 128, device='cuda', dtype=torch.float16, requires_grad=True) | ||
k = torch.randn(2, 32, 6404, 128, device='cuda', dtype=torch.float16, requires_grad=True) | ||
v = torch.randn(2, 32, 6404, 128, device='cuda', dtype=torch.float16, requires_grad=True) | ||
q_cpu = q.detach().clone().cpu() | ||
k_cpu = k.detach().clone().cpu() | ||
v_cpu = v.detach().clone().cpu() | ||
|
@@ -2564,6 +2564,36 @@ def test_cudnn_attention_nonmodulo64seqlen(self, device): | |
torch.testing.assert_close(k.grad, k_cpu.grad.cuda(), atol=3e-3, rtol=2e-3) | ||
torch.testing.assert_close(v.grad, v_cpu.grad.cuda(), atol=3e-3, rtol=2e-3) | ||
|
||
@skipIfRocm | ||
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") | ||
def test_cudnn_attention_preserves_query_layout(self, device): | ||
|
||
def test_attention(backend: SDPBackend, permute_order: List[List[int]]): | ||
BHSqD = [4, 16, 256, 64] | ||
BHSkvD = [4, 16, 512, 64] | ||
|
||
shape_q = [BHSqD[idx] for idx in permute_order] | ||
shape_kv = [BHSkvD[idx] for idx in permute_order] | ||
reverse = [permute_order.index(idx) for idx in range(4)] | ||
q = torch.randn(*shape_q, dtype=torch.bfloat16, device='cuda', requires_grad=True).permute(reverse) | ||
k = torch.randn(*shape_kv, dtype=torch.bfloat16, device='cuda', requires_grad=True).permute(reverse) | ||
v = torch.randn(*shape_kv, dtype=torch.bfloat16, device='cuda', requires_grad=True).permute(reverse) | ||
self.assertEqual(q.shape, BHSqD) | ||
self.assertEqual(k.shape, BHSkvD) | ||
self.assertEqual(v.shape, BHSkvD) | ||
|
||
with sdpa_kernel(backend): | ||
out = F.scaled_dot_product_attention(q, k, v) | ||
self.assertTrue(out.permute(permute_order).is_contiguous()) | ||
out.sum().backward() | ||
|
||
permute_orders = list() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit I think itertols has a permuatations func:
|
||
permutable = [0, 1, 2] | ||
permute_orders = itertools.permutations(permutable) | ||
|
||
for permute_order in permute_orders: | ||
test_attention(SDPBackend.CUDNN_ATTENTION, list(permute_order) + [3]) | ||
|
||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") | ||
@parametrize("mask_dim", [1, 2, 3, 4]) | ||
def test_mem_efficient_attention_mask_variants(self, device, mask_dim: List[int]): | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.