[Cutlass] E2E Tests for EVT by mlazos · Pull Request #152815 · pytorch/pytorch · GitHub

[Cutlass] E2E Tests for EVT #152815


Closed · mlazos wants to merge 13 commits
226 changes: 226 additions & 0 deletions test/inductor/test_cutlass_backend.py
@@ -1,4 +1,5 @@
# Owner(s): ["module: inductor"]
import itertools
import logging
import math
import os
@@ -62,6 +63,38 @@ def _get_path_without_sccache() -> str:
return ":".join(path_envs)


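# Ops exercised by the EVT (epilogue visitor tree) tests; the decorators below
# parametrize each test over these lists so every fusable op gets coverage.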
un_ops_under_test = [torch.relu]
bin_ops_under_test = [torch.add, torch.mul, torch.sub, torch.div]

evt_all_ops = parametrize(
"op", un_ops_under_test + bin_ops_under_test, name_fn=lambda f: f.__name__
)

evt_bin_ops = parametrize("op", bin_ops_under_test, name_fn=lambda f: f.__name__)

evt_all_shapes = parametrize("shape", itertools.product([512, 1024], repeat=2))


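# Binary ops take one extra operand shaped like the matmul output; unary ops
# take none.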
def gen_args(op, shape, dtype=torch.float16):
if op in bin_ops_under_test:
return (torch.rand(*shape, device="cuda:0", dtype=dtype),)
else:
        return ()


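# Shared config for the EVT tests: autotune with the CUTLASS backend only
# (TMA kernels, epilogue fusion enabled) and fail instead of falling back
# to ATen.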
use_evt_config = config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 1,
"autotune_fallback_to_aten": False,
"benchmark_epilogue_fusion": False, # EVT doesn't support benchmark fusion yet
"cuda.cutlass_tma_only": True,
"cuda.cutlass_epilogue_fusion_enabled": True,
}
)


@instantiate_parametrized_tests
class TestCutlassBackend(TestCase):
def setUp(self):
@@ -91,6 +124,22 @@ def tearDown(self):
super().tearDown()
clear_inductor_caches()

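    # Helper: compile the model, compare against eager, and check that the
    # expected number of CUTLASS epilogue fusions was recorded.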
def run_evt_test(self, model, op, shape, num_fusions=1):
M, N = shape
a = torch.ones(M, N).cuda().half()
b = torch.ones(N, N).cuda().half()
extra_args = gen_args(op, (M, N))
model = model.cuda()

result = torch.compile(model)(a, b, extra_args)
ref_result = model(a, b, extra_args)

self.assertEqual(
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"],
num_fusions,
)
torch.testing.assert_close(result, ref_result)

@unittest.skipIf(not SM90OrLater, "need sm_90")
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_max_autotune_cutlass_threshold(self):
@@ -1316,6 +1365,33 @@ def forward(self, B):
):
_ = torch.compile(model)(B)

@unittest.skipIf(not SM90OrLater, "need sm_90")
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
@use_evt_config
def test_evt_flexible_layout(self):
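        # torch.zeros_like gives Inductor a flexible-layout buffer, so this
        # appears to exercise EVT fusion when the GEMM output layout is not
        # fixed up front.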
class TestModel(torch.nn.Module):
def forward(self, B):
A = torch.zeros_like(B)
return (A @ B).relu()

M = 1024
B = torch.randn(M, M).cuda().half()
model = TestModel().cuda().half()

with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 1,
"autotune_fallback_to_aten": False,
}
):
_ = torch.compile(model)(B)

self.assertEqual(
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
)

@unittest.skipIf(not SM90OrLater, "need sm_90")
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_filtered_ops_cache(self):
@@ -1359,6 +1435,156 @@ def test_compilation_time(self):
_ = torch.compile(torch.mm)(A, B)
self.assertTrue(time.time() - start_time < 50)

@unittest.skipIf(not SM90OrLater, "need sm_90")
@use_evt_config
@evt_all_ops
@evt_all_shapes
def test_evt_fusions_basic(self, op, shape):
class TestModel(torch.nn.Module):
def forward(self, a, b, extra_args):
            res = (a @ b).relu()  # extra activation so we don't hit the addmm path
return op(res, *extra_args)

self.run_evt_test(TestModel(), op, shape)

@unittest.skipIf(not SM90OrLater, "need sm_90")
@use_evt_config
@evt_bin_ops
def test_evt_broadcasting(self, op):
class TestModel(torch.nn.Module):
def forward(self, a, b, extra_args):
acc = a @ b
return acc, op(acc.relu(), *extra_args)

M = 1024
N = 512
a = torch.ones(M, N).cuda().half()
b = torch.ones(N, N).cuda().half()
extra_args = gen_args(op, (M, N))
model = TestModel().cuda()

result = torch.compile(model)(a, b, extra_args)
ref_result = model(a, b, extra_args)

self.assertEqual(
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
)
torch.testing.assert_close(result, ref_result)

@unittest.skipIf(not SM90OrLater, "need sm_90")
@use_evt_config
@evt_all_ops
def test_evt_mixed_dtypes(self, op):
M = 1024
N = 256

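        # fp32 operand captured by the model below; adding it to the fp16
        # epilogue output upcasts the final result to float32.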
fp32_tensor = torch.ones(M, N).cuda().float()

class TestModel(torch.nn.Module):
def forward(self, a, b, extra_args):
acc = a @ b
out0 = op(acc.relu(), *extra_args)
out1 = torch.add(out0, fp32_tensor)
return out1

model = TestModel().cuda()
a = torch.ones(M, N).cuda().half()
b = torch.ones(N, N).cuda().half()
extra_args = gen_args(op, (M, N), dtype=torch.float16)

        # Baseline: CUTLASS matmul plus a separate Triton epilogue, which
        # matches the expected casting behavior
with config.patch({"cuda.cutlass_epilogue_fusion_enabled": False}):
ref_result = torch.compile(model)(a, b, extra_args)

self.assertEqual(
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 0
)

torch._dynamo.reset()
result = torch.compile(model)(a, b, extra_args)

self.assertEqual(
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"],
1,
)

torch.testing.assert_close(result, ref_result)

@unittest.skipIf(not SM90OrLater, "need sm_90")
@use_evt_config
@evt_all_ops
def test_evt_multi_op(self, op):
class TestModel(torch.nn.Module):
def forward(self, a, b, extra_args):
acc = a @ b
return torch.add(op(acc.relu(), *extra_args).relu(), acc)

self.run_evt_test(TestModel(), op, (1024, 512))

@unittest.skipIf(not SM90OrLater, "need sm_90")
@use_evt_config
@evt_all_ops
def test_evt_reuse_matmul_input(self, op):
class TestModel(torch.nn.Module):
def forward(self, a, b, extra_args):
acc = a @ b
return torch.add(op(acc.relu(), *extra_args).relu(), a)

self.run_evt_test(TestModel(), op, (1024, 1024)) # shape needs to be square

@unittest.skipIf(not SM90OrLater, "need sm_90")
@use_evt_config
@evt_all_ops
    @unittest.skip("Needs fused scheduler node fusion support (upcoming PR)")
def test_evt_multi_output(self, op):
class TestModel(torch.nn.Module):
def forward(self, a, b, extra_args):
acc = a @ b
z = op(acc.relu(), *extra_args)
y = z + 1
return acc, y

M = 1024
N = 512
a = torch.ones(M, N).cuda().half()
b = torch.ones(N, N).cuda().half()
extra_args = gen_args(op, (M, N))
model = TestModel().cuda()

        result = torch.compile(model)(a, b, extra_args)
ref_result = model(a, b, extra_args)

self.assertEqual(
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
)
torch.testing.assert_close(result, ref_result)

@unittest.skipIf(not SM90OrLater, "need sm_90")
@use_evt_config
def test_evt_return_accumulator(self):
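        # The accumulator itself is returned alongside the fused epilogue
        # result, checking that acc can be an EVT output in its own right.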
op = torch.add

class TestModel(torch.nn.Module):
def forward(self, a, b, extra_args):
acc = a @ b
return acc, op(acc.relu(), *extra_args)

M = 1024
N = 512
a = torch.ones(M, N).cuda().half()
b = torch.ones(N, N).cuda().half()
extra_args = gen_args(op, (M, N))
model = TestModel().cuda()

result = torch.compile(model)(a, b, extra_args)
ref_result = model(a, b, extra_args)

self.assertEqual(
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
)
torch.testing.assert_close(result, ref_result)


if __name__ == "__main__":
from torch._inductor.utils import is_big_gpu