[Cutlass] E2E Tests for EVT · pytorch/pytorch@434ca07 · GitHub


Commit 434ca07

[Cutlass] E2E Tests for EVT
ghstack-source-id: 8a37df6
Pull Request resolved: #152815
1 parent 97fc2ac commit 434ca07

File tree

7 files changed: +62 -8 lines changed

test/inductor/test_cutlass_backend.py

Lines changed: 25 additions & 0 deletions
@@ -1314,6 +1314,31 @@ def forward(self, B):
         ):
             _ = torch.compile(model)(B)
 
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
+    def test_evt_flexible_layout(self):
+        class TestModel(torch.nn.Module):
+            def forward(self, B):
+                A = torch.zeros_like(B)
+                return (A @ B).relu()
+
+        M = 1024
+        B = torch.randn(M, M).cuda().half()
+        model = TestModel().cuda()
+
+        with config.patch(
+            {
+                "max_autotune": True,
+                "benchmark_epilogue_fusion": False,  # does not support benchmark fusion yet
+                "max_autotune_gemm_backends": "CUTLASS",
+                "cuda.cutlass_max_profiling_configs": 20,
+                "autotune_fallback_to_aten": False,
+            }
+        ):
+            _ = torch.compile(model)(B)
+
+        self.assertEqual(torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1)
+
 
 if __name__ == "__main__":
     from torch._inductor.utils import is_big_gpu

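For reference, the new end-to-end check can be reproduced outside the unittest harness with a short script like the sketch below. It is not part of this commit; it assumes an SM90 GPU and a CUTLASS installation visible to Inductor, and it prints the same "cuda_epilogue_fusion_counter" that the test asserts equals 1.

# Standalone sketch mirroring test_evt_flexible_layout (not part of this commit).
# Assumes an SM90 GPU and CUTLASS available to Inductor's max-autotune backend.
import torch
from torch._inductor import config


class TestModel(torch.nn.Module):
    def forward(self, B):
        A = torch.zeros_like(B)
        return (A @ B).relu()


B = torch.randn(1024, 1024, device="cuda", dtype=torch.half)
model = TestModel().cuda()

with config.patch(
    {
        "max_autotune": True,
        "benchmark_epilogue_fusion": False,
        "max_autotune_gemm_backends": "CUTLASS",
        "cuda.cutlass_max_profiling_configs": 20,
        "autotune_fallback_to_aten": False,
    }
):
    _ = torch.compile(model)(B)

# The relu epilogue should have been fused into the CUTLASS GEMM exactly once.
print(torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"])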
torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py

Lines changed: 3 additions & 1 deletion
@@ -187,7 +187,6 @@ def _can_fuse_epilogue_impl(
         - bool: True if the given node can be fused with the epilogue, False otherwise.
 
         """
-
         why = WhyNoFuseNames(cuda_template_buffer.get_name(), node_to_fuse.get_name())
 
         ir_node_to_fuse = node_to_fuse.node
@@ -227,6 +226,9 @@ def _can_fuse_epilogue_impl(
         elif not config.epilogue_fusion:
             why("epilogue fusion is not enabled")
             return False
+        elif not cuda_template_buffer.supports_epilogue_fusion:
+            why("epilogue fusion is only supported for TMA-enabled gemm ops")
+            return False
 
         try:
             from torch._inductor.codegen.cuda.cutlass_python_evt import (

torch/_inductor/codegen/cuda/cuda_kernel.py

Lines changed: 3 additions & 0 deletions
@@ -563,6 +563,7 @@ def __init__(
             tuple[CUDATemplateKernel, functools.partial[str]],
         ],
         bmreq: CUDABenchmarkRequest,
+        supports_epilogue_fusion: bool,
         template: "CUDATemplate",  # type: ignore[name-defined]
         info_kwargs: Optional[
             dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
@@ -573,6 +574,7 @@ def __init__(
         self.category = category
         self.make_kernel_render = make_kernel_render
         self.bmreq = bmreq
+        self.supports_epilogue_fusion = supports_epilogue_fusion
         self.template = template
         self.info_kwargs = info_kwargs
@@ -629,6 +631,7 @@ def output_node(self) -> TensorBox:
                 inputs=self.input_nodes,
                 make_kernel_render=self.make_kernel_render,
                 workspace_size=self.bmreq.workspace_size,
+                supports_epilogue_fusion=self.supports_epilogue_fusion,
                 template=self.template,
             )
         )

torch/_inductor/codegen/cuda/cuda_template.py

Lines changed: 17 additions & 0 deletions
@@ -24,6 +24,7 @@
 else:
     BaseSchedulerNode = Any
 
+GemmOperation = Any
 
 autotuning_log = getArtifactLogger(__name__, "autotuning")
 
@@ -61,6 +62,10 @@ def __init__(
         self.input_reorder = input_reorder
         self.layout = layout
 
+    @staticmethod
+    def supports_epilogue_fusion(op: GemmOperation) -> bool:
+        return False
+
     def generate(  # type: ignore[override]
         self,
         description,
@@ -122,10 +127,21 @@ def generate(  # type: ignore[override]
             source_code=code,
         )
 
+        # kwargs has "op" argument in case of CUTLASSGemmTemplate
+        op = kwargs["op"]
+        if not op:
+            supports_epilogue_fusion = False
+        else:
+            # epilogue fusion is only supported for TMA kernels
+            supports_epilogue_fusion = self.supports_epilogue_fusion(op)
+
         def make_kernel_render(
             template_node: CUDATemplateBuffer,
             epilogue_nodes: Optional[list[BaseSchedulerNode]] = None,
         ) -> tuple[CUDATemplateKernel, functools.partial[str]]:
+            assert supports_epilogue_fusion or not epilogue_nodes, (
+                "epilogue fusion is not supported for this kernel"
+            )
             kernel = CUDATemplateKernel(
                 kernel_name="KERNEL_NAME",
                 runtime_arg_info=self.get_runtime_arg_info(),
@@ -147,6 +163,7 @@ def make_kernel_render(
             self.output_node.get_layout(),
             make_kernel_render,
             bmreq,
+            supports_epilogue_fusion,
             self,
             kwargs,
             description,

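The plumbing introduced above works as follows: generate() evaluates supports_epilogue_fusion(op) once per candidate, threads the flag through the CUDATemplateCaller into the CUDATemplateBuffer, and the make_kernel_render closure asserts that it never receives epilogue nodes when the flag is off. A minimal toy of that pattern is sketched below; the names are illustrative only and are not the actual Inductor classes.

# Toy sketch of the gating pattern added above; names are illustrative,
# not the real classes in torch/_inductor/codegen/cuda/.
from typing import Any, Optional


class Template:
    @staticmethod
    def supports_epilogue_fusion(op: Any) -> bool:
        # Conservative default: no epilogue fusion.
        return False

    def generate(self, op: Optional[Any]):
        # Decide once, when the candidate is generated, whether fusion is legal.
        supports_fusion = bool(op) and self.supports_epilogue_fusion(op)

        def make_kernel_render(epilogue_nodes: Optional[list] = None) -> str:
            # The scheduler must not hand this kernel epilogues it cannot fuse.
            assert supports_fusion or not epilogue_nodes, (
                "epilogue fusion is not supported for this kernel"
            )
            return f"render(op={op}, epilogues={epilogue_nodes or []})"

        return make_kernel_render, supports_fusion


class TmaTemplate(Template):
    @staticmethod
    def supports_epilogue_fusion(op: Any) -> bool:
        # Mirrors the CUTLASS 3.x override: only TMA epilogues qualify.
        return str(op).lower().startswith("tma")


render, ok = TmaTemplate().generate("TmaWarpSpecialized")
print(ok, render(["relu"]))  # True render(op=TmaWarpSpecialized, epilogues=['relu'])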
torch/_inductor/codegen/cuda/gemm_template.py

Lines changed: 8 additions & 7 deletions
@@ -826,9 +826,6 @@ def filter_op(
         ):
             return None
 
-        if not self._has_tma_epilogue(op):
-            return None
-
         # Filter ops by alignment.
         if not self._alignment_match(op):
             log.debug(
@@ -989,7 +986,6 @@ def render(  # type: ignore[override]
         All inputs and their corresponding buffer addresses and names take precedence over previously
         passed inputs to the template at construction time. However, they should be layout compatible.
         """
-
         assert cutlass_utils.try_import_cutlass()
         import cutlass_library.gemm_operation as cutlass_gemm_op
         import cutlass_library.library as cutlass_lib
@@ -1038,6 +1034,10 @@ def render(  # type: ignore[override]
         # operand
         op.C.element = op.A.element
 
+        assert op.C.element == op.D.element, (
+            f"Expect C and D to have the same dtype, found {op.C.element} and {op.D.element}"
+        )
+
         argument_template, epilogue_template = self._get_template_args(op)
         should_swap_xw: bool = False
         if Bias is not None and self._has_tma_epilogue(op):
@@ -1219,6 +1219,10 @@ def _has_tma_epilogue(  # noqa: F821 # type: ignore[arg-type,name-defined]
         result = epilogue_schedule_str.lower().startswith("tma")
         return result
 
+    @staticmethod
+    def supports_epilogue_fusion(op: GemmOperation) -> bool:
+        return CUTLASS3xGemmTemplate._has_tma_epilogue(op)
+
     def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool:
         """
         Evaluates whether input layouts are compatible for General Matrix Multiply (GEMM).
@@ -1355,9 +1359,6 @@ def _set_bias_layout_and_alignment(
             op.C.element = cutlass_utils.torch_dtype_to_cutlass_type(
                 Bias.get_layout().dtype
             )
-            assert op.C.element == op.D.element, (
-                f"Expect C and D to have the same dtype, found {op.C.element} and {op.D.element}"
-            )
 
             # Bias layout
             bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout())

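In the 3.x GEMM template, supports_epilogue_fusion now simply delegates to _has_tma_epilogue, which (per the context line kept in the hunk above) checks whether the op's epilogue schedule name starts with "tma". A rough, self-contained version of that check is sketched below; the op.epilogue_schedule attribute is an assumed stand-in for the cutlass_library field and may differ from the real one.

# Rough sketch of the TMA-epilogue predicate referenced above. Only the
# startswith("tma") check is taken from the diff; op.epilogue_schedule is
# an assumed stand-in for the cutlass_library GemmOperation field.
from types import SimpleNamespace


def has_tma_epilogue(op) -> bool:
    epilogue_schedule_str = str(op.epilogue_schedule).split(".")[-1]
    return epilogue_schedule_str.lower().startswith("tma")


op = SimpleNamespace(epilogue_schedule="EpilogueScheduleType.TmaWarpSpecialized")
print(has_tma_epilogue(op))  # True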
torch/_inductor/codegen/cuda_combined_scheduling.py

Lines changed: 4 additions & 0 deletions
@@ -60,6 +60,10 @@ def can_fuse_vertical(
     ) -> bool:
         if self._cuda_cpp_scheduling.can_fuse_vertical(node1, node2):
             return True
+        elif self._cuda_cpp_scheduling.is_cuda_cpp_template(
+            node1
+        ) or self._cuda_cpp_scheduling.is_cuda_cpp_template(node2):
+            return False
         return self._triton_scheduling.can_fuse_vertical(node1, node2)
 
     def can_fuse_horizontal(

torch/_inductor/ir.py

Lines changed: 2 additions & 0 deletions
@@ -4725,11 +4725,13 @@ def __init__(  # type: ignore[no-untyped-def]
         make_kernel_render,
         workspace_size: int,
         template: CUDATemplate,
+        supports_epilogue_fusion: bool,
     ) -> None:
         super().__init__(layout, inputs, make_kernel_render)
         # Global memory (in bytes) needed for this template.
         self.workspace_size = workspace_size
         self.template = template
+        self.supports_epilogue_fusion = supports_epilogue_fusion
 
     def get_workspace_size(self):  # type: ignore[no-untyped-def]
         return self.workspace_size if self.workspace_size is not None else 0

0 commit comments
