[Cutlass] Integrate EVT into CUDACPPScheduling · pytorch/pytorch@e2b32ab · GitHub
[go: up one dir, main page]

Skip to content

Commit e2b32ab

Browse files
committed
[Cutlass] Integrate EVT into CUDACPPScheduling
Allow epilogue nodes in cuda combined scheduling ghstack-source-id: 0b9f5d8 Pull Request resolved: #150906
1 parent a4df6b0 commit e2b32ab

File tree

2 files changed

+125
-6
lines changed

2 files changed

+125
-6
lines changed

torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py

Lines changed: 125 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,13 @@
88
from ...._dynamo.utils import counters
99
from ... import config
1010
from ...codecache import code_hash, get_path
11-
from ...ir import CUDATemplateBuffer
12-
from ...scheduler import BaseSchedulerNode, BaseScheduling, SchedulerNode
11+
from ...ir import Buffer, ComputedBuffer, CUDATemplateBuffer, IRNode, Pointwise
12+
from ...scheduler import (
13+
BaseSchedulerNode,
14+
BaseScheduling,
15+
FusedSchedulerNode,
16+
SchedulerNode,
17+
)
1318
from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product
1419
from ...virtualized import V
1520
from ..common import BackendFeature, IndentedBuffer
@@ -40,9 +45,32 @@ def is_cuda_cpp_template(node: BaseSchedulerNode) -> bool:
4045
node.node, CUDATemplateBuffer
4146
)
4247

48+
def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool:
    """Return True iff ``node`` is a fused group built around a CUDA C++ template."""
    # Only fused groups can qualify; non-fused nodes are handled by
    # is_cuda_cpp_template directly.
    if not isinstance(node, FusedSchedulerNode):
        return False
    return self.is_cuda_cpp_template(node)
50+
4351
def can_fuse_vertical(
    self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:
    """Decide whether ``node2`` may be fused as an epilogue onto ``node1``.

    ``node1`` must be a CUDA C++ template node (plain or already fused);
    ``node2`` must be a plain SchedulerNode. Everything else is rejected.
    """
    # Both accepting branches require node2 to be a plain SchedulerNode,
    # so check that once up front.
    if not isinstance(node2, SchedulerNode):
        return False

    if self.is_cuda_cpp_template(node1):
        assert node1.node, "node1.node should not be None"
        assert node2.node, "node2.node should not be None"
        # First epilogue being attached: no previously fused epilogue nodes.
        return self._can_fuse_epilogue_impl(
            cast(CUDATemplateBuffer, node1.node),
            [],
            node2.node,  # type: ignore[arg-type]
        )

    if self.is_cuda_cpp_fused_template(node1):
        assert node1.node, "node1.node should not be None"
        assert node2.node, "node2.node should not be None"
        fused = cast(FusedSchedulerNode, node1)
        # Template already carries fused epilogues; pass them along so the
        # dependency chain can be validated against the new node.
        return self._can_fuse_epilogue_impl(
            fused.get_template_node(),  # type: ignore[arg-type]
            self._unwrap_epilogue_nodes(fused),
            node2.node,  # type: ignore[arg-type]
        )

    return False
4775

4876
def define_kernel(self, src_code: str, node_schedule) -> str:
@@ -94,13 +122,19 @@ def codegen_template(
94122
_, (_numel, rnumel) = template_node.group
95123
assert rnumel == 1
96124
ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node)
97-
kernel, render = ctb.make_kernel_render(ctb)
125+
epilogue_ir_nodes: list[Buffer] = [n.node for n in epilogue_nodes] # type: ignore[misc]
126+
assert all(isinstance(n, ComputedBuffer) for n in epilogue_ir_nodes), (
127+
"Epilogue nodes must all be instances of ir.ComputedBuffer"
128+
)
129+
kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes)
130+
98131
with kernel:
99-
template_node.mark_run()
132+
for node in [template_node, *epilogue_nodes]:
133+
node.mark_run()
100134
src_code = render()
101135

102136
with V.set_kernel_handler(kernel):
103-
node_schedule = [template_node]
137+
node_schedule = [template_node, *epilogue_nodes]
104138
kernel_name = self.define_kernel(src_code, node_schedule)
105139

106140
# debug printing values of intermediate tensors
@@ -114,3 +148,89 @@ def codegen_template(
114148

115149
V.graph.removed_buffers |= kernel.removed_buffers
116150
self.free_buffers_in_scheduler()
151+
152+
@staticmethod
def _unwrap_epilogue_nodes(fused_node: FusedSchedulerNode) -> list[IRNode]:
    """Return the IR nodes of every member of ``fused_node`` except the template.

    Each scheduler node in the fused group must wrap a concrete IRNode.
    """
    template_node = fused_node.get_template_node()
    unwrapped: list[IRNode] = []
    for snode in fused_node.get_nodes():
        assert snode.node is not None, "All epilogue nodes should have an IRNode"
        # Identity comparison: the template's own IR node is excluded.
        if snode.node is not template_node:
            unwrapped.append(snode.node)
    return unwrapped
162+
163+
def _can_fuse_epilogue_impl(
    self,
    cuda_template_buffer: CUDATemplateBuffer,
    epilogue_nodes: list[IRNode],
    additional_node: IRNode,
) -> bool:
    """
    Check if the given node can be fused with the epilogue. At the moment, Kernels
    support fusion with Pointwise operations, wrapped in (named) ComputedBuffer nodes.

    Args:
        cuda_template_buffer: A CUDATemplateBuffer object representing the CUDA
            template and its result buffer.
        epilogue_nodes: The list of already fused epilogue nodes.
        additional_node: The ir.Buffer node to be checked if it can be fused
            with the epilogue.
    Returns:
        bool: True if the given node can be fused with the epilogue, False otherwise.
    """
    if not isinstance(cuda_template_buffer, CUDATemplateBuffer):
        return False
    if not isinstance(additional_node, ComputedBuffer):
        return False
    if not isinstance(additional_node.data, Pointwise):
        return False
    # We can fuse a Pointwise op that depends on the last fused epilogue node
    # if any. If there is no epilogue node yet, it needs to depend on the template
    # node.
    node_name = additional_node.get_computed_buffer_name()  # type: ignore[attr-defined]
    if node_name is None:
        return False

    if len(epilogue_nodes) == 0:
        if cuda_template_buffer.name not in additional_node.get_read_names():
            return False
    else:
        last_epilogue_node = epilogue_nodes[-1]
        assert isinstance(last_epilogue_node, ComputedBuffer)  # for mypy
        last_epilogue_name = (
            last_epilogue_node.name
            if last_epilogue_node.name is not None
            else last_epilogue_node.data.name  # type: ignore[attr-defined]
        )
        if last_epilogue_name not in additional_node.get_read_names():
            return False
    # The epilogue must write to a buffer with the same layout as the template
    # output, otherwise the EVT cannot store it in-place.
    if additional_node.layout != cuda_template_buffer.layout:
        return False

    try:
        # Imported lazily to avoid a CUDA codegen dependency on module import.
        from torch._inductor.codegen.cuda.cutlass_epilogue_visitor import (
            CutlassEVTCodegen,
        )

        # Dry-run the EVT codegen: if the candidate epilogue chain cannot be
        # expressed as a CUTLASS EVT, this raises NotImplementedError and the
        # fusion is rejected below.
        CutlassEVTCodegen.ir_to_evt_python_code(
            cast(str, cuda_template_buffer.name), epilogue_nodes + [additional_node]
        )
    except NotImplementedError as e:
        not_implemented_op = str(e)
        if not_implemented_op.startswith("_op_"):
            not_implemented_op = not_implemented_op[4:]
            # Lazy %-style args: the message is only formatted if the
            # warning is actually emitted.
            log.warning(
                "Cannot fuse epilogue node %s into %s, likely due to unsupported operation: %s",  # noqa: B950
                additional_node,
                cuda_template_buffer.name,
                not_implemented_op,
            )
        else:  # Likely due to unsupported dtype.
            log.warning(
                "Cannot fuse epilogue node %s into %s. Reason: %s",  # noqa: B950
                additional_node,
                cuda_template_buffer.name,
                not_implemented_op,
            )
        return False

    return True

torch/_inductor/codegen/cuda_combined_scheduling.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ def codegen_template(
8484
prologue_nodes: Sequence[BaseSchedulerNode],
8585
) -> Optional[str]:
8686
if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node):
87-
assert not epilogue_nodes
8887
assert not prologue_nodes
8988
return self._cuda_cpp_scheduling.codegen_template(
9089
template_node, epilogue_nodes, prologue_nodes

0 commit comments

Comments
 (0)
0