8000 [Cutlass] Import cutlass python API for EVT · pytorch/pytorch@cb7fbd9 · GitHub
[go: up one dir, main page]

Skip to content

Commit cb7fbd9

Browse files
committed
[Cutlass] Import cutlass python API for EVT
ghstack-source-id: c723dc8 Pull Request resolved: #150344
1 parent 4d6ff6c commit cb7fbd9

File tree

5 files changed

+86
-21
lines changed

5 files changed

+86
-21
lines changed

test/inductor/test_cutlass_backend.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,15 @@ def mm(a, b):
135135
), "Cutlass Kernels should have been filtered, GEMM size is too small"
136136
torch.testing.assert_close(Y_compiled, Y)
137137

138+
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
139+
def test_import_cutlass(self):
140+
from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
141+
142+
self.assertTrue(try_import_cutlass())
143+
144+
import cutlass # noqa: F401
145+
import cutlass_library # noqa: F401
146+
138147
@unittest.skipIf(not SM90OrLater, "need sm_90")
139148
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
140149
def test_cutlass_backend_subproc_mm(self):
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
__version__ = "12.6"
2+
3+
from .cuda import *
4+
from .cudart import *
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
2+
class CUdeviceptr:
3+
pass
4+
5+
class CUstream:
6+
def __init__(self, v):
7+
pass
8+
9+
class CUresult:
10+
pass
11+
12+
class nvrtc:
13+
pass

torch/_inductor/codegen/cuda/cutlass_lib_extensions/mock_cuda_bindings/cuda/cudart.py

Whitespace-only changes.

torch/_inductor/codegen/cuda/cutlass_utils.py

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -75,28 +75,67 @@ def try_import_cutlass() -> bool:
7575
# This is a temporary hack to avoid CUTLASS module naming conflicts.
7676
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
7777

78-
cutlass_py_full_path = os.path.abspath(
79-
os.path.join(config.cuda.cutlass_dir, "python/cutlass_library")
78+
# TODO(mlazos): epilogue visitor tree currently lives in python/cutlass,
79+
# but will be moved to python/cutlass_library in the future
80+
def path_join(path0, path1):
81+
return os.path.abspath(os.path.join(path0, path1))
82+
83+
# contains both cutlass and cutlass_library
84+
# we need cutlass for eVT
85+
cutlass_python_path = path_join(config.cuda.cutlass_dir, "python")
86+
torch_root = os.path.abspath(os.path.dirname(torch.__file__))
87+
mock_cuda_src_path = os.path.join(
88+
torch_root,
89+
"_inductor",
90+
"codegen",
91+
"cuda",
92+
"cutlass_lib_extensions",
93+
"mock_cuda_bindings",
94+
"cuda",
8095
)
81-
tmp_cutlass_py_full_path = os.path.abspath(
82-
os.path.join(cache_dir(), "torch_cutlass_library")
83-
)
84-
dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")
85-
86-
if os.path.isdir(cutlass_py_full_path):
87-
if tmp_cutlass_py_full_path not in sys.path:
88-
if os.path.exists(dst_link):
89-
assert os.path.islink(dst_link), (
90-
f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
91-
)
92-
assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
93-
cutlass_py_full_path
94-
), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
95-
else:
96-
os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
97-
os.symlink(cutlass_py_full_path, dst_link)
98-
sys.path.append(tmp_cutlass_py_full_path)
96+
97+
cutlass_library_src_path = path_join(cutlass_python_path, "cutlass_library")
98+
cutlass_src_path = path_join(cutlass_python_path, "cutlass")
99+
pycute_src_path = path_join(cutlass_python_path, "pycute")
100+
101+
tmp_cutlass_full_path = os.path.abspath(os.path.join(cache_dir(), "torch_cutlass"))
102+
103+
dst_link_library = path_join(tmp_cutlass_full_path, "cutlass_library")
104+
dst_link_cutlass = path_join(tmp_cutlass_full_path, "cutlass")
105+
# cuda bindings needed to import cutlass
106+
# pycute needed for EVT
107+
dst_link_pycute = path_join(tmp_cutlass_full_path, "pycute")
108+
dst_link_mock_cuda = path_join(tmp_cutlass_full_path, "cuda")
109+
110+
if os.path.isdir(cutlass_python_path):
111+
if tmp_cutlass_full_path not in sys.path:
112+
113+
def link_and_append(dst_link, src_path, parent_dir):
114+
if os.path.exists(dst_link):
115+
assert os.path.islink(dst_link), (
116+
f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
117+
)
118+
assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
119+
src_path,
120+
), f"Symlink at {dst_link} does not point to {src_path}"
121+
else:
122+
os.makedirs(parent_dir, exist_ok=True)
123+
os.symlink(src_path, dst_link)
124+
125+
if parent_dir not in sys.path:
126+
sys.path.append(parent_dir)
127+
128+
link_and_append(dst_link_pycute, pycute_src_path, tmp_cutlass_full_path)
129+
link_and_append(
130+
dst_link_library, cutlass_library_src_path, tmp_cutlass_full_path
131+
)
132+
link_and_append(dst_link_cutlass, cutlass_src_path, tmp_cutlass_full_path)
133+
link_and_append(
134+
dst_link_mock_cuda, mock_cuda_src_path, tmp_cutlass_full_path
135+
)
136+
99137
try:
138+
import cutlass # noqa: F401
100139
import cutlass_library.generator # noqa: F401
101140
import cutlass_library.library # noqa: F401
102141
import cutlass_library.manifest # noqa: F401
@@ -110,7 +149,7 @@ def try_import_cutlass() -> bool:
110149
else:
111150
log.debug(
112151
"Failed to import CUTLASS packages: CUTLASS repo does not exist: %s",
113-
cutlass_py_full_path,
152+
cutlass_python_path,
114153
)
115154
return False
116155

0 commit comments

Comments
 (0)
0