@@ -78,31 +78,74 @@ def try_import_cutlass() -> bool:
78
78
# This is a temporary hack to avoid CUTLASS module naming conflicts.
79
79
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
80
80
81
- cutlass_py_full_path = os .path .abspath (
82
- os .path .join (config .cuda .cutlass_dir , "python/cutlass_library" )
81
+ # TODO(mlazos): epilogue visitor tree currently lives in python/cutlass,
82
+ # but will be moved to python/cutlass_library in the future
83
+
8000
def path_join (path0 , path1 ):
84
+ return os .path .abspath (os .path .join (path0 , path1 ))
85
+
86
+ # contains both cutlass and cutlass_library
87
+ # we need cutlass for eVT
88
+ cutlass_python_path = path_join (config .cuda .cutlass_dir , "python" )
89
+ torch_root = os .path .abspath (os .path .dirname (torch .__file__ ))
90
+ mock_src_path = os .path .join (
91
+ torch_root ,
92
+ "_inductor" ,
93
+ "codegen" ,
94
+ "cuda" ,
95
+ "cutlass_lib_extensions" ,
96
+ "cutlass_mock_imports" ,
83
97
)
84
- tmp_cutlass_py_full_path = os .path .abspath (
85
- os .path .join (cache_dir (), "torch_cutlass_library" )
86
- )
87
- dst_link = os .path .join (tmp_cutlass_py_full_path , "cutlass_library" )
88
98
89
- if os .path .isdir (cutlass_py_full_path ):
90
- if tmp_cutlass_py_full_path not in sys .path :
91
-
2D08
if os .path .exists (dst_link ):
92
- assert os .path .islink (dst_link ), (
93
- f"{ dst_link } is not a symlink. Try to remove { dst_link } manually and try again."
99
+ cutlass_library_src_path = path_join (cutlass_python_path , "cutlass_library" )
100
+ cutlass_src_path = path_join (cutlass_python_path , "cutlass" )
101
+ pycute_src_path = path_join (cutlass_python_path , "pycute" )
102
+
103
+ tmp_cutlass_full_path = os .path .abspath (os .path .join (cache_dir (), "torch_cutlass" ))
104
+
105
+ dst_link_library = path_join (tmp_cutlass_full_path , "cutlass_library" )
106
+ dst_link_cutlass = path_join (tmp_cutlass_full_path , "cutlass" )
107
+ dst_link_pycute = path_join (tmp_cutlass_full_path , "pycute" )
108
+
109
+ # mock modules to import cutlass
110
+ mock_modules = ["cuda" , "scipy" , "pydot" ]
111
+
112
+ if os .path .isdir (cutlass_python_path ):
113
+ if tmp_cutlass_full_path not in sys .path :
114
+
115
+ def link_and_append (dst_link , src_path , parent_dir ):
116
+ if os .path .exists (dst_link ):
117
+ assert os .path .islink (dst_link ), (
118
+ f"{ dst_link } is not a symlink. Try to remove { dst_link } manually and try again."
119
+ )
120
+ assert os .path .realpath (os .readlink (dst_link )) == os .path .realpath (
121
+ src_path ,
122
+ ), f"Symlink at { dst_link } does not point to { src_path } "
123
+ else :
124
+ os .makedirs (parent_dir , exist_ok = True )
125
+ os .symlink (src_path , dst_link )
126
+
127
+ if parent_dir not in sys .path :
128
+ sys .path .append (parent_dir )
129
+
130
+ link_and_append (
131
+ dst_link_library , cutlass_library_src_path , tmp_cutlass_full_path
132
+ )
133
+ link_and_append (dst_link_cutlass , cutlass_src_path , tmp_cutlass_full_path )
134
+ link_and_append (dst_link_pycute , pycute_src_path , tmp_cutlass_full_path )
135
+
136
+ for module in mock_modules :
137
+ link_and_append (
138
+ path_join (tmp_cutlass_full_path , module ), # dst_link
139
+ path_join (mock_src_path , module ), # src_path
140
+ tmp_cutlass_full_path , # parent
94
141
)
95
- assert os .path .realpath (os .readlink (dst_link )) == os .path .realpath (
96
- cutlass_py_full_path
97
- ), f"Symlink at { dst_link } does not point to { cutlass_py_full_path } "
98
- else :
99
- os .makedirs (tmp_cutlass_py_full_path , exist_ok = True )
100
- os .symlink (cutlass_py_full_path , dst_link )
101
- sys .path .append (tmp_cutlass_py_full_path )
142
+
102
143
try :
144
+ import cutlass # noqa: F401
103
145
import cutlass_library .generator # noqa: F401
104
146
import cutlass_library .library # noqa: F401
105
147
import cutlass_library .manifest # noqa: F401
148
+ import pycute # type: ignore[import-not-found] # noqa: F401
106
149
107
150
return True
108
151
except ImportError as e :
@@ -113,7 +156,7 @@ def try_import_cutlass() -> bool:
113
156
else :
114
157
log .debug (
115
158
"Failed to import CUTLASS packages: CUTLASS repo does not exist: %s" ,
116
- cutlass_py_full_path ,
159
+ cutlass_python_path ,
117
160
)
118
161
return False
119
162
0 commit comments