@@ -115,8 +115,9 @@ def trace(
115
115
output_type : DataType ,
116
116
tile_description : TileDescription ,
117
117
epilogue_schedule : EpilogueScheduleType ,
118
+ name_to_buffer : dict [str , Buffer ],
118
119
** kwargs : dict [str , Any ],
119
- ) -> tuple [str , str ]:
120
+ ) -> tuple [str , str , str ]:
120
121
cuda_arch = int (cuda_env .get_cuda_arch ()) # type: ignore[arg-type]
121
122
assert cuda_arch >= 90 , "Only SM90+ is supported for EVT"
122
123
epilogue_functor = _trace (fn_src , example_tensors , ** kwargs )
@@ -129,8 +130,9 @@ def trace(
129
130
output_type ,
130
131
fusion_callbacks ,
131
132
)
132
-
133
- return collective_epilogue .emit ()
133
+ evt_name , evt_code = collective_epilogue .emit ()
134
+ evt_args = _render_argument_type (epilogue_functor , name_to_buffer )
135
+ return evt_name , evt_args , evt_code
134
136
135
137
# Based off of
136
138
# https://github.com/NVIDIA/cutlass/blob/df18f5e4f5de76bed8be1de8e4c245f2f5ec3020/python/cutlass/epilogue/epilogue.py#L117
@@ -167,33 +169,42 @@ def is_nested_visitor_type(t: type) -> bool:
167
169
)
168
170
169
171
buffer = IndentedBuffer ()
172
+ with buffer .set_tabwidth (2 ):
170
173
171
- def render_argument_type (name : str , t : CutlassArgType ) -> None :
172
- if issubclass (t , ctypes .c_byte ):
173
- buffer .writeline (f"{{}}, /* { name } */" )
174
- else :
175
- fields = [
176
- (fname , _get_arg_from_node (ty , name_to_buffer [name ]))
177
- for fname , ty in t ._fields_
178
- ]
179
- field_strs = [f"/* { fname } */ { str (field )} " for fname , field in fields ]
180
- buffer .writeline (f"{{{ ', ' .join (field_strs )} }}, /* { name } */" )
181
-
182
- def render_thread_type (name : str , t : CutlassArgType ) -> None :
183
- if is_nested_visitor_type (t ):
184
- buffer .writeline (f"{{ /* { name } */" )
185
- with buffer .indent ():
186
- for name , inner_t in t ._fields_ :
187
- render_thread_type (name , inner_t )
188
- buffer .writeline ("}," )
189
- else :
190
- render_argument_type (name , t )
191
-
192
- buffer .writeline ("{{" )
193
- with buffer .indent ():
194
- render_thread_type ("thread" , epilogue_thread_type )
195
-
196
- buffer .writeline ("}};" )
174
+ def render_argument_type (name : str , t : CutlassArgType ) -> None :
175
+ if issubclass (t , ctypes .c_byte ):
176
+ buffer .writeline (f"{{}}, /* { name } */" )
177
+ else :
178
+ fields = [
179
+ (fname , _get_arg_from_node (ty , name_to_buffer [name ]))
180
+ for fname , ty in t ._fields_
181
+ ]
182
+ field_strs = [
183
+ f"/* { fname } */ { str (field )} " for fname , field in fields
184
+ ]
185
+ buffer .writeline (f"{{{ ', ' .join (field_strs )} }}, /* { name } */" )
186
+
187
+ def render_thread_type (name : str , t : CutlassArgType ) -> None :
188
+ if is_nested_visitor_type (t ):
189
+ buffer .writeline (f"{{ /* { name } */" )
190
+ with buffer .indent ():
191
+ for name , inner_t in t ._fields_ :
192
+ render_thread_type (name , inner_t )
193
+ buffer .writeline ("}," )
194
+ else :
195
+ render_argument_type (name , t )
196
+
197
+ # unroll the recursion once to address special case formatting
198
+ # namely, no ending comma and no indentation for the outermost thread type
199
+ buffer .writeline ("{ /* thread */" )
200
+ with buffer .indent (3 ):
201
+ if is_nested_visitor_type (epilogue_thread_type ):
202
+ with buffer .indent ():
203
+ for name , inner_t in epilogue_thread_type ._fields_ :
204
+ render_thread_type (name , inner_t )
205
+ else :
206
+ render_argument_type ("thread" , epilogue_thread_type )
207
+ buffer .writeline ("}" )
197
208
198
209
return buffer .getvalue ()
199
210
@@ -225,11 +236,11 @@ def render_stride(x: int) -> str:
225
236
return f"{{{ ', ' .join ([render_stride (x ) for x in stride ])} }}"
226
237
227
238
elif issubclass (arg_ty , ctypes .c_void_p ):
228
- return f"{ node .get_name () } .get() "
239
+ return f"( { CUTLASSTemplate . _DTYPE_TO_CUTLASS [ node .get_layout (). dtype ] } *) { node . get_name () } "
229
240
elif (
230
241
arg_ty in _CUTLASS_C_DTYPES
231
242
): # Assumption: this is the element dtype, this holds for all cutlass ir nodes currently
232
- return CUTLASSTemplate ._DTYPE_TO_CUTLASS [node .get_layout ().dtype ]
243
+ return f" { CUTLASSTemplate ._DTYPE_TO_CUTLASS [node .get_layout ().dtype ]} (0)"
233
244
elif issubclass (arg_ty , EmptyByte ):
234
245
return "{}"
235
246
0 commit comments