8
8
from ...._dynamo .utils import counters
9
9
from ... import config
10
10
from ...codecache import code_hash , get_path
11
- from ...ir import CUDATemplateBuffer
12
- from ...scheduler import BaseSchedulerNode , BaseScheduling , SchedulerNode
11
+ from ...ir import Buffer , ComputedBuffer , CUDATemplateBuffer , Pointwise
12
+ from ...scheduler import (
13
+ BaseSchedulerNode ,
14
+ BaseScheduling ,
15
+ FusedSchedulerNode ,
16
+ SchedulerNode ,
17
+ )
13
18
from ...utils import get_fused_kernel_name , get_kernel_metadata , sympy_product
14
19
from ...virtualized import V
15
20
from ..common import BackendFeature , IndentedBuffer
@@ -40,9 +45,32 @@ def is_cuda_cpp_template(node: BaseSchedulerNode) -> bool:
40
45
node .node , CUDATemplateBuffer
41
46
)
42
47
48
def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool:
    """Tell whether ``node`` is a fused node built around a CUDA C++ template.

    The node must be a ``FusedSchedulerNode``; when it is, the verdict is
    whatever ``is_cuda_cpp_template`` reports for the same node.
    """
    if not isinstance(node, FusedSchedulerNode):
        return False
    return self.is_cuda_cpp_template(node)
50
+
43
51
def can_fuse_vertical(
    self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:
    """Decide whether ``node2`` may be fused into ``node1`` as an epilogue.

    Fusion is only considered when ``node2`` is a plain ``SchedulerNode`` and
    ``node1`` is either accepted by ``is_cuda_cpp_template`` (no epilogue
    fused so far) or by ``is_cuda_cpp_fused_template`` (the already fused
    epilogue nodes are passed along). The actual decision is delegated to
    ``_can_fuse_epilogue_impl``; every other combination yields ``False``.
    """
    # Both fusable cases require a plain SchedulerNode as the candidate.
    if not isinstance(node2, SchedulerNode):
        return False

    if self.is_cuda_cpp_template(node1):
        assert node1.node, "node1.node should not be None"
        assert node2.node, "node2.node should not be None"
        # Nothing has been fused yet, so the epilogue list starts empty.
        return self._can_fuse_epilogue_impl(
            cast(CUDATemplateBuffer, node1.node),
            [],
            node2,  # type: ignore[arg-type]
        )

    if self.is_cuda_cpp_fused_template(node1):
        assert node1.node, "node1.node should not be None"
        assert node2.node, "node2.node should not be None"
        fused = cast(FusedSchedulerNode, node1)
        return self._can_fuse_epilogue_impl(
            fused.get_template_node(),  # type: ignore[arg-type]
            self._unwrap_epilogue_nodes(fused),
            node2,  # type: ignore[arg-type]
        )

    return False
47
75
48
76
def define_kernel (self , src_code : str , node_schedule ) -> str :
@@ -94,13 +122,19 @@ def codegen_template(
94
122
_ , (_numel , rnumel ) = template_node .group
95
123
assert rnumel == 1
96
124
ctb : CUDATemplateBuffer = cast (CUDATemplateBuffer , template_node .node )
97
- kernel , render = ctb .make_kernel_render (ctb )
125
+ epilogue_ir_nodes : list [Buffer ] = [n .node for n in epilogue_nodes ] # type: ignore[misc]
126
+ assert all (isinstance (n , ComputedBuffer ) for n in epilogue_ir_nodes ), (
127
+ "Epilogue nodes must all be instances of ir.ComputedBuffer"
128
+ )
129
+ kernel , render = ctb .make_kernel_render (ctb , epilogue_nodes = epilogue_nodes )
130
+
98
131
with kernel :
99
- template_node .mark_run ()
132
+ for node in [template_node , * epilogue_nodes ]:
133
+ node .mark_run ()
100
134
src_code = render ()
101
135
102
136
with V .set_kernel_handler (kernel ):
103
- node_schedule = [template_node ]
137
+ node_schedule = [template_node , * epilogue_nodes ]
104
138
kernel_name = self .define_kernel (src_code , node_schedule )
105
139
106
140
# debug printing values of intermediate tensors
@@ -114,3 +148,93 @@ def codegen_template(
114
148
115
149
V .graph .removed_buffers |= kernel .removed_buffers
116
150
self .free_buffers_in_scheduler ()
151
+
152
@staticmethod
def _unwrap_epilogue_nodes(
    fused_node: FusedSchedulerNode,
) -> list[BaseSchedulerNode]:
    """Return the members of ``fused_node`` other than its template node.

    Every member is asserted to carry an IR node. Members are kept in the
    order ``get_nodes()`` yields them; those whose ``node`` is the very object
    returned by ``get_template_node()`` are dropped.
    """
    members = list(fused_node.get_nodes())
    template = fused_node.get_template_node()
    for member in members:
        assert member.node is not None, "All epilogue nodes should have an IRNode"
    # Identity comparison: only the template object itself is filtered out.
    epilogue: list[BaseSchedulerNode] = [
        member for member in members if member.node is not template
    ]
    return epilogue
164
+
165
def _can_fuse_epilogue_impl(
    self,
    cuda_template_buffer: CUDATemplateBuffer,
    epilogue_nodes: list[BaseSchedulerNode],
    additional_node: BaseSchedulerNode,
) -> bool:
    """
    Check whether ``additional_node`` can be appended to the fused epilogue.

    At the moment, kernels support fusion with Pointwise operations wrapped in
    (named) ComputedBuffer nodes.

    Args:
        cuda_template_buffer: The CUDATemplateBuffer representing the CUDA
            template and its result buffer.
        epilogue_nodes: Scheduler nodes already fused as epilogue, in fusion
            order (empty if nothing has been fused yet).
        additional_node: The scheduler node to be checked; its IR node
            (``additional_node.node``) is what gets inspected.

    Returns:
        True if the given node can be fused with the epilogue, False otherwise.
    """
    additional_ir_node = additional_node.node

    if not isinstance(cuda_template_buffer, CUDATemplateBuffer):
        return False
    # NOTE(review): a check on cuda_template_buffer.template.can_fuse_epilogue
    # (the used GEMM op not supporting fused epilogues) is currently disabled.
    if not isinstance(additional_ir_node, ComputedBuffer):
        return False
    if not isinstance(additional_ir_node.data, Pointwise):
        return False
    # Only named computed buffers can be fused.
    if additional_ir_node.get_computed_buffer_name() is None:  # type: ignore[attr-defined]
        return False

    # We can fuse a Pointwise op that depends on the last fused epilogue node,
    # if any. If there is no epilogue node yet, it needs to depend on the
    # template node.
    if epilogue_nodes:
        last_epilogue_node = epilogue_nodes[-1].node
        assert isinstance(last_epilogue_node, ComputedBuffer)  # for mypy
        producer_name = (
            last_epilogue_node.name
            if last_epilogue_node.name is not None
            else last_epilogue_node.data.name  # type: ignore[attr-defined]
        )
    else:
        producer_name = cuda_template_buffer.name
    if producer_name not in additional_ir_node.get_read_names():
        return False

    # Compare the layouts of the IR buffers; the scheduler node wrapper is
    # not the object that carries the layout.
    if additional_ir_node.layout != cuda_template_buffer.layout:
        return False

    try:
        from torch._inductor.codegen.cuda.cutlass_python_evt import (
            CutlassEVTCodegen,
        )

        # Dry-run the EVT code generation; it raises NotImplementedError when
        # the epilogue chain cannot be expressed.
        CutlassEVTCodegen.ir_to_evt_python_code(
            cast(str, cuda_template_buffer.name), epilogue_nodes + [additional_node]
        )
    except NotImplementedError as e:
        not_implemented_op = str(e)
        if not_implemented_op.startswith("_op_"):
            # Strip the "_op_" prefix so only the operation name is reported.
            log.warning(
                "Cannot fuse epilogue node %s into %s, "
                "likely due to unsupported operation: %s",
                additional_node,
                cuda_template_buffer.name,
                not_implemented_op[4:],
            )
        else:  # Likely due to unsupported dtype.
            log.warning(
                "Cannot fuse epilogue node %s into %s. Reason: %s",
                additional_node,
                cuda_template_buffer.name,
                not_implemented_op,
            )
        return False

    return True
0 commit comments