@@ -105,6 +105,7 @@ def __init__(self, name_to_buffer):
105
105
106
106
self .sizevars = torch ._inductor .sizevars .SizeVarAllocator ()
107
107
self .name_to_buffer = name_to_buffer
108
+ self .graph_inputs = dict ()
108
109
self .mutated_buffers = OrderedSet ()
109
110
110
111
@@ -210,6 +211,64 @@ def inner_fn_buf4(index):
210
211
"""Unsupported indexing for buf0 with index 200*i0 + 60000*i1 + i2 and strides [200, 60000, 1]""" ,
211
212
)
212
213
214
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
def test_py_codegen_broadcasting(self):
    """Check EVT Python codegen when the last epilogue buffer is declared smaller.

    Three mock input buffers (``buf0``..``buf2``) share the size
    ``(100, 300, 200)``.  ``buf3`` is computed from them at the same size, and
    ``buf4`` is computed from ``buf0`` and ``buf3`` but declared with size
    ``(100, 300, 1)``.  The test runs
    ``CutlassEVTCodegen.ir_to_evt_python_code`` over the two scheduler nodes
    and pins the returned reads, writes, renames and generated source.
    """
    from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
    from torch._inductor.virtualized import V

    size = (100, 300, 200)
    buf0 = MockComputedBuffer("buf0", None, torch.float32, size)
    buf1 = MockComputedBuffer("buf1", None, torch.float32, size)
    buf2 = MockComputedBuffer("buf2", None, torch.float32, size)

    # buf0 is acc
    # buf1 is external
    def inner_fn_buf3(index):
        tmp0 = buf0.make_loader()(index)
        tmp1 = buf1.make_loader()(index)
        tmp2 = buf2.make_loader()(index)
        return tmp0 * tmp1 + tmp2

    def inner_fn_buf4(index):
        # `buf3` is resolved when this closure is called, so it is fine that
        # the buffer is only bound further down.
        tmp0 = buf0.make_loader()(index)
        tmp3 = buf3.make_loader()(index)
        return tmp0 + tmp3 * tmp3

    buf3 = MockComputedBuffer("buf3", inner_fn_buf3, torch.float32, size)
    buf4 = MockComputedBuffer(
        "buf4", inner_fn_buf4, torch.float32, (100, 300, 1)
    )  # broadcast

    with V.set_graph_handler(
        MockGraphHandler(
            {"buf0": buf0, "buf1": buf1, "buf2": buf2, "buf3": buf3, "buf4": buf4}
        )
    ):
        reads, writes, renames, code = CutlassEVTCodegen.ir_to_evt_python_code(
            "buf0",
            [
                MockSchedulerNode(buf3),
                MockSchedulerNode(buf4, last_usage=OrderedSet(["buf0"])),
            ],
        )

    self.assertExpectedInline(reads, """['buf0', 'buf1', 'buf2']""")
    self.assertExpectedInline(writes, """['buf3', 'buf4']""")
    self.assertExpectedInline(
        renames, """{'buf3': 'D', 'buf4': 'tmp_3', 'buf0': 'accum'}"""
    )
    # NOTE(review): the four-space indentation of the function body inside the
    # expected string was reconstructed; confirm against the generator output.
    self.assertExpectedInline(
        code,
        """\
def fn(accum, buf1, buf2):
    tmp_0 = accum * buf1
    tmp_1 = tmp_0 + buf2
    D = tmp_1 # cutlass evt requirement
    tmp_2 = D * D
    tmp_3 = accum + tmp_2

    return D, tmp_3""",
    )
271
+
213
272
@unittest .skipIf (not SM90OrLater , "need sm_90" )
214
273
@unittest .skipIf (not try_import_cutlass (), "requires cutlass" )
215
274
def test_py_codegen (self ):
0 commit comments