@@ -75,28 +75,46 @@ def try_import_cutlass() -> bool:
75
75
# This is a temporary hack to avoid CUTLASS module naming conflicts.
76
76
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
77
77
78
- cutlass_py_full_path = os .path .abspath (
79
- os .path .join (config .cuda .cutlass_dir , "python/cutlass_library" )
80
- )
81
- tmp_cutlass_py_full_path = os .path .abspath (
82
- os .path .join (cache_dir (), "torch_cutlass_library" )
83
- )
84
- dst_link = os .path .join (tmp_cutlass_py_full_path , "cutlass_library" )
85
-
86
- if os .path .isdir (cutlass_py_full_path ):
87
- if tmp_cutlass_py_full_path not in sys .path :
88
- if os .path .exists (dst_link ):
89
- assert os .path .islink (dst_link ), (
90
- f"{ dst_link } is not a symlink. Try to remove { dst_link } manually and try again."
91
- )
92
- assert os .path .realpath (os .readlink (dst_link )) == os .path .realpath (
93
- cutlass_py_full_path
94
- ), f"Symlink at { dst_link } does not point to { cutlass_py_full_path } "
95
- else :
96
- os .makedirs (tmp_cutlass_py_full_path , exist_ok = True )
97
- os .symlink (cutlass_py_full_path , dst_link )
98
- sys .path .append (tmp_cutlass_py_full_path )
78
+ # TODO(mlazos): epilogue visitor tree currently livers in python/cutlass,
79
+ # but will be moved to python/cutlass_library in the future
80
+ def path_join (path0 , path1 ):
81
+ return os .path .abspath (os .path .join (path0 , path1 ))
82
+
83
+ # contains both cutlass and cutlass_library
84
+ # we need cutlass for eVT
85
+ cutlass_python_path = path_join (config .cuda .cutlass_dir , "python" )
86
+
87
+ cutlass_library_src_path = path_join (cutlass_python_path , "cutlass_library" )
88
+ cutlass_src_path = path_join (cutlass_python_path , "cutlass" )
89
+
90
+ tmp_cutlass_full_path = os .path .abspath (os .path .join (cache_dir (), "torch_cutlass" ))
91
+
92
+ dst_link_library = path_join (tmp_cutlass_full_path , "cutlass_library" )
93
+ dst_link_cutlass = path_join (tmp_cutlass_full_path , "cutlass" )
94
+
95
+ if os .path .isdir (cutlass_python_path ):
96
+ if tmp_cutlass_full_path not in sys .path :
97
+
98
+ def link_and_append (dst_link , src_path , parent_dir ):
99
+ if os .path .exists (dst_link ):
100
+ assert os .path .islink (dst_link ), (
101
+ f"{ dst_link } is not a symlink. Try to remove { dst_link } manually and try again."
102
+ )
103
+ assert os .path .realpath (os .readlink (dst_link )) == os .path .realpath (
104
+ src_path ,
105
+ ), f"Symlink at { dst_link } does not point to { src_path } "
106
+ else :
107
+ os .makedirs (parent_dir , exist_ok = True )
108
+ os .symlink (src_path , dst_link )
109
+ sys .path .append (parent_dir )
110
+
111
+ link_and_append (
112
+ dst_link_library , cutlass_library_src_path , tmp_cutlass_full_path
113
+ )
114
+ link_and_append (dst_link_cutlass , cutlass_src_path , tmp_cutlass_full_path )
115
+
99
116
try :
117
+ import cutlass # noqa: F401
100
118
import cutlass_library .generator # noqa: F401
101
119
import cutlass_library .library # noqa: F401
102
120
import cutlass_library .manifest # noqa: F401
@@ -110,7 +128,7 @@ def try_import_cutlass() -> bool:
110
128
else :
111
129
log .debug (
112
130
"Failed to import CUTLASS packages: CUTLASS repo does not exist: %s" ,
113
- cutlass_py_full_path ,
131
+ cutlass_python_path ,
114
132
)
115
133
return False
116
134
0 commit comments