1
1
# mypy: allow-untyped-defs
2
2
import functools
3
+ import importlib .util
3
4
import logging
4
5
import os
5
6
import sys
@@ -75,28 +76,72 @@ def try_import_cutlass() -> bool:
75
76
# This is a temporary hack to avoid CUTLASS module naming conflicts.
76
77
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
77
78
78
- cutlass_py_full_path = os .path .abspath (
79
- os .path .join (config .cuda .cutlass_dir , "python/cutlass_library" )
79
+ # TODO(mlazos): epilogue visitor tree currently lives in python/cutlass,
80
+ # but will be moved to python/cutlass_library in the future
81
+ def path_join (path0 , path1 ):
82
+ return os .path .abspath (os .path .join (path0 , path1 ))
83
+
84
+ # contains both cutlass and cutlass_library
85
+ # we need cutlass for eVT
86
+ cutlass_python_path = path_join (config .cuda .cutlass_dir , "python" )
87
+ torch_root = os .path .abspath (os .path .dirname (torch .__file__ ))
88
+ mock_src_path = os .path .join (
89
+ torch_root ,
90
+ "_inductor" ,
91
+ "codegen" ,
92
+ "cuda" ,
93
+ "cutlass_lib_extensions" ,
94
+ "cutlass_mock_imports" ,
80
95
)
81
- tmp_cutlass_py_full_path = os .path .abspath (
82
- os .path .join (cache_dir (), "torch_cutlass_library" )
83
- )
84
- dst_link = os .path .join (tmp_cutlass_py_full_path , "cutlass_library" )
85
-
86
- if os .path .isdir (cutlass_py_full_path ):
87
- if tmp_cutlass_py_full_path not in sys .path :
88
- if os .path .exists (dst_link ):
89
- assert os .path .islink (dst_link ), (
90
- f"{ dst_link } is not a symlink. Try to remove { dst_link } manually and try again."
91
- )
92
- assert os .path .realpath (os .readlink (dst_link )) == os .path .realpath (
93
- cutlass_py_full_path
94
- ), f"Symlink at { dst_link } does not point to { cutlass_py_full_path } "
95
- else :
96
- os .makedirs (tmp_cutlass_py_full_path , exist_ok = True )
97
- os .symlink (cutlass_py_full_path , dst_link )
98
- sys .path .append (tmp_cutlass_py_full_path )
96
+
97
+ cutlass_library_src_path = path_join (cutlass_python_path , "cutlass_library" )
98
+ cutlass_src_path = path_join (cutlass_python_path , "cutlass" )
99
+ pycute_src_path = path_join (cutlass_python_path , "pycute" )
100
+
101
+ tmp_cutlass_full_path = os .path .abspath (os .path .join (cache_dir (), "torch_cutlass" ))
102
+
103
+ dst_link_library = path_join (tmp_cutlass_full_path , "cutlass_library" )
104
+ dst_link_cutlass = path_join (tmp_cutlass_full_path , "cutlass" )
105
+ dst_link_pycute = path_join (tmp_cutlass_full_path , "pycute" )
106
+
107
+ # mock modules to import cutlass
108
+ mock_modules = ["cuda" , "scipy" , "pydot" ]
109
+
110
+ if os .path .isdir (cutlass_python_path ):
111
+ if tmp_cutlass_full_path not in sys .path :
112
+
113
+ def link_and_append (dst_link , src_path , parent_dir ):
114
+ if os .path .exists (dst_link ):
115
+ assert os .path .islink (dst_link ), (
116
+ f"{ dst_link } is not a symlink. Try to remove { dst_link } manually and try again."
117
+ )
118
+ assert os .path .realpath (os .readlink (dst_link )) == os .path .realpath (
119
+ src_path ,
120
+ ), f"Symlink at { dst_link } does not point to { src_path } "
121
+ else :
122
+ os .makedirs (parent_dir , exist_ok = True )
123
+ os .symlink (src_path , dst_link )
124
+
125
+ if parent_dir not in sys .path :
126
+ sys .path .append (parent_dir )
127
+
128
+ link_and_append (
129
+ dst_link_library , cutlass_library_src_path , tmp_cutlass_full_path
130
+ )
131
+ link_and_append (dst_link_cutlass , cutlass_src_path , tmp_cutlass_full_path )
132
+ link_and_append (dst_link_pycute , pycute_src_path , tmp_cutlass_full_path )
133
+
134
+ for module in mock_modules :
135
+ if not importlib .util .find_spec (module ):
136
+ link_and_append (
137
+ path_join (tmp_cutlass_full_path , module ), # dst_link
138
+ path_join (mock_src_path , module ), # src_path
139
+ tmp_cutlass_full_path , # parent
140
+ )
141
+
99
142
try :
143
+ breakpoint ()
144
+ import cutlass # noqa: F401
100
145
import cutlass_library .generator # noqa: F401
101
146
import cutlass_library .library # noqa: F401
102
147
import cutlass_library .manifest # noqa: F401
@@ -110,7 +155,7 @@ def try_import_cutlass() -> bool:
110
155
else :
111
156
log .debug (
112
157
"Failed to import CUTLASS packages: CUTLASS repo does not exist: %s" ,
113
- cutlass_py_full_path ,
158
+ cutlass_python_path ,
114
159
)
115
160
return False
116
161
0 commit comments