8000 [Cutlass] Implement Epilogue Argument emitter (#150903) · pytorch/pytorch@4f62dcc · GitHub
[go: up one dir, main page]

Skip to content

Commit 4f62dcc

Browse files
mlazospytorchmergebot
authored andcommitted
[Cutlass] Implement Epilogue Argument emitter (#150903)
This implements epilogue visitor tree argument generation (example type [here](https://github.com/NVIDIA/cutlass/blob/3fe62887d8dd75700fdaf57f9c181878701b0802/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp#L332)). Details: The codegen task here is to implement a function which can generate a tree of C++ structs and properly extract the correct properties from Inductor buffers and write them to the correct locations in the generated struct. To implement this with the minimum amount of code, I generate the cutlass DAGIR (the EVT internal represenation) which specifically has a pass, [pass_argument_type.py ](https://github.com/NVIDIA/cutlass/blob/5e497243f7ad13a2aa842143f9b10bbb23d98292/python/cutlass/backend/evt/passes/pass_argument_type.py#L4) which generates a nested tree of custom argument types for each node in the DAGIR. This nested tree of constructors is then passed kwargs to fill in the proper values, where the node's name is used to differentiate between different values in the kwarg dictionary. This however is non-customizable; the nested tree of EVT args is a nested tree of ctypes which looks for *actual values* so that this object can be passed directly to the cutlass-python C++ runner. Inductor on the other hand needs to fill this struct with string C++ expressions representing the values (or extracting the values from kernel launcher args). So `_render_argument_type` implements this: it iterates over the tree of types created by pass_argument_type.py and generates a string representing the nested structs, filling in C++ expressions representing the different fields. Long term plan: Long term, I will ask the nvidia to provide an overridable [visitor_factory](https://github.com/NVIDIA/cutlass/blob/5e497243f7ad13a2aa842143f9b10bbb23d98292/python/cutlass/backend/evt/passes/pass_argument_type.py#L82) which could allow us to override the behavior of pass_argument_type.py to generate the string we would like during DAGIR generation. Previously merged: * #150346 * #150345 * #150344 Pull Request resolved: #150903 Approved by: https://github.com/henrylhtsang, https://github.com/eellison
1 parent 8e0f9fb commit 4f62dcc

File tree

2 files changed

+226
-102
lines changed

2 files changed

+226
-102
lines changed

test/inductor/test_cutlass_evt.py

Lines changed: 114 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Owner(s): ["module: inductor"]
22
import unittest
33

4+
import torch
45
from torch._dynamo.test_case import TestCase
56
from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
67
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
@@ -13,54 +14,111 @@
1314
LayoutType = cutlass_lib.LayoutType
1415
DataType = cutlass_lib.DataType
1516
from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
17+
_render_argument_type,
18+
_trace,
1619
CutlassTensor,
1720
trace,
1821
)
1922

23+
BIAS_CODE = """def example_epilogue(accum, C, aux, bias):
24+
F = accum + C + aux
25+
E = relu(F) + bias
26+
D = E + F
27+
return D, F"""
28+
29+
TYPE_C = DataType.f32
30+
M = 4224
31+
N = 2048
32+
BIAS = CutlassTensor(shape=(M, 1), element=TYPE_C, layout_tag=LayoutType.RowMajor)
33+
34+
EXAMPLE_TENSORS = {
35+
"accum": CutlassTensor(
36+
element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
37+
),
38+
"bias": BIAS,
39+
# "beta": 0.5, TODO: mlazos support scalars
40+
# "alpha": 0.5, TODO: mlazos support scalars
41+
"D": CutlassTensor(
42+
element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
43+
),
44+
"C": CutlassTensor(
45+
element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
46+
),
47+
"F": CutlassTensor(
48+
element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
49+
),
50+
"aux": CutlassTensor(
51+
element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
52+
),
53+
}
54+
2055
class MockTileDescription:
2156
threadblock_shape = (128, 128, 8)
2257

58+
def _create_mock_buffer_name_map(example_tensors):
59+
class MockNode:
60+
def __init__(self, name, stride, dtype):
61+
self.name = name
62+
self.dtype = dtype
63+
self.stride = stride
64+
65+
def get_layout(self):
66+
class MockLayout:
67+
def __init__(self, stride, dtype):
68+
self.dtype = dtype
69+
self.stride = stride
70+
71+
return MockLayout(self.stride, self.dtype)
72+
73+
def get_name(self):
74+
return self.name
75+
76+
name_to_buffer = {}
77+
for name, tensor in example_tensors.items():
78+
if isinstance(tensor, CutlassTensor):
79+
name_to_buffer[name] = MockNode(name, tensor.stride, torch.float32)
80+
81+
return name_to_buffer
82+
2383

2484
class TestCutlassEVT(TestCase):
2585
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
26-
def test_evt_codegen(self):
27-
bias_code = """def example_epilogue(accum, alpha, C, beta, aux, bias):
28-
F = alpha * accum + (beta * C + aux)
29-
E = relu(F + 1) + bias
30-
D = E + F
31-
return D, F"""
32-
33-
type_C = DataType.f32
34-
m = 4224
35-
n = 2048
36-
bias = CutlassTensor(
37-
shape=(m, 1), element=type_C, layout_tag=LayoutType.RowMajor
38-
)
86+
def test_evt_argument_codegen(self):
87+
epilogue_functor = _trace(BIAS_CODE, EXAMPLE_TENSORS)
3988

40-
examples_tensors = {
41-
"accum": CutlassTensor(
42-
element=DataType.f32, shape=(m, n), layout_tag=LayoutType.RowMajor
43-
),
44-
"bias": bias,
45-
"beta": 0.5,
46-
"alpha": 0.5,
47-
"D": CutlassTensor(
48-
element=DataType.f32, shape=(m, n), layout_tag=LayoutType.RowMajor
49-
),
50-
"C": CutlassTensor(
51-
element=DataType.f32, shape=(m, n), layout_tag=LayoutType.RowMajor
52-
),
53-
"F": CutlassTensor(
54-
element=DataType.f32, shape=(m, n), layout_tag=LayoutType.RowMajor
55-
),
56-
"aux": CutlassTensor(
57-
element=DataType.f32, shape=(m, n), layout_tag=LayoutType.RowMajor
89+
self.assertExpectedInline(
90+
_render_argument_type(
91+
epilogue_functor, _create_mock_buffer_name_map(EXAMPLE_TENSORS)
5892
),
59-
}
93+
"""\
94+
{{
95+
{ /* thread */
96+
{ /* F */
97+
{ /* compute_1 */
98+
{ /* compute_0 */
99+
{}, /* accum */
100+
{}, /* C */
101+
{}, /* compute_0 */
102+
},
103+
{/* ptr_aux */ aux.get(), /* null_default */ float, /* dAux */ {2048, _1{}, _0{}}}, /* aux */
104+
{}, /* compute_1 */
105+
},
106+
{/* ptr_aux */ F.get(), /* dAux */ {2048, _1{}, _0{}}}, /* F */
107+
},
108+
{/* ptr_col */ bias.get(), /* null_default */ float, /* dCol */ {}}, /* bias */
109+
{}, /* compute_2 */
110+
{}, /* compute_3 */
111+
{}, /* compute_4 */
112+
},
113+
}};
114+
""",
115+
)
60116

117+
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
118+
def test_evt_codegen(self):
61119
_, code = trace(
62-
bias_code,
63-
examples_tensors,
120+
BIAS_CODE,
121+
EXAMPLE_TENSORS,
64122
DataType.f32,
65123
DataType.f32,
66124
MockTileDescription(),
@@ -82,20 +140,13 @@ def test_evt_codegen(self):
82140
83141
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
84142
85-
using Alpha = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
86-
float, cute::Stride<cute::Int<0>, cute::Int<0>, cute::Int<0>>, 1, cutlass::multiplies
87-
>;
88-
89-
using AuxDescriptor = cutlass::epilogue::collective::detail::AuxLoadDescriptor\
90-
<EpilogueDescriptor, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, float>;
143+
using AuxDescriptor = cutlass::epilogue::collective::detail::AuxLoadDescriptor<EpilogueDescriptor, \
144+
cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, float>;
91145
92146
using Aux = cutlass::epilogue::fusion::Sm90AuxLoad<
93147
AuxDescriptor::Stages, typename AuxDescriptor::EpilogueTile, float,
94-
cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, typename AuxDescriptor::SmemLayoutAtom, typename AuxDescriptor::CopyOpS2R
95-
>;
96-
97-
using Beta = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
98-
float, cute::Stride<cute::Int<0>, cute::Int<0>, cute::Int<0>>, 1, cutlass::multiplies
148+
cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, typename AuxDescriptor::SmemLayoutAtom, \
149+
typename AuxDescriptor::CopyOpS2R
99150
>;
100151
101152
using Bias = cutlass::epilogue::fusion::Sm90ColBroadcast<
@@ -104,102 +155,69 @@ def test_evt_codegen(self):
104155
>;
105156
106157
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
107-
cutlass::multiplies, float, float,
158+
cutlass::plus, float, float,
108159
cutlass::FloatRoundStyle::round_to_nearest
109160
>;
110161
111162
using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<
112163
Compute0,
113-
Alpha,
114-
Accum>;
164+
Accum,
165+
TensorC>;
115166
116167
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
117-
cutlass::multiplies, float, float,
168+
cutlass::plus, float, float,
118169
cutlass::FloatRoundStyle::round_to_nearest
119170
>;
120171
121172
using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<
122173
Compute1,
123-
Beta,
124-
TensorC>;
125-
126-
using Compute2 = cutlass::epilogue::fusion::Sm90Compute<
127-
cutlass::plus, float, float,
128-
cutlass::FloatRoundStyle::round_to_nearest
129-
>;
130-
131-
using EVTCompute2 = cutlass::epilogue::fusion::Sm90EVT<
132-
Compute2,
133-
EVTCompute1,
134-
Aux>;
135-
136-
using Compute3 = cutlass::epilogue::fusion::Sm90Compute<
137-
cutlass::plus, float, float,
138-
cutlass::FloatRoundStyle::round_to_nearest
139-
>;
140-
141-
using EVTCompute3 = cutlass::epilogue::fusion::Sm90EVT<
142-
Compute3,
143174
EVTCompute0,
144-
EVTCompute2>;
175+
Aux>;
145176
146177
using FDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
147178
EpilogueDescriptor, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, float
148179
>;
149180
150181
using F = cutlass::epilogue::fusion::Sm90AuxStore<
151182
FDescriptor::Stages, typename FDescriptor::EpilogueTile, float,
152-
cutlass::FloatRoundStyle::round_to_nearest, \
153-
cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, typename FDescriptor::SmemLayoutAtom,
183+
cutlass::FloatRoundStyle::round_to_nearest, cute::Stride<int64_t, cute::Int<1>, \
184+
cute::Int<0>>, typename FDescriptor::SmemLayoutAtom,
154185
typename FDescriptor::CopyOpR2S
155186
>;
156187
157188
using EVTF = cutlass::epilogue::fusion::Sm90EVT<
158189
F,
159-
EVTCompute3>;
160-
161-
using Imm10 = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
162-
float, cute::Stride<cute::Int<0>, cute::Int<0>, cute::Int<0>>, 1, cutlass::multiplies
163-
>;
164-
165-
using Compute4 = cutlass::epilogue::fusion::Sm90Compute<
166-
cutlass::plus, float, float,
167-
cutlass::FloatRoundStyle::round_to_nearest
168-
>;
190+
EVTCompute1>;
169191
170-
using Compute5 = cutlass::epilogue::fusion::Sm90Compute<
192+
using Compute2 = cutlass::epilogue::fusion::Sm90Compute<
171193
cutlass::epilogue::thread::ReLu, float, float,
172194
cutlass::FloatRoundStyle::round_to_nearest
173195
>;
174196
175-
using Compute6 = cutlass::epilogue::fusion::Sm90Compute<
197+
using Compute3 = cutlass::epilogue::fusion::Sm90Compute<
176198
cutlass::plus, float, float,
177199
cutlass::FloatRoundStyle::round_to_nearest
178200
>;
179201
180-
using Compute7 = cutlass::epilogue::fusion::Sm90Compute<
202+
using Compute4 = cutlass::epilogue::fusion::Sm90Compute<
181203
cutlass::plus, float, float,
182204
cutlass::FloatRoundStyle::round_to_nearest
183205
>;
184206
185-
using DagCompute7 = cutlass::epilogue::fusion::Sm90TopologicalVisitor<
207+
using DagCompute4 = cutlass::epilogue::fusion::Sm90TopologicalVisitor<
186208
float,
187209
cute::tuple<
188210
cute::seq<>,
189211
cute::seq<>,
190-
cute::seq<>,
191-
cute::seq<0, 2>,
192-
cute::seq<3>,
193-
cute::seq<4, 1>,
194-
cute::seq<5, 0>,
212+
cute::seq<0>,
213+
cute::seq<2, 1>,
214+
cute::seq<3, 0>,
195215
>,
196216
EVTF,
197217
Bias,
198-
Imm10,
199-
Compute4,
200-
Compute5,
201-
Compute6,
202-
Compute7
218+
Compute2,
219+
Compute3,
220+
Compute4
203221
>;
204222
205223
using ElementD = float;

0 commit comments

Comments
 (0)
0