[mm_logs][ez] dump tuned mm info at lowering stage (#148363) · pytorch/pytorch@1673bc7 · GitHub
Commit 1673bc7

YUNQIUGUO authored and pytorchmergebot committed
[mm_logs][ez] dump tuned mm info at lowering stage (#148363)
Summary: As title. This would be beneficial for judging e2e perf improvement. An easy first step is to dump mm info at the lowering stage, e.g.:

```
fbsource/fbcode/caffe2/torch/_inductor/kernel/mm.py:525] [0/0] Tuned aten.addmm: m=16, n=6, k=16, layout=FixedLayout('cuda:0', torch.float32, size=[16, 6], stride=[6, 1])
```

Next step: dump overview info at the `post_grad_graph` stage, such as the overall count of `aten.mm` ops in the graph, and visualize it as a table.

Test Plan: By looking very hard at the AOT Inductor bmm and mm UTs.

Differential Revision: D70507880

Pull Request resolved: #148363
Approved by: https://github.com/henrylhtsang
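For a quick sense of what these lines look like in practice, here is a minimal, hypothetical repro (not part of this commit): compile a small matmul on a CUDA machine with INFO-level Inductor logging enabled, e.g. via the `TORCH_LOGS` environment variable.

```python
# Hypothetical repro sketch: run with TORCH_LOGS="inductor" (or otherwise enable
# INFO logging for torch._inductor) so the new "Tuned aten.*" lines are emitted.
import torch

@torch.compile
def f(a, b):
    return a @ b

a = torch.randn(16, 16, device="cuda")
b = torch.randn(16, 6, device="cuda")
f(a, b)  # expect a line like: Tuned aten.mm: m=16, n=6, k=16, mat1_dtype=torch.float32, ...
```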
1 parent edc3ca5 · commit 1673bc7

File tree: 3 files changed, +61 -7 lines changed

torch/_inductor/kernel/bmm.py
torch/_inductor/kernel/mm.py
torch/_inductor/kernel/mm_scaled.py

torch/_inductor/kernel/bmm.py

Lines changed: 10 additions & 0 deletions
@@ -168,6 +168,16 @@ def may_require_contiguous(t, meta_t):
 
     m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
 
+    log.info(
+        "Tuned aten.bmm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
+        m,
+        n,
+        k,
+        mat1.get_dtype(),
+        mat2.get_dtype(),
+        layout,
+    )
+
     # options to tune from
     choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
     if use_triton_template(layout):
torch/_inductor/kernel/mm.py

Lines changed: 41 additions & 7 deletions
@@ -60,7 +60,8 @@
 mm_template = TritonTemplate(
     name="mm",
     grid=mm_grid,
-    source=r"""
+    source=(
+        r"""
 {{def_kernel("A", "B")}}
     M = {{size("A", 0)}}
     N = {{size("B", 1)}}

@@ -125,11 +126,11 @@
     # inductor generates a suffix
     {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
 """
-    if torch.version.hip is None
-    # FIXME: To get around rocm failures like https://github.com/pytorch/pytorch/actions/runs/13123783322/job/36617154943
-    # The only difference between the two templates is M >= BLOCK_M and N >= BLOCK_N checking.
-    # See more details in https://github.com/pytorch/pytorch/pull/146293
-    else r"""
+        if torch.version.hip is None
+        # FIXME: To get around rocm failures like https://github.com/pytorch/pytorch/actions/runs/13123783322/job/36617154943
+        # The only difference between the two templates is M >= BLOCK_M and N >= BLOCK_N checking.
+        # See more details in https://github.com/pytorch/pytorch/pull/146293
+        else r"""
 {{def_kernel("A", "B")}}
     M = {{size("A", 0)}}
     N = {{size("B", 1)}}

@@ -193,7 +194,8 @@
 
     # inductor generates a suffix
     {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
-""",
+"""
+    ),
 )
 
 persistent_tma_mm_template = TritonTemplate(

@@ -357,6 +359,16 @@ def tuned_mm(mat1, mat2, *, layout=None):
     m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
     name = "mm"
 
+    log.info(
+        "Tuned aten.mm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
+        m,
+        n,
+        k,
+        mat1.get_dtype(),
+        mat2.get_dtype(),
+        layout,
+    )
+
     aten_layout = layout
     if not use_max_autotune():
         aten_layout = FlexibleLayout(

@@ -472,6 +484,17 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
     m, n, k, layout, mat1, mat2 = mm_args(
         mat1, mat2, layout=layout, out_dtype=torch.int32
     )
+
+    log.info(
+        "Tuned aten._int_mm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
+        m,
+        n,
+        k,
+        mat1.get_dtype(),
+        mat2.get_dtype(),
+        layout,
+    )
+
     static_shape, is_nonzero = _is_static_problem(layout)
     use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k)

@@ -516,6 +539,17 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
     ordered_kwargs_for_cpp_kernel = ("beta", "alpha")
     m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
     static_shape, is_nonzero = _is_static_problem(layout)
+
+    log.info(
+        "Tuned aten.addmm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
+        m,
+        n,
+        k,
+        mat1.get_dtype(),
+        mat2.get_dtype(),
+        layout,
+    )
+
     if (not is_nonzero) or (not use_max_autotune()):
         # Use a FlexibleLayout if we are not autotuning.
         # This allows padding strides for the output.
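Note that the added calls pass `m`, `n`, `k`, and the dtypes as separate `%s` arguments rather than pre-formatting the string, so interpolation is deferred until a record is actually emitted at INFO level. A tiny standalone illustration of that pattern (hypothetical logger name, not PyTorch code):

```python
# Lazy %-style formatting: the message string is only built if INFO is enabled.
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("mm_logs_demo")

m, n, k = 16, 6, 16
log.info("Tuned aten.mm: m=%s, n=%s, k=%s", m, n, k)
```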

torch/_inductor/kernel/mm_scaled.py

Lines changed: 10 additions & 0 deletions
@@ -509,6 +509,16 @@ def tuned_scaled_mm(
         mat_a, mat_b, layout=layout, out_dtype=out_dtype
     )
 
+    log.info(
+        "Tuned aten._scaled_mm.default: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
+        m,
+        n,
+        k,
+        mat_a.get_dtype(),
+        mat_b.get_dtype(),
+        layout,
+    )
+
     check_supported_striding(mat_a, mat_b)
 
     scale_a, scale_b = realize_inputs(scale_a, scale_b)
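The commit message's next step (an overview table of mm ops per graph) could be prototyped outside of Inductor by aggregating the emitted lines. An illustrative sketch only, not part of this commit; the regex simply matches the message format introduced above:

```python
# Illustrative aggregation of captured "Tuned aten.*" log lines into an
# overview table of (op, m, n, k, count).
import re
from collections import Counter

PATTERN = re.compile(r"Tuned (aten\.\S+): m=(\d+), n=(\d+), k=(\d+)")

def summarize(log_lines):
    counts = Counter()
    for line in log_lines:
        match = PATTERN.search(line)
        if match:
            op, m, n, k = match.groups()
            counts[(op, int(m), int(n), int(k))] += 1
    return counts

sample = [
    "Tuned aten.addmm: m=16, n=6, k=16, mat1_dtype=torch.float32, ...",
    "Tuned aten.mm: m=16, n=6, k=16, mat1_dtype=torch.float32, ...",
    "Tuned aten.mm: m=16, n=6, k=16, mat1_dtype=torch.float32, ...",
]
for (op, m, n, k), count in summarize(sample).items():
    print(f"{op:<20} m={m:<6} n={n:<6} k={k:<6} count={count}")
```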

0 commit comments