8000 [Cutlass] Import cutlass python API for EVT · Divigroup-RAP/PYTORCH@a4db2a3 · GitHub
[go: up one dir, main page]

Skip to content

Commit a4db2a3

Browse files
committed
[Cutlass] Import cutlass python API for EVT
ghstack-source-id: 309a8ff Pull Request resolved: pytorch/pytorch#150344
1 parent c9aef50 commit a4db2a3

File tree

8 files changed

+101
-19
lines changed

8 files changed

+101
-19
lines changed

test/inductor/test_cutlass_backend.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,15 @@ def mm(a, b):
135135
), "Cutlass Kernels should have been filtered, GEMM size is too small"
136136
torch.testing.assert_close(Y_compiled, Y)
137137

138+
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
139+
def test_import_cutlass(self):
140+
from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
141+
142+
self.assertTrue(try_import_cutlass())
143+
144+
import cutlass # noqa: F401
145+
import cutlass_library # noqa: F401
146+
138147
@unittest.skipIf(not SM90OrLater, "need sm_90")
139148
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
140149
def test_cutlass_backend_subproc_mm(self):
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import torch
2+
3+
4+
__version__ = torch.version.cuda
5+
6+
from .cuda import * # noqa: F403
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# mypy: disable-error-code="no-untyped-def"
2+
# flake8: noqa
3+
class CUdeviceptr:
4+
pass
5+
6+
7+
class CUstream:
8+
def __init__(self, v):
9+
pass
10+
11+
12+
class CUresult:
13+
pass
14+
15+
16+
class nvrtc:
17+
pass

torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cudart.py

Whitespace-only changes.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# mypy: disable-error-code="var-annotated"
2+
Dot = None
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# typing: ignore
2+
# flake8: noqa
3+
from .special import *
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# mypy: disable-error-code="var-annotated"
2+
erf = None

torch/_inductor/codegen/cuda/cutlass_utils.py

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -78,31 +78,74 @@ def try_import_cutlass() -> bool:
7878
# This is a temporary hack to avoid CUTLASS module naming conflicts.
7979
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
8080

81-
cutlass_py_full_path = os.path.abspath(
82-
os.path.join(config.cuda.cutlass_dir, "python/cutlass_library")
81+
# TODO(mlazos): epilogue visitor tree currently lives in python/cutlass,
82+
# but will be moved to python/cutlass_library in the future
83+
8000 def path_join(path0, path1):
84+
return os.path.abspath(os.path.join(path0, path1))
85+
86+
# contains both cutlass and cutlass_library
87+
# we need cutlass for eVT
88+
cutlass_python_path = path_join(config.cuda.cutlass_dir, "python")
89+
torch_root = os.path.abspath(os.path.dirname(torch.__file__))
90+
mock_src_path = os.path.join(
91+
torch_root,
92+
"_inductor",
93+
"codegen",
94+
"cuda",
95+
"cutlass_lib_extensions",
96+
"cutlass_mock_imports",
8397
)
84-
tmp_cutlass_py_full_path = os.path.abspath(
85-
os.path.join(cache_dir(), "torch_cutlass_library")
86-
)
87-
dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")
8898

89-
if os.path.isdir(cutlass_py_full_path):
90-
if tmp_cutlass_py_full_path not in sys.path:
91-
2D08 if os.path.exists(dst_link):
92-
assert os.path.islink(dst_link), (
93-
f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
99+
cutlass_library_src_path = path_join(cutlass_python_path, "cutlass_library")
100+
cutlass_src_path = path_join(cutlass_python_path, "cutlass")
101+
pycute_src_path = path_join(cutlass_python_path, "pycute")
102+
103+
tmp_cutlass_full_path = os.path.abspath(os.path.join(cache_dir(), "torch_cutlass"))
104+
105+
dst_link_library = path_join(tmp_cutlass_full_path, "cutlass_library")
106+
dst_link_cutlass = path_join(tmp_cutlass_full_path, "cutlass")
107+
dst_link_pycute = path_join(tmp_cutlass_full_path, "pycute")
108+
109+
# mock modules to import cutlass
110+
mock_modules = ["cuda", "scipy", "pydot"]
111+
112+
if os.path.isdir(cutlass_python_path):
113+
if tmp_cutlass_full_path not in sys.path:
114+
115+
def link_and_append(dst_link, src_path, parent_dir):
116+
if os.path.exists(dst_link):
117+
assert os.path.islink(dst_link), (
118+
f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
119+
)
120+
assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
121+
src_path,
122+
), f"Symlink at {dst_link} does not point to {src_path}"
123+
else:
124+
os.makedirs(parent_dir, exist_ok=True)
125+
os.symlink(src_path, dst_link)
126+
127+
if parent_dir not in sys.path:
128+
sys.path.append(parent_dir)
129+
130+
link_and_append(
131+
dst_link_library, cutlass_library_src_path, tmp_cutlass_full_path
132+
)
133+
link_and_append(dst_link_cutlass, cutlass_src_path, tmp_cutlass_full_path)
134+
link_and_append(dst_link_pycute, pycute_src_path, tmp_cutlass_full_path)
135+
136+
for module in mock_modules:
137+
link_and_append(
138+
path_join(tmp_cutlass_full_path, module), # dst_link
139+
path_join(mock_src_path, module), # src_path
140+
tmp_cutlass_full_path, # parent
94141
)
95-
assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
96-
cutlass_py_full_path
97-
), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
98-
else:
99-
os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
100-
os.symlink(cutlass_py_full_path, dst_link)
101-
sys.path.append(tmp_cutlass_py_full_path)
142+
102143
try:
144+
import cutlass # noqa: F401
103145
import cutlass_library.generator # noqa: F401
104146
import cutlass_library.library # noqa: F401
105147
import cutlass_library.manifest # noqa: F401
148+
import pycute # type: ignore[import-not-found] # noqa: F401
106149

107150
return True
108151
except ImportError as e:
@@ -113,7 +156,7 @@ def try_import_cutlass() -> bool:
113156
else:
114157
log.debug(
115158
"Failed to import CUTLASS packages: CUTLASS repo does not exist: %s",
116-
cutlass_py_full_path,
159+
cutlass_python_path,
117160
)
118161
return False
119162

0 commit comments

Comments
 (0)
0