Commit eef51a6 · pytorch/pytorch
atalman and malfet authored
[Inductor] Skip triton templates for mixedmm on SM70- (#118591) (#119894)
As it results in numerical errors, see #117144

Fixes #117144

Pull Request resolved: #118591
Approved by: https://github.com/jansel

Co-authored-by: Nikita Shulga <nshulga@meta.com>
1 parent 940358f commit eef51a6
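
For context: "SM70-" refers to CUDA compute capability 7.0 and below, the Volta generation that includes the V100 named in the diff's comment. Below is a minimal sketch of querying compute capability through the public torch.cuda API to see whether a given machine falls under the new fallback path; the standalone script and the hard-coded device index 0 are illustrative assumptions, not part of the commit.

import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)  # device 0 assumed for the demo
    print(f"{props.name}: compute capability {props.major}.{props.minor}")
    # A V100 reports major == 7, so the commit's check (props.major <= 7) holds
    print("Triton mixed_mm templates skipped:", props.major <= 7)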

File tree

1 file changed: +12 -3 lines
  • torch/_inductor/kernel/mm.py


torch/_inductor/kernel/mm.py

Lines changed: 12 additions & 3 deletions
@@ -1,5 +1,6 @@
+import functools
 import logging
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import torch
 from torch._inductor.virtualized import V
@@ -259,11 +260,19 @@ def fallback_mixed_mm(mat1, mat2, *, out):
 aten_fallback_mixed_mm = ExternKernelChoice(fallback_mixed_mm, None)
 
 
+@functools.lru_cache(None)
+def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool:
+    props = torch.cuda.get_device_properties(index or 0)
+    return props.major <= 7
+
+
 def tuned_mixed_mm(mat1, mat2, mat2_dtype):
     m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None)
     choices = [aten_fallback_mixed_mm.bind((mat1, mat2), layout)]
-    if mat1.layout.dtype != torch.float32 and not mat2.layout.is_contiguous():
-        # can't use triton kernel unless one of these is true
+    if (
+        mat1.layout.dtype != torch.float32 and not mat2.layout.is_contiguous()
+    ) or _is_sm7x_or_older_gpu(layout.device.index):
+        # can't use triton kernel unless one of these is true or if running on v100 (numerical issues)
         return autotune_select_algorithm("mixed_mm", choices, [mat1, mat2], layout)
     if inductor_config.force_mixed_mm:
         choices = []
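
A note on the new helper: it is memoized with functools.lru_cache(None), so the device-properties query runs at most once per distinct device index, and `index or 0` maps a None index (a layout whose device carries no explicit index) to device 0. A self-contained sketch of the same pattern outside Inductor; the function name and the demo calls at the bottom are illustrative.

import functools
from typing import Optional

import torch


@functools.lru_cache(None)  # unbounded cache: one CUDA query per distinct index
def is_sm7x_or_older_gpu(index: Optional[int]) -> bool:
    # None (no explicit device index) falls back to device 0 via `index or 0`
    props = torch.cuda.get_device_properties(index or 0)
    return props.major <= 7


if torch.cuda.is_available():
    print(is_sm7x_or_older_gpu(None))  # first call queries the device
    print(is_sm7x_or_older_gpu(None))  # second call is served from the cache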

0 commit comments