8000 [mm_logs] enhance the printing for overview info (#148716) · pytorch/pytorch@3f069e7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3f069e7

Browse files
YUNQIUGUOpytorchmergebot
authored andcommitted
[mm_logs] enhance the printing for overview info (#148716)
Summary: previously the dynamo counters does not print the counts information automatically. explicitly added a log msg to print after lowering for overview info for inductor aten mms it will look like: the name is in `{a 8000 ten_op_name}_{m}_{n}_{k}` ``` torch/_inductor/compile_fx.py:832] [0/0] Overview info of inductor aten mms: (aten.addmm_16_6_16: 1), (name: count), xxx ``` {F1975874802} Test Plan: ``` TORCH_LOGS="+inductor" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_addmm_cuda ``` Differential Revision: D70739912 Pull Request resolved: #148716 Approved by: https://github.com/henrylhtsang
1 parent 5f392ae commit 3f069e7

File tree

4 files changed

+14
-6
lines changed

4 files changed

+14
-6
lines changed

torch/_inductor/compile_fx.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,14 @@ def _compile_fx_inner(
841841

842842
log.debug("FX codegen and compilation took %.3fs", time.time() - start)
843843

844+
# This message is for printing overview information of inductor mm counts, shapes,etc after lowering
845+
log.info(
846+
"Overview info of inductor aten mms: %s",
847+
", ".join(
848+
f"({key}: {value})" for key, value in counters["aten_mm_info"].items()
849+
),
850+
)
851+
844852
# Clear Compiled Triton Kernels per inductor compile, as the future objects
845853
# may not be valid for use after they are run/autotuned
846854
torch._inductor.async_compile.CompiledTritonKernels.cache_clear()

torch/_inductor/kernel/bmm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def may_require_contiguous(t, meta_t):
170170
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
171171

172172
# below is for getting an overview logging info of inductor mms
173-
counters["inductor"][f"aten.bmm_{m}_{n}_{k}"] += 1
173+
counters["aten_mm_info"][f"aten.bmm_{m}_{n}_{k}"] += 1
174174
log.info(
175175
"Tuned aten.bmm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
176176
m,
@@ -220,7 +220,7 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
220220
m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)
221221

222222
# below is for getting an overview logging info of inductor mms
223-
counters["inductor"][f"aten.baddbmm_{m}_{n}_{k}"] += 1
223+
counters["aten_mm_info"][f"aten.baddbmm_{m}_{n}_{k}"] += 1
224224
log.info(
225225
"Tuned aten.baddbmm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, inp=%s, output_layout=%s",
226226
m,

torch/_inductor/kernel/mm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
360360
name = "mm"
361361

362362
# below is for getting an overview logging info of inductor mms
363-
counters["inductor"][f"aten.mm_{m}_{n}_{k}"] += 1
363+
counters["aten_mm_info"][f"aten.mm_{m}_{n}_{k}"] += 1
364364
log.info(
365365
"Tuned aten.mm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
366366
m,
@@ -482,7 +482,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
482482
)
483483

484484
# below is for getting an overview logging info of inductor mms
485-
counters["inductor"][f"aten._int_mm_{m}_{n}_{k}"] += 1
485+
counters["aten_mm_info"][f"aten._int_mm_{m}_{n}_{k}"] += 1
486486
log.info(
487487
"Tuned aten._int_mm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
488488
m,
@@ -528,7 +528,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
528528
static_shape, is_nonzero = _is_static_problem(layout)
529529

530530
# below is for getting an overview logging info of inductor mms
531-
counters["inductor"][f"aten.addmm_{m}_{n}_{k}"] += 1
531+
counters["aten_mm_info"][f"aten.addmm_{m}_{n}_{k}"] += 1
532532
log.info(
533533
"Tuned aten.addmm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
534534
m,

torch/_inductor/kernel/mm_scaled.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ def tuned_scaled_mm(
509509
mat_a, mat_b, layout=layout, out_dtype=out_dtype
510510
)
511511
# below is for getting an overview logging info of inductor mms
512-
counters["inductor"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1
512+
counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1
513513
log.info(
514514
"Tuned aten._scaled_mm.default: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
515515
m,

0 commit comments

Comments
 (0)
0