[Cutlass] Integrate EVT into CUDACPPScheduling (#150906) · pytorch/pytorch@d483aef · GitHub

Commit d483aef

mlazos authored and pytorchmergebot committed
[Cutlass] Integrate EVT into CUDACPPScheduling (#150906)
Previously merged:
* #151713
* #151405
* #150905
* #152306
* #152305

Allow epilogue nodes in cuda combined scheduling

Pull Request resolved: #150906
Approved by: https://github.com/eellison
ghstack dependencies: #152733
1 parent 6b9d741 commit d483aef
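For context (an illustrative sketch, not part of the commit): this stack lets Inductor fuse pointwise epilogues into CUTLASS GEMM templates via EVT (Epilogue Visitor Trees). The kind of pattern it targets, assuming CUTLASS is enabled as a GEMM autotune backend:

import torch

# A GEMM followed by a pointwise op. With EVT integration, the relu can be
# fused into the CUTLASS template kernel instead of running as a separate
# pointwise kernel.
def gemm_relu(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.relu(a @ b)

# Assumption for illustration: CUTLASS enabled as a GEMM autotune backend,
# e.g. torch._inductor.config.max_autotune_gemm_backends = "CUTLASS".
compiled = torch.compile(gemm_relu, mode="max-autotune")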

File tree

5 files changed: +164 -21 lines changed


torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py

Lines changed: 146 additions & 5 deletions
@@ -8,8 +8,14 @@
 from ...._dynamo.utils import counters
 from ... import config
 from ...codecache import code_hash, get_path
-from ...ir import CUDATemplateBuffer
-from ...scheduler import BaseSchedulerNode, BaseScheduling, SchedulerNode
+from ...ir import Buffer, ComputedBuffer, CUDATemplateBuffer, Pointwise
+from ...scheduler import (
+    BaseSchedulerNode,
+    BaseScheduling,
+    FusedSchedulerNode,
+    SchedulerNode,
+    WhyNoFuse,
+)
 from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product
 from ...virtualized import V
 from ..common import BackendFeature, IndentedBuffer
@@ -18,6 +24,12 @@
 log = logging.getLogger(__name__)
 
 
+class WhyNoFuseNames(WhyNoFuse):
+    def __init__(self, name1: str, name2: str) -> None:
+        self.name1 = name1
+        self.name2 = name2
+
+
 class CUDACPPScheduling(BaseScheduling):
     """
     Partial Scheduling implementation for CUDA C++ Kernels.
@@ -40,9 +52,32 @@ def is_cuda_cpp_template(node: BaseSchedulerNode) -> bool:
             node.node, CUDATemplateBuffer
         )
 
+    def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool:
+        return isinstance(node, FusedSchedulerNode) and self.is_cuda_cpp_template(node)
+
     def can_fuse_vertical(
         self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
     ) -> bool:
+        if self.is_cuda_cpp_template(node1) and isinstance(node2, SchedulerNode):
+            assert node1.node, "node1.node should not be None"
+            assert node2.node, "node2.node should not be None"
+            return self._can_fuse_epilogue_impl(
+                cast(CUDATemplateBuffer, node1.node),
+                [],
+                node2,  # type: ignore[arg-type]
+            )
+        elif self.is_cuda_cpp_fused_template(node1) and isinstance(
+            node2, SchedulerNode
+        ):
+            assert node1.node, "node1.node should not be None"
+            assert node2.node, "node2.node should not be None"
+            fnode1 = cast(FusedSchedulerNode, node1)
+            return self._can_fuse_epilogue_impl(
+                fnode1.get_template_node(),  # type: ignore[arg-type]
+                self._unwrap_epilogue_nodes(fnode1),
+                node2,  # type: ignore[arg-type]
+            )
+
         return False
 
     def define_kernel(self, src_code: str, node_schedule) -> str:
@@ -94,13 +129,19 @@ def codegen_template(
         _, (_numel, rnumel) = template_node.group
         assert rnumel == 1
         ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node)
-        kernel, render = ctb.make_kernel_render(ctb)
+        epilogue_ir_nodes: list[Buffer] = [n.node for n in epilogue_nodes]  # type: ignore[misc]
+        assert all(isinstance(n, ComputedBuffer) for n in epilogue_ir_nodes), (
+            "Epilogue nodes must all be instances of ir.ComputedBuffer"
+        )
+        kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_nodes)
+
         with kernel:
-            template_node.mark_run()
+            for node in [template_node, *epilogue_nodes]:
+                node.mark_run()
             src_code = render()
 
         with V.set_kernel_handler(kernel):
-            node_schedule = [template_node]
+            node_schedule = [template_node, *epilogue_nodes]
             kernel_name = self.define_kernel(src_code, node_schedule)
 
         # debug printing values of intermediate tensors
@@ -114,3 +155,103 @@ def codegen_template(
 
         V.graph.removed_buffers |= kernel.removed_buffers
         self.free_buffers_in_scheduler()
+
+    @staticmethod
+    def _unwrap_epilogue_nodes(
+        fused_node: FusedSchedulerNode,
+    ) -> list[BaseSchedulerNode]:
+        nodes = fused_node.get_nodes()
+        template_node = fused_node.get_template_node()
+        assert all(n.node is not None for n in nodes), (
+            "All epilogue nodes should have an IRNode"
+        )
+        return cast(
+            list[BaseSchedulerNode], [n for n in nodes if n.node is not template_node]
+        )
+
+    def _can_fuse_epilogue_impl(
+        self,
+        cuda_template_buffer: CUDATemplateBuffer,
+        existing_epilogue_nodes: list[BaseSchedulerNode],
+        node_to_fuse: BaseSchedulerNode,
+    ) -> bool:
+        """
+        Check if the given node can be fused with the epilogue. At the moment, kernels
+        support fusion with Pointwise operations, wrapped in (named) ComputedBuffer nodes.
+
+        Args:
+            cuda_template_buffer: A CUDATemplateBuffer object representing the CUDA template and its result buffer
+            existing_epilogue_nodes: List[SchedulerNode]: The list of already fused epilogue nodes.
+            node_to_fuse: The SchedulerNode to be checked for fusion with the epilogue.
+        Returns:
+        - bool: True if the given node can be fused with the epilogue, False otherwise.
+
+        """
+
+        why = WhyNoFuseNames(cuda_template_buffer.get_name(), node_to_fuse.get_name())
+
+        ir_node_to_fuse = node_to_fuse.node
+        # for typing
+        assert ir_node_to_fuse
+
+        assert isinstance(cuda_template_buffer, CUDATemplateBuffer)
+        if not isinstance(ir_node_to_fuse, ComputedBuffer):
+            return False
+        if not isinstance(ir_node_to_fuse.data, Pointwise):
+            return False
+        # We can fuse a Pointwise op that depends on the last fused epilogue node
+        # if any. If there is no epilogue node yet, it needs to depend on the template
+        # node
+        node_name = ir_node_to_fuse.get_computed_buffer_name()  # type: ignore[attr-defined]
+        if node_name is None:
+            return False
+
+        assert (
+            len(existing_epilogue_nodes)
+            or cuda_template_buffer.get_name() in ir_node_to_fuse.get_read_names()
+        ), "First epilogue node must read from cuda template buffer"
+
+        # dtype can differ, and strides can differ as long as they are broadcastable
+        if ir_node_to_fuse.get_size() != cuda_template_buffer.get_size():
+            why(
+                f"{cuda_template_buffer.get_name()}'s size: {cuda_template_buffer.get_size()} \
+differs from {node_name}'s size: {ir_node_to_fuse.get_size()}"
+            )
+            return False
+        elif node_to_fuse.has_aliasing_or_mutation():
+            why(f"{node_name} has aliasing or mutation")
+            return False
+        elif node_to_fuse.is_reduction():
+            why(f"{node_name} is a reduction which is not yet supported by EVT")
+            return False
+        elif not config.epilogue_fusion:
+            why("epilogue fusion is not enabled")
+            return False
+
+        try:
+            from torch._inductor.codegen.cuda.cutlass_python_evt import (
+                CutlassEVTCodegen,
+            )
+
+            CutlassEVTCodegen.ir_to_evt_python_code(
+                cuda_template_buffer.get_name(),
+                existing_epilogue_nodes + [node_to_fuse],
+            )
+
+        except NotImplementedError as e:
+            not_implemented_op = str(e)
+            if not_implemented_op.startswith("_op_"):
+                not_implemented_op = not_implemented_op[4:]
+                why(
+                    f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}, \
+likely due to unsupported operation: {not_implemented_op}"  # noqa: G004, B950
+                )
+                return False
+            else:  # Likely due to unsupported dtype.
+                why(
+                    f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}. \
+Reason: {not_implemented_op}"  # noqa: G004, B950
+                )
+                return False
+
+        return True
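A compact way to read the gate above (an illustrative sketch with hypothetical names, not this module's API): the final check is a dry run of the CUTLASS EVT code generator over the already-fused epilogue plus the candidate node, and any NotImplementedError vetoes the fusion.

def can_fuse_as_evt(template_name: str, candidate_nodes: list, codegen) -> bool:
    # `codegen` stands in for CutlassEVTCodegen.ir_to_evt_python_code.
    try:
        codegen(template_name, candidate_nodes)
    except NotImplementedError as e:
        # Same normalization as above: strip the "_op_" prefix if present.
        op = str(e).removeprefix("_op_")
        print(f"cannot fuse into {template_name}: unsupported: {op}")
        return False
    return True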

torch/_inductor/codegen/cuda/cuda_template.py

Lines changed: 7 additions & 1 deletion
@@ -2,7 +2,7 @@
 import functools
 import itertools
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any, Optional, TYPE_CHECKING
 from typing_extensions import override
 from unittest.mock import patch
 
@@ -19,6 +19,12 @@
 from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel
 
 
+if TYPE_CHECKING:
+    from ...scheduler import BaseSchedulerNode  # noqa: TC004
+else:
+    BaseSchedulerNode = Any
+
+
 autotuning_log = getArtifactLogger(__name__, "autotuning")
 
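An aside on the TYPE_CHECKING pattern above (a generic, self-contained illustration, not this file's code): importing only for type checking, with an Any fallback at runtime, keeps annotations resolvable while avoiding a runtime import cycle.

from typing import Any, TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime, so it cannot
    # create an import cycle. `heavy_module` is a hypothetical name.
    from heavy_module import HeavyType
else:
    HeavyType = Any  # runtime fallback keeps annotations evaluable

def use(value: HeavyType) -> None:
    print(value)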

torch/_inductor/codegen/cuda_combined_scheduling.py

Lines changed: 0 additions & 1 deletion
@@ -84,7 +84,6 @@ def codegen_template(
         prologue_nodes: Sequence[BaseSchedulerNode],
     ) -> Optional[str]:
         if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node):
-            assert not epilogue_nodes
             assert not prologue_nodes
             return self._cuda_cpp_scheduling.codegen_template(
                 template_node, epilogue_nodes, prologue_nodes
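(Context: the deleted assertion previously rejected epilogue nodes on the CUDA C++ path; they are now forwarded to codegen_template, which renders them into the template kernel as shown in cuda_cpp_scheduling.py above.)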

torch/_inductor/codegen/simd.py

Lines changed: 7 additions & 10 deletions
@@ -34,7 +34,11 @@
 from ..analyze_preserves_zero_mask import prologue_preserves_zero_mask
 from ..codecache import code_hash
 from ..dependencies import MemoryDep, StarDep, WeakDep
-from ..ir import IRNode, TritonTemplateBuffer
+
+
+if TYPE_CHECKING:
+    from ..ir import IRNode
+
 from ..optimize_indexing import indexing_dtype_strength_reduction
 from ..runtime.runtime_utils import green_text, yellow_text
 from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
@@ -1155,16 +1159,9 @@ def can_fuse(self, node1, node2):
             )
             return False
 
-        for n, node_name in zip((node1, node2), ("node1", "node2")):
+        for n in (node1, node2):
             if n.is_template():
-                # Only allow fusion for TritonTemplates for now.
-                # Fusion for CUDATemplates are not supported.
-                is_triton_template = isinstance(
-                    n.get_template_node(), TritonTemplateBuffer
-                )
-                if not is_triton_template:
-                    why(f"{node_name} is not TritonTemplateBuffer")
-                return is_triton_template
+                return True
 
         # check for a bad combined tiling
         tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1)
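The simd.py change removes the Triton-only template gate from the shared SIMD scheduler; CUDA template fusions are now vetted by CUDACPPScheduling.can_fuse_vertical instead. A simplified, hypothetical sketch of the resulting dispatch (the real routing lives in CUDACombinedScheduling):

class CombinedSchedulingSketch:
    # Hypothetical, simplified: each backend owns its own template fusion
    # policy instead of simd.py rejecting non-Triton templates globally.
    def __init__(self, cuda_cpp_scheduling, triton_scheduling):
        self._cuda_cpp = cuda_cpp_scheduling
        self._triton = triton_scheduling

    def can_fuse_vertical(self, node1, node2) -> bool:
        # CUDA templates route to the CUDA C++ scheduler, which applies the
        # EVT-based _can_fuse_epilogue_impl check shown earlier.
        if self._cuda_cpp.is_cuda_cpp_template(node1):
            return self._cuda_cpp.can_fuse_vertical(node1, node2)
        return self._triton.can_fuse_vertical(node1, node2)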

torch/_inductor/scheduler.py

Lines changed: 4 additions & 4 deletions
@@ -901,21 +901,21 @@ def get_prologue_template_epilogue(
 class WhyNoFuse:
     # TODO when we drop support for Python < 3.10, we can use
     # @dataclass(slots=True) instead of manually specifying __slots__.
-    __slots__ = ["node1", "node2", "reason", "args"]
+    __slots__ = ["name1", "name2", "reason", "args"]
     reason: str
     args: tuple[Any, ...]
 
     def __init__(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> None:
-        self.node1 = node1
-        self.node2 = node2
+        self.name1 = node1.get_name()
+        self.name2 = node2.get_name()
 
     def __call__(self, reason: str, *args: Any) -> None:
         self.reason = reason
         self.args = args
         fusion_log.debug(self)
 
     def __str__(self) -> str:
-        return f"cannot fuse {self.node1.get_name()} with {self.node2.get_name()}: " + (
+        return f"cannot fuse {self.name1} with {self.name2}: " + (
             self.reason % self.args
         )
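To see the refactor in action (a toy demo with stub nodes; WhyNoFuse lives in torch/_inductor/scheduler.py): names are captured eagerly in __init__, so __str__ no longer touches live scheduler nodes, and subclasses such as WhyNoFuseNames can supply plain strings.

from torch._inductor.scheduler import WhyNoFuse

# Stub standing in for BaseSchedulerNode, for illustration only.
class StubNode:
    def __init__(self, name: str) -> None:
        self._name = name

    def get_name(self) -> str:
        return self._name

why = WhyNoFuse(StubNode("buf0"), StubNode("buf1"))
why("%s is a reduction which is not yet supported by EVT", "buf1")
# fusion_log.debug(why) records:
#   "cannot fuse buf0 with buf1: buf1 is a reduction which is not yet supported by EVT"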
