Commit 1f9858c

Update base for Update on "[Cutlass] Integrate EVT into CUDACPPScheduling"
Previously merged:
* #151713
* #151405
* #150905
* #152306
* #152305

Allow epilogue nodes in cuda combined scheduling

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
1 parent 884ff34 commit 1f9858c

File tree

2 files changed: +70 −3 lines changed


test/inductor/test_cutlass_evt.py

Lines changed: 59 additions & 0 deletions
@@ -105,6 +105,7 @@ def __init__(self, name_to_buffer):
 
         self.sizevars = torch._inductor.sizevars.SizeVarAllocator()
         self.name_to_buffer = name_to_buffer
+        self.graph_inputs = dict()
         self.mutated_buffers = OrderedSet()
 
 
@@ -210,6 +211,64 @@ def inner_fn_buf4(index):
                """Unsupported indexing for buf0 with index 200*i0 + 60000*i1 + i2 and strides [200, 60000, 1]""",
            )
 
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
+    def test_py_codegen_broadcasting(self):
+        from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
+        from torch._inductor.virtualized import V
+
+        size = (100, 300, 200)
+        buf0 = MockComputedBuffer("buf0", None, torch.float32, size)
+        buf1 = MockComputedBuffer("buf1", None, torch.float32, size)
+        buf2 = MockComputedBuffer("buf2", None, torch.float32, size)
+
+        # buf0 is acc
+        # buf1 is external
+        def inner_fn_buf3(index):
+            tmp0 = buf0.make_loader()(index)
+            tmp1 = buf1.make_loader()(index)
+            tmp2 = buf2.make_loader()(index)
+            return tmp0 * tmp1 + tmp2
+
+        def inner_fn_buf4(index):
+            tmp0 = buf0.make_loader()(index)
+            tmp3 = buf3.make_loader()(index)
+            return tmp0 + tmp3 * tmp3
+
+        buf3 = MockComputedBuffer("buf3", inner_fn_buf3, torch.float32, size)
+        buf4 = MockComputedBuffer(
+            "buf4", inner_fn_buf4, torch.float32, (100, 300, 1)
+        )  # broadcast
+        with V.set_graph_handler(
+            MockGraphHandler(
+                {"buf0": buf0, "buf1": buf1, "buf2": buf2, "buf3": buf3, "buf4": buf4}
+            )
+        ):
+            reads, writes, renames, code = CutlassEVTCodegen.ir_to_evt_python_code(
+                "buf0",
+                [
+                    MockSchedulerNode(buf3),
+                    MockSchedulerNode(buf4, last_usage=OrderedSet(["buf0"])),
+                ],
+            )
+            self.assertExpectedInline(reads, """['buf0', 'buf1', 'buf2']""")
+            self.assertExpectedInline(writes, """['buf3', 'buf4']""")
+            self.assertExpectedInline(
+                renames, """{'buf3': 'D', 'buf4': 'tmp_3', 'buf0': 'accum'}"""
+            )
+            self.assertExpectedInline(
+                code,
+                """\
+def fn(accum, buf1, buf2):
+    tmp_0 = accum * buf1
+    tmp_1 = tmp_0 + buf2
+    D = tmp_1 # cutlass evt requirement
+    tmp_2 = D * D
+    tmp_3 = accum + tmp_2
+
+    return D, tmp_3""",
+            )
+
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
     def test_py_codegen(self):
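
The new test_py_codegen_broadcasting test declares buf4 with shape (100, 300, 1), i.e. a buffer that is broadcast along the last dimension of the (100, 300, 200) iteration space. As a rough standalone illustration (plain torch tensors, not the inductor mocks used above): a size-1 dimension that participates in a larger iteration space shows up with stride 0, which is the kind of layout/index-stride mismatch the relaxed check added to torch/_inductor/codegen/cuda/cutlass_python_evt.py below is intended to accept.

import torch

# Sketch only: how a broadcast dimension appears as a zero stride.
t = torch.randn(100, 300, 1)
print(t.stride())                        # (300, 1, 1): contiguous layout of the (100, 300, 1) buffer
print(t.expand(100, 300, 200).stride())  # (300, 1, 0): the broadcast dimension has stride 0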

torch/_inductor/codegen/cuda/cutlass_python_evt.py

Lines changed: 11 additions & 3 deletions
@@ -69,6 +69,7 @@ def __init__(self, accumulator_node_name: str, last_usages: OrderedSet[str]):
         self.reads: OrderedSet[str] = OrderedSet()
         self.last_usages: OrderedSet[str] = OrderedSet()
         self.cur_node: Optional[ComputedBuffer] = None
+        self.name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
 
         if accumulator_node_name not in last_usages:
             self.store(accumulator_node_name, value=OpsValue(_ACCUMULATOR_ALIAS))
@@ -207,15 +208,22 @@ def _check_indexing(self, name: str, index: sympy.Expr) -> None:
         # We only support indexing that matches the layout today because
         # CUTLASS doesn't support arbitrary indexing
         buffer_name = self.accumulator_node_name if name == _ACCUMULATOR_ALIAS else name
-        buffer = V.graph.name_to_buffer[buffer_name]
+        buffer = self.name_to_buffer[buffer_name]
         index_strides = V.graph.sizevars.stride_vars(
            index, self._get_current_index_vars()
        )
-        if buffer.get_layout().stride != index_strides:
+        stride = buffer.get_layout().stride
+        if not self._stride_compatible(stride, index_strides):
             raise NotImplementedError(
-                f"Unsupported indexing for {name} with index {index} and strides {index_strides}"
+                f"Unsupported indexing for {name} with index {index}, index strides {index_strides}, and layout stride {stride}"
             )
 
+    def _stride_compatible(self, left, right):
+        return all(
+            sympy.Eq(l, r) or sympy.Eq(l, 0) or sympy.Eq(r, 0)
+            for l, r in (zip(left, right))
+        )
+
     def _render_input_signature(self) -> str:
         arguments = ", ".join(
             [_ACCUMULATOR_ALIAS]
0 commit comments
