10000 [Cutlass] Import cutlass python API for EVT · pytorch/pytorch@9afdc51 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9afdc51

Browse files
committed
[Cutlass] Import cutlass python API for EVT
ghstack-source-id: 780a382 Pull Request resolved: #150344
1 parent 4d6ff6c commit 9afdc51

File tree

8 files changed

+95
-21
lines changed

8 files changed

+95
-21
lines changed

test/inductor/test_cutlass_backend.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,15 @@ def mm(a, b):
135135
), "Cutlass Kernels should have been filtered, GEMM size is too small"
136136
torch.testing.assert_close(Y_compiled, Y)
137137

138+
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
139+
def test_import_cutlass(self):
140+
from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
141+
142+
self.assertTrue(try_import_cutlass())
143+
144+
import cutlass # noqa: F401
145+
import cutlass_library # noqa: F401
146+
138147
@unittest.skipIf(not SM90OrLater, "need sm_90")
139148
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
140149
def test_cutlass_backend_subproc_mm(self):
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
__version__ = "12.6"
2+
3+
from .cuda import *
4+
from .cudart import *
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
2+
class CUdeviceptr:
3+
pass
4+
5+
class CUstream:
6+
def __init__(self, v):
7+
pass
8+
9+
class CUresult:
10+
pass
11+
12+
class nvrtc:
13+
pass

torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cudart.py

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Dot = None
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .special import *
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
erf = None

torch/_inductor/codegen/cuda/cutlass_utils.py

Lines changed: 66 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# mypy: allow-untyped-defs
22
import functools
3+
import importlib.util
34
import logging
45
import os
56
import sys
@@ -75,28 +76,72 @@ def try_import_cutlass() -> bool:
7576
# This is a temporary hack to avoid CUTLASS module naming conflicts.
7677
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
7778

78-
cutlass_py_full_path = os.path.abspath(
79-
os.path.join(config.cuda.cutlass_dir, "python/cutlass_library")
79+
# TODO(mlazos): epilogue visitor tree currently lives in python/cutlass,
80+
# but will be moved to python/cutlass_library in the future
81+
def path_join(path0, path1):
82+
return os.path.abspath(os.path.join(path0, path1))
83+
84+
# contains both cutlass and cutlass_library
85+
# we need cutlass for eVT
86+
cutlass_python_path = path_join(config.cuda.cutlass_dir, "python")
87+
torch_root = os.path.abspath(os.path.dirname(torch.__file__))
88+
mock_src_path = os.path.join(
89+
torch_root,
90+
"_inductor",
91+
"codegen",
92+
"cuda",
93+
"cutlass_lib_extensions",
94+
"cutlass_mock_imports",
8095
)
81-
tmp_cutlass_py_full_path = os.path.abspath(
82-
os.path.join(cache_dir(), "torch_cutlass_library")
83-
)
84-
dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")
85-
86-
if os.path.isdir(cutlass_py_full_path):
87-
if tmp_cutlass_py_full_path not in sys.path:
88-
if os.path.exists(dst_link):
89-
assert os.path.islink(dst_link), (
90-
f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
91-
)
92-
assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
93-
cutlass_py_full_path
94-
), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
95-
else:
96-
os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
97-
os.symlink(cutlass_py_full_path, dst_link)
98-
sys.path.append(tmp_cutlass_py_full_path)
96+
97+
cutlass_library_src_path = path_join(cutlass_python_path, "cutlass_library")
98+
cutlass_src_path = path_join(cutlass_python_path, "cutlass")
99+
pycute_src_path = path_join(cutlass_python_path, "pycute")
100+
101+
tmp_cutlass_full_path = os.path.abspath(os.path.join(cache_dir(), "torch_cutlass"))
102+
103+
dst_link_library = path_join(tmp_cutlass_full_path, "cutlass_library")
104+
dst_link_cutlass = path_join(tmp_cutlass_full_path, "cutlass")
105+
dst_link_pycute = path_join(tmp_cutlass_full_path, "pycute")
106+
107+
# mock modules to import cutlass
108+
mock_modules = ["cuda", "scipy", "pydot"]
109+
110+
if os.path.isdir(cutlass_python_path):
111+
if tmp_cutlass_full_path not in sys.path:
112+
113+
def link_and_append(dst_link, src_path, parent_dir):
114+
if os.path.exists(dst_link):
115+
assert os.path.islink(dst_link), (
116+
f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
117+
)
118+
assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
119+
src_path,
120+
), f"Symlink at {dst_link} does not point to {src_path}"
121+
else:
122+
os.makedirs(parent_dir, exist_ok=True)
123+
os.symlink(src_path, dst_link)
124+
125+
if parent_dir not in sys.path:
126+
sys.path.append(parent_dir)
127+
128+
link_and_append(
129+
dst_link_library, cutlass_library_src_path, tmp_cutlass_full_path
130+
)
131+
link_and_append(dst_link_cutlass, cutlass_src_path, tmp_cutlass_full_path)
132+
link_and_append(dst_link_pycute, pycute_src_path, tmp_cutlass_full_path)
133+
134+
for module in mock_modules:
135+
if not importlib.util.find_spec(module):
136+
link_and_append(
137+
path_join(tmp_cutlass_full_path, module), # dst_link
138+
path_join(mock_src_path, module), # src_path
139+
tmp_cutlass_full_path, # parent
140+
)
141+
99142
try:
143+
breakpoint()
144+
import cutlass # noqa: F401
100145
import cutlass_library.generator # noqa: F401
101146
import cutlass_library.library # noqa: F401
102147
import cutlass_library.manifest # noqa: F401
@@ -110,7 +155,7 @@ def try_import_cutlass() -> bool:
110155
else:
111156
log.debug(
112157
"Failed to import CUTLASS packages: CUTLASS repo does not exist: %s",
113-
cutlass_py_full_path,
158+
cutlass_python_path,
114159
)
115160
return False
116161

0 commit comments

Comments
 (0)
0