8000 [Cutlass] Fixes for e2e compilation in arg rendering (#151405) · pytorch/pytorch@a1f6d85 · GitHub
[go: up one dir, main page]

Skip to content

Commit a1f6d85

Browse files
mlazospytorchmergebot
authored andcommitted
[Cutlass] Fixes for e2e compilation in arg rendering (#151405)
Pull Request resolved: #151405 Approved by: https://github.com/eellison ghstack dependencies: #152305, #152306, #150905
1 parent a0ce5ce commit a1f6d85

File tree

3 files changed

+65
-46
lines changed

3 files changed

+65
-46
lines changed

test/inductor/test_cutlass_evt.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -301,38 +301,37 @@ def test_evt_argument_codegen(self):
301301
epilogue_functor, _create_mock_buffer_name_map(EXAMPLE_TENSORS)
302302
),
303303
"""\
304-
{{
305-
{ /* thread */
304+
{ /* thread */
306305
{ /* F */
307-
{ /* compute_1 */
308-
{ /* compute_0 */
309-
{}, /* accum */
310-
{}, /* C */
311-
{}, /* compute_0 */
312-
},
313-
{/* ptr_aux */ aux.get(), /* null_default */ float, /* dAux */ {2048, _1{}, _0{}}}, /* aux */
314-
{}, /* compute_1 */
306+
{ /* compute_1 */
307+
{ /* compute_0 */
308+
{}, /* accum */
309+
{}, /* C */
310+
{}, /* compute_0 */
315311
},
316-
{/* ptr_aux */ F.get(), /* dAux */ {2048, _1{}, _0{}}}, /* F */
312+
{/* ptr_aux */ (float*) aux, /* null_default */ float(0), /* dAux */ {2048, _1{}, _0{}}}, /* aux */
313+
{}, /* compute_1 */
314+
},
315+
{/* ptr_aux */ (float*) F, /* dAux */ {2048, _1{}, _0{}}}, /* F */
317316
},
318-
{/* ptr_col */ bias.get(), /* null_default */ float, /* dCol */ {}}, /* bias */
317+
{/* ptr_col */ (float*) bias, /* null_default */ float(0), /* dCol */ {}}, /* bias */
319318
{}, /* compute_2 */
320319
{}, /* compute_3 */
321320
{}, /* compute_4 */
322-
},
323-
}};
321+
}
324322
""",
325323
)
326324

327325
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
328326
def test_evt_codegen(self):
329-
_, code = trace(
327+
_, _, code = trace(
330328
BIAS_CODE,
331329
EXAMPLE_TENSORS,
332330
DataType.f32,
333331
DataType.f32,
334332
MockTileDescription(),
335333
EpilogueScheduleType.ScheduleAuto,
334+
_create_mock_buffer_name_map(EXAMPLE_TENSORS),
336335
)
337336
self.assertExpectedInline(
338337
code,

torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,9 @@ def trace(
115115
output_type: DataType,
116116
tile_description: TileDescription,
117117
epilogue_schedule: EpilogueScheduleType,
118+
name_to_buffer: dict[str, Buffer],
118119
**kwargs: dict[str, Any],
119-
) -> tuple[str, str]:
120+
) -> tuple[str, str, str]:
120121
cuda_arch = int(cuda_env.get_cuda_arch()) # type: ignore[arg-type]
121122
assert cuda_arch >= 90, "Only SM90+ is supported for EVT"
122123
epilogue_functor = _trace(fn_src, example_tensors, **kwargs)
@@ -129,8 +130,9 @@ def trace(
129130
output_type,
130131
fusion_callbacks,
131132
)
132-
133-
return collective_epilogue.emit()
133+
evt_name, evt_code = collective_epilogue.emit()
134+
evt_args = _render_argument_type(epilogue_functor, name_to_buffer)
135+
return evt_name, evt_args, evt_code
134136

135137
# Based off of
136138
# https://github.com/NVIDIA/cutlass/blob/df18f5e4f5de76bed8be1de8e4c245f2f5ec3020/python/cutlass/epilogue/epilogue.py#L117
@@ -167,33 +169,42 @@ def is_nested_visitor_type(t: type) -> bool:
167169
)
168170

169171
buffer = IndentedBuffer()
172+
with buffer.set_tabwidth(2):
170173

171-
def render_argument_type(name: str, t: CutlassArgType) -> None:
172-
if issubclass(t, ctypes.c_byte):
173-
buffer.writeline(f"{{}}, /* {name} */")
174-
else:
175-
fields = [
176-
(fname, _get_arg_from_node(ty, name_to_buffer[name]))
177-
for fname, ty in t._fields_
178-
]
179-
field_strs = [f"/* {fname} */ {str(field)}" for fname, field in fields]
180-
buffer.writeline(f"{{{', '.join(field_strs)}}}, /* {name} */")
181-
182-
def render_thread_type(name: str, t: CutlassArgType) -> None:
183-
if is_nested_visitor_type(t):
184-
buffer.writeline(f"{{ /* {name} */")
185-
with buffer.indent():
186-
for name, inner_t in t._fields_:
187-
render_thread_type(name, inner_t)
188-
buffer.writeline("},")
189-
else:
190-
render_argument_type(name, t)
191-
192-
buffer.writeline("{{")
193-
with buffer.indent():
194-
render_thread_type("thread", epilogue_thread_type)
195-
196-
buffer.writeline("}};")
174+
def render_argument_type(name: str, t: CutlassArgType) -> None:
175+
if issubclass(t, ctypes.c_byte):
176+
buffer.writeline(f"{{}}, /* {name} */")
177+
else:
178+
fields = [
179+
(fname, _get_arg_from_node(ty, name_to_buffer[name]))
180+
for fname, ty in t._fields_
181+
]
182+
field_strs = [
183+
f"/* {fname} */ {str(field)}" for fname, field in fields
184+
]
185+
buffer.writeline(f"{{{', '.join(field_strs)}}}, /* {name} */")
186+
187+
def render_thread_type(name: str, t: CutlassArgType) -> None:
188+
if is_nested_visitor_type(t):
189+
buffer.writeline(f"{{ /* {name} */")
190+
with buffer.indent():
191+
for name, inner_t in t._fields_:
192+
render_thread_type(name, inner_t)
193+
buffer.writeline("},")
194+
else:
195+
render_argument_type(name, t)
196+
197+
# unroll the recursion once to address special case formatting
198+
# namely, no ending comma and no indentation for the outermost thread type
199+
buffer.writeline("{ /* thread */")
200+
with buffer.indent(3):
201+
if is_nested_visitor_type(epilogue_thread_type):
202+
with buffer.indent():
203+
for name, inner_t in epilogue_thread_type._fields_:
204+
render_thread_type(name, inner_t)
205+
else:
206+
render_argument_type("thread", epilogue_thread_type)
207+
buffer.writeline("}")
197208

198209
return buffer.getvalue()
199210

@@ -225,11 +236,11 @@ def render_stride(x: int) -> str:
225236
return f"{{{', '.join([render_stride(x) for x in stride])}}}"
226237

227238
elif issubclass(arg_ty, ctypes.c_void_p):
228-
return f"{node.get_name()}.get()"
239+
return f"({CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}*) {node.get_name()}"
229240
elif (
230241
arg_ty in _CUTLASS_C_DTYPES
231242
): # Assumption: this is the element dtype, this holds for all cutlass ir nodes currently
232-
return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]
243+
return f"{CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}(0)"
233244
elif issubclass(arg_ty, EmptyByte):
234245
return "{}"
235246

torch/_inductor/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,6 +1165,15 @@ def __init__(self, initial_indent: int = 0) -> None:
11651165
self._lines: list[Union[DeferredLineBase, LineContext, str]] = []
11661166
self._indent = initial_indent
11671167

1168+
@contextlib.contextmanager
1169+
def set_tabwidth(self, tabwidth: int) -> Iterator[None]:
1170+
prev = self.tabwidth
1171+
try:
1172+
self.tabwidth = tabwidth
1173+
yield
1174+
finally:
1175+
self.tabwidth = prev
1176+
11681177
def getvaluewithlinemap(self) -> ValueWithLineMap:
11691178
buf = StringIO()
11701179
p = 1

0 commit comments

Comments
 (0)
0