@@ -75,28 +75,67 @@ def try_import_cutlass() -> bool:
75
75
# This is a temporary hack to avoid CUTLASS module naming conflicts.
76
76
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
77
77
78
- cutlass_py_full_path = os .path .abspath (
79
- os .path .join (config .cuda .cutlass_dir , "python/cutlass_library" )
78
+ # TODO(mlazos): epilogue visitor tree currently lives in python/cutlass,
79
+ # but will be moved to python/cutlass_library in the future
80
+ def path_join (path0 , path1 ):
81
+ return os .path .abspath (os .path .join (path0 , path1 ))
82
+
83
+ # contains both cutlass and cutlass_library
84
+ # we need cutlass for eVT
85
+ cutlass_python_path = path_join (config .cuda .cutlass_dir , "python" )
86
+ torch_root = os .path .abspath (os .path .dirname (torch .__file__ ))
87
+ mock_cuda_src_path = os .path .join (
88
+ torch_root ,
89
+ "_inductor" ,
90
+ "codegen" ,
91
+ "cuda" ,
92
+ "cutlass_lib_extensions" ,
93
+ "mock_cuda_bindings" ,
94
+ "cuda" ,
80
95
)
81
- tmp_cutlass_py_full_path = os .path .abspath (
82
- os .path .join (cache_dir (), "torch_cutlass_library" )
83
- )
84
- dst_link = os .path .join (tmp_cutlass_py_full_path , "cutlass_library" )
85
-
86
- if os .path .isdir (cutlass_py_full_path ):
87
- if tmp_cutlass_py_full_path not in sys .path :
88
- if os .path .exists (dst_link ):
89
- assert os .path .islink (dst_link ), (
90
- f"{ dst_link } is not a symlink. Try to remove { dst_link } manually and try again."
91
- )
92
- assert os .path .realpath (os .readlink (dst_link )) == os .path .realpath (
93
- cutlass_py_full_path
94
- ), f"Symlink at { dst_link } does not point to { cutlass_py_full_path } "
95
- else :
96
- os .makedirs (tmp_cutlass_py_full_path , exist_ok = True )
97
- os .symlink (cutlass_py_full_path , dst_link )
98
- sys .path .append (tmp_cutlass_py_full_path )
96
+
97
+ cutlass_library_src_path = path_join (cutlass_python_path , "cutlass_library" )
98
+ cutlass_src_path = path_join (cutlass_python_path , "cutlass" )
99
+ pycute_src_path = path_join (cutlass_python_path , "pycute" )
100
+
101
+ tmp_cutlass_full_path = os .path .abspath (os .path .join (cache_dir (), "torch_cutlass" ))
102
+
103
+ dst_link_library = path_join (tmp_cutlass_full_path , "cutlass_library" )
104
+ dst_link_cutlass = path_join (tmp_cutlass_full_path , "cutlass" )
105
+ # cuda bindings needed to import cutlass
106
+ # pycute needed for EVT
107
+ dst_link_pycute = path_join (tmp_cutlass_full_path , "pycute" )
108
+ dst_link_mock_cuda = path_join (tmp_cutlass_full_path , "cuda" )
109
+
110
+ if os .path .isdir (cutlass_python_path ):
111
+ if tmp_cutlass_full_path not in sys .path :
112
+
113
+ def link_and_append (dst_link , src_path , parent_dir ):
114
+ if os .path .exists (dst_link ):
115
+ assert os .path .islink (dst_link ), (
116
+ f"{ dst_link } is not a symlink. Try to remove { dst_link } manually and try again."
117
+ )
118
+ assert os .path .realpath (os .readlink (dst_link )) == os .path .realpath (
119
+ src_path ,
120
+ ), f"Symlink at { dst_link } does not point to { src_path } "
121
+ else :
122
+ os .makedirs (parent_dir , exist_ok = True )
123
+ os .symlink (src_path , dst_link )
124
+
125
+ if parent_dir not in sys .path :
126
+ sys .path .append (parent_dir )
127
+
128
+ link_and_append (dst_link_pycute , pycute_src_path , tmp_cutlass_full_path )
129
+ link_and_append (
130
+ dst_link_library , cutlass_library_src_path , tmp_cutlass_full_path
131
+ )
132
+ link_and_append (dst_link_cutlass , cutlass_src_path , tmp_cutlass_full_path )
133
+ link_and_append (
134
+ dst_link_mock_cuda , mock_cuda_src_path , tmp_cutlass_full_path
135
+ )
136
+
99
137
try :
138
+ import cutlass # noqa: F401
100
139
import cutlass_library .generator # noqa: F401
101
140
import cutlass_library .library # noqa: F401
102
141
import cutlass_library .manifest # noqa: F401
@@ -110,7 +149,7 @@ def try_import_cutlass() -> bool:
110
149
else :
111
150
log .debug (
112
151
"Failed to import CUTLASS packages: CUTLASS repo does not exist: %s" ,
113
- cutlass_py_full_path ,
152
+ cutlass_python_path ,
114
153
)
115
154
return False
116
155
0 commit comments