|
13 | 13 | from torch._inductor.ir import ComputedBuffer, FixedLayout, PermuteView, Pointwise
|
14 | 14 | from torch._inductor.scheduler import BaseSchedulerNode
|
15 | 15 | from torch._inductor.utils import OrderedSet
|
| 16 | +from torch.testing._internal.common_cuda import SM90OrLater |
16 | 17 | from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
|
17 | 18 |
|
18 | 19 |
|
@@ -108,6 +109,7 @@ def __init__(self, name_to_buffer):
|
108 | 109 |
|
109 | 110 |
|
110 | 111 | class TestCutlassEVT(TestCase):
|
| 112 | + @unittest.skipIf(not SM90OrLater, "need sm_90") |
111 | 113 | @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
|
112 | 114 | def test_py_codegen_accumulator_return(self):
|
113 | 115 | from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
|
@@ -162,6 +164,7 @@ def fn(accum, buf1, buf2):
|
162 | 164 | return D, tmp_1, tmp_2""",
|
163 | 165 | )
|
164 | 166 |
|
| 167 | + @unittest.skipIf(not SM90OrLater, "need sm_90") |
165 | 168 | @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
|
166 | 169 | def test_py_codegen_disjoint_read_indexing(self):
|
167 | 170 | from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
|
@@ -207,6 +210,7 @@ def inner_fn_buf4(index):
|
207 | 210 | """Unsupported indexing for buf0 with index 200*i0 + 60000*i1 + i2 and strides [200, 60000, 1]""",
|
208 | 211 | )
|
209 | 212 |
|
| 213 | + @unittest.skipIf(not SM90OrLater, "need sm_90") |
210 | 214 | @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
|
211 | 215 | def test_py_codegen(self):
|
212 | 216 | from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
|
@@ -261,6 +265,7 @@ def fn(accum, buf1, buf2):
|
261 | 265 | return D, tmp_2""",
|
262 | 266 | )
|
263 | 267 |
|
| 268 | + @unittest.skipIf(not SM90OrLater, "need sm_90") |
264 | 269 | @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
|
265 | 270 | def test_example_tensor_creation(self):
|
266 | 271 | from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
|
@@ -292,6 +297,7 @@ def test_example_tensor_creation(self):
|
292 | 297 | result["buf1"].element, torch_dtype_to_cutlass_type(torch.float32)
|
293 | 298 | )
|
294 | 299 |
|
| 300 | + @unittest.skipIf(not SM90OrLater, "need sm_90") |
295 | 301 | @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
|
296 | 302 | def test_evt_argument_codegen(self):
|
297 | 303 | epilogue_functor = _trace(BIAS_CODE, EXAMPLE_TENSORS)
|
@@ -322,6 +328,7 @@ def test_evt_argument_codegen(self):
|
322 | 328 | """,
|
323 | 329 | )
|
324 | 330 |
|
| 331 | + @unittest.skipIf(not SM90OrLater, "need sm_90") |
325 | 332 | @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
|
326 | 333 | def test_evt_codegen(self):
|
327 | 334 | _, _, code = trace(
|
|
0 commit comments