 60 |  60 |   mm_template = TritonTemplate(
 61 |  61 |       name="mm",
 62 |  62 |       grid=mm_grid,
 63 |     | -     source=r"""
    |  63 | +     source=(
    |  64 | +         r"""
 64 |  65 |   {{def_kernel("A", "B")}}
 65 |  66 |       M = {{size("A", 0)}}
 66 |  67 |       N = {{size("B", 1)}}
... | ... |
125 | 126 |       # inductor generates a suffix
126 | 127 |       {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
127 | 128 |   """
128 |     | -     if torch.version.hip is None
129 |     | -     # FIXME: To get around rocm failures like https://github.com/pytorch/pytorch/actions/runs/13123783322/job/36617154943
130 |     | -     # The only difference between the two templates is M >= BLOCK_M and N >= BLOCK_N checking.
131 |     | -     # See more details in https://github.com/pytorch/pytorch/pull/146293
132 |     | -     else r"""
    | 129 | +         if torch.version.hip is None
    | 130 | +         # FIXME: To get around rocm failures like https://github.com/pytorch/pytorch/actions/runs/13123783322/job/36617154943
    | 131 | +         # The only difference between the two templates is M >= BLOCK_M and N >= BLOCK_N checking.
    | 132 | +         # See more details in https://github.com/pytorch/pytorch/pull/146293
    | 133 | +         else r"""
133 | 134 |   {{def_kernel("A", "B")}}
134 | 135 |       M = {{size("A", 0)}}
135 | 136 |       N = {{size("B", 1)}}
... | ... |
193 | 194 |
194 | 195 |       # inductor generates a suffix
195 | 196 |       {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
196 |     | - """,
    | 197 | + """
    | 198 | +     ),
197 | 199 |   )
198 | 200 |
199 | 201 |   persistent_tma_mm_template = TritonTemplate(

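The hunk above only adds explicit parentheses around an existing conditional: the template source is picked once, at import time, between two near-identical Triton templates depending on whether this is a ROCm build (per the FIXME, the two bodies differ only in the M >= BLOCK_M / N >= BLOCK_N checks). A minimal standalone sketch of that selection pattern, with placeholder template bodies rather than the real Inductor ones:

    import torch

    # Hedged sketch: torch.version.hip is None on CUDA/CPU builds and a version
    # string on ROCm builds, so this conditional expression picks one of two raw
    # template strings when the module is imported. Bodies are placeholders.
    _TEMPLATE_A = r"""
    {{def_kernel("A", "B")}}
        ...
    """

    _TEMPLATE_B = r"""
    {{def_kernel("A", "B")}}
        ...
    """

    mm_source = _TEMPLATE_A if torch.version.hip is None else _TEMPLATE_B

The comment lines sitting between the if and the else in the diff are legal because the whole expression is parenthesized; inside parentheses, line breaks and comment-only lines do not end the expression.
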
@@ -357,6 +359,16 @@ def tuned_mm(mat1, mat2, *, layout=None):
357 | 359 |       m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
358 | 360 |       name = "mm"
359 | 361 |
    | 362 | +     log.info(
    | 363 | +         "Tuned aten.mm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
    | 364 | +         m,
    | 365 | +         n,
    | 366 | +         k,
    | 367 | +         mat1.get_dtype(),
    | 368 | +         mat2.get_dtype(),
    | 369 | +         layout,
    | 370 | +     )
    | 371 | +
360 | 372 |       aten_layout = layout
361 | 373 |       if not use_max_autotune():
362 | 374 |           aten_layout = FlexibleLayout(

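The new message is emitted at INFO level through the module logger, so it stays silent under the default WARNING threshold. A minimal sketch for surfacing it with the standard logging module; the logger name assumes the usual logging.getLogger(__name__) pattern in torch/_inductor/kernel/mm.py, which is an assumption rather than something this diff shows:

    import logging

    # Hedged sketch: raise the (assumed) module logger to INFO and attach a
    # root handler so the "Tuned aten.mm: ..." records are actually printed.
    logging.basicConfig(level=logging.INFO)
    logging.getLogger("torch._inductor.kernel.mm").setLevel(logging.INFO)

Passing m, n, k and the dtypes as separate log.info arguments, rather than pre-formatting an f-string, keeps the call cheap when the logger is disabled: the %s formatting only happens if a handler actually emits the record.
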
@@ -472,6 +484,17 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
472 | 484 |       m, n, k, layout, mat1, mat2 = mm_args(
473 | 485 |           mat1, mat2, layout=layout, out_dtype=torch.int32
474 | 486 |       )
    | 487 | +
    | 488 | +     log.info(
    | 489 | +         "Tuned aten._int_mm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
    | 490 | +         m,
    | 491 | +         n,
    | 492 | +         k,
    | 493 | +         mat1.get_dtype(),
    | 494 | +         mat2.get_dtype(),
    | 495 | +         layout,
    | 496 | +     )
    | 497 | +
475 | 498 |       static_shape, is_nonzero = _is_static_problem(layout)
476 | 499 |       use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k)
477 | 500 |

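For context on what reaches this lowering: aten._int_mm is the int8 x int8 matmul with an int32 output (hence the out_dtype=torch.int32 above). A hedged usage sketch; the shapes, device, and compile mode are illustrative assumptions, not requirements stated by this diff:

    import torch

    # Hedged sketch: an int8 matmul traced through torch.compile; under
    # max-autotune the Inductor lowering above selects the kernel and now
    # also logs the problem shape and dtypes at INFO level.
    @torch.compile(mode="max-autotune")
    def int_mm(a, b):
        return torch._int_mm(a, b)

    a = torch.randint(-128, 127, (32, 64), dtype=torch.int8, device="cuda")
    b = torch.randint(-128, 127, (64, 32), dtype=torch.int8, device="cuda")
    c = int_mm(a, b)  # int32 result of shape (32, 32)
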
@@ -516,6 +539,17 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
516 | 539 |       ordered_kwargs_for_cpp_kernel = ("beta", "alpha")
517 | 540 |       m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
518 | 541 |       static_shape, is_nonzero = _is_static_problem(layout)
    | 542 | +
    | 543 | +     log.info(
    | 544 | +         "Tuned aten.addmm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
    | 545 | +         m,
    | 546 | +         n,
    | 547 | +         k,
    | 548 | +         mat1.get_dtype(),
    | 549 | +         mat2.get_dtype(),
    | 550 | +         layout,
    | 551 | +     )
    | 552 | +
519 | 553 |       if (not is_nonzero) or (not use_max_autotune()):
520 | 554 |           # Use a FlexibleLayout if we are not autotuning.
521 | 555 |           # This allows padding strides for the output.

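The logged m, n, k for addmm follow the usual shape algebra of beta * inp + alpha * (mat1 @ mat2): with mat1 of shape (m, k) and mat2 of shape (k, n), those are the three values reported, alongside the input dtypes and the chosen output layout. A hedged sketch of a call that would exercise this path (shapes, dtypes, device, and compile mode are illustrative):

    import torch

    # Hedged sketch: with mat1 (128, 64) and mat2 (64, 256), the values logged
    # above would correspond to m=128, n=256, k=64 and float16 input dtypes.
    @torch.compile(mode="max-autotune")
    def addmm(inp, mat1, mat2):
        return torch.addmm(inp, mat1, mat2)

    inp = torch.randn(256, dtype=torch.float16, device="cuda")
    mat1 = torch.randn(128, 64, dtype=torch.float16, device="cuda")
    mat2 = torch.randn(64, 256, dtype=torch.float16, device="cuda")
    out = addmm(inp, mat1, mat2)  # shape (128, 256)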