8
8
from ...._dynamo .utils import counters
9
9
from ... import config
10
10
from ...codecache import code_hash , get_path
11
- from ...ir import CUDATemplateBuffer
12
- from ...scheduler import BaseSchedulerNode , BaseScheduling , SchedulerNode
11
+ from ...ir import Buffer , ComputedBuffer , CUDATemplateBuffer , IRNode , Pointwise
12
+ from ...scheduler import (
13
+ BaseSchedulerNode ,
14
+ BaseScheduling ,
15
+ FusedSchedulerNode ,
16
+ SchedulerNode ,
17
+ )
13
18
from ...utils import get_fused_kernel_name , get_kernel_metadata , sympy_product
14
19
from ...virtualized import V
15
20
from ..common import BackendFeature , IndentedBuffer
@@ -40,9 +45,32 @@ def is_cuda_cpp_template(node: BaseSchedulerNode) -> bool:
40
45
node .node , CUDATemplateBuffer
41
46
)
42
47
48
+ def is_cuda_cpp_fused_template (self , node : BaseSchedulerNode ) -> bool :
49
+ return isinstance (node , FusedSchedulerNode ) and self .is_cuda_cpp_template (node )
50
+
43
51
def can_fuse_vertical (
44
52
self , node1 : BaseSchedulerNode , node2 : BaseSchedulerNode
45
53
) -> bool :
54
+ if self .is_cuda_cpp_template (node1 ) and isinstance (node2 , SchedulerNode ):
55
+ assert node1 .node , "node1.node should not be None"
56
+ assert node2 .node , "node2.node should not be None"
57
+ return self ._can_fuse_epilogue_impl (
58
+ cast (CUDATemplateBuffer , node1 .node ),
59
+ [],
60
+ node2 .node , # type: ignore[arg-type]
61
+ )
62
+ elif self .is_cuda_cpp_fused_template (node1 ) and isinstance (
63
+ node2 , SchedulerNode
64
+ ):
65
+ assert node1 .node , "node1.node should not be None"
66
+ assert node2 .node , "node2.node should not be None"
67
+ fnode1 = cast (FusedSchedulerNode , node1 )
68
+ return self ._can_fuse_epilogue_impl (
69
+ fnode1 .get_template_node (), # type: ignore[arg-type]
70
+ self ._unwrap_epilogue_nodes (fnode1 ),
71
+ node2 .node , # type: ignore[arg-type]
72
+ )
73
+
46
74
return False
47
75
48
76
def define_kernel (self , src_code : str , node_schedule ) -> str :
@@ -94,13 +122,19 @@ def codegen_template(
94
122
_ , (_numel , rnumel ) = template_node .group
95
123
assert rnumel == 1
96
124
ctb : CUDATemplateBuffer = cast (CUDATemplateBuffer , template_node .node )
97
- kernel , render = ctb .make_kernel_render (ctb )
125
+ epilogue_ir_nodes : list [Buffer ] = [n .node for n in epilogue_nodes ] # type: ignore[misc]
126
+ assert all (isinstance (n , ComputedBuffer ) for n in epilogue_ir_nodes ), (
127
+ "Epilogue nodes must all be instances of ir.ComputedBuffer"
128
+ )
129
+ kernel , render = ctb .make_kernel_render (ctb , epilogue_nodes = epilogue_ir_nodes )
130
+
98
131
with kernel :
99
- template_node .mark_run ()
132
+ for node in [template_node , * epilogue_nodes ]:
133
+ node .mark_run ()
100
134
src_code = render ()
101
135
102
136
with V .set_kernel_handler (kernel ):
103
- node_schedule = [template_node ]
137
+ node_schedule = [template_node , * epilogue_nodes ]
104
138
kernel_name = self .define_kernel (src_code , node_schedule )
105
139
106
140
# debug printing values of intermediate tensors
@@ -114,3 +148,89 @@ def codegen_template(
114
148
115
149
V .graph .removed_buffers |= kernel .removed_buffers
116
150
self .free_buffers_in_scheduler ()
151
+
152
+ @staticmethod
153
+ def _unwrap_epilogue_nodes (fused_node : FusedSchedulerNode ) -> list [IRNode ]:
154
+ nodes = list (fused_node .get_nodes ())
155
+ template_node = fused_node .get_template_node ()
156
+ assert all (n .node is not None for n in nodes ), (
157
+ "All epilogue nodes should have an IRNode"
158
+ )
159
+ return cast (
160
+ list [IRNode ], [n .node for n in nodes if n .node is not template_node ]
161
+ )
162
+
163
    def _can_fuse_epilogue_impl(
        self,
        cuda_template_buffer: CUDATemplateBuffer,
        epilogue_nodes: list[IRNode],
        additional_node: IRNode,
    ) -> bool:
        """
        Check if the given node can be fused with the epilogue. At the moment, Kernels
        support fusion with Pointwise operations, wrapped in (named) ComputedBuffer nodes.

        Args:
            cuda_template_buffer : A CUDATemplateBuffer object representing the CUDA template and it's result buffer
            epilogue_nodes : List[ir.Buffer]: The list of already fused epilogue nodes.
            additional_node: The ir.Buffer node to be checked if it can be fused with the epilogue.
        Returns:
            - bool: True if the given node can be fused with the epilogue, False otherwise.

        """
        # Guard against being called with a non-template buffer (the type
        # annotation is not enforced at runtime).
        if not isinstance(cuda_template_buffer, CUDATemplateBuffer):
            return False
        # if not cuda_template_buffer.template.can_fuse_epilogue:
        #     # The used GEMM op does not support fusing epilogues
        #     return False
        # Only named Pointwise computations (ComputedBuffer wrapping Pointwise)
        # are fusable as epilogues.
        if not isinstance(additional_node, ComputedBuffer):
            return False
        if not isinstance(additional_node.data, Pointwise):
            return False
        # We can fuse a Pointwise op that depends on the last fused epilogue node
        # if any. If there is no epilogue node yet, it needs to depend on the template
        # node
        node_name = additional_node.get_computed_buffer_name()  # type: ignore[attr-defined]
        if node_name is None:
            return False

        if len(epilogue_nodes) == 0:
            # First epilogue: it must read the template's output buffer directly.
            if cuda_template_buffer.name not in additional_node.get_read_names():
                return False
        else:
            # Subsequent epilogue: it must read the output of the most recently
            # fused epilogue node (fall back to the inner data's name when the
            # buffer itself is unnamed).
            last_epilogue_node = epilogue_nodes[-1]
            assert isinstance(last_epilogue_node, ComputedBuffer)  # for mypy
            last_epilogue_name = (
                last_epilogue_node.name
                if last_epilogue_node.name is not None
                else last_epilogue_node.data.name  # type: ignore[attr-defined]
            )
            if last_epilogue_name not in additional_node.get_read_names():
                return False
        # The epilogue must produce a buffer with the same layout as the
        # template's result buffer.
        if additional_node.layout != cuda_template_buffer.layout:
            return False

        try:
            # Local import, presumably to avoid an import cycle with the
            # visitor module — TODO confirm.
            from torch._inductor.codegen.cuda.cutlass_epilogue_visitor import (
                CutlassEVTCodegen,
            )

            # Dry-run the EVT code generation over the would-be epilogue chain;
            # a NotImplementedError here means the fusion is unsupported.
            CutlassEVTCodegen.ir_to_evt_python_code(
                cast(str, cuda_template_buffer.name), epilogue_nodes + [additional_node]
            )

        except NotImplementedError as e:
            not_implemented_op = str(e)
            if not_implemented_op.startswith("_op_"):
                # Strip the visitor's "_op_" method-name prefix for readability.
                not_implemented_op = not_implemented_op[4:]
                log.warning(
                    f"Cannot fuse epilogue node {additional_node} into {cuda_template_buffer.name}, likely due to unsupported operation: {not_implemented_op}"  # noqa: G004, B950
                )
                return False
            else:  # Likely due to unsupported dtype.
                log.warning(
                    f"Cannot fuse epilogue node {additional_node} into {cuda_template_buffer.name}. Reason: {not_implemented_op}"  # noqa: G004, B950
                )
                return False

        return True
0 commit comments