[Cutlass] Changes to gemm template for EVT (#150907) · pytorch/pytorch@a3154ca
Commit a3154ca

mlazos authored and pytorchmergebot committed
[Cutlass] Changes to gemm template for EVT (#150907)
Pull Request resolved: #150907
Approved by: https://github.com/henrylhtsang, https://github.com/eellison
ghstack dependencies: #153196
1 parent c54aa0d commit a3154ca

3 files changed, +130 -20 lines changed

torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py

Lines changed: 5 additions & 2 deletions
@@ -224,8 +224,11 @@ def _can_fuse_epilogue_impl(
         elif node_to_fuse.is_reduction():
             why(f"{node_name} is a reduction which is not yet supported by EVT")
             return False
-        elif not config.epilogue_fusion:
-            why("epilogue fusion is not enabled")
+        elif (
+            not config.cuda.cutlass_epilogue_fusion_enabled
+            or not config.epilogue_fusion
+        ):
+            why("cutlass epilogue fusion is not enabled")
             return False
 
         try:
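For context, the fusion gate above now requires two flags at once. A minimal standalone sketch of the check (not the scheduler's actual code; `config` is torch._inductor.config):

    from torch._inductor import config

    def cutlass_epilogue_fusion_allowed() -> bool:
        # EVT epilogue fusion is attempted only when both the general
        # inductor epilogue_fusion switch and the CUTLASS-specific flag
        # (added in this commit, default False) are enabled.
        return bool(config.epilogue_fusion) and bool(
            config.cuda.cutlass_epilogue_fusion_enabled
        )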

torch/_inductor/codegen/cuda/gemm_template.py

Lines changed: 122 additions & 18 deletions
@@ -8,6 +8,8 @@
 from abc import ABC, abstractmethod
 from typing import Any, Optional, Union
 
+import torch
+from torch._inductor.scheduler import BaseSchedulerNode
 from torch._inductor.select_algorithm import create_inputs_key
 from torch._inductor.utils import clear_on_fresh_inductor_cache
 
@@ -22,21 +24,26 @@
     Layout,
     ReinterpretView,
 )
-from ...utils import is_dynamic
+from ...utils import is_dynamic, OrderedSet
 from ...virtualized import V
 from ..common import IndentedBuffer
 from . import cutlass_utils
 from .cuda_kernel import CUDATemplateKernel
 from .cuda_template import CUTLASSTemplate
 from .cutlass_presets import gen_cutlass_presets
+from .cutlass_python_evt import CutlassEVTCodegen
+from .cutlass_utils import torch_dtype_to_cutlass_type
 
 
+GemmOperation = Any
+
 log = logging.getLogger(__name__)
 
 # Jinja template for GEMM Kernel, used by the CUTLASSGemm3xTemplate class below.
 GEMM_TEMPLATE_CUTLASS_3X = r"""
 {{template.header().getvalue()}}
 {{template.globals().getvalue()}}
+{{epilogue_visitor_tree}}
 {{instance_definition}}
 // When workspace_size is not a nullptr, populates requested workspace_size and returns.
 // Otherwise, computes the Gemm kernel using the given workspace ptr.
@@ -495,7 +502,8 @@ def _set_bias_layout_and_alignment(
     @abstractmethod
     def _define_gemm_instance(
         self,
-        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
+        op: GemmOperation,
+        evt_name: Optional[str] = None,
     ) -> tuple[str, str]:
         raise NotImplementedError
 
@@ -965,6 +973,7 @@ def render(  # type: ignore[override]
         kernel: CUDATemplateKernel,
         op: "cutlass_gemm_op.GemmOperation" = None,  # type: ignore[name-defined]  # noqa: F821
         template_buffer_node: Optional[CUDATemplateBuffer] = None,
+        epilogue_nodes: Optional[list[BaseSchedulerNode]] = None,
         **kwargs,
     ) -> str:
         """
@@ -995,6 +1004,11 @@ def render(  # type: ignore[override]
                 "op argument is required and has to be an instance of GemmOperation"
             )
 
+        if epilogue_nodes and not self._has_tma_epilogue(op):
+            raise NotImplementedError(
+                "Non-TMA epilogue visitor tree is not supported in Cutlass."
+            )
+
         assert len(self.input_nodes) >= 2 and self.output_node is not None
         X, W = self.input_nodes[0], self.input_nodes[1]
         for input_node in self.input_nodes:
@@ -1017,15 +1031,7 @@ def render(  # type: ignore[override]
             input_reorder = self.input_reorder
         else:
             input_reorder = None
-        kernel_call_signature = kernel.def_kernel(
-            inputs=inputs,  # type: ignore[arg-type]
-            outputs=[Y],
-            names_str=names_str,
-            input_reorder=input_reorder,
-            epilogue_inputs=[],  # TODO mlazos: will be filled in in https://github.com/pytorch/pytorch/pull/150907
-            epilogue_outputs=[],  # TODO mlazos: will be filled in in https://github.com/pytorch/pytorch/pull/150907
-        )
-        test_call_statement = self.test_call_statement(kernel, inputs, names_str)
+
         # The layouts might have changed between autotuning and this call if they were FlexibleLayout
         # we need to adapt, which might lead to suboptimal performance.
         op = self.fix_op_layout(op, X, W, Bias, Y)
@@ -1040,7 +1046,6 @@ def render(  # type: ignore[override]
 
         argument_template, epilogue_template = self._get_template_args(op)
         should_swap_xw: bool = False
-        epilogue_args = f"{{ElementComputeEpilogue({self.alpha}), ElementComputeEpilogue({self.beta})}}"
         if Bias is not None and self._has_tma_epilogue(op):
             if (
                 op.epilogue_schedule
@@ -1051,7 +1056,45 @@ def render(  # type: ignore[override]
                 op = self.swap_XW(op)
                 should_swap_xw = True
 
-        instance_definition, instance_type = self._define_gemm_instance(op)
+        if epilogue_nodes:
+            evt_read_names, evt_write_names, buffer_renames, evt_py_code = (
+                CutlassEVTCodegen.ir_to_evt_python_code(Y.get_name(), epilogue_nodes)
+            )
+            read_names = OrderedSet(evt_read_names) - OrderedSet(evt_write_names)
+            write_names = OrderedSet(evt_write_names)
+
+            name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
+            epilogue_inputs = [name_to_buffer[name] for name in read_names]
+            epilogue_outputs = [name_to_buffer[name] for name in write_names]
+
+            evt_name, evt_args, evt_code = self._render_evt(
+                op,
+                evt_py_code,
+                evt_read_names,
+                evt_write_names,
+                buffer_renames,
+                Y.get_layout().dtype,
+                W.get_layout().dtype,
+            )
+        else:
+            evt_name = None
+            epilogue_inputs = []
+            epilogue_outputs = []
+            evt_args = f"{{ElementComputeEpilogue({self.alpha}), ElementComputeEpilogue({self.beta})}}"
+            evt_code = ""
+
+        kernel_call_signature = kernel.def_kernel(
+            inputs=inputs,  # type: ignore[arg-type]
+            outputs=[Y],
+            epilogue_inputs=epilogue_inputs,
+            epilogue_outputs=epilogue_outputs,
+            names_str=names_str,
+            input_reorder=input_reorder,
+        )
+
+        test_call_statement = self.test_call_statement(kernel, inputs, names_str)
+
+        instance_definition, instance_type = self._define_gemm_instance(op, evt_name)
 
         options = dict(
             alpha=self.alpha,
@@ -1069,9 +1112,10 @@ def render(  # type: ignore[override]
             instance_definition=instance_definition,
             instance_type=instance_type,
             input_reorder=self.input_reorder,
-            epilogue_args=epilogue_args,
+            epilogue_args=evt_args,
             test_call_statement=test_call_statement,
             op_conf_name=op.configuration_name(),
+            epilogue_visitor_tree=evt_code,
         )
         options.update(dict(zip(extra_names, extra_inputs)))
         res = self._template_from_string(self._get_template()).render(**options)
@@ -1106,8 +1150,25 @@ def test_call_statement(
         ]
         return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);"  # noqa: B950
 
+    def _render_evt(
+        self,
+        op: GemmOperation,
+        evt_py_code: str,
+        read_names: list[str],
+        write_names: list[str],
+        buffer_renames: dict[str, str],
+        output_dtype: torch.dtype,
+        accumulator_dtype: torch.dtype,
+    ) -> tuple[str, str, str]:  # type: ignore[name-defined]  # noqa: F821
+        raise NotImplementedError("_render_evt in CUTLASSGemmTemplate not implemented")
+
 
 class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
+    """
+    CUTLASS 3x GEMM Template, which is used to generate CUTLASS GEMM kernels
+    including those which allow flexible fusions with epilogues.
+    """
+
     def __init__(
         self,
         input_nodes: list[Buffer],
@@ -1239,6 +1300,43 @@ def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool:
             return False
         return True
 
+    def _render_evt(
+        self,
+        op: GemmOperation,
+        evt_py_code: str,
+        read_names: list[str],
+        write_names: list[str],
+        buffer_renames: dict[str, str],
+        output_dtype: torch.dtype,
+        accumulator_dtype: torch.dtype,
+    ) -> tuple[str, str, str]:  # type: ignore[name-defined]  # noqa: F821
+        from .cutlass_lib_extensions.evt_extensions import create_example_tensors, trace
+
+        name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
+
+        acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype)
+        output_dtype = torch_dtype_to_cutlass_type(output_dtype)
+        evt_name, evt_args, evt_code = trace(
+            evt_py_code,
+            create_example_tensors(
+                read_names,
+                write_names,
+                buffer_renames,
+                name_to_buffer,  # type: ignore[arg-type]
+            ),
+            acc_dtype,
+            output_dtype,
+            op.tile_description,  # type: ignore[attr-defined]
+            op.epilogue_schedule,  # type: ignore[attr-defined]
+            name_to_buffer,  # type: ignore[arg-type]
+        )
+
+        return (
+            evt_name,
+            evt_args,
+            evt_code,
+        )
+
     def _shape_match(
         self,
         op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
@@ -1282,7 +1380,8 @@ def _set_bias_layout_and_alignment(
 
     def _define_gemm_instance(
         self,
-        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
+        op: GemmOperation,
+        evt_name: Optional[str] = None,
     ) -> tuple[str, str]:
         """Defines and renders the Cutlass / CUDA C++ code for a given GEMM operation instance.
 
@@ -1298,15 +1397,18 @@ def _define_gemm_instance(
         code (render) and the second part is the string that specifies the operation type.
         """
         assert cutlass_utils.try_import_cutlass()
-        import cutlass_library.gemm_operation as cutlass_gemm_op
         import cutlass_library.library as cutlass_lib
 
-        emitter = cutlass_gemm_op.EmitGemmUniversal3xInstance()
+        from .cutlass_lib_extensions import gemm_operation_extensions as gemm_extensions
+
+        emitter = gemm_extensions.EmitGemmUniversal3xInstanceWithEVT(evt_name=evt_name)
+
         if not hasattr(op, "epilogue_functor") or not isinstance(
             op.epilogue_functor, enum.Enum
         ):
             op = copy.deepcopy(op)
             op.epilogue_functor = cutlass_lib.EpilogueFunctor.LinearCombination
+
         op_def = emitter.emit(op)
         pattern = re.compile(r"\s*struct\s(.*?)\s:")
         decl = [line for line in op_def.split("\n") if "struct " in line][-1]
@@ -1318,6 +1420,7 @@ def _define_gemm_instance(
         if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
             op_def += f"\n using {op_type}_device_type = cutlass::gemm::device::GemmUniversalAdapter<{op_type}>;\n"
             op_type = f"{op_type}_device_type"
+
         return op_def, op_type
 
     def _get_extra_inputs_and_names(
@@ -1564,7 +1667,8 @@ def _set_bias_layout_and_alignment(
 
     def _define_gemm_instance(
         self,
-        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
+        op: GemmOperation,
+        evt_name: Optional[str] = None,
     ) -> tuple[str, str]:
         """Defines and renders the Cutlass / CUDA C++ code for a given GEMM operation instance.
 
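A note on the buffer bookkeeping in render() above: names the epilogue writes are subtracted from the read set, so a buffer that is both produced and consumed inside the fused epilogue is passed to the kernel only as an output. A small illustrative sketch with made-up buffer names (OrderedSet is the same class the diff imports from torch._inductor.utils):

    from torch._inductor.utils import OrderedSet

    evt_read_names = ["buf0", "buf1", "buf2"]  # hypothetical buffers the EVT reads
    evt_write_names = ["buf2", "buf3"]         # hypothetical buffers the EVT writes

    # Mirrors the partition in render(): buf2 is written in-kernel,
    # so it drops out of the read set.
    read_names = OrderedSet(evt_read_names) - OrderedSet(evt_write_names)
    write_names = OrderedSet(evt_write_names)

    assert list(read_names) == ["buf0", "buf1"]
    assert list(write_names) == ["buf2", "buf3"]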

torch/_inductor/config.py

Lines changed: 3 additions & 0 deletions
@@ -1332,6 +1332,9 @@ class cuda:
     # The L2 swizzle values to consider when profiling CUTLASS configs in max_autotune.
     cutlass_max_profiling_swizzle_options: list[int] = [1, 2, 4]
 
+    # Whether to use CUTLASS EVT for epilogue fusion
+    cutlass_epilogue_fusion_enabled = False
+
     # Path to CUDA NVCC.
     # NVCC search order:
     # 1) cuda_cxx set in this config
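Since cutlass_epilogue_fusion_enabled defaults to False, CUTLASS EVT fusion is opt-in. A hedged usage sketch, assuming a CUDA build with the CUTLASS backend available (gemm_relu is a made-up example function, not part of this commit):

    import torch
    from torch._inductor import config

    config.max_autotune = True
    config.max_autotune_gemm_backends = "CUTLASS"       # autotune CUTLASS gemms
    config.epilogue_fusion = True                       # general inductor switch
    config.cuda.cutlass_epilogue_fusion_enabled = True  # new EVT switch

    @torch.compile
    def gemm_relu(a, b):
        # The relu is a pointwise epilogue that EVT fusion can absorb
        # into the CUTLASS gemm kernel.
        return torch.relu(a @ b)

    out = gemm_relu(
        torch.randn(128, 64, device="cuda", dtype=torch.float16),
        torch.randn(64, 256, device="cuda", dtype=torch.float16),
    )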
