[Cutlass] Import cutlass python API for EVT by mlazos · Pull Request #150344 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[Cutlass] Import cutlass python API for EVT #150344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,15 @@ def mm(a, b):
), "Cutlass Kernels should have been filtered, GEMM size is too small"
torch.testing.assert_close(Y_compiled, Y)

@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_import_cutlass(self):
    """Check that the CUTLASS Python packages can be imported.

    ``try_import_cutlass`` has to report success, and afterwards both
    ``cutlass`` and ``cutlass_library`` have to import without raising.
    """
    from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass

    import_succeeded = try_import_cutlass()
    self.assertTrue(import_succeeded)

    # These imports are the assertion: the test fails if either one raises.
    import cutlass  # noqa: F401
    import cutlass_library  # noqa: F401

@unittest.skipIf(not SM90OrLater, "need sm_90")
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_cutlass_backend_subproc_mm(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import torch


# Publish the CUDA version string reported by torch as this package's
# ``__version__``.
# NOTE(review): presumably this mimics the ``__version__`` attribute of the
# real package this module stands in for — confirm against what the importer
# actually reads.
__version__ = torch.version.cuda

# Re-export every public name of the sibling ``cuda`` submodule at package level.
from .cuda import * # noqa: F403
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# mypy: disable-error-code="no-untyped-def"
# flake8: noqa
class CUdeviceptr:
    """Empty placeholder class; defines no attributes or behavior.

    NOTE(review): presumably exists only so the name ``CUdeviceptr`` resolves
    when CUTLASS's Python code is imported — confirm.
    """


class CUstream:
    """Placeholder class whose constructor accepts one value and discards it.

    NOTE(review): presumably exists only so ``CUstream(v)`` can be evaluated
    when CUTLASS's Python code is imported — confirm.
    """

    def __init__(self, v):
        """Accept ``v`` and ignore it; no state is stored on the instance."""
        return None


class CUresult:
    """Empty placeholder class; defines no attributes or behavior.

    NOTE(review): presumably exists only so the name ``CUresult`` resolves
    when CUTLASS's Python code is imported — confirm.
    """


class nvrtc:
    """Empty placeholder class; defines no attributes or behavior.

    NOTE(review): presumably exists only so the name ``nvrtc`` resolves when
    CUTLASS's Python code is imported — confirm.
    """
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# mypy: disable-error-code="var-annotated"
# Placeholder binding: the name ``Dot`` exists but is bound to None, so any
# attempt to call or instantiate it raises TypeError.
# NOTE(review): presumably only needed so an import of ``Dot`` succeeds —
# confirm that no code path actually uses it.
Dot = None
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# typing: ignore
# NOTE(review): ``# typing: ignore`` is not a comment mypy recognizes. If the
# intent is to silence type checking for this file, ``# mypy: ignore-errors``
# (or a leading ``# type: ignore``) is likely what was meant — confirm.
# flake8: noqa
# Re-export every public name of the sibling ``special`` submodule.
from .special import *
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# mypy: disable-error-code="var-annotated"
# Placeholder binding: the name ``erf`` exists but is bound to None, so any
# attempt to call it raises TypeError.
# NOTE(review): presumably only needed so an import of ``erf`` succeeds —
# confirm that no code path actually calls it.
erf = None
81 changes: 62 additions & 19 deletions torch/_inductor/codegen/cuda/cutlass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,31 +78,74 @@ def try_import_cutlass() -> bool:
# This is a temporary hack to avoid CUTLASS module naming conflicts.
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.

cutlass_py_full_path = os.path.abspath(
os.path.join(config.cuda.cutlass_dir, "python/cutlass_library")
# TODO(mlazos): epilogue visitor tree currently lives in python/cutlass,
# but will be moved to python/cutlass_library in the future
def path_join(path0, path1):
    """Join ``path0`` and ``path1`` and return the result as an absolute path."""
    joined = os.path.join(path0, path1)
    return os.path.abspath(joined)

# This directory contains both cutlass and cutlass_library;
# we need cutlass for EVT (the epilogue visitor tree).
cutlass_python_path = path_join(config.cuda.cutlass_dir, "python")
torch_root = os.path.abspath(os.path.dirname(torch.__file__))
mock_src_path = os.path.join(
torch_root,
"_inductor",
"codegen",
"cuda",
"cutlass_lib_extensions",
"cutlass_mock_imports",
)
tmp_cutlass_py_full_path = os.path.abspath(
os.path.join(cache_dir(), "torch_cutlass_library")
)
dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")

if os.path.isdir(cutlass_py_full_path):
if tmp_cutlass_py_full_path not in sys.path:
if os.path.exists(dst_link):
assert os.path.islink(dst_link), (
f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
cutlass_library_src_path = path_join(cutlass_python_path, "cutlass_library")
cutlass_src_path = path_join(cutlass_python_path, "cutlass")
pycute_src_path = path_join(cutlass_python_path, "pycute")

tmp_cutlass_full_path = os.path.abspath(os.path.join(cache_dir(), "torch_cutlass"))

dst_link_library = path_join(tmp_cutlass_full_path, "cutlass_library")
dst_link_cutlass = path_join(tmp_cutlass_full_path, "cutlass")
dst_link_pycute = path_join(tmp_cutlass_full_path, "pycute")

# mock modules to import cutlass
mock_modules = ["cuda", "scipy", "pydot"]

if os.path.isdir(cutlass_python_path):
if tmp_cutlass_full_path not in sys.path:

def link_and_append(dst_link, src_path, parent_dir):
    """Ensure ``dst_link`` is a symlink to ``src_path`` and ``parent_dir`` is on ``sys.path``.

    Args:
        dst_link: Path at which the symlink should live.
        src_path: Path the symlink must point to.
        parent_dir: Directory that is created if missing before the link is
            made, and appended to ``sys.path`` if not already present.

    Raises:
        AssertionError: If ``dst_link`` exists but is not a symlink, or is a
            symlink that resolves to something other than ``src_path``.
    """
    # lexists (unlike exists) is also true for a dangling symlink. Without it,
    # a stale link whose target has disappeared would be treated as absent and
    # os.symlink below would fail with FileExistsError.
    if os.path.lexists(dst_link):
        assert os.path.islink(dst_link), (
            f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
        )
        # Resolve the link itself rather than the raw readlink() text, so a
        # relative link target is interpreted relative to the link's directory
        # and not the current working directory.
        assert os.path.realpath(dst_link) == os.path.realpath(
            src_path,
        ), f"Symlink at {dst_link} does not point to {src_path}"
    else:
        os.makedirs(parent_dir, exist_ok=True)
        os.symlink(src_path, dst_link)

    if parent_dir not in sys.path:
        sys.path.append(parent_dir)

link_and_append(
dst_link_library, cutlass_library_src_path, tmp_cutlass_full_path
)
link_and_append(dst_link_cutlass, cutlass_src_path, tmp_cutlass_full_path)
link_and_append(dst_link_pycute, pycute_src_path, tmp_cutlass_full_path)

for module in mock_modules:
link_and_append(
path_join(tmp_cutlass_full_path, module), # dst_link
path_join(mock_src_path, module), # src_path
tmp_cutlass_full_path, # parent
)
assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
cutlass_py_full_path
), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
else:
os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
os.symlink(cutlass_py_full_path, dst_link)
sys.path.append(tmp_cutlass_py_full_path)

try:
import cutlass # noqa: F401
import cutlass_library.generator # noqa: F401
import cutlass_library.library # noqa: F401
import cutlass_library.manifest # noqa: F401
import pycute # type: ignore[import-not-found] # noqa: F401

return True
except ImportError as e:
Expand All @@ -113,7 +156,7 @@ def try_import_cutlass() -> bool:
else:
log.debug(
"Failed to import CUTLASS packages: CUTLASS repo does not exist: %s",
cutlass_py_full_path,
cutlass_python_path,
)
return False

Expand Down
Loading
0