[Cutlass] Codegen for EVT Epilogue · pytorch/pytorch@0855d04 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0855d04

Browse files
committed
[Cutlass] Codegen for EVT Epilogue
ghstack-source-id: 2f8fd1f Pull Request resolved: #150345
1 parent 8a5f668 commit 0855d04

File tree

2 files changed

+71
-2
lines changed

2 files changed

+71
-2
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# mypy: allow-untyped-defs
2+
3+
from ..cutlass_utils import try_import_cutlass
4+
5+
6+
if try_import_cutlass():
7+
import ast
8+
import textwrap
9+
10+
from cutlass.backend.evt import ( # type: ignore[import-untyped, import-not-found]
11+
EpilogueFunctorVisitor,
12+
)
13+
from cutlass.backend.evt.backend.emitter_base import ( # type: ignore[import-untyped, import-not-found]
14+
FusionCallbacks,
15+
)
16+
from cutlass.backend.evt.backend.sm90_emitter import ( # type: ignore[import-untyped, import-not-found]
17+
CollectiveEpilogue,
18+
)
19+
from cutlass.backend.evt.frontend import ( # type: ignore[import-untyped, import-not-found]
20+
PythonASTFrontend,
21+
)
22+
from cutlass.backend.evt.ir.tensor import ( # type: ignore[import-untyped, import-not-found]
23+
Tensor as CutlassTensor,
24+
)
25+
from cutlass_library import DataType, EpilogueScheduleType, TileDescription
26+
27+
def generate(
    fn_src: str,
    example_tensors: dict[str, CutlassTensor],
    accum_type: DataType,
    output_type: DataType,
    tile_description: TileDescription,
    epilogue_schedule: EpilogueScheduleType,
    **kwargs,
):
    """Emit the text for an EVT epilogue described by Python source.

    The epilogue source is traced via ``_trace``, wrapped in an
    ``EpilogueFunctorVisitor``, and the visitor's graph is handed to
    ``FusionCallbacks``. A ``CollectiveEpilogue`` is then built on top of
    those callbacks.

    Args:
        fn_src: Source code of the epilogue function, as a string.
        example_tensors: Example tensors, keyed by name, used for tracing.
        accum_type: Accumulator data type passed to ``CollectiveEpilogue``.
        output_type: Output data type passed to ``CollectiveEpilogue``.
        tile_description: Tile description passed to ``CollectiveEpilogue``.
        epilogue_schedule: Schedule type passed to ``CollectiveEpilogue``.
        **kwargs: Forwarded unchanged to ``_trace``.

    Returns:
        The concatenated output of ``FusionCallbacks.emit()``, a blank-line
        separator, and the concatenated output of
        ``CollectiveEpilogue.emit()``.
    """
    # The same architecture number is given to both the visitor and the
    # callbacks; CollectiveEpilogue here comes from the sm90 emitter module.
    sm_arch = 90

    traced_functor = _trace(fn_src, example_tensors, **kwargs)
    evt_visitor = EpilogueFunctorVisitor(sm_arch, traced_functor)
    callbacks = FusionCallbacks(evt_visitor.graph, sm_arch)
    epilogue = CollectiveEpilogue(
        tile_description,
        epilogue_schedule,
        accum_type,
        output_type,
        callbacks,
    )

    # Emit the callbacks first, then the collective epilogue that uses them.
    callbacks_text = "".join(callbacks.emit())
    epilogue_text = "".join(epilogue.emit())
    return f"{callbacks_text}\n\n{epilogue_text}"
51+
52+
# Based off of
53+
# https://github.com/NVIDIA/cutlass/blob/df18f5e4f5de76bed8be1de8e4c245f2f5ec3020/python/cutlass/epilogue/epilogue.py#L117
54+
# This is modified to accept the epilogue's source code directly instead of retrieving it from a bona fide Python function
55+
# The reason for this is that inspect.getsource does not work with functions defined at runtime via exec/eval
56+
def _trace(fn_src, example_tensors, **kwargs):
    """Trace epilogue source text and return the traced frontend object.

    A local ``PythonASTFrontend`` subclass is defined whose ``parse`` step
    reads the AST from ``fn_src`` (after dedenting) rather than from a
    function object.

    Args:
        fn_src: Source code of the epilogue function, as a string.
        example_tensors: Example inputs handed to the frontend's ``trace``.
        **kwargs: Forwarded to ``PythonASTFrontend.__init__``.

    Returns:
        The ``EpilogueFunctor`` instance after ``trace`` has been called.
    """
    # Strip common leading whitespace so indented source still parses.
    dedented_src = textwrap.dedent(fn_src)

    class EpilogueFunctor(PythonASTFrontend):
        def __init__(self, **frontend_kwargs):
            # The source must be set before the base initializer runs,
            # matching the original statement order.
            self.source = dedented_src
            super().__init__(**frontend_kwargs)

        def parse(self, example_inputs):
            self.example_inputs = example_inputs
            tree = ast.parse(self.source)
            self.ast = tree
            self.visit(tree)

    functor = EpilogueFunctor(**kwargs)
    functor.trace(example_tensors)
    return functor

torch/_inductor/codegen/cuda/cutlass_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ def try_import_cutlass() -> bool:
7575
# This is a temporary hack to avoid CUTLASS module naming conflicts.
7676
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
7777

78-
# TODO(mlazos): epilogue visitor tree currently livers in python/cutlass,
79-
# but will be moved to python/cutlass_library in the future
78+
# TODO(mlazos): epilogue visitor tree currently lives in python/cutlass,
79+
# but will be moved to python/cutlass_library in the future (later 2025)
8080
def path_join(path0, path1):
8181
return os.path.abspath(os.path.join(path0, path1))
8282

0 commit comments

Comments
 (0)
0