8000 [Cutlass] Import cutlass python API for EVT · pytorch/pytorch@8a5f668 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8a5f668

Browse files
committed
[Cutlass] Import cutlass python API for EVT
ghstack-source-id: 7d6b727 Pull Request resolved: #150344
1 parent d6887f4 commit 8a5f668

File tree

1 file changed

+40
-22
lines changed

1 file changed

+40
-22
lines changed

torch/_inductor/codegen/cuda/cutlass_utils.py

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -75,28 +75,46 @@ def try_import_cutlass() -> bool:
7575
# This is a temporary hack to avoid CUTLASS module naming conflicts.
7676
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
7777

78-
cutlass_py_full_path = os.path.abspath(
79-
os.path.join(config.cuda.cutlass_dir, "python/cutlass_library")
80-
)
81-
tmp_cutlass_py_full_path = os.path.abspath(
82-
os.path.join(cache_dir(), "torch_cutlass_library")
83-
)
84-
dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")
85-
86-
if os.path.isdir(cutlass_py_full_path):
87-
if tmp_cutlass_py_full_path not in sys.path:
88-
if os.path.exists(dst_link):
89-
assert os.path.islink(dst_link), (
90-
f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
91-
)
92-
assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
93-
cutlass_py_full_path
94-
), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
95-
else:
96-
os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
97-
os.symlink(cutlass_py_full_path, dst_link)
98-
sys.path.append(tmp_cutlass_py_full_path)
78+
# TODO(mlazos): epilogue visitor tree currently livers in python/cutlass,
79+
# but will be moved to python/cutlass_library in the future
80+
def path_join(path0, path1):
81+
return os.path.abspath(os.path.join(path0, path1))
82+
83+
# contains both cutlass and cutlass_library
84+
# we need cutlass for eVT
85+
cutlass_python_path = path_join(config.cuda.cutlass_dir, "python")
86+
87+
cutlass_library_src_path = path_join(cutlass_python_path, "cutlass_library")
88+
cutlass_src_path = path_join(cutlass_python_path, "cutlass")
89+
90+
tmp_cutlass_full_path = os.path.abspath(os.path.join(cache_dir(), "torch_cutlass"))
91+
92+
dst_link_library = path_join(tmp_cutlass_full_path, "cutlass_library")
93+
dst_link_cutlass = path_join(tmp_cutlass_full_path, "cutlass")
94+
95+
if os.path.isdir(cutlass_python_path):
96+
if tmp_cutlass_full_path not in sys.path:
97+
98+
def link_and_append(dst_link, src_path, parent_dir):
99+
if os.path.exists(dst_link):
100+
assert os.path.islink(dst_link), (
101+
f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
102+
)
103+
assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
104+
src_path,
105+
), f"Symlink at {dst_link} does not point to {src_path}"
106+
else:
107+
os.makedirs(parent_dir, exist_ok=True)
108+
os.symlink(src_path, dst_link)
109+
sys.path.append(parent_dir)
110+
111+
link_and_append(
112+
dst_link_library, cutlass_library_src_path, tmp_cutlass_full_path
113+
)
114+
link_and_append(dst_link_cutlass, cutlass_src_path, tmp_cutlass_full_path)
115+
99116
try:
117+
import cutlass # noqa: F401
100118
import cutlass_library.generator # noqa: F401
101119
import cutlass_library.library # noqa: F401
102120
import cutlass_library.manifest # noqa: F401
@@ -110,7 +128,7 @@ def try_import_cutlass() -> bool:
110128
else:
111129
log.debug(
112130
"Failed to import CUTLASS packages: CUTLASS repo does not exist: %s",
113-
cutlass_py_full_path,
131+
cutlass_python_path,
114132
)
115133
return False
116134

0 commit comments

Comments
 (0)
0