[Cutlass] E2E Tests for EVT · pytorch/pytorch@88ed8eb · GitHub

Commit 88ed8eb
[Cutlass] E2E Tests for EVT
ghstack-source-id: 58d2daf
Pull Request resolved: #152815
1 parent 0104ac0 commit 88ed8eb

File tree: 11 files changed, +196 −32 lines changed

test/inductor/test_cutlass_backend.py

Lines changed: 111 additions & 0 deletions
@@ -62,6 +62,34 @@ def _get_path_without_sccache() -> str:
     return ":".join(path_envs)
 
 
+un_ops_under_test = [torch.relu, torch.sigmoid, torch.tanh]
+bin_ops_under_test = [torch.add, torch.mul, torch.sub, torch.div]
+
+evt_all_ops = parametrize(
+    "op", un_ops_under_test + bin_ops_under_test, name_fn=lambda f: f.__name__
+)
+
+
+def gen_args(op, shape):
+    if op in bin_ops_under_test:
+        return (torch.rand(*shape, device="cuda:0").half(),)
+    else:
+        return ()
+
+
+use_evt_config = config.patch(
+    {
+        "max_autotune": True,
+        "max_autotune_gemm_backends": "CUTLASS",
+        "cuda.cutlass_max_profiling_configs": 1,
+        "autotune_fallback_to_aten": False,
+        "benchmark_epilogue_fusion": False,  # EVT doesn't support benchmark fusion yet
+        "cuda.cutlass_tma_only": True,
+        "cuda.cutlass_epilogue_fusion_enabled": True,
+    }
+)
+
+
 @instantiate_parametrized_tests
 class TestCutlassBackend(TestCase):
     def setUp(self):
@@ -1316,6 +1344,35 @@ def forward(self, B):
         ):
             _ = torch.compile(model)(B)
 
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
+    @config.patch(
+        {"benchmark_epilogue_fusion": False, "cuda.cutlass_tma_only": True}
+    )  # EVT doesn't support benchmark fusion yet
+    def test_evt_flexible_layout(self):
+        class TestModel(torch.nn.Module):
+            def forward(self, B):
+                A = torch.zeros_like(B)
+                return (A @ B).relu()
+
+        M = 1024
+        B = torch.randn(M, M).cuda().half()
+        model = TestModel().cuda()
+
+        with config.patch(
+            {
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "CUTLASS",
+                "cuda.cutlass_max_profiling_configs": 1,
+                "autotune_fallback_to_aten": False,
+            }
+        ):
+            _ = torch.compile(model)(B)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
+        )
+
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     def test_filtered_ops_cache(self):
@@ -1359,6 +1416,60 @@ def test_compilation_time(self):
         _ = torch.compile(torch.mm)(A, B)
         self.assertTrue(time.time() - start_time < 50)
 
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    def test_evt_fusions_basic(self, op):
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b, extra_args):
+                res = (a @ b).relu()  # extra activation to avoid hitting the addmm path
+                return op(res, *extra_args)
+
+        M = 16
+        N = 16
+        a = torch.ones(M, N).cuda().half()
+        b = torch.ones(N, N).cuda().half()
+        extra_args = gen_args(op, (M, N))
+        model = TestModel().cuda()
+
+        result = torch.compile(model)(a, b, extra_args)
+        ref_result = model(a, b, extra_args)
+
+        self.assertEqual(
+            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
+        )
+        torch.testing.assert_close(result, ref_result)
+
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    def test_evt_broadcasting(self, op):
+        pass
+
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    def test_evt_mixed_dtypes(self, op):
+        pass
+
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    def test_evt_multi_op(self, op):
+        pass
+
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    def test_evt_multi_output(self, op):
+        pass
+
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @use_evt_config
+    @evt_all_ops
+    def test_evt_return_accumulator(self, op):
+        pass
+
 
 if __name__ == "__main__":
     from torch._inductor.utils import is_big_gpu
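
For readers reproducing these tests outside the suite, the pieces above compose into a small standalone script. This is a hedged sketch, assuming an SM90+ GPU and a CUTLASS-enabled inductor build; `fused_mm_relu_add` is a made-up function name, and the config keys are the ones set by `use_evt_config` above.

import torch
from torch._inductor import config

def fused_mm_relu_add(a, b, c):
    # GEMM with a relu + add epilogue that EVT can fuse into the kernel
    return torch.relu(a @ b) + c

a = torch.ones(16, 16, device="cuda", dtype=torch.half)
b = torch.ones(16, 16, device="cuda", dtype=torch.half)
c = torch.rand(16, 16, device="cuda", dtype=torch.half)

with config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTLASS",
        "cuda.cutlass_max_profiling_configs": 1,
        "benchmark_epilogue_fusion": False,
        "cuda.cutlass_tma_only": True,
        "cuda.cutlass_epilogue_fusion_enabled": True,
    }
):
    out = torch.compile(fused_mm_relu_add)(a, b, c)

# mirrors the counter assertion in the tests: one fused epilogue expected
print(torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"])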

test/inductor/test_cutlass_evt.py

Lines changed: 4 additions & 1 deletion
@@ -360,7 +360,10 @@ def test_example_tensor_creation(self):
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
     def test_evt_argument_codegen(self):
-        epilogue_functor = _trace(BIAS_CODE, EXAMPLE_TENSORS)
+        from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch
+
+        cuda_arch = int(get_cuda_arch())  # type: ignore[arg-type]
+        epilogue_functor = _trace(BIAS_CODE, EXAMPLE_TENSORS, cuda_arch)
 
         self.assertExpectedInline(
             _render_argument_type(
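
The test now resolves the GPU's compute capability itself and passes it into `_trace` explicitly. A minimal sketch of that resolution step, under the assumption (consistent with the `type: ignore[arg-type]` above) that `get_cuda_arch` returns a string or None:

from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch

arch = get_cuda_arch()                        # e.g. "90" on Hopper, None without CUDA
cuda_arch = int(arch) if arch is not None else 0
assert cuda_arch >= 90, "Only SM90+ is supported for EVT"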

torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py

Lines changed: 3 additions & 1 deletion
@@ -201,7 +201,6 @@ def _can_fuse_epilogue_impl(
         - bool: True if the given node can be fused with the epilogue, False otherwise.
 
     """
-
     why = WhyNoFuseNames(cuda_template_buffer.get_name(), node_to_fuse.get_name())
 
     ir_node_to_fuse = node_to_fuse.node
@@ -244,6 +243,9 @@ def _can_fuse_epilogue_impl(
     ):
         why("cutlass epilogue fusion is not enabled")
         return False
+    elif not cuda_template_buffer.supports_epilogue_fusion:
+        why("epilogue fusion is only supported for TMA-enabled gemm ops")
+        return False
 
     try:
         from torch._inductor.codegen.cuda.cutlass_python_evt import (
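
The new branch consults a flag computed once at autotune time (see the `cuda_template.py` and `cuda_kernel.py` hunks below) rather than re-deriving TMA support during scheduling. A rough sketch of the shape of this check, with `WhyStub` a made-up stand-in for `WhyNoFuseNames`:

class WhyStub:
    def __call__(self, reason: str) -> None:
        print(f"cannot fuse: {reason}")  # the real class records why-no-fuse reasons

def can_fuse_epilogue_sketch(template_buffer, epilogue_enabled: bool) -> bool:
    why = WhyStub()
    if not epilogue_enabled:
        why("cutlass epilogue fusion is not enabled")
        return False
    elif not template_buffer.supports_epilogue_fusion:
        why("epilogue fusion is only supported for TMA-enabled gemm ops")
        return False
    return True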

torch/_inductor/codegen/cuda/cuda_kernel.py

Lines changed: 11 additions & 8 deletions
@@ -12,6 +12,7 @@
 from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
 from torch._inductor.scheduler import BaseSchedulerNode
 from torch._inductor.utils import Placeholder
+from torch.utils._sympy.value_ranges import ValueRanges
 
 
 if TYPE_CHECKING:
@@ -30,6 +31,7 @@
 from ...utils import sympy_product
 from ...virtualized import V
 from ..common import (
+    CSEVariable,
     IndentedBuffer,
     Kernel,
     OpOverrides,
@@ -238,7 +240,6 @@ def def_kernel(
         inputs: list[IRNode],
         outputs: list[IRNode],
         epilogue_inputs: list[IRNode],
-        epilogue_outputs: list[IRNode],
         names_str: str = "",
         input_reorder: Optional[list[int]] = None,
     ) -> str:
@@ -285,13 +286,6 @@ def def_kernel(
             self.named_nodes[name] = node
             self.args.output_buffers[node.get_name()] = name
 
-        for epilogue_output in epilogue_outputs:
-            if epilogue_output is not None:
-                self.named_nodes[epilogue_output.get_name()] = epilogue_output
-                self.args.output_buffers[epilogue_output.get_name()] = (
-                    epilogue_output.get_name()
-                )
-
         arg_defs, *_ = self.args.cpp_argdefs()
 
         self.init_layout_args()
@@ -540,6 +534,12 @@ def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str:
                 f"At least 1 stride should be 1. Strides: {node.get_stride()=}"
             )
 
+    def load(self, name: str, index: Expr, mode: Any = None) -> CSEVariable:
+        """
+        Mock load function for memory planning to optimize allocations properly.
+        """
+        return self.create_cse_var(name, bounds=ValueRanges.unknown())
+
     def store(self, name: str, index: Expr, value: Any, mode: Any = None) -> None:
         """
         Mock store function for memory planning to optimize allocations properly.
@@ -570,6 +570,7 @@ def __init__(
             tuple[CUDATemplateKernel, functools.partial[str]],
         ],
         bmreq: CUDABenchmarkRequest,
+        supports_epilogue_fusion: bool,
         template: "CUDATemplate",  # type: ignore[name-defined]
         info_kwargs: Optional[
             dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
@@ -580,6 +581,7 @@ def __init__(
         self.category = category
         self.make_kernel_render = make_kernel_render
         self.bmreq = bmreq
+        self.supports_epilogue_fusion = supports_epilogue_fusion
         self.template = template
         self.info_kwargs = info_kwargs
 
@@ -636,6 +638,7 @@ def output_node(self) -> TensorBox:
                 inputs=self.input_nodes,
                 make_kernel_render=self.make_kernel_render,
                 workspace_size=self.bmreq.workspace_size,
+                supports_epilogue_fusion=self.supports_epilogue_fusion,
                 template=self.template,
             )
         )
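
The new `load` is the counterpart of the pre-existing mock `store`: during memory planning the epilogue body is replayed through the kernel's ops handler, and returning a CSE variable with unknown bounds lets buffer reads be tracked without emitting device code. A stripped-down sketch of that idea, with `PlanningHandler` invented for illustration:

from torch.utils._sympy.value_ranges import ValueRanges

class PlanningHandler:
    """Records which buffers an epilogue reads/writes; generates nothing."""

    def __init__(self):
        self.reads: set[str] = set()
        self.writes: set[str] = set()

    def load(self, name, index):
        self.reads.add(name)
        return (name, ValueRanges.unknown())  # value is opaque to planning

    def store(self, name, index, value):
        self.writes.add(name)  # only the allocation footprint matters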

torch/_inductor/codegen/cuda/cuda_template.py

Lines changed: 17 additions & 0 deletions
@@ -26,6 +26,7 @@
 else:
     BaseSchedulerNode = Any
 
+GemmOperation = Any
 
 autotuning_log = getArtifactLogger(__name__, "autotuning")
 
@@ -63,6 +64,10 @@ def __init__(
         self.input_reorder = input_reorder
         self.layout = layout
 
+    @staticmethod
+    def supports_epilogue_fusion(op: GemmOperation) -> bool:
+        return False
+
     def generate(  # type: ignore[override]
         self,
         description,
@@ -126,10 +131,21 @@ def generate(  # type: ignore[override]
             source_code=code,
         )
 
+        # kwargs has "op" argument in case of CUTLASSGemmTemplate
+        op = kwargs["op"]
+        if not op:
+            supports_epilogue_fusion = False
+        else:
+            # epilogue fusion is only supported for TMA kernels
+            supports_epilogue_fusion = self.supports_epilogue_fusion(op)
+
         def make_kernel_render(
             template_node: CUDATemplateBuffer,
             epilogue_nodes: Optional[list[BaseSchedulerNode]] = None,
         ) -> tuple[CUDATemplateKernel, functools.partial[str]]:
+            assert supports_epilogue_fusion or not epilogue_nodes, (
+                "epilogue fusion is not supported for this kernel"
+            )
             kernel = CUDATemplateKernel(
                 kernel_name=str(Placeholder.KERNEL_NAME),
                 runtime_arg_info=self.get_runtime_arg_info(),
@@ -151,6 +167,7 @@ def make_kernel_render(
             self.output_node.get_layout(),
             make_kernel_render,
             bmreq,
+            supports_epilogue_fusion,
             self,
             kwargs,
             description,
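
`CUDATemplate.supports_epilogue_fusion` defaults to False, so a concrete template must opt in per op. A hedged sketch of the override point; `MyTmaGemmTemplate` and its `uses_tma` predicate are made up, since the real TMA check lives in the gemm template subclass:

class MyTmaGemmTemplate(CUDATemplate):
    @staticmethod
    def supports_epilogue_fusion(op) -> bool:
        # assumption for illustration: only TMA-scheduled SM90 ops qualify
        return getattr(op, "uses_tma", False)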

torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py

Lines changed: 2 additions & 2 deletions
@@ -119,7 +119,7 @@ def trace(
 ) -> tuple[str, str, str]:
     cuda_arch = int(cuda_env.get_cuda_arch())  # type: ignore[arg-type]
     assert cuda_arch >= 90, "Only SM90+ is supported for EVT"
-    epilogue_functor = _trace(fn_src, example_tensors, **kwargs)
+    epilogue_functor = _trace(fn_src, example_tensors, cuda_arch, **kwargs)
     visitor = EpilogueFunctorVisitor(cuda_arch, epilogue_functor)
     fusion_callbacks = FusionCallbacks(visitor.graph, cuda_arch, emit_CD=False)
     collective_epilogue = CollectiveEpilogue(
@@ -138,7 +138,7 @@ def trace(
 # This is modified to enable directly passing the source code of the epilogue vs getting it from a bona fide Python function.
 # The reason is that inspect.getsource does not work with functions defined at runtime via exec/eval.
 def _trace(
-    fn_src: str, example_tensors: dict[str, CutlassTensor], **kwargs: Any
+    fn_src: str, example_tensors: dict[str, CutlassTensor], cc: int, **kwargs: Any
 ) -> EpilogueFunctor:
     class EpilogueFunctor(PythonASTFrontend):
         def __init__(self, cc: int, **kwargs: Any):
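
Because `_trace` builds an `EpilogueFunctor` whose constructor needs the compute capability, `cc` now rides along as an explicit positional argument instead of being resolved inside. A tiny sketch of the plumbing, with `FrontendStub` standing in for cutlass's `PythonASTFrontend`:

class FrontendStub:
    def __init__(self, cc: int, **kwargs):
        self.cc = cc  # compute capability, e.g. 90

def _trace_sketch(fn_src: str, example_tensors: dict, cc: int, **kwargs):
    functor = FrontendStub(cc, **kwargs)
    # ... parse fn_src and trace it against example_tensors ...
    return functor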

torch/_inductor/codegen/cuda/cutlass_python_evt.py

Lines changed: 2 additions & 2 deletions
@@ -68,15 +68,15 @@ def relu(x0: str) -> str:
 
     @staticmethod
     def sigmoid(x0: str) -> str:
-        return CutlassEVTOpsMixIn._prefix_un_op("sigmoid", x0)
+        raise NotImplementedError("sigmoid is not supported in CUTLASS python evt")
 
     @staticmethod
     def sub(x0: str, x1: str) -> str:
         return CutlassEVTOpsMixIn._infix_bin_op("-", x0, x1)
 
     @staticmethod
     def tanh(x0: str) -> str:
-        return CutlassEVTOpsMixIn._prefix_un_op("tanh", x0)
+        raise NotImplementedError("tanh is not supported in CUTLASS python evt")
 
 
 class MockCutlassHandler(CutlassEVTOpsMixIn, WrapperHandler):
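
Raising `NotImplementedError` from an op handler makes unsupported ops fail loudly while fusibility is being decided; presumably the `try` block visible in the `cuda_cpp_scheduling.py` hunk above is what converts such a failure into a declined fusion rather than bad EVT code. A small sketch of that contract, with invented names:

class EvtOpsStub:
    @staticmethod
    def tanh(x0: str) -> str:
        raise NotImplementedError("tanh is not supported in CUTLASS python evt")

def op_is_fusible(op_name: str) -> bool:
    handler = getattr(EvtOpsStub, op_name, None)
    if handler is None:
        return False
    try:
        handler("accum")
        return True
    except NotImplementedError:
        return False  # fall back to an unfused epilogue

assert op_is_fusible("tanh") is False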
