pytorch/pytorch · Commit 59700b7

[Cutlass] E2E Tests for EVT

ghstack-source-id: cea63c9
Pull Request resolved: #152815

1 parent: 0104ac0

File tree: 11 files changed, +379 −114 lines

test/inductor/test_cutlass_backend.py — 156 additions & 0 deletions
@@ -62,6 +62,36 @@ def _get_path_without_sccache() -> str:
     return ":".join(path_envs)
 
 
+un_ops_under_test = [torch.relu]
+bin_ops_under_test = [torch.add, torch.mul, torch.sub, torch.div]
+
+evt_all_ops = parametrize(
+    "op", un_ops_under_test + bin_ops_under_test, name_fn=lambda f: f.__name__
+)
+
+evt_bin_ops = parametrize("op", bin_ops_under_test, name_fn=lambda f: f.__name__)
+
+
+def gen_args(op, shape):
+    if op in bin_ops_under_test:
+        return (torch.rand(*shape, device="cuda:0").half(),)
+    else:
+        return ()
+
+
+use_evt_config = config.patch(
+    {
+        "max_autotune": True,
+        "max_autotune_gemm_backends": "CUTLASS",
+        "cuda.cutlass_max_profiling_configs": 1,
+        "autotune_fallback_to_aten": False,
+        "benchmark_epilogue_fusion": False,  # EVT doesn't support benchmark fusion yet
+        "cuda.cutlass_tma_only": True,
+        "cuda.cutlass_epilogue_fusion_enabled": True,
+    }
+)
+
+
 @instantiate_parametrized_tests
 class TestCutlassBackend(TestCase):
     def setUp(self):
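
Note: `evt_all_ops` and `evt_bin_ops` expand each decorated test into one instance per op, named via `name_fn`, while `gen_args` supplies the extra fp16 CUDA operand only for the binary ops, so a single test body can handle both arities. A quick illustration of the generated test-name suffixes (evaluates with just the definitions above):

    print([f.__name__ for f in un_ops_under_test + bin_ops_under_test])
    # ['relu', 'add', 'mul', 'sub', 'div']  -> test_..._relu, test_..._add, etc.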
@@ -91,6 +121,22 @@ def tearDown(self):
         super().tearDown()
         clear_inductor_caches()
 
+    def run_evt_test(self, model, op, shape, num_fusions=1):
+        M, N = shape
+        a = torch.ones(M, N).cuda().half()
+        b = torch.ones(N, N).cuda().half()
+        extra_args = gen_args(op, (M, N))
+        model = model.cuda()
+
+        result = torch.compile(model)(a, b, extra_args)
+        ref_result = model(a, b, extra_args)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"],
+            num_fusions,
+        )
+        torch.testing.assert_close(result, ref_result)
+
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     def test_max_autotune_cutlass_threshold(self):
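
With this helper, each fusion test only defines a module and delegates the compiled-vs-eager comparison and the `cuda_epilogue_fusion_counter` check, as in the tests added further down in this diff:

    self.run_evt_test(TestModel(), torch.add, (1024, 512), num_fusions=1)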
@@ -1316,6 +1362,33 @@ def forward(self, B):
         ):
             _ = torch.compile(model)(B)
 
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
+    @use_evt_config
+    def test_evt_flexible_layout(self):
+        class TestModel(torch.nn.Module):
+            def forward(self, B):
+                A = torch.zeros_like(B)
+                return (A @ B).relu()
+
+        M = 1024
+        B = torch.randn(M, M).cuda().half()
+        model = TestModel().cuda().half()
+
+        with config.patch(
+            {
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "CUTLASS",
+                "cuda.cutlass_max_profiling_configs": 1,
+                "autotune_fallback_to_aten": False,
+            }
+        ):
+            _ = torch.compile(model)(B)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
+        )
+
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     def test_filtered_ops_cache(self):
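
The inner `config.patch` here layers on top of the patch already installed by `@use_evt_config`: Inductor config patches nest, so the EVT flags from `use_evt_config` stay in effect and this block only re-pins the autotune settings. A minimal sketch of the nesting semantics, using keys from this diff:

    from torch._inductor import config

    with config.patch({"cuda.cutlass_epilogue_fusion_enabled": True}):
        with config.patch({"max_autotune": True}):
            assert config.max_autotune
            assert config.cuda.cutlass_epilogue_fusion_enabled  # outer patch still visible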
@@ -1359,6 +1432,89 @@ def test_compilation_time(self):
         _ = torch.compile(torch.mm)(A, B)
         self.assertTrue(time.time() - start_time < 50)
 
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    def test_evt_fusions_basic(self, op):
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b, extra_args):
+                res = (a @ b).relu()  # add extra activation to not hit addmm path
+                return op(res, *extra_args)
+
+        self.run_evt_test(TestModel(), op, (1024, 512))
+
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_bin_ops
+    def test_evt_broadcasting(self, op):
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b, extra_args):
+                acc = a @ b
+                return acc, op(acc.relu(), *extra_args)
+
+        M = 1024
+        N = 512
+        a = torch.ones(M, N).cuda().half()
+        b = torch.ones(N, N).cuda().half()
+        extra_args = gen_args(op, (M, N))
+        model = TestModel().cuda()
+
+        result = torch.compile(model)(a, b, extra_args)
+        ref_result = model(a, b, extra_args)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
+        )
+        torch.testing.assert_close(result, ref_result)
+
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    def test_evt_mixed_dtypes(self, op):
+        pass
+
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    def test_evt_multi_op(self, op):
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b, extra_args):
+                acc = a @ b
+                return torch.add(op(acc.relu(), *extra_args).relu(), *extra_args)
+
+        self.run_evt_test(TestModel(), op, (1024, 512))
+
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    def test_evt_multi_output(self, op):
+        pass
+
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    def test_evt_return_accumulator(self):
+        op = torch.add
+
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b, extra_args):
+                acc = a @ b
+                return acc, op(acc.relu(), *extra_args)
+
+        M = 1024
+        N = 512
+        a = torch.ones(M, N).cuda().half()
+        b = torch.ones(N, N).cuda().half()
+        extra_args = gen_args(op, (M, N))
+        model = TestModel().cuda()
+
+        result = torch.compile(model)(a, b, extra_args)
+        ref_result = model(a, b, extra_args)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
+        )
+        torch.testing.assert_close(result, ref_result)
+
 
 if __name__ == "__main__":
     from torch._inductor.utils import is_big_gpu

test/inductor/test_cutlass_evt.py — 82 additions & 28 deletions
@@ -147,22 +147,25 @@ def inner_fn_buf4(index):
                 MockSchedulerNode(buf3),
                 MockSchedulerNode(buf4, last_usage=OrderedSet(["buf3"])),
             ],
+            OrderedSet([]),
         )
-        self.assertExpectedInline(reads, """['buf0', 'buf1', 'buf2']""")
+        self.assertExpectedInline(reads, """['buf1', 'buf2']""")
         self.assertExpectedInline(writes, """['buf0', 'buf3', 'buf4']""")
         self.assertExpectedInline(
-            renames, """{'buf0': 'accum', 'buf3': 'tmp_1', 'buf4': 'tmp_2'}"""
+            renames,
+            """{'accum': 'buf0', 'tmp_0': 'buf0', 'buf1': 'buf1', 'buf2': 'buf2', 'D': 'buf3', 'tmp_3': 'buf4'}""",
         )
         self.assertExpectedInline(
             code,
             """\
 def fn(accum, buf1, buf2):
-    D = accum # cutlass evt requirement
-    tmp_0 = accum * buf1
-    tmp_1 = tmp_0 + buf2
-    tmp_2 = accum + tmp_1
+    tmp_0 = accum
+    tmp_1 = tmp_0 * buf1
+    tmp_2 = tmp_1 + buf2
+    D = tmp_2 # cutlass evt requirement
+    tmp_3 = tmp_0 + D
 
-    return D, tmp_1, tmp_2""",
+    return tmp_0, D, tmp_3""",
         )
 
     @unittest.skipIf(not SM90OrLater, "need sm_90")
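
The new expected program threads the accumulator through `tmp_0` and returns it alongside `D`; since the generated EVT code is plain Python over tensors, the expectation can be sanity-checked standalone (illustrative values, not part of the test):

    import torch

    def fn(accum, buf1, buf2):
        tmp_0 = accum
        tmp_1 = tmp_0 * buf1
        tmp_2 = tmp_1 + buf2
        D = tmp_2  # cutlass evt requirement
        tmp_3 = tmp_0 + D
        return tmp_0, D, tmp_3

    accum, buf1, buf2 = torch.ones(2, 2), torch.full((2, 2), 3.0), torch.ones(2, 2)
    tmp_0, D, tmp_3 = fn(accum, buf1, buf2)
    assert torch.equal(D, accum * buf1 + buf2)  # 1*3 + 1 = 4
    assert torch.equal(tmp_3, accum + D)        # 1 + 4 = 5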
@@ -201,7 +204,9 @@ def inner_fn_buf4(index):
         result = None
         try:
             CutlassEVTCodegen.ir_to_evt_python_code(
-                "buf0", [MockSchedulerNode(buf3), MockSchedulerNode(buf4)]
+                "buf0",
+                [MockSchedulerNode(buf3), MockSchedulerNode(buf4)],
+                OrderedSet([]),
             )
         except NotImplementedError as e:
             result = e
@@ -251,23 +256,26 @@ def inner_fn_buf4(index):
                 MockSchedulerNode(buf3),
                 MockSchedulerNode(buf4, last_usage=OrderedSet(["buf0"])),
             ],
+            OrderedSet([]),
         )
-        self.assertExpectedInline(reads, """['buf0', 'buf1', 'buf2']""")
-        self.assertExpectedInline(writes, """['buf3', 'buf4']""")
+        self.assertExpectedInline(reads, """['buf1', 'buf2']""")
+        self.assertExpectedInline(writes, """['buf0', 'buf3', 'buf4']""")
         self.assertExpectedInline(
-            renames, """{'buf3': 'D', 'buf4': 'tmp_3', 'buf0': 'accum'}"""
+            renames,
+            """{'accum': 'buf0', 'tmp_0': 'buf0', 'buf1': 'buf1', 'buf2': 'buf2', 'D': 'buf3', 'tmp_4': 'buf4'}""",
         )
         self.assertExpectedInline(
             code,
             """\
 def fn(accum, buf1, buf2):
-    tmp_0 = accum * buf1
-    tmp_1 = tmp_0 + buf2
-    D = tmp_1 # cutlass evt requirement
-    tmp_2 = D * D
-    tmp_3 = accum + tmp_2
-
-    return D, tmp_3""",
+    tmp_0 = accum
+    tmp_1 = tmp_0 * buf1
+    tmp_2 = tmp_1 + buf2
+    D = tmp_2 # cutlass evt requirement
+    tmp_3 = D * D
+    tmp_4 = tmp_0 + tmp_3
+
+    return tmp_0, D, tmp_4""",
         )
 
     @unittest.skipIf(not SM90OrLater, "need sm_90")
@@ -305,13 +313,15 @@ def inner_fn_buf4(index):
             "buf0",
             [
                 MockSchedulerNode(buf3),
-                MockSchedulerNode(buf4, last_usage=OrderedSet(["buf0"])),
+                MockSchedulerNode(buf4),
             ],
+            OrderedSet(["buf0"]),
         )
-        self.assertExpectedInline(reads, """['buf0', 'buf1', 'buf2']""")
+        self.assertExpectedInline(reads, """['buf1', 'buf2']""")
         self.assertExpectedInline(writes, """['buf3', 'buf4']""")
         self.assertExpectedInline(
-            renames, """{'buf3': 'D', 'buf4': 'tmp_2', 'buf0': 'accum'}"""
+            renames,
+            """{'accum': 'buf0', 'buf1': 'buf1', 'buf2': 'buf2', 'D': 'buf3', 'tmp_2': 'buf4'}""",
         )
         self.assertExpectedInline(
             code,
@@ -338,13 +348,9 @@ def test_example_tensor_creation(self):
         col_major_buf1 = MockComputedBuffer(
             "buf1", None, torch.float32, (3, 2, 1), (1, 3, 0)
         )
-        read_names = ["buf0"]
-        write_names = ["buf1"]
-        buffer_renames = {"buf0": "acc"}
+        buffer_renames = {"buf0": "buf0", "buf1": "buf1", "acc": "buf0"}
         name_to_buffer = {"buf0": row_major_buf0, "buf1": col_major_buf1}
-        result = create_example_tensors(
-            read_names, write_names, buffer_renames, name_to_buffer
-        )
+        result = create_example_tensors(buffer_renames, name_to_buffer)
         self.assertEqual(result["acc"].shape, (3, 4, 1))
         self.assertEqual(result["acc"].stride, (4, 1, 0))
         self.assertEqual(
@@ -360,7 +366,10 @@ def test_example_tensor_creation(self):
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
     def test_evt_argument_codegen(self):
-        epilogue_functor = _trace(BIAS_CODE, EXAMPLE_TENSORS)
+        from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch
+
+        cuda_arch = int(get_cuda_arch())  # type: ignore[arg-type]
+        epilogue_functor = _trace(BIAS_CODE, EXAMPLE_TENSORS, cuda_arch)
 
         self.assertExpectedInline(
             _render_argument_type(
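
`get_cuda_arch()` returns the target compute capability as a string, or possibly None when it cannot be determined, hence the `int(...)` coercion and the `type: ignore[arg-type]`. For example:

    from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch

    arch = get_cuda_arch()  # e.g. "90" on SM90; may be None
    if arch is not None:
        print(int(arch))    # 90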
@@ -388,6 +397,51 @@ def test_evt_argument_codegen(self):
 """,
         )
 
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
+    def test_evt_argument_codegen_return_accumulator(self):
+        from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch
+
+        code = """
+def fn(accum, bias):
+    E = accum
+    D = E + bias
+    return D, E
+"""
+        example_tensors = {
+            "accum": CutlassTensor(
+                element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
+            ),
+            "bias": BIAS,
+            # "beta": 0.5, TODO: mlazos support scalars
+            # "alpha": 0.5, TODO: mlazos support scalars
+            "D": CutlassTensor(
+                element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
+            ),
+            "E": CutlassTensor(
+                element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
+            ),
+        }
+
+        cuda_arch = int(get_cuda_arch())  # type: ignore[arg-type]
+        epilogue_functor = _trace(code, example_tensors, cuda_arch)
+
+        self.assertExpectedInline(
+            _render_argument_type(
+                epilogue_functor, _create_mock_buffer_name_map(example_tensors)
+            ),
+            """\
+{ /* thread */
+  { /* E */
+    {}, /* accum */
+    {/* ptr_aux */ (float*) E, /* dAux */ {2048, _1{}, _0{}}}, /* E */
+  },
+  {/* ptr_col */ (float*) bias, /* null_default */ float(0), /* dCol */ {}}, /* bias */
+  {}, /* compute_0 */
+}
+""",
+        )
+
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
     def test_evt_codegen(self):
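
In the rendered arguments, `dAux` `{2048, _1{}, _0{}}` reads as the (row, column, batch) strides of the auxiliary output `E`: a row-major (M, N) tensor with N = 2048 has strides (2048, 1), and the batch mode is statically zero. A quick check of that correspondence (illustrative M; only N is pinned by the expected output above):

    import torch

    E = torch.empty(16, 2048)       # any M; N = 2048 per the rendered dAux
    assert E.stride() == (2048, 1)  # matches {2048, _1{}, _0{}} up to the batch mode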
