@@ -1,4 +1,5 @@
 # Owner(s): ["module: inductor"]
+import itertools
 import logging
 import math
 import os
@@ -62,6 +63,38 @@ def _get_path_without_sccache() -> str:
     return ":".join(path_envs)


+un_ops_under_test = [torch.relu]
+bin_ops_under_test = [torch.add, torch.mul, torch.sub, torch.div]
+
+evt_all_ops = parametrize(
+    "op", un_ops_under_test + bin_ops_under_test, name_fn=lambda f: f.__name__
+)
+
+evt_bin_ops = parametrize("op", bin_ops_under_test, name_fn=lambda f: f.__name__)
+
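+# Sweep all four (M, N) combinations of 512 and 1024.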
+evt_all_shapes = parametrize("shape", itertools.product([512, 1024], repeat=2))
+
+
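+# Binary ops need a second operand shaped like the GEMM output; unary ops take
+# no extra arguments. E.g. gen_args(torch.add, (512, 512)) returns a single
+# (512, 512) fp16 CUDA tensor in a tuple, while gen_args(torch.relu, ...)
+# returns ().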
+def gen_args(op, shape, dtype=torch.float16):
+    if op in bin_ops_under_test:
+        return (torch.rand(*shape, device="cuda:0", dtype=dtype),)
+    else:
+        return ()
+
+
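+# Shared config for the EVT tests: restrict GEMM autotuning to the CUTLASS
+# backend (one profiling config to keep runtime down) so the epilogue-fusion
+# path is the one actually exercised.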
+use_evt_config = config.patch(
+    {
+        "max_autotune": True,
+        "max_autotune_gemm_backends": "CUTLASS",
+        "cuda.cutlass_max_profiling_configs": 1,
+        "autotune_fallback_to_aten": False,
+        "benchmark_epilogue_fusion": False,  # EVT doesn't support benchmark fusion yet
+        "cuda.cutlass_tma_only": True,
+        "cuda.cutlass_epilogue_fusion_enabled": True,
+    }
+)
+
+
 @instantiate_parametrized_tests
 class TestCutlassBackend(TestCase):
     def setUp(self):
@@ -91,6 +124,22 @@ def tearDown(self):
         super().tearDown()
         clear_inductor_caches()

+    def run_evt_test(self, model, op, shape, num_fusions=1):
+        M, N = shape
+        a = torch.ones(M, N).cuda().half()
+        b = torch.ones(N, N).cuda().half()
+        extra_args = gen_args(op, (M, N))
+        model = model.cuda()
+
+        result = torch.compile(model)(a, b, extra_args)
+        ref_result = model(a, b, extra_args)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"],
+            num_fusions,
+        )
+        torch.testing.assert_close(result, ref_result)
+
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     def test_max_autotune_cutlass_threshold(self):
@@ -1316,6 +1365,33 @@ def forward(self, B):
         ):
             _ = torch.compile(model)(B)

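+    # `A` is produced inside the graph, so the GEMM's output layout is not
+    # pinned by a graph input; fusion should still fire exactly once.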
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
+    @use_evt_config
+    def test_evt_flexible_layout(self):
+        class TestModel(torch.nn.Module):
+            def forward(self, B):
+                A = torch.zeros_like(B)
+                return (A @ B).relu()
+
+        M = 1024
+        B = torch.randn(M, M).cuda().half()
+        model = TestModel().cuda().half()
+
+        with config.patch(
+            {
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "CUTLASS",
+                "cuda.cutlass_max_profiling_configs": 1,
+                "autotune_fallback_to_aten": False,
+            }
+        ):
+            _ = torch.compile(model)(B)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
+        )
+
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     def test_filtered_ops_cache(self):
@@ -1359,6 +1435,156 @@ def test_compilation_time(self):
         _ = torch.compile(torch.mm)(A, B)
         self.assertTrue(time.time() - start_time < 50)

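+    # Basic coverage: a single elementwise op fused into the GEMM epilogue,
+    # swept over every op and shape under test.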
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    @evt_all_shapes
+    def test_evt_fusions_basic(self, op, shape):
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b, extra_args):
+                res = (a @ b).relu()  # extra activation to avoid hitting the addmm path
+                return op(res, *extra_args)
+
+        self.run_evt_test(TestModel(), op, shape)
+
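+    # Returns the accumulator alongside the fused epilogue result, so the
+    # fused kernel must materialize both outputs.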
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_bin_ops
+    def test_evt_broadcasting(self, op):
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b, extra_args):
+                acc = a @ b
+                return acc, op(acc.relu(), *extra_args)
+
+        M = 1024
+        N = 512
+        a = torch.ones(M, N).cuda().half()
+        b = torch.ones(N, N).cuda().half()
+        extra_args = gen_args(op, (M, N))
+        model = TestModel().cuda()
+
+        result = torch.compile(model)(a, b, extra_args)
+        ref_result = model(a, b, extra_args)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
+        )
+        torch.testing.assert_close(result, ref_result)
+
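+    # fp16 GEMM with an fp32 add at the end of the epilogue; the fused result
+    # must match the casting behavior of the unfused CUTLASS + Triton baseline.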
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    def test_evt_mixed_dtypes(self, op):
+        M = 1024
+        N = 256
+
+        fp32_tensor = torch.ones(M, N).cuda().float()
+
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b, extra_args):
+                acc = a @ b
+                out0 = op(acc.relu(), *extra_args)
+                out1 = torch.add(out0, fp32_tensor)
+                return out1
+
+        model = TestModel().cuda()
+        a = torch.ones(M, N).cuda().half()
+        b = torch.ones(N, N).cuda().half()
+        extra_args = gen_args(op, (M, N), dtype=torch.float16)
+
+        # Baseline: CUTLASS kernel plus a separate Triton epilogue, which
+        # matches the expected casting behavior.
+        with config.patch({"cuda.cutlass_epilogue_fusion_enabled": False}):
+            ref_result = torch.compile(model)(a, b, extra_args)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 0
+        )
+
+        torch._dynamo.reset()
+        result = torch.compile(model)(a, b, extra_args)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"],
+            1,
+        )
+
+        torch.testing.assert_close(result, ref_result)
+
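+    # Chain multiple elementwise ops and re-read the accumulator in the
+    # same epilogue.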
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    def test_evt_multi_op(self, op):
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b, extra_args):
+                acc = a @ b
+                return torch.add(op(acc.relu(), *extra_args).relu(), acc)
+
+        self.run_evt_test(TestModel(), op, (1024, 512))
+
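+    # The epilogue re-reads the matmul input `a`, not just the accumulator.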
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    def test_evt_reuse_matmul_input(self, op):
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b, extra_args):
+                acc = a @ b
+                return torch.add(op(acc.relu(), *extra_args).relu(), a)
+
+        self.run_evt_test(TestModel(), op, (1024, 1024))  # shape needs to be square
+
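+    # Two graph outputs, the accumulator and a value derived from the fused
+    # epilogue; blocked on fused scheduler node support (see skip reason).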
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    @unittest.skip("Needs fused scheduler node fusion support (upcoming PR)")
+    def test_evt_multi_output(self, op):
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b, extra_args):
+                acc = a @ b
+                z = op(acc.relu(), *extra_args)
+                y = z + 1
+                return acc, y
+
+        M = 1024
+        N = 512
+        a = torch.ones(M, N).cuda().half()
+        b = torch.ones(N, N).cuda().half()
+        extra_args = gen_args(op, (M, N))
+        model = TestModel().cuda()
+
+        result = torch.compile(model)(a, b, extra_args)
+        ref_result = model(a, b, extra_args)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
+        )
+        torch.testing.assert_close(result, ref_result)
+
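+    # The untouched accumulator itself is a graph output next to the fused
+    # epilogue result.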
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    def test_evt_return_accumulator(self):
+        op = torch.add
+
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b, extra_args):
+                acc = a @ b
+                return acc, op(acc.relu(), *extra_args)
+
+        M = 1024
+        N = 512
+        a = torch.ones(M, N).cuda().half()
+        b = torch.ones(N, N).cuda().half()
+        extra_args = gen_args(op, (M, N))
+        model = TestModel().cuda()
+
+        result = torch.compile(model)(a, b, extra_args)
+        ref_result = model(a, b, extra_args)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
+        )
+        torch.testing.assert_close(result, ref_result)
+

 if __name__ == "__main__":
     from torch._inductor.utils import is_big_gpu