@@ -8,8 +8,14 @@
 from ...._dynamo.utils import counters
 from ... import config
 from ...codecache import code_hash, get_path
-from ...ir import CUDATemplateBuffer
-from ...scheduler import BaseSchedulerNode, BaseScheduling, SchedulerNode
+from ...ir import Buffer, ComputedBuffer, CUDATemplateBuffer, Pointwise
+from ...scheduler import (
+    BaseSchedulerNode,
+    BaseScheduling,
+    FusedSchedulerNode,
+    SchedulerNode,
+    WhyNoFuse,
+)
 from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product
 from ...virtualized import V
 from ..common import BackendFeature, IndentedBuffer
@@ -18,6 +24,12 @@
 log = logging.getLogger(__name__)


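+# Variant of WhyNoFuse that is constructed from buffer names (strings) rather
+# than scheduler nodes; used below to record why a candidate epilogue node
+# could not be fused.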
+class WhyNoFuseNames(WhyNoFuse):
+    def __init__(self, name1: str, name2: str) -> None:
+        self.name1 = name1
+        self.name2 = name2
+
+
 class CUDACPPScheduling(BaseScheduling):
     """
     Partial Scheduling implementation for CUDA C++ Kernels.
@@ -40,9 +52,32 @@ def is_cuda_cpp_template(node: BaseSchedulerNode) -> bool:
             node.node, CUDATemplateBuffer
         )

+    def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool:
+        return isinstance(node, FusedSchedulerNode) and self.is_cuda_cpp_template(node)
+
     def can_fuse_vertical(
         self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
     ) -> bool:
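+        # Two fusion shapes are handled: a standalone CUDA template node picking
+        # up its first epilogue, and an already-fused template node picking up a
+        # further epilogue.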
+        if self.is_cuda_cpp_template(node1) and isinstance(node2, SchedulerNode):
+            assert node1.node, "node1.node should not be None"
+            assert node2.node, "node2.node should not be None"
+            return self._can_fuse_epilogue_impl(
+                cast(CUDATemplateBuffer, node1.node),
+                [],
+                node2,  # type: ignore[arg-type]
+            )
+        elif self.is_cuda_cpp_fused_template(node1) and isinstance(
+            node2, SchedulerNode
+        ):
+            assert node1.node, "node1.node should not be None"
+            assert node2.node, "node2.node should not be None"
+            fnode1 = cast(FusedSchedulerNode, node1)
+            return self._can_fuse_epilogue_impl(
+                fnode1.get_template_node(),  # type: ignore[arg-type]
+                self._unwrap_epilogue_nodes(fnode1),
+                node2,  # type: ignore[arg-type]
+            )
+
         return False

     def define_kernel(self, src_code: str, node_schedule) -> str:
@@ -94,13 +129,19 @@ def codegen_template(
         _, (_numel, rnumel) = template_node.group
         assert rnumel == 1
         ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node)
-        kernel, render = ctb.make_kernel_render(ctb)
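+        # Epilogue nodes are handed to the template's kernel renderer so that
+        # they are code-generated into the same CUDA kernel as the template.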
+        epilogue_ir_nodes: list[Buffer] = [n.node for n in epilogue_nodes]  # type: ignore[misc]
+        assert all(isinstance(n, ComputedBuffer) for n in epilogue_ir_nodes), (
+            "Epilogue nodes must all be instances of ir.ComputedBuffer"
+        )
+        kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_nodes)
+
         with kernel:
-            template_node.mark_run()
+            for node in [template_node, *epilogue_nodes]:
+                node.mark_run()
             src_code = render()

         with V.set_kernel_handler(kernel):
-            node_schedule = [template_node]
+            node_schedule = [template_node, *epilogue_nodes]
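+            # The epilogue nodes now also feed into the fused kernel's name and
+            # metadata via define_kernel below.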
             kernel_name = self.define_kernel(src_code, node_schedule)

         # debug printing values of intermediate tensors
@@ -114,3 +155,103 @@ def codegen_template(

         V.graph.removed_buffers |= kernel.removed_buffers
         self.free_buffers_in_scheduler()
+
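+    # Given a fused template node, return its member nodes minus the template
+    # node itself, i.e. the epilogue nodes fused into it so far.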
+    @staticmethod
+    def _unwrap_epilogue_nodes(
+        fused_node: FusedSchedulerNode,
+    ) -> list[BaseSchedulerNode]:
+        nodes = fused_node.get_nodes()
+        template_node = fused_node.get_template_node()
+        assert all(n.node is not None for n in nodes), (
+            "All epilogue nodes should have an IRNode"
+        )
+        return cast(
+            list[BaseSchedulerNode], [n for n in nodes if n.node is not template_node]
+        )
+
+    def _can_fuse_epilogue_impl(
+        self,
+        cuda_template_buffer: CUDATemplateBuffer,
+        existing_epilogue_nodes: list[BaseSchedulerNode],
+        node_to_fuse: BaseSchedulerNode,
+    ) -> bool:
+        """
+        Check if the given node can be fused with the epilogue. At the moment, Kernels
+        support fusion with Pointwise operations, wrapped in (named) ComputedBuffer nodes.
+
+        Args:
+            cuda_template_buffer : A CUDATemplateBuffer object representing the CUDA template and its result buffer
+            existing_epilogue_nodes : List[SchedulerNode]: The list of already fused epilogue nodes.
+            node_to_fuse: The SchedulerNode node to be checked if it can be fused with the epilogue.
+        Returns:
+        - bool: True if the given node can be fused with the epilogue, False otherwise.
+
+        """
+
+        why = WhyNoFuseNames(cuda_template_buffer.get_name(), node_to_fuse.get_name())
+
+        ir_node_to_fuse = node_to_fuse.node
+        # for typing
+        assert ir_node_to_fuse
+
+        assert isinstance(cuda_template_buffer, CUDATemplateBuffer)
+        if not isinstance(ir_node_to_fuse, ComputedBuffer):
+            return False
+        if not isinstance(ir_node_to_fuse.data, Pointwise):
+            return False
+        # We can fuse a Pointwise op that depends on the last fused epilogue node
+        # if any. If there is no epilogue node yet, it needs to depend on the template
+        # node
+        node_name = ir_node_to_fuse.get_computed_buffer_name()  # type: ignore[attr-defined]
+        if node_name is None:
+            return False
+
+        assert (
+            len(existing_epilogue_nodes)
+            or cuda_template_buffer.get_name() in ir_node_to_fuse.get_read_names()
+        ), "First epilogue node must read from cuda template buffer"
+
+        # dtype can differ, and strides can differ as long as they are broadcastable
+        if ir_node_to_fuse.get_size() != cuda_template_buffer.get_size():
+            why(
+                f"{cuda_template_buffer.get_name()}'s size: {cuda_template_buffer.get_size()} \
+differs from {node_name}'s size: {ir_node_to_fuse.get_size()}"
+            )
+            return False
+        elif node_to_fuse.has_aliasing_or_mutation():
+            why(f"{node_name} has aliasing or mutation")
+            return False
+        elif node_to_fuse.is_reduction():
+            why(f"{node_name} is a reduction which is not yet supported by EVT")
+            return False
+        elif not config.epilogue_fusion:
+            why("epilogue fusion is not enabled")
+            return False
+
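+        # Fusibility is probed by attempting to lower the template plus the
+        # candidate epilogue to CUTLASS EVT Python code; a NotImplementedError
+        # signals an epilogue op (or dtype) that EVT cannot express yet.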
+        try:
+            from torch._inductor.codegen.cuda.cutlass_python_evt import (
+                CutlassEVTCodegen,
+            )
+
+            CutlassEVTCodegen.ir_to_evt_python_code(
+                cuda_template_buffer.get_name(),
+                existing_epilogue_nodes + [node_to_fuse],
+            )
+
+        except NotImplementedError as e:
+            not_implemented_op = str(e)
+            if not_implemented_op.startswith("_op_"):
+                not_implemented_op = not_implemented_op[4:]
+                why(
+                    f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}, \
+likely due to unsupported operation: {not_implemented_op}"  # noqa: G004, B950
+                )
+                return False
+            else:  # Likely due to unsupported dtype.
+                why(
+                    f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}. \
+Reason: {not_implemented_op}"  # noqa: G004, B950
+                )
+                return False
+
+        return True