@@ -147,22 +147,25 @@ def inner_fn_buf4(index):
147
147
MockSchedulerNode (buf3 ),
148
148
MockSchedulerNode (buf4 , last_usage = OrderedSet (["buf3" ])),
149
149
],
150
+ OrderedSet ([]),
150
151
)
151
- self .assertExpectedInline (reads , """['buf0', ' buf1', 'buf2']""" )
152
+ self .assertExpectedInline (reads , """['buf1', 'buf2']""" )
152
153
self .assertExpectedInline (writes , """['buf0', 'buf3', 'buf4']""" )
153
154
self .assertExpectedInline (
154
- renames , """{'buf0': 'accum', 'buf3': 'tmp_1', 'buf4': 'tmp_2'}"""
155
+ renames ,
156
+ """{'accum': 'buf0', 'tmp_0': 'buf0', 'buf1': 'buf1', 'buf2': 'buf2', 'D': 'buf3', 'tmp_3': 'buf4'}""" ,
155
157
)
156
158
self .assertExpectedInline (
157
159
code ,
158
160
"""\
159
161
def fn(accum, buf1, buf2):
160
- D = accum # cutlass evt requirement
161
- tmp_0 = accum * buf1
162
- tmp_1 = tmp_0 + buf2
163
- tmp_2 = accum + tmp_1
162
+ tmp_0 = accum
163
+ tmp_1 = tmp_0 * buf1
164
+ tmp_2 = tmp_1 + buf2
165
+ D = tmp_2 # cutlass evt requirement
166
+ tmp_3 = tmp_0 + D
164
167
165
- return D, tmp_1, tmp_2 """ ,
168
+ return tmp_0, D, tmp_3 """ ,
166
169
)
167
170
168
171
@unittest .skipIf (not SM90OrLater , "need sm_90" )
@@ -201,7 +204,9 @@ def inner_fn_buf4(index):
201
204
result = None
202
205
try :
203
206
CutlassEVTCodegen .ir_to_evt_python_code (
204
- "buf0" , [MockSchedulerNode (buf3 ), MockSchedulerNode (buf4 )]
207
+ "buf0" ,
208
+ [MockSchedulerNode (buf3 ), MockSchedulerNode (buf4 )],
209
+ OrderedSet ([]),
205
210
)
206
211
except NotImplementedError as e :
207
212
result = e
@@ -251,23 +256,26 @@ def inner_fn_buf4(index):
251
256
MockSchedulerNode (buf3 ),
252
257
MockSchedulerNode (buf4 , last_usage = OrderedSet (["buf0" ])),
253
258
],
259
+ OrderedSet ([]),
254
260
)
255
- self .assertExpectedInline (reads , """['buf0', ' buf1', 'buf2']""" )
256
- self .assertExpectedInline (writes , """['buf3', 'buf4']""" )
261
+ self .assertExpectedInline (reads , """['buf1', 'buf2']""" )
262
+ self .assertExpectedInline (writes , """['buf0', 'buf3', 'buf4']""" )
257
263
self .assertExpectedInline (
258
- renames , """{'buf3': 'D', 'buf4': 'tmp_3', 'buf0': 'accum'}"""
264
+ renames ,
265
+ """{'accum': 'buf0', 'tmp_0': 'buf0', 'buf1': 'buf1', 'buf2': 'buf2', 'D': 'buf3', 'tmp_4': 'buf4'}""" ,
259
266
)
260
267
self .assertExpectedInline (
261
268
code ,
262
269
"""\
263
270
def fn(accum, buf1, buf2):
264
- tmp_0 = accum * buf1
265
- tmp_1 = tmp_0 + buf2
266
- D = tmp_1 # cutlass evt requirement
267
- tmp_2 = D * D
268
- tmp_3 = accum + tmp_2
269
-
270
- return D, tmp_3""" ,
271
+ tmp_0 = accum
272
+ tmp_1 = tmp_0 * buf1
273
+ tmp_2 = tmp_1 + buf2
274
+ D = tmp_2 # cutlass evt requirement
275
+ tmp_3 = D * D
276
+ tmp_4 = tmp_0 + tmp_3
277
+
278
+ return tmp_0, D, tmp_4""" ,
271
279
)
272
280
273
281
@unittest .skipIf (not SM90OrLater , "need sm_90" )
@@ -305,13 +313,15 @@ def inner_fn_buf4(index):
305
313
"buf0" ,
306
314
[
307
315
MockSchedulerNode (buf3 ),
308
- MockSchedulerNode (buf4 , last_usage = OrderedSet ([ "buf0" ]) ),
316
+ MockSchedulerNode (buf4 ),
309
317
],
318
+ OrderedSet (["buf0" ]),
310
319
)
311
- self .assertExpectedInline (reads , """['buf0', ' buf1', 'buf2']""" )
320
+ self .assertExpectedInline (reads , """['buf1', 'buf2']""" )
312
321
self .assertExpectedInline (writes , """['buf3', 'buf4']""" )
313
322
self .assertExpectedInline (
314
- renames , """{'buf3': 'D', 'buf4': 'tmp_2', 'buf0': 'accum'}"""
323
+ renames ,
324
+ """{'accum': 'buf0', 'buf1': 'buf1', 'buf2': 'buf2', 'D': 'buf3', 'tmp_2': 'buf4'}""" ,
315
325
)
316
326
self .assertExpectedInline (
317
327
code ,
@@ -338,13 +348,9 @@ def test_example_tensor_creation(self):
338
348
col_major_buf1 = MockComputedBuffer (
339
349
"buf1" , None , torch .float32 , (3 , 2 , 1 ), (1 , 3 , 0 )
340
350
)
341
- read_names = ["buf0" ]
342
- write_names = ["buf1" ]
343
- buffer_renames = {"buf0" : "acc" }
351
+ buffer_renames = {"buf0" : "buf0" , "buf1" : "buf1" , "acc" : "buf0" }
344
352
name_to_buffer = {"buf0" : row_major_buf0 , "buf1" : col_major_buf1 }
345
- result = create_example_tensors (
346
- read_names , write_names , buffer_renames , name_to_buffer
347
- )
353
+ result = create_example_tensors (buffer_renames , name_to_buffer )
348
354
self .assertEqual (result ["acc" ].shape , (3 , 4 , 1 ))
349
355
self .assertEqual (result ["acc" ].stride , (4 , 1 , 0 ))
350
356
self .assertEqual (
@@ -360,7 +366,10 @@ def test_example_tensor_creation(self):
360
366
@unittest .skipIf (not SM90OrLater , "need sm_90" )
361
367
@unittest .skipIf (not try_import_cutlass (), "requires cutlass" )
362
368
def test_evt_argument_codegen (self ):
363
- epilogue_functor = _trace (BIAS_CODE , EXAMPLE_TENSORS )
369
+ from torch ._inductor .codegen .cuda .cuda_env import get_cuda_arch
370
+
371
+ cuda_arch = int (get_cuda_arch ()) # type: ignore[arg-type]
372
+ epilogue_functor = _trace (BIAS_CODE , EXAMPLE_TENSORS , cuda_arch )
364
373
365
374
self .assertExpectedInline (
366
375
_render_argument_type (
@@ -388,6 +397,51 @@ def test_evt_argument_codegen(self):
388
397
""" ,
389
398
)
390
399
400
+ @unittest .skipIf (not SM90OrLater , "need sm_90" )
401
+ @unittest .skipIf (not try_import_cutlass (), "requires cutlass" )
402
+ def test_evt_argument_codegen_return_accumulator (self ):  # checks the rendered EVT argument text when the epilogue returns an alias of the accumulator (E = accum) together with D
403
+ from torch ._inductor .codegen .cuda .cuda_env import get_cuda_arch  # imported locally, inside the test body
404
+
405
+ code = """
406
+ def fn(accum, bias):
407
+ E = accum
408
+ D = E + bias
409
+ return D, E
410
+ """
411
+ example_tensors = {  # one entry per name appearing in fn: inputs accum/bias and returned values D/E
412
+ "accum" : CutlassTensor (
413
+ element = DataType .f32 , shape = (M , N ), layout_tag = LayoutType .RowMajor
414
+ ),
415
+ "bias" : BIAS ,
416
+ # "beta": 0.5, TODO: mlazos support scalars
417
+ # "alpha": 0.5, TODO: mlazos support scalars
418
+ "D" : CutlassTensor (
419
+ element = DataType .f32 , shape = (M , N ), layout_tag = LayoutType .RowMajor
420
+ ),
421
+ "E" : CutlassTensor (
422
+ element = DataType .f32 , shape = (M , N ), layout_tag = LayoutType .RowMajor
423
+ ),
424
+ }
425
+
426
+ cuda_arch = int (get_cuda_arch ()) # type: ignore[arg-type]
427
+ epilogue_functor = _trace (code , example_tensors , cuda_arch )  # NOTE(review): presumably traces fn into an EVT epilogue functor for this arch — confirm against _trace
428
+
429
+ self .assertExpectedInline (  # the expected text below pins a ptr_aux entry for E next to the accum node, plus the bias and compute_0 entries
430
+ _render_argument_type (
431
+ epilogue_functor , _create_mock_buffer_name_map (example_tensors )
432
+ ),
433
+ """\
434
+ { /* thread */
435
+ { /* E */
436
+ {}, /* accum */
437
+ {/* ptr_aux */ (float*) E, /* dAux */ {2048, _1{}, _0{}}}, /* E */
438
+ },
439
+ {/* ptr_col */ (float*) bias, /* null_default */ float(0), /* dCol */ {}}, /* bias */
440
+ {}, /* compute_0 */
441
+ }
442
+ """ ,
443
+ )
444
+
391
445
@unittest .skipIf (not SM90OrLater , "need sm_90" )
392
446
@unittest .skipIf (not try_import_cutlass (), "requires cutlass" )
393
447
def test_evt_codegen (self ):
0 commit comments