[Cutlass] E2E Tests for EVT · pytorch/pytorch@699d6c3 · GitHub

Commit 699d6c3

[Cutlass] E2E Tests for EVT

ghstack-source-id: db7217e
Pull Request resolved: #152815

1 parent f5e0806 commit 699d6c3

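Background, for readers of this commit: EVT is CUTLASS's Epilogue Visitor Tree, the mechanism that fuses elementwise epilogues (relu, add, mul, ...) into the GEMM kernel itself instead of launching a separate elementwise kernel. The tests added here exercise that path end to end through torch.compile with Inductor's CUTLASS backend. Below is a minimal sketch of the pattern under test, assuming an SM90 GPU; the config keys are taken from the use_evt_config patch in the diff, but the sketch itself is illustrative and not part of the commit:

    import torch
    import torch._inductor.config as config

    def mm_with_epilogue(a, b, c):
        # GEMM followed by elementwise ops; with EVT epilogue fusion enabled,
        # Inductor can emit a single CUTLASS kernel for the whole expression.
        return torch.relu(a @ b) + c

    with config.patch(
        {
            "max_autotune": True,
            "max_autotune_gemm_backends": "CUTLASS",
            "cuda.cutlass_epilogue_fusion_enabled": True,
        }
    ):
        a = torch.ones(1024, 512, device="cuda", dtype=torch.float16)
        b = torch.ones(512, 512, device="cuda", dtype=torch.float16)
        c = torch.ones(1024, 512, device="cuda", dtype=torch.float16)
        out = torch.compile(mm_with_epilogue)(a, b, c)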
File tree

11 files changed: +459 -114 lines changed

test/inductor/test_cutlass_backend.py

Lines changed: 226 additions & 0 deletions
@@ -1,4 +1,5 @@
 # Owner(s): ["module: inductor"]
+import itertools
 import logging
 import math
 import os
@@ -62,6 +63,38 @@ def _get_path_without_sccache() -> str:
     return ":".join(path_envs)


+un_ops_under_test = [torch.relu]
+bin_ops_under_test = [torch.add, torch.mul, torch.sub, torch.div]
+
+evt_all_ops = parametrize(
+    "op", un_ops_under_test + bin_ops_under_test, name_fn=lambda f: f.__name__
+)
+
+evt_bin_ops = parametrize("op", bin_ops_under_test, name_fn=lambda f: f.__name__)
+
+evt_all_shapes = parametrize("shape", itertools.product([512, 1024], repeat=2))
+
+
+def gen_args(op, shape, dtype=torch.float16):
+    if op in bin_ops_under_test:
+        return (torch.rand(*shape, device="cuda:0", dtype=dtype),)
+    else:
+        return ()
+
+
+use_evt_config = config.patch(
+    {
+        "max_autotune": True,
+        "max_autotune_gemm_backends": "CUTLASS",
+        "cuda.cutlass_max_profiling_configs": 1,
+        "autotune_fallback_to_aten": False,
+        "benchmark_epilogue_fusion": False,  # EVT doesn't support benchmark fusion yet
+        "cuda.cutlass_tma_only": True,
+        "cuda.cutlass_epilogue_fusion_enabled": True,
+    }
+)
+
+
 @instantiate_parametrized_tests
 class TestCutlassBackend(TestCase):
     def setUp(self):
@@ -91,6 +124,22 @@ def tearDown(self):
         super().tearDown()
         clear_inductor_caches()

+    def run_evt_test(self, model, op, shape, num_fusions=1):
+        M, N = shape
+        a = torch.ones(M, N).cuda().half()
+        b = torch.ones(N, N).cuda().half()
+        extra_args = gen_args(op, (M, N))
+        model = model.cuda()
+
+        result = torch.compile(model)(a, b, extra_args)
+        ref_result = model(a, b, extra_args)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"],
+            num_fusions,
+        )
+        torch.testing.assert_close(result, ref_result)
+
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     def test_max_autotune_cutlass_threshold(self):
@@ -1316,6 +1365,33 @@ def forward(self, B):
         ):
             _ = torch.compile(model)(B)

+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
+    @use_evt_config
+    def test_evt_flexible_layout(self):
+        class TestModel(torch.nn.Module):
+            def forward(self, B):
+                A = torch.zeros_like(B)
+                return (A @ B).relu()
+
+        M = 1024
+        B = torch.randn(M, M).cuda().half()
+        model = TestModel().cuda().half()
+
+        with config.patch(
+            {
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "CUTLASS",
+                "cuda.cutlass_max_profiling_configs": 1,
+                "autotune_fallback_to_aten": False,
+            }
+        ):
+            _ = torch.compile(model)(B)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
+        )
+
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     def test_filtered_ops_cache(self):
@@ -1359,6 +1435,156 @@ def test_compilation_time(self):
         _ = torch.compile(torch.mm)(A, B)
         self.assertTrue(time.time() - start_time < 50)

+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    @evt_all_shapes
+    def test_evt_fusions_basic(self, op, shape):
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b, extra_args):
+                res = (a @ b).relu()  # add extra activation to not hit addmm path
+                return op(res, *extra_args)
+
+        self.run_evt_test(TestModel(), op, shape)
+
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_bin_ops
+    def test_evt_broadcasting(self, op):
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b, extra_args):
+                acc = a @ b
+                return acc, op(acc.relu(), *extra_args)
+
+        M = 1024
+        N = 512
+        a = torch.ones(M, N).cuda().half()
+        b = torch.ones(N, N).cuda().half()
+        extra_args = gen_args(op, (M, N))
+        model = TestModel().cuda()
+
+        result = torch.compile(model)(a, b, extra_args)
+        ref_result = model(a, b, extra_args)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
+        )
+        torch.testing.assert_close(result, ref_result)
+
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    def test_evt_mixed_dtypes(self, op):
+        M = 1024
+        N = 256
+
+        fp32_tensor = torch.ones(M, N).cuda().float()
+
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b, extra_args):
+                acc = a @ b
+                out0 = op(acc.relu(), *extra_args)
+                out1 = torch.add(out0, fp32_tensor)
+                return out1
+
+        model = TestModel().cuda()
+        a = torch.ones(M, N).cuda().half()
+        b = torch.ones(N, N).cuda().half()
+        extra_args = gen_args(op, (M, N), dtype=torch.float16)
+
+        # baseline is cutlass kernel + triton
+        # matches expected casting behavior
+        with config.patch({"cuda.cutlass_epilogue_fusion_enabled": False}):
+            ref_result = torch.compile(model)(a, b, extra_args)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 0
+        )
+
+        torch._dynamo.reset()
+        result = torch.compile(model)(a, b, extra_args)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"],
+            1,
+        )
+
+        torch.testing.assert_close(result, ref_result)
+
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    def test_evt_multi_op(self, op):
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b, extra_args):
+                acc = a @ b
+                return torch.add(op(acc.relu(), *extra_args).relu(), acc)
+
+        self.run_evt_test(TestModel(), op, (1024, 512))
+
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    def test_evt_reuse_matmul_input(self, op):
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b, extra_args):
+                acc = a @ b
+                return torch.add(op(acc.relu(), *extra_args).relu(), a)
+
+        self.run_evt_test(TestModel(), op, (1024, 1024))  # shape needs to be square
+
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    @unittest.skip("Needs fused scheduler node fusion support, (upcoming PR)")
+    def test_evt_multi_output(self, op):
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b, extra_args):
+                acc = a @ b
+                z = op(acc.relu(), *extra_args)
+                y = z + 1
+                return acc, y
+
+        M = 1024
+        N = 512
+        a = torch.ones(M, N).cuda().half()
+        b = torch.ones(N, N).cuda().half()
+        extra_args = gen_args(op, (M, N))
+        model = TestModel().cuda()
+
+        result = torch.compile(model)(a, b, extra_args)
+        ref_result = model(a, b, extra_args)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
+        )
+        torch.testing.assert_close(result, ref_result)
+
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    def test_evt_return_accumulator(self):
+        op = torch.add
+
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b, extra_args):
+                acc = a @ b
+                return acc, op(acc.relu(), *extra_args)
+
+        M = 1024
+        N = 512
+        a = torch.ones(M, N).cuda().half()
+        b = torch.ones(N, N).cuda().half()
+        extra_args = gen_args(op, (M, N))
+        model = TestModel().cuda()
+
+        result = torch.compile(model)(a, b, extra_args)
+        ref_result = model(a, b, extra_args)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
+        )
+        torch.testing.assert_close(result, ref_result)
+

 if __name__ == "__main__":
     from torch._inductor.utils import is_big_gpu
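For reference, one way to run only the new EVT tests locally (illustrative invocation, assuming a CUDA build and an SM90 GPU; the tests skip themselves otherwise):

    pytest test/inductor/test_cutlass_backend.py -k evt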
