enable skipped cases · pytorch/pytorch@ffdb463 · GitHub

Commit ffdb463

enable skipped cases

1 parent 685bfda · commit ffdb463

File tree

1 file changed: +12 -6 lines changed

test/inductor/test_flex_attention.py (12 additions, 6 deletions)
@@ -53,9 +53,12 @@
 
 # Use this decorator only when hitting Triton bugs on H100
 running_on_a100_only = skipUnless(
-    (torch.cuda.is_available() and has_triton())
-    and (torch.cuda.get_device_capability() == (8, 0) or torch.version.hip),
-    "Requires Triton + A100 or Triton + ROCm",
+    (
+        (torch.cuda.is_available() and has_triton())
+        and (torch.cuda.get_device_capability() == (8, 0) or torch.version.hip)
+    )
+    or (torch.xpu.is_available() and has_triton()),
+    "Requires Triton + A100 or Triton + ROCm or Triton + XPU",
 )
 
 Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
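
For context, here is a minimal, self-contained sketch of how a gate like this is applied to a test. Everything below is hypothetical and not part of the commit: the local has_triton() probe stands in for the helper the real suite imports, the test class is invented, and a PyTorch build that exposes torch.xpu (2.4+) is assumed.

# Hedged sketch, not from the commit: how running_on_a100_only gates a test.
import unittest
from unittest import skipUnless

import torch


def has_triton() -> bool:
    # Hypothetical stand-in for the helper the real test file imports.
    try:
        import triton  # noqa: F401

        return True
    except ImportError:
        return False


running_on_a100_only = skipUnless(
    (
        (torch.cuda.is_available() and has_triton())
        # get_device_capability() is only reached when CUDA is available,
        # thanks to short-circuiting; (8, 0) is the A100 (sm80) capability.
        and (torch.cuda.get_device_capability() == (8, 0) or torch.version.hip)
    )
    # New branch from this commit: also run on Intel XPU builds with Triton.
    or (torch.xpu.is_available() and has_triton()),
    "Requires Triton + A100 or Triton + ROCm or Triton + XPU",
)


class FlexAttentionSmokeTest(unittest.TestCase):  # hypothetical test class
    @running_on_a100_only
    def test_gated_case(self):
        self.assertTrue(torch.cuda.is_available() or torch.xpu.is_available())


if __name__ == "__main__":
    unittest.main()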
@@ -4975,9 +4978,12 @@ def get_params(dtypes: list[torch.dtype]) -> list[Params]:
 
 
 supports_learnable_bias = unittest.skipUnless(
-    (torch.cuda.is_available() and has_triton())
-    and (torch.cuda.get_device_capability() >= (8, 0) or torch.version.hip),
-    "Requires Triton + A100 or Triton + ROCm",
+    (
+        (torch.cuda.is_available() and has_triton())
+        and (torch.cuda.get_device_capability() >= (8, 0) or torch.version.hip)
+    )
+    or (torch.xpu.is_available() and has_triton()),
+    "Requires Triton + A100 or Triton + ROCm or Triton + XPU",
 )
 
 
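A design note on both gates: the condition passed to unittest.skipUnless is evaluated once, when the decorator expression runs at module import, so the hardware probes above execute at collection time rather than per test. (Also visible in the diff itself: running_on_a100_only pins the capability with == (8, 0), while supports_learnable_bias accepts >= (8, 0).) A tiny hypothetical demonstration, not from the commit:

# Hypothetical demonstration: skipUnless captures the condition's value at
# decoration time (module import), not when the test later runs.
import unittest

SUPPORTED = False  # stand-in for a hardware probe like the ones above


class ProbeTiming(unittest.TestCase):
    @unittest.skipUnless(SUPPORTED, "requires supported hardware")
    def test_gated(self):
        self.fail("unreachable: SUPPORTED was False at import time")


if __name__ == "__main__":
    SUPPORTED = True  # too late; the skip decision is already baked in
    unittest.main()   # reports the test as skipped, not failed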