Update · pytorch/pytorch@7703789 · GitHub

Commit 7703789

Update
[ghstack-poisoned]
1 parent 1941cf1 commit 7703789

File tree: 1 file changed, +9 -0 lines changed

test/inductor/test_flex_attention.py

Lines changed: 9 additions & 0 deletions
@@ -106,6 +106,13 @@ def skip_on_cuda(test_func):
     return decorated_func
 
 
+def skip_on_rocm(test_func):
+    """Decorator to skip tests that are not supported on ROCM."""
+    IS_ROCM = torch.cuda.is_available() and torch.version.hip
+    decorated_func = skipCUDAIf(IS_ROCM, "Not supported on ROCM")(test_func)
+    return decorated_func
+
+
 def rmse(ref, res):
     """
     Calculate root mean squared error
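
For context, here is a minimal, self-contained sketch of how a decorator like skip_on_rocm plugs into PyTorch's device-type test machinery via skipCUDAIf. The test class, test body, and instantiation boilerplate below are illustrative assumptions for this sketch, not part of the commit; in the actual file the decorator is applied to existing FlexAttention tests, as the hunks below show.

# Sketch: composing a ROCm-skip decorator with the device-type test framework.
# TestRocmSkipExample and test_dummy are hypothetical names used for illustration.
import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    skipCUDAIf,
)
from torch.testing._internal.common_utils import run_tests, TestCase


def skip_on_rocm(test_func):
    """Decorator to skip tests that are not supported on ROCM."""
    # torch.version.hip is non-None only on ROCm/HIP builds of PyTorch.
    IS_ROCM = torch.cuda.is_available() and torch.version.hip
    return skipCUDAIf(IS_ROCM, "Not supported on ROCM")(test_func)


class TestRocmSkipExample(TestCase):
    @skip_on_rocm  # only the CUDA variant of this test is skipped, and only on ROCm
    def test_dummy(self, device):
        x = torch.ones(2, device=device)
        self.assertEqual(x.sum().item(), 2.0)


# Generates per-device test classes (e.g. TestRocmSkipExampleCPU / ...CUDA);
# skipCUDAIf only takes effect when the test instance runs with device_type "cuda".
instantiate_device_type_tests(TestRocmSkipExample, globals())

if __name__ == "__main__":
    run_tests()
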
@@ -1398,6 +1405,7 @@ def mask_mod(b, h, q, kv):
     @dtypes(*device_configs["cpu"].dtypes_fast)
     @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
     @common_utils.parametrize("score_mod", test_score_mods)
+    @skip_on_rocm  # TODO: NaNs on ROCM
     def test_GQA(self, device, dtype: torch.dtype, score_mod: Callable):
         inputs = (
             score_mod,
@@ -1899,6 +1907,7 @@ def f(q, k1, k2, v1, v2):
 
     @supported_platform
     @skip_on_cpu
+    @skip_on_rocm  # TODO: Investigate
     def test_multiple_mask_calls(self, device):
         make_tensor = functools.partial(
             torch.randn,
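
As a quick local sanity check of the new skips (not part of this commit), the two affected tests can be run in isolation, assuming pytest is installed alongside the PyTorch test suite: pytest test/inductor/test_flex_attention.py -k "test_GQA or test_multiple_mask_calls". On a ROCm build the CUDA variants of these tests should now be reported as skipped; CPU variants and non-ROCm CUDA runs are unaffected.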

0 commit comments
