from abc import ABC, abstractmethod
from typing import Any, Optional, Union

+import torch
+from torch._inductor.scheduler import BaseSchedulerNode
from torch._inductor.select_algorithm import create_inputs_key
from torch._inductor.utils import clear_on_fresh_inductor_cache

    Layout,
    ReinterpretView,
)
-from ...utils import is_dynamic
+from ...utils import is_dynamic, OrderedSet
from ...virtualized import V
from ..common import IndentedBuffer
from . import cutlass_utils
from .cuda_kernel import CUDATemplateKernel
from .cuda_template import CUTLASSTemplate
from .cutlass_presets import gen_cutlass_presets
+from .cutlass_python_evt import CutlassEVTCodegen
+from .cutlass_utils import torch_dtype_to_cutlass_type


+GemmOperation = Any
+
log = logging.getLogger(__name__)

# Jinja template for GEMM Kernel, used by the CUTLASSGemm3xTemplate class below.
GEMM_TEMPLATE_CUTLASS_3X = r"""
{{template.header().getvalue()}}
{{template.globals().getvalue()}}
+{{epilogue_visitor_tree}}
{{instance_definition}}
// When workspace_size is not a nullptr, populates requested workspace_size and returns.
// Otherwise, computes the Gemm kernel using the given workspace ptr.
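
The new {{epilogue_visitor_tree}} placeholder sits between the template globals and the instance definition, so when no epilogue is fused the rendered C++ simply gets an empty line there. A minimal Jinja sketch of that behaviour, using made-up placeholder values rather than the real GEMM template:

from jinja2 import Template

# Two of the real placeholder names, but toy values: this only shows how an
# empty vs. populated epilogue_visitor_tree renders.
snippet = Template("{{epilogue_visitor_tree}}\n{{instance_definition}}")

print(snippet.render(epilogue_visitor_tree="", instance_definition="struct GemmInstance {};"))
# -> a blank line followed by the instance definition (non-fused path)

print(snippet.render(
    epilogue_visitor_tree="using EVT = /* generated visitor tree */;",
    instance_definition="struct GemmInstance {};",
))
# -> the visitor-tree declarations are emitted ahead of the instance definition
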
@@ -495,7 +502,8 @@ def _set_bias_layout_and_alignment(
    @abstractmethod
    def _define_gemm_instance(
        self,
-        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
+        op: GemmOperation,
+        evt_name: Optional[str] = None,
    ) -> tuple[str, str]:
        raise NotImplementedError

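Giving evt_name a None default keeps existing _define_gemm_instance(op) call sites valid while letting the fused-epilogue path thread the traced EVT name through to the emitter. A minimal sketch of that pattern with hypothetical stand-in classes (not the real emitters):

from abc import ABC, abstractmethod
from typing import Optional


class GemmTemplateBase(ABC):  # hypothetical stand-in for CUTLASSGemmTemplate
    @abstractmethod
    def _define_gemm_instance(
        self, op: object, evt_name: Optional[str] = None
    ) -> tuple[str, str]:
        raise NotImplementedError


class Gemm3x(GemmTemplateBase):  # hypothetical stand-in for CUTLASS3xGemmTemplate
    def _define_gemm_instance(self, op, evt_name=None):
        # Only the fused-epilogue path supplies evt_name; plain GEMMs omit it.
        suffix = f"  // uses EVT {evt_name}" if evt_name else ""
        return f"struct GemmInstance {{}};{suffix}", "GemmInstance"


print(Gemm3x()._define_gemm_instance(object()))            # non-fused path, call site unchanged
print(Gemm3x()._define_gemm_instance(object(), "my_evt"))   # fused path passes the EVT name
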
@@ -965,6 +973,7 @@ def render(  # type: ignore[override]
        kernel: CUDATemplateKernel,
        op: "cutlass_gemm_op.GemmOperation" = None,  # type: ignore[name-defined]  # noqa: F821
        template_buffer_node: Optional[CUDATemplateBuffer] = None,
+        epilogue_nodes: Optional[list[BaseSchedulerNode]] = None,
        **kwargs,
    ) -> str:
        """
@@ -995,6 +1004,11 @@ def render(  # type: ignore[override]
                "op argument is required and has to be an instance of GemmOperation"
            )

+        if epilogue_nodes and not self._has_tma_epilogue(op):
+            raise NotImplementedError(
+                "Non-TMA epilogue visitor tree is not supported in Cutlass."
+            )
+
        assert len(self.input_nodes) >= 2 and self.output_node is not None
        X, W = self.input_nodes[0], self.input_nodes[1]
        for input_node in self.input_nodes:
@@ -1017,15 +1031,7 @@ def render(  # type: ignore[override]
            input_reorder = self.input_reorder
        else:
            input_reorder = None
-        kernel_call_signature = kernel.def_kernel(
-            inputs=inputs,  # type: ignore[arg-type]
-            outputs=[Y],
-            names_str=names_str,
-            input_reorder=input_reorder,
-            epilogue_inputs=[],  # TODO mlazos: will be filled in in https://github.com/pytorch/pytorch/pull/150907
-            epilogue_outputs=[],  # TODO mlazos: will be filled in in https://github.com/pytorch/pytorch/pull/150907
-        )
-        test_call_statement = self.test_call_statement(kernel, inputs, names_str)
+
        # The layouts might have changed between autotuning and this call if they were FlexibleLayout
        # we need to adapt, which might lead to suboptimal performance.
        op = self.fix_op_layout(op, X, W, Bias, Y)
@@ -1040,7 +1046,6 @@ def render(  # type: ignore[override]

        argument_template, epilogue_template = self._get_template_args(op)
        should_swap_xw: bool = False
-        epilogue_args = f"{{ElementComputeEpilogue({self.alpha}), ElementComputeEpilogue({self.beta})}}"
        if Bias is not None and self._has_tma_epilogue(op):
            if (
                op.epilogue_schedule
@@ -1051,7 +1056,45 @@ def render(  # type: ignore[override]
                op = self.swap_XW(op)
                should_swap_xw = True

-        instance_definition, instance_type = self._define_gemm_instance(op)
+        if epilogue_nodes:
+            evt_read_names, evt_write_names, buffer_renames, evt_py_code = (
+                CutlassEVTCodegen.ir_to_evt_python_code(Y.get_name(), epilogue_nodes)
+            )
+            read_names = OrderedSet(evt_read_names) - OrderedSet(evt_write_names)
+            write_names = OrderedSet(evt_write_names)
+
+            name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
+            epilogue_inputs = [name_to_buffer[name] for name in read_names]
+            epilogue_outputs = [name_to_buffer[name] for name in write_names]
+
+            evt_name, evt_args, evt_code = self._render_evt(
+                op,
+                evt_py_code,
+                evt_read_names,
+                evt_write_names,
+                buffer_renames,
+                Y.get_layout().dtype,
+                W.get_layout().dtype,
+            )
+        else:
+            evt_name = None
+            epilogue_inputs = []
+            epilogue_outputs = []
+            evt_args = f"{{ElementComputeEpilogue({self.alpha}), ElementComputeEpilogue({self.beta})}}"
+            evt_code = ""
+
+        kernel_call_signature = kernel.def_kernel(
+            inputs=inputs,  # type: ignore[arg-type]
+            outputs=[Y],
+            epilogue_inputs=epilogue_inputs,
+            epilogue_outputs=epilogue_outputs,
+            names_str=names_str,
+            input_reorder=input_reorder,
+        )
+
+        test_call_statement = self.test_call_statement(kernel, inputs, names_str)
+
+        instance_definition, instance_type = self._define_gemm_instance(op, evt_name)

        options = dict(
            alpha=self.alpha,
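
The OrderedSet subtraction above is what keeps intermediate epilogue buffers out of the kernel signature: a buffer that the fused epilogue both writes and later reads stays internal to the visitor tree, and only the remaining reads become extra kernel inputs. A small sketch of that bookkeeping with hypothetical buffer names (plain lists stand in for the real IR buffers):

# Everything the fused epilogue loads / stores, in the order codegen saw them.
evt_read_names = ["buf0", "arg_bias", "buf2"]
evt_write_names = ["buf2", "buf3"]

# buf2 is produced and then consumed inside the epilogue, so it is not an
# external input and never reaches def_kernel's epilogue_inputs.
epilogue_input_names = [n for n in evt_read_names if n not in set(evt_write_names)]
epilogue_output_names = list(evt_write_names)

assert epilogue_input_names == ["buf0", "arg_bias"]
assert epilogue_output_names == ["buf2", "buf3"]
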
@@ -1069,9 +1112,10 @@ def render(  # type: ignore[override]
            instance_definition=instance_definition,
            instance_type=instance_type,
            input_reorder=self.input_reorder,
-            epilogue_args=epilogue_args,
+            epilogue_args=evt_args,
            test_call_statement=test_call_statement,
            op_conf_name=op.configuration_name(),
+            epilogue_visitor_tree=evt_code,
        )
        options.update(dict(zip(extra_names, extra_inputs)))
        res = self._template_from_string(self._get_template()).render(**options)
@@ -1106,8 +1150,25 @@ def test_call_statement(
        ]
        return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);"  # noqa: B950

+    def _render_evt(
+        self,
+        op: GemmOperation,
+        evt_py_code: str,
+        read_names: list[str],
+        write_names: list[str],
+        buffer_renames: dict[str, str],
+        output_dtype: torch.dtype,
+        accumulator_dtype: torch.dtype,
+    ) -> tuple[str, str, str]:  # type: ignore[name-defined]  # noqa: F821
+        raise NotImplementedError("_render_evt in CUTLASSGemmTemplate not implemented")
+

class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
+    """
+    CUTLASS 3x GEMM Template, which is used to generate CUTLASS GEMM kernels
+    including those which allow flexible fusions with epilogues.
+    """
+
    def __init__(
        self,
        input_nodes: list[Buffer],
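
For context, the "flexible fusions with epilogues" mentioned in the new docstring are GEMM-plus-pointwise patterns of roughly the following shape; whether a given graph actually takes this EVT path depends on the CUTLASS backend being chosen during autotuning. Illustrative example only, not code from this PR:

import torch


@torch.compile
def gemm_with_epilogue(x, w, bias):
    # The matmul maps onto the CUTLASS GEMM template; the bias add and relu are
    # the kind of pointwise tail an epilogue visitor tree can fuse into the kernel.
    return torch.relu(x @ w + bias)
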
@@ -1239,6 +1300,43 @@ def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool:
                return False
        return True

+    def _render_evt(
+        self,
+        op: GemmOperation,
+        evt_py_code: str,
+        read_names: list[str],
+        write_names: list[str],
+        buffer_renames: dict[str, str],
+        output_dtype: torch.dtype,
+        accumulator_dtype: torch.dtype,
+    ) -> tuple[str, str, str]:  # type: ignore[name-defined]  # noqa: F821
+        from .cutlass_lib_extensions.evt_extensions import create_example_tensors, trace
+
+        name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
+
+        acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype)
+        output_dtype = torch_dtype_to_cutlass_type(output_dtype)
+        evt_name, evt_args, evt_code = trace(
+            evt_py_code,
+            create_example_tensors(
+                read_names,
+                write_names,
+                buffer_renames,
+                name_to_buffer,  # type: ignore[arg-type]
+            ),
+            acc_dtype,
+            output_dtype,
+            op.tile_description,  # type: ignore[attr-defined]
+            op.epilogue_schedule,  # type: ignore[attr-defined]
+            name_to_buffer,  # type: ignore[arg-type]
+        )
+
+        return (
+            evt_name,
+            evt_args,
+            evt_code,
+        )
+
    def _shape_match(
        self,
        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
@@ -1282,7 +1380,8 @@ def _set_bias_layout_and_alignment(

    def _define_gemm_instance(
        self,
-        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
+        op: GemmOperation,
+        evt_name: Optional[str] = None,
    ) -> tuple[str, str]:
        """Defines and renders the Cutlass / CUDA C++ code for a given GEMM operation instance.

@@ -1298,15 +1397,18 @@ def _define_gemm_instance(
        code (render) and the second part is the string that specifies the operation type.
        """
        assert cutlass_utils.try_import_cutlass()
-        import cutlass_library.gemm_operation as cutlass_gemm_op
        import cutlass_library.library as cutlass_lib

-        emitter = cutlass_gemm_op.EmitGemmUniversal3xInstance()
+        from .cutlass_lib_extensions import gemm_operation_extensions as gemm_extensions
+
+        emitter = gemm_extensions.EmitGemmUniversal3xInstanceWithEVT(evt_name=evt_name)
+
        if not hasattr(op, "epilogue_functor") or not isinstance(
            op.epilogue_functor, enum.Enum
        ):
            op = copy.deepcopy(op)
            op.epilogue_functor = cutlass_lib.EpilogueFunctor.LinearCombination
+
        op_def = emitter.emit(op)
        pattern = re.compile(r"\s*struct\s(.*?)\s:")
        decl = [line for line in op_def.split("\n") if "struct " in line][-1]
@@ -1318,6 +1420,7 @@ def _define_gemm_instance(
        if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
            op_def += f"\nusing {op_type}_device_type = cutlass::gemm::device::GemmUniversalAdapter<{op_type}>;\n"
            op_type = f"{op_type}_device_type"
+
        return op_def, op_type

    def _get_extra_inputs_and_names(
@@ -1564,7 +1667,8 @@ def _set_bias_layout_and_alignment(

    def _define_gemm_instance(
        self,
-        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
+        op: GemmOperation,
+        evt_name: Optional[str] = None,
    ) -> tuple[str, str]:
        """Defines and renders the Cutlass / CUDA C++ code for a given GEMM operation instance.