[ROCm] Update meta_registration for efficient attention (#146979) · pytorch/pytorch@83bb921 · GitHub

Commit 83bb921

AmdSampsa authored and pytorchmergebot committed
[ROCm] Update meta_registration for efficient attention (#146979)
Fixes a series of failing and skipped unit tests. On NVIDIA hardware the logsumexp last dimension is required to be a multiple of 32; this is not the case for ROCm. Related issue: #146848

The unit tests in question:

```bash
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_prev_13_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_prev_14_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_prev_15_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_11_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_14_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_15_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_17_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_1_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_1_freezing
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_2_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_3_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_4_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaDynamicTests test_sdpa_rewriter_6_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_prev_13_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_prev_14_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_prev_15_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_11_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_14_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_15_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_17_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_1_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_1_freezing
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_2_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_3_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_4_cuda
inductor.test_fused_attention SDPAPatternRewriterCudaTests test_sdpa_rewriter_6_cuda
```

Pull Request resolved: #146979
Approved by: https://github.com/shunting314
1 parent 382fbcc commit 83bb921
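For reference, the padding rule described above can be summarized in a small sketch. This is illustrative only, not the PyTorch source: the helper name `logsumexp_last_dim` and the `is_hip` flag are made up for the example.

```python
import math

def logsumexp_last_dim(seq_len: int, compute_log_sumexp: bool, is_hip: bool) -> int:
    # Illustrative helper: last dimension of the logsumexp buffer allocated by
    # the efficient-attention meta registration.
    if not compute_log_sumexp:
        return 0
    if is_hip:
        # ROCm backend writes one value per query position, so the meta shape
        # is the sequence length itself.
        return seq_len
    # CUDA backend pads the last dimension up to a multiple of 32.
    return math.ceil(seq_len / 32) * 32

# Example: seq_len=41 gives 64 on CUDA but 41 on ROCm; seq_len=128 gives 128 on both.
```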

File tree: 4 files changed, +10 −30 lines

test/inductor/test_fused_attention.py

Lines changed: 0 additions & 14 deletions

@@ -105,7 +105,6 @@ def _check_common(
         ):
             self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol)

-    @skipIfRocm
     def _test_sdpa_rewriter_1(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -132,7 +131,6 @@ def dot_prod_attention(
                 rtol=rtol,
             )

-    @skipIfRocm
     @torch._inductor.config.patch("freezing", True)
     def _test_sdpa_rewriter_1_freezing(self):
         def dot_prod_attention(
@@ -161,7 +159,6 @@ def dot_prod_attention(
             check_train=False,
         )

-    @skipIfRocm  # https://github.com/pytorch/pytorch/issues/146848
     def _test_insignificant_strides(self):
         f32 = torch.float32

@@ -265,7 +262,6 @@ def dot_prod_attention(
         _, (source_code,) = run_and_get_code(dot_prod_attention, *args)
         self.assertNotIn("aten._scaled_dot_product_efficient_attention", source_code)

-    @skipIfRocm
     def _test_sdpa_rewriter_2(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -280,7 +276,6 @@ def dot_prod_attention(
         self._check_common(dot_prod_attention)
         self._check_common(checkpoint_wrapper(dot_prod_attention))

-    @skipIfRocm  # AssertionError: expected size 4==4, stride 32==64 at dim=0
     def _test_sdpa_rewriter_3(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training: bool
@@ -297,7 +292,6 @@ def dot_prod_attention(
             checkpoint_wrapper(dot_prod_attention), contains=False, has_dropout=True
         )

-    @skipIfRocm  # AssertionError: expected size 4==4, stride 32==64 at dim=0
     def _test_sdpa_rewriter_4(self):
         def dot_prod_attention(
             query: torch.Tensor,
@@ -347,7 +341,6 @@ def sfdp_pattern_5_v2(query, key, value):
         self._check_common(sfdp_pattern_5_v2, contains=False)
         self._check_common(checkpoint_wrapper(sfdp_pattern_5_v2), contains=False)

-    @skipIfRocm
     def _test_sdpa_rewriter_6(self):
         def sfdp_pattern_6(query, key, value, training):
             attn_mask = torch.ones(
@@ -571,7 +564,6 @@ def forward(self, query, key, value, attn_mask) -> torch.Tensor:
             model, args1=args, contains=False, atol=1e-4, has_fuse_pattern=False
         )

-    @skipIfRocm
     def _test_sdpa_rewriter_11(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -612,7 +604,6 @@ def dot_prod_attention(

         self._check_common(dot_prod_attention, contains=False, has_dropout=True)

-    @skipIfRocm
     def _test_sdpa_prev_13(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -629,7 +620,6 @@ def dot_prod_attention(
         self._check_common(dot_prod_attention, check_train=False)
         self._check_common(checkpoint_wrapper(dot_prod_attention), check_train=False)

-    @skipIfRocm
     def _test_sdpa_prev_14(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -645,7 +635,6 @@ def dot_prod_attention(
         self._check_common(dot_prod_attention, check_train=False)
         self._check_common(checkpoint_wrapper(dot_prod_attention), check_train=False)

-    @skipIfRocm
     def _test_sdpa_prev_15(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -695,7 +684,6 @@ def dot_prod_attention(
             rtol=1e-2,
         )

-    @skipIfRocm
     def _test_sdpa_rewriter_14(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -718,7 +706,6 @@ def dot_prod_attention(

         self._check_common(dot_prod_attention)

-    @skipIfRocm
     def _test_sdpa_rewriter_15(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -811,7 +798,6 @@ def dot_prod_attention(
             dot_prod_attention, args1=args, contains=False, has_dropout=True
         )

-    @skipIfRocm
     def _test_sdpa_rewriter_17(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training
test/inductor/test_torchinductor.py

Lines changed: 0 additions & 4 deletions

@@ -11091,10 +11091,6 @@ def fn(z):
     def test_scaled_dot_product_attention(self):
         if self.device == "cuda" and not PLATFORM_SUPPORTS_FLASH_ATTENTION:
             raise unittest.SkipTest("Can't run flash attention on this platform")
-        if self.device == "cuda" and TEST_WITH_ROCM:
-            raise unittest.SkipTest(
-                "Flash attention support is incomplete on this platform"
-            )

         def fn(q, k, v):
             return torch.nn.functional.scaled_dot_product_attention(

torch/_inductor/fx_passes/fuse_attention.py

Lines changed: 1 addition & 9 deletions

@@ -5,7 +5,6 @@
 import math

 import torch
-from torch.nn.attention import sdpa_kernel, SDPBackend

 from ..._dynamo.utils import counters
 from ..pattern_matcher import (
@@ -20,14 +19,7 @@
 aten = torch.ops.aten


-if torch.version.hip:
-
-    def _scaled_dot_product_attention(*args, **kwargs):
-        with sdpa_kernel(backends=[SDPBackend.MATH, SDPBackend.FLASH_ATTENTION]):
-            return aten.scaled_dot_product_attention(*args, **kwargs)
-
-else:
-    _scaled_dot_product_attention = aten.scaled_dot_product_attention
+_scaled_dot_product_attention = aten.scaled_dot_product_attention


 def _sfdp_pattern_1(query, key, value, inv_scale):
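For context, the ROCm-only wrapper removed above pinned backend selection with the public `torch.nn.attention.sdpa_kernel` context manager. A minimal standalone usage sketch of that API (tensor shapes are arbitrary and chosen for illustration):

```python
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

# Arbitrary (batch, heads, seq_len, head_dim) inputs for illustration.
q = torch.randn(2, 4, 16, 8)
k = torch.randn(2, 4, 16, 8)
v = torch.randn(2, 4, 16, 8)

# Inside the context manager only the listed backends may be dispatched to;
# outside of it, scaled_dot_product_attention picks a backend automatically.
with sdpa_kernel(backends=[SDPBackend.MATH, SDPBackend.FLASH_ATTENTION]):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
```

With this commit the wrapper is gone on both CUDA and ROCm, and the fusion patterns trace straight through `aten.scaled_dot_product_attention`.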

torch/_meta_registrations.py

Lines changed: 9 additions & 3 deletions

@@ -4448,8 +4448,7 @@ def pool3d_shape_check(
     torch._check(
         dT > 0 and dW > 0 and dH > 0,
         lambda: (
-            f"stride should be greater than zero, but got "
-            f"dT: {dT}, dH: {dH}, dW: {dW}"
+            f"stride should be greater than zero, but got dT: {dT}, dH: {dH}, dW: {dW}"
         ),
     )
     torch._check(
@@ -5724,7 +5723,14 @@ def meta__scaled_dot_product_efficient_attention(

     res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)

-    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
+    if torch.version.hip and torch.cuda.is_available():
+        """Please see: https://github.com/pytorch/pytorch/issues/146848
+        longsumexp last dim should be seq length
+        """
+        logsumexp_dim = M if compute_log_sumexp else 0
+    else:
+        logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
+
     logsum_exp = torch.empty(
         (B, num_heads, logsumexp_dim),
         dtype=torch.float,
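One way to observe the effect of this meta registration is to call the op on meta tensors and inspect the logsumexp shape. This is a hedged sketch: the positional argument order (query, key, value, attn_bias, compute_log_sumexp) and the four-tensor return are assumed from the registration above and may need adjusting against a given PyTorch build.

```python
import torch

B, H, M, K = 2, 4, 41, 64  # arbitrary sizes; M is the query sequence length
q = torch.empty(B, H, M, K, device="meta")
k = torch.empty(B, H, M, K, device="meta")
v = torch.empty(B, H, M, K, device="meta")

# Assumed overload order: (query, key, value, attn_bias, compute_log_sumexp).
out, logsumexp, seed, offset = torch.ops.aten._scaled_dot_product_efficient_attention(
    q, k, v, None, True
)

# On CUDA builds the last dim is padded to a multiple of 32 (41 -> 64); with this
# patch, ROCm builds report (B, H, 41), matching what the backend actually writes.
print(logsumexp.shape)
```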
