[inductor] Generate synthetic offsets appropriately for autotuning _scaled_grouped_mm · pytorch/pytorch@e820b05 · GitHub

Commit e820b05

bertmaher authored and pytorchmergebot committed
[inductor] Generate synthetic offsets appropriately for autotuning _scaled_grouped_mm (#152968)
Summary: The autotuner is using zero-filled tensors to autotune _scaled_grouped_mm and that's not appropriate for the offsets tensor, since it essentially corresponds to "no input" and thus yields invalid perf results. We can't really use the actual input tensors, since we might be compiling this op in the context of an entire graph. So instead, I decided to create a synthetic offsets tensor assuming that each group is (roughly) the same size. I don't have data but I'd guess this approach is OK for MoE since we're generally hoping to load-balance the experts; I'm not sure how well it applies to other scenarios that might be more heavy-tailed. Test Plan: ``` pytest test_matmul_cuda.py -k test_scaled_grouped_gemm_ ``` Pull Request resolved: #152968 Approved by: https://github.com/ngimel
1 parent 590965f commit e820b05

File tree

1 file changed: +16 -1 lines changed

torch/_inductor/kernel/mm_scaled_grouped.py

Lines changed: 16 additions & 1 deletion
```diff
@@ -380,6 +380,16 @@ def can_use_triton_kernel(
     )
 
 
+def create_offsets(x, m1_size, m2_size, offs_size):
+    assert len(m1_size) == 2 and len(m2_size) == 3, (
+        "Autotuning _scaled_grouped_mm is only implemented for 2d-3d tensors"
+    )
+    m = V.graph.sizevars.size_hint(m1_size[0])
+    noffs = V.graph.sizevars.size_hint(offs_size[0])
+    step = m / noffs
+    return torch.linspace(step, m, noffs, dtype=x.get_dtype(), device=x.get_device())
+
+
 @register_lowering(aten._scaled_grouped_mm.default, type_promotion_kind=None)
 def tuned_scaled_grouped_mm(
     mat_a: TensorBox,
@@ -461,4 +471,9 @@ def tuned_scaled_grouped_mm(
         **config.kwargs,
     )
 
-    return autotune_select_algorithm("scaled_grouped_mm", choices, input_nodes, layout)
+    input_gen_fns = {
+        4: lambda x: create_offsets(x, m1_size, m2_size, offs.get_size()),
+    }
+    return autotune_select_algorithm(
+        "scaled_grouped_mm", choices, input_nodes, layout, input_gen_fns=input_gen_fns
+    )
```
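For context, `input_gen_fns` maps the positional index of an input node (here 4, the offsets argument in `input_nodes`) to a callable that produces the benchmark tensor for that input, overriding the autotuner's default zero-filled buffers. A rough sketch of that contract under simplified assumptions (the function name, the `.shape` accessor, and the zero-fill fallback shown here are illustrative, not inductor's actual internals):

```python
import torch

def materialize_autotune_inputs(input_nodes, input_gen_fns):
    # Illustrative only: build one benchmark tensor per input node.
    # Inputs with a registered generator (e.g. index 4 -> create_offsets)
    # get realistic synthetic data; the rest fall back to zeros, which is
    # fine for dense operands but misleading for an offsets tensor.
    tensors = []
    for i, node in enumerate(input_nodes):
        if i in input_gen_fns:
            tensors.append(input_gen_fns[i](node))
        else:
            tensors.append(torch.zeros(node.shape))  # hypothetical .shape accessor
    return tensors
```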
