[Cutlass] Integrate EVT into CUDACPPScheduling · pytorch/pytorch@ad7cc67 · GitHub
[go: up one dir, main page]

Skip to content

Commit ad7cc67

Browse files
committed
[Cutlass] Integrate EVT into CUDACPPScheduling
Allow epilogue nodes in cuda combined scheduling ghstack-source-id: 06f94a7 Pull Request resolved: #150906 Updates to cuda_cpp_scheduling
1 parent 3c26dcd commit ad7cc67

File tree

2 files changed

+129
-6
lines changed

2 files changed

+129
-6
lines changed

torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py

Lines changed: 129 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,13 @@
88
from ...._dynamo.utils import counters
99
from ... import config
1010
from ...codecache import code_hash, get_path
11-
from ...ir import CUDATemplateBuffer
12-
from ...scheduler import BaseSchedulerNode, BaseScheduling, SchedulerNode
11+
from ...ir import Buffer, ComputedBuffer, CUDATemplateBuffer, Pointwise
12+
from ...scheduler import (
13+
BaseSchedulerNode,
14+
BaseScheduling,
15+
FusedSchedulerNode,
16+
SchedulerNode,
17+
)
1318
from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product
1419
from ...virtualized import V
1520
from ..common import BackendFeature, IndentedBuffer
@@ -40,9 +45,32 @@ def is_cuda_cpp_template(node: BaseSchedulerNode) -> bool:
4045
node.node, CUDATemplateBuffer
4146
)
4247

48+
def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool:
    """Return True iff ``node`` is a fused group whose template is a CUDA C++ template."""
    if not isinstance(node, FusedSchedulerNode):
        return False
    return self.is_cuda_cpp_template(node)
50+
4351
def can_fuse_vertical(
    self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:
    """Decide whether ``node2`` may be fused below ``node1``.

    Fusion is only considered when ``node1`` is a CUDA C++ template (either a
    bare template node or a FusedSchedulerNode already containing one) and
    ``node2`` is a plain SchedulerNode; the epilogue-compatibility check is
    delegated to ``_can_fuse_epilogue_impl``.
    """
    # Only plain scheduler nodes can become epilogues of a template.
    if not isinstance(node2, SchedulerNode):
        return False

    if self.is_cuda_cpp_template(node1):
        assert node1.node, "node1.node should not be None"
        assert node2.node, "node2.node should not be None"
        template_buffer = cast(CUDATemplateBuffer, node1.node)
        # No epilogues fused yet, so the candidate list is empty.
        return self._can_fuse_epilogue_impl(
            template_buffer,
            [],
            node2,  # type: ignore[arg-type]
        )

    if self.is_cuda_cpp_fused_template(node1):
        assert node1.node, "node1.node should not be None"
        assert node2.node, "node2.node should not be None"
        fused = cast(FusedSchedulerNode, node1)
        # The candidate must be compatible with the already-fused epilogues.
        return self._can_fuse_epilogue_impl(
            fused.get_template_node(),  # type: ignore[arg-type]
            self._unwrap_epilogue_nodes(fused),
            node2,  # type: ignore[arg-type]
        )

    return False
4775

4876
def define_kernel(self, src_code: str, node_schedule) -> str:
@@ -94,13 +122,19 @@ def codegen_template(
94122
_, (_numel, rnumel) = template_node.group
95123
assert rnumel == 1
96124
ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node)
97-
kernel, render = ctb.make_kernel_render(ctb)
125+
epilogue_ir_nodes: list[Buffer] = [n.node for n in epilogue_nodes] # type: ignore[misc]
126+
assert all(isinstance(n, ComputedBuffer) for n in epilogue_ir_nodes), (
127+
"Epilogue nodes must all be instances of ir.ComputedBuffer"
128+
)
129+
kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_nodes)
130+
98131
with kernel:
99-
template_node.mark_run()
132+
for node in [template_node, *epilogue_nodes]:
133+
node.mark_run()
100134
src_code = render()
101135

102136
with V.set_kernel_handler(kernel):
103-
node_schedule = [template_node]
137+
node_schedule = [template_node, *epilogue_nodes]
104138
kernel_name = self.define_kernel(src_code, node_schedule)
105139

106140
# debug printing values of intermediate tensors
@@ -114,3 +148,93 @@ def codegen_template(
114148

115149
V.graph.removed_buffers |= kernel.removed_buffers
116150
self.free_buffers_in_scheduler()
151+
152+
@staticmethod
def _unwrap_epilogue_nodes(
    fused_node: FusedSchedulerNode,
) -> list[BaseSchedulerNode]:
    """Return the epilogue members of a fused template node.

    Every constituent scheduler node of ``fused_node`` whose IR node is not
    the template buffer itself is considered an epilogue node.
    """
    members = list(fused_node.get_nodes())
    template_buffer = fused_node.get_template_node()
    assert all(member.node is not None for member in members), (
        "All epilogue nodes should have an IRNode"
    )
    epilogues = [m for m in members if m.node is not template_buffer]
    return cast(list[BaseSchedulerNode], epilogues)
164+
165+
def _can_fuse_epilogue_impl(
    self,
    cuda_template_buffer: CUDATemplateBuffer,
    epilogue_nodes: list[BaseSchedulerNode],
    additional_node: BaseSchedulerNode,
) -> bool:
    """
    Check if the given node can be fused with the epilogue. At the moment, Kernels
    support fusion with Pointwise operations, wrapped in (named) ComputedBuffer nodes.

    Args:
        cuda_template_buffer: A CUDATemplateBuffer object representing the CUDA
            template and its result buffer.
        epilogue_nodes: The list of already-fused epilogue scheduler nodes.
        additional_node: The scheduler node to be checked for fusability.
    Returns:
        bool: True if the given node can be fused with the epilogue, False otherwise.
    """
    additional_ir_node = additional_node.node

    if not isinstance(cuda_template_buffer, CUDATemplateBuffer):
        return False
    # NOTE(review): the template-level capability check is currently disabled:
    # if not cuda_template_buffer.template.can_fuse_epilogue:
    #     # The used GEMM op does not support fusing epilogues
    #     return False

    # Only named Pointwise computations (ComputedBuffer-wrapped) can be fused.
    if not isinstance(additional_ir_node, ComputedBuffer):
        return False
    if not isinstance(additional_ir_node.data, Pointwise):
        return False
    # We can fuse a Pointwise op that depends on the last fused epilogue node
    # if any. If there is no epilogue node yet, it needs to depend on the template
    # node
    node_name = additional_ir_node.get_computed_buffer_name()  # type: ignore[attr-defined]
    if node_name is None:
        return False

    if len(epilogue_nodes) == 0:
        if cuda_template_buffer.name not in additional_ir_node.get_read_names():
            return False
    else:
        last_epilogue_node = epilogue_nodes[-1].node
        assert isinstance(last_epilogue_node, ComputedBuffer)  # for mypy
        last_epilogue_name = (
            last_epilogue_node.name
            if last_epilogue_node.name is not None
            else last_epilogue_node.data.name  # type: ignore[attr-defined]
        )
        if last_epilogue_name not in additional_ir_node.get_read_names():
            return False
    # Epilogue must write with the exact same layout as the template output.
    if additional_node.layout != cuda_template_buffer.layout:
        return False

    try:
        from torch._inductor.codegen.cuda.cutlass_python_evt import (
            CutlassEVTCodegen,
        )

        # Dry-run the EVT codegen: if the candidate epilogue cannot be
        # expressed as a CUTLASS EVT, this raises NotImplementedError and
        # we decline the fusion.
        CutlassEVTCodegen.ir_to_evt_python_code(
            cast(str, cuda_template_buffer.name), epilogue_nodes + [additional_node]
        )

    except NotImplementedError as e:
        not_implemented_op = str(e)
        if not_implemented_op.startswith("_op_"):
            not_implemented_op = not_implemented_op[4:]
            # Lazy %-style args avoid formatting cost when the warning is
            # filtered out (and drop the need for noqa G004/B950).
            log.warning(
                "Cannot fuse epilogue node %s into %s, likely due to unsupported operation: %s",
                additional_node,
                cuda_template_buffer.name,
                not_implemented_op,
            )
            return False
        else:  # Likely due to unsupported dtype.
            log.warning(
                "Cannot fuse epilogue node %s into %s. Reason: %s",
                additional_node,
                cuda_template_buffer.name,
                not_implemented_op,
            )
            return False

    return True

torch/_inductor/codegen/cuda_combined_scheduling.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ def codegen_template(
8484
prologue_nodes: Sequence[BaseSchedulerNode],
8585
) -> Optional[str]:
8686
if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node):
87-
assert not epilogue_nodes
8887
assert not prologue_nodes
8988
return self._cuda_cpp_scheduling.codegen_template(
9089
template_node, epilogue_nodes, prologue_nodes

0 commit comments

Comments
 (0)
0