  53 |   53 |
  54 |   54 | # Use this decorator only when hitting Triton bugs on H100
  55 |   55 | running_on_a100_only = skipUnless(
  56 |      | -    (torch.cuda.is_available() and has_triton())
  57 |      | -    and (torch.cuda.get_device_capability() == (8, 0) or torch.version.hip),
  58 |      | -    "Requires Triton + A100 or Triton + ROCm",
     |   56 | +    (
     |   57 | +        (torch.cuda.is_available() and has_triton())
     |   58 | +        and (torch.cuda.get_device_capability() == (8, 0) or torch.version.hip)
     |   59 | +    )
     |   60 | +    or (torch.xpu.is_available() and has_triton()),
     |   61 | +    "Requires Triton + A100 or Triton + ROCm or Triton + XPU",
  59 |   62 | )
  60 |   63 |
  61 |   64 | Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
@@ -4975,9 +4978,12 @@ def get_params(dtypes: list[torch.dtype]) -> list[Params]:

4975 | 4978 |
4976 | 4979 |
4977 | 4980 | supports_learnable_bias = unittest.skipUnless(
4978 |      | -    (torch.cuda.is_available() and has_triton())
4979 |      | -    and (torch.cuda.get_device_capability() >= (8, 0) or torch.version.hip),
4980 |      | -    "Requires Triton + A100 or Triton + ROCm",
     | 4981 | +    (
     | 4982 | +        (torch.cuda.is_available() and has_triton())
     | 4983 | +        and (torch.cuda.get_device_capability() >= (8, 0) or torch.version.hip)
     | 4984 | +    )
     | 4985 | +    or (torch.xpu.is_available() and has_triton()),
     | 4986 | +    "Requires Triton + A100 or Triton + ROCm or Triton + XPU",
4981 | 4987 | )
4982 | 4988 |
4983 | 4989 |
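
For context, here is a minimal sketch of how a gate built this way is consumed. `unittest.skipUnless(condition, reason)` evaluates the condition once at import time and skips the decorated test with the given reason when the condition is false; the extra parentheses added in this diff group the existing CUDA/ROCm branch so the XPU branch can be `or`-ed in without changing operator precedence. The `ExampleTest` class is hypothetical and not part of this change, and the sketch assumes `has_triton` comes from `torch.utils._triton`.

```python
# Sketch only: reproduces the gate's structure from the diff. ExampleTest is a
# hypothetical consumer, not code from the PR.
import unittest
from unittest import skipUnless

import torch
from torch.utils._triton import has_triton  # assumed import location

# Same condition shape as the new running_on_a100_only gate: Triton plus an
# A100 (SM 8.0), a ROCm build, or an available XPU device.
running_on_a100_only = skipUnless(
    (
        (torch.cuda.is_available() and has_triton())
        and (torch.cuda.get_device_capability() == (8, 0) or torch.version.hip)
    )
    or (torch.xpu.is_available() and has_triton()),
    "Requires Triton + A100 or Triton + ROCm or Triton + XPU",
)


class ExampleTest(unittest.TestCase):  # hypothetical test class
    @running_on_a100_only
    def test_only_runs_on_supported_backends(self):
        # Skipped (with the reason string above) on machines without Triton
        # and a supported device; runs normally on A100, ROCm, or XPU setups.
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
```

Note that `supports_learnable_bias` uses `>= (8, 0)` (A100 or newer) where `running_on_a100_only` uses `== (8, 0)` (A100 only); the diff preserves that distinction while appending the same XPU clause to both.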