@@ -1,5 +1,6 @@
+import functools
 import logging
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import torch
 from torch._inductor.virtualized import V
@@ -259,11 +260,19 @@ def fallback_mixed_mm(mat1, mat2, *, out):
 aten_fallback_mixed_mm = ExternKernelChoice(fallback_mixed_mm, None)
 
 
+@functools.lru_cache(None)
+def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool:
+    props = torch.cuda.get_device_properties(index or 0)
+    return props.major <= 7
+
+
 def tuned_mixed_mm(mat1, mat2, mat2_dtype):
     m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None)
     choices = [aten_fallback_mixed_mm.bind((mat1, mat2), layout)]
-    if mat1.layout.dtype != torch.float32 and not mat2.layout.is_contiguous():
-        # can't use triton kernel unless one of these is true
+    if (
+        mat1.layout.dtype != torch.float32 and not mat2.layout.is_contiguous()
+    ) or _is_sm7x_or_older_gpu(layout.device.index):
+        # can't use triton kernel unless one of these is true or if running on v100 (numerical issues)
         return autotune_select_algorithm("mixed_mm", choices, [mat1, mat2], layout)
     if inductor_config.force_mixed_mm:
         choices = []
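For context, a minimal illustrative sketch (not part of the change) of the check the new `_is_sm7x_or_older_gpu` helper performs: it reads the CUDA compute capability of the target device and treats major version 7 or lower (e.g. V100 is sm_70, T4/RTX 20-series are sm_75) as "fall back to the ATen kernel" rather than using the Triton mixed_mm kernel. The device index 0 below is just an example.

```python
import torch

# Illustrative sketch only: mirrors the compute-capability check added above.
# On a V100 (sm_70) or Turing GPU (sm_75) the last line prints True, meaning
# tuned_mixed_mm would keep only the ATen fallback choice.
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)  # hypothetical device index
    print(f"compute capability: {props.major}.{props.minor}")
    print("falls back to ATen mixed_mm:", props.major <= 7)
```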