Support tuning of _scaled_grouped_mm · pytorch/pytorch@7d498f3 · GitHub

Commit 7d498f3
Support tuning of _scaled_grouped_mm
This includes the default aten implementation, as well as a Triton implementation imported from FBGEMM (https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py).

ghstack-source-id: 5f7b9d2
Pull Request resolved: #150421
1 parent 414b9ae commit 7d498f3
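For orientation, a minimal sketch (not part of the commit) of exercising the new tuning path end to end. The shapes, scale layouts, and the cumulative `offs` convention are assumptions inferred from the Triton template below (per-row scales for mat_a, per-group-column scales for mat_b); an SM90-class GPU with FP8 support is assumed.

import torch

G, M, N, K = 4, 64, 128, 256  # 4 groups of 64 rows each; total rows = G * M
dev = "cuda"

a = torch.randn(G * M, K, device=dev).to(torch.float8_e4m3fn)  # row-major: stride[-1] == 1
b = torch.randn(G, N, K, device=dev).to(torch.float8_e4m3fn).transpose(-2, -1)  # (G, K, N), column-major: stride[-2] == 1
scale_a = torch.ones(G * M, device=dev)  # assumed: one fp32 scale per row of a
scale_b = torch.ones(G, N, device=dev)   # assumed: one fp32 scale per (group, column) of b
offs = torch.arange(M, G * M + 1, M, device=dev, dtype=torch.int32)  # cumulative group boundaries

def grouped_fp8_mm(a, b, sa, sb, offs):
    return torch._scaled_grouped_mm(a, b, sa, sb, offs=offs, out_dtype=torch.bfloat16)

# max-autotune makes Inductor benchmark the aten extern kernel against the new
# Triton template and keep the faster one.
out = torch.compile(grouped_fp8_mm, mode="max-autotune")(a, b, scale_a, scale_b, offs)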

File tree

7 files changed: +406 -1 lines changed


torch/_dynamo/trace_rules.py

Lines changed: 1 addition & 0 deletions

@@ -1575,6 +1575,7 @@
         "torch._scaled_dot_product_flash_attention_for_cpu",
         "torch._scaled_dot_product_cudnn_attention",
         "torch._scaled_mm",
+        "torch._scaled_grouped_mm",
         "torch._shape_as_tensor",
         "torch._sobol_engine_draw",
         "torch._sobol_engine_ff_",

torch/_inductor/choices.py

Lines changed: 6 additions & 0 deletions

@@ -114,6 +114,12 @@ def get_scaled_persistent_mm_configs(
         mm_heuristics = self.get_config_heuristics(device_type)
         return mm_heuristics.get_scaled_persistent_mm_configs()
 
+    def get_scaled_grouped_mm_configs(
+        self, device_type: Optional[str] = "cuda"
+    ) -> partial[Generator[TritonConfig, None, None]]:
+        mm_heuristics = self.get_config_heuristics(device_type)
+        return mm_heuristics.get_scaled_grouped_mm_configs()
+
     def get_mm_plus_mm_configs(
         self, device_type: Optional[str] = "cuda"
     ) -> partial[Generator[TritonConfig, None, None]]:

torch/_inductor/graph.py

Lines changed: 1 addition & 0 deletions

@@ -204,6 +204,7 @@ def mark_nodes_dislike_padding(
             aten.convolution,
             aten.convolution_backward,
             aten._scaled_mm,
+            aten._scaled_grouped_mm,
         ]
     )
     # what's a better way to collect the reduction ops?

torch/_inductor/kernel/mm_common.py

Lines changed: 25 additions & 1 deletion

@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import logging
+from collections.abc import Sequence
 from typing import Any
 
 import sympy
@@ -10,7 +11,7 @@
 
 from .. import config as inductor_config
 from ..codegen.wrapper import PythonWrapperCodegen
-from ..ir import ChoiceCaller, Layout
+from ..ir import _IntLike, ChoiceCaller, Layout, TensorBox
 from ..utils import get_num_sms, TMA_DESCRIPTOR_SIZE, use_aten_gemm_kernels
 
 
@@ -253,3 +254,26 @@ def _is_static_problem(layout: Layout) -> tuple[bool, bool]:
         numel *= dim
     nonzero = numel > 0
     return static_shape, nonzero
+
+
+def check_supported_striding(mat_a: TensorBox, mat_b: TensorBox) -> None:
+    def is_row_major(stride: Sequence[_IntLike]) -> bool:
+        return stride[-1] == 1
+
+    def is_col_major(stride: Sequence[_IntLike]) -> bool:
+        return stride[-2] == 1
+
+    def has_zero_dim(size: Sequence[_IntLike]) -> bool:
+        return bool(size[0] == 0 or size[1] == 0)
+
+    # Check mat_a (self) stride requirements
+    torch._check(
+        is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()),
+        lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}",
+    )
+
+    # Check mat_b stride requirements
+    torch._check(
+        is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()),
+        lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}",
+    )
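To make the accepted layouts concrete, a small illustration (not part of the commit) of strides that satisfy these checks, shown on plain tensors:

import torch

a = torch.empty(64, 32)                      # strides (32, 1): stride[-1] == 1, row-major
b = torch.empty(4, 16, 8).transpose(-2, -1)  # shape (4, 8, 16), strides (128, 1, 8): stride[-2] == 1
assert a.stride(-1) == 1                     # mat_a requirement
assert b.stride(-2) == 1                     # mat_b requirement: column-major in the last two dims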
torch/_inductor/kernel/mm_scaled_grouped.py (new file)

Lines changed: 292 additions & 0 deletions

@@ -0,0 +1,292 @@
import logging
from typing import Any, Optional

import torch
from torch._dynamo.utils import counters
from torch._inductor.virtualized import V
from torch.utils._triton import has_triton_tma_device

from ..ir import ChoiceCaller, get_device_type, Layout, TensorBox
from ..lowering import register_lowering
from ..runtime.runtime_utils import next_power_of_2
from ..select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    realize_inputs,
    TritonTemplate,
)
from ..utils import get_num_sms, get_tma_workspace_arg, use_aten_gemm_kernels
from .mm_common import _is_static_problem, check_supported_striding, persistent_mm_grid


log = logging.getLogger(__name__)
aten = torch.ops.aten


# Copied from fbgemm grouped_gemm.py
triton_scaled_grouped_mm_source = r"""
{{def_kernel("a_ptr", "b_ptr", "a_scale_ptr", "b_scale_ptr", "m_sizes")}}
    tidx = tl.program_id(0)

    dtype = tl.float8e4nv
    TMA_SIZE: tl.constexpr = tl.constexpr(128)
    if USE_TMA_STORE:
        workspace_base = ws_ptr + tidx * 3 * TMA_SIZE
        c_desc_ptr = workspace_base + 2 * TMA_SIZE
    else:
        workspace_base = ws_ptr + tidx * 2 * TMA_SIZE
        c_desc_ptr = None

    a_desc_ptr = workspace_base
    b_desc_ptr = workspace_base + TMA_SIZE

    triton.language.extra.cuda.experimental_device_tensormap_create2d(
        desc_ptr=a_desc_ptr,
        global_address=a_ptr,
        load_size=[BLOCK_M, BLOCK_K],
        global_size=[M, K],
        element_ty=a_ptr.dtype.element_ty,
    )
    triton.language.extra.cuda.experimental_device_tensormap_create2d(
        desc_ptr=b_desc_ptr,
        global_address=b_ptr,
        load_size=[BLOCK_N, BLOCK_K],
        global_size=[N * G, K],
        element_ty=b_ptr.dtype.element_ty,
    )
    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)

    M_end_offset = 0
    iterated_tiles = 0
    for g in tl.range(G):
        # Move across groups
        M_start_offset = M_end_offset
        M_end_offset = tl.load(m_sizes + g)
        m_size = M_end_offset - M_start_offset

        if m_size > 0:
            N_start_offset = g.to(tl.int64) * N
            n_size = N
            num_m_tiles = tl.cdiv(m_size, BLOCK_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_N)
            num_tiles = num_m_tiles * num_n_tiles

            if USE_TMA_STORE:
                # pyre-ignore
                tl.extra.cuda.experimental_device_tensormap_create2d(
                    desc_ptr=c_desc_ptr,
                    global_address=c_ptr + M_start_offset * N,
                    load_size=[BLOCK_M, BLOCK_N],
                    global_size=[m_size, n_size],
                    element_ty=c_ptr.dtype.element_ty,
                )
                # pyre-ignore
                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Move across tiles
            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
                gidx = tidx - iterated_tiles
                # Split M first and N second.
                tile_m_idx = gidx % num_m_tiles
                tile_n_idx = gidx // num_m_tiles

                accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                tl.static_assert(K % BLOCK_K == 0)
                if USE_TMA_LOAD:
                    m_offset = (M_start_offset + tile_m_idx * BLOCK_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_N).to(tl.int32)
                    for k_offset in range(0, K, BLOCK_K):
                        a = tl._experimental_descriptor_load(
                            a_desc_ptr,
                            [m_offset, k_offset],
                            [BLOCK_M, BLOCK_K],
                            dtype,
                        )
                        b = tl._experimental_descriptor_load(
                            b_desc_ptr,
                            [n_offset, k_offset],
                            [BLOCK_N, BLOCK_K],
                            dtype,
                        )
                        if USE_FAST_ACCUM:
                            accumulator = tl.dot(a, b.T, accumulator)
                        else:
                            accumulator += tl.dot(a, b.T)
                else:
                    offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
                    offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
                    offs_k = tl.arange(0, BLOCK_K)
                    a_ptrs = (
                        a_desc_ptr
                        + (M_start_offset + offs_am[:, None]) * K
                        + offs_k[None, :]
                    )
                    b_ptrs = (
                        b_desc_ptr
                        + (N_start_offset + offs_bn[:, None]) * K
                        + offs_k[None, :]
                    )
                    for k_offset in range(0, K, BLOCK_K):
                        a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
                        b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
                        accumulator += tl.dot(a, b.T)
                        a_ptrs += BLOCK_K
                        b_ptrs += BLOCK_K

                offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
                offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
                a_scale = tl.load(
                    a_scale_ptr + M_start_offset + offs_am[:, None],
                    mask=offs_am[:, None] < m_size,
                )
                b_scale = tl.load(
                    b_scale_ptr + N_start_offset + offs_bn[None, :],
                    mask=offs_bn[None, :] < n_size,
                )
                c = accumulator.to(tl.float32) * a_scale * b_scale

                if USE_TMA_STORE:
                    m_offset = (tile_m_idx * BLOCK_M).to(tl.int32)
                    n_offset = (tile_n_idx * BLOCK_N).to(tl.int32)
                    tl._experimental_descriptor_store(
                        c_desc_ptr,
                        c.to(c_ptr.dtype.element_ty),
                        [m_offset, n_offset],
                    )
                else:
                    idx_m = (M_start_offset + offs_am[:, None])
                    idx_n = offs_bn[None, :]
                    mask = offs_am[:, None] < m_size and offs_bn[None, :] < n_size
                    {{store_output(("idx_m", "idx_n"), "c", "mask", indent_width=20)}}
                tidx += NUM_SMS

            iterated_tiles += num_tiles
"""

triton_scaled_grouped_mm_template = TritonTemplate(
    name="scaled_grouped_mm",
    grid=persistent_mm_grid,
    source=triton_scaled_grouped_mm_source,
)


def grouped_mm_args(
    mat1,
    mat2,
    layout=None,
    out_dtype=None,
):
    mat1, mat2 = realize_inputs(mat1, mat2)
    m, k1 = mat1.get_size()
    g, k2, n = mat2.get_size()
    k = V.graph.sizevars.guard_equals(k1, k2)
    if layout is None:
        from torch._inductor.ir import FixedLayout

        if out_dtype is None:
            out_dtype = mat1.get_dtype()

        layout = FixedLayout(
            mat1.get_device(),
            out_dtype,
            [m, n],
        )
    else:
        assert out_dtype is None, "out_dtype is ignored if layout is specified."

    return (g, m, n, k, layout, mat1, mat2)


aten__scaled_grouped_mm = ExternKernelChoice(
    torch._scaled_grouped_mm,
    "at::_scaled_grouped_mm",
    op_overload=aten._scaled_grouped_mm,
    has_out_variant=False,
)


@register_lowering(aten._scaled_grouped_mm.default, type_promotion_kind=None)
def tuned_scaled_grouped_mm(
    mat_a: TensorBox,
    mat_b: TensorBox,
    scale_a: TensorBox,
    scale_b: TensorBox,
    offs: Optional[TensorBox] = None,
    bias: Optional[TensorBox] = None,
    scale_result: Optional[TensorBox] = None,
    out_dtype: Optional[torch.dtype] = None,
    use_fast_accum: bool = False,
    layout: Optional[Layout] = None,
) -> TensorBox:
    g, m, n, k, layout, mat_a, mat_b = grouped_mm_args(
        mat_a, mat_b, layout=layout, out_dtype=out_dtype
    )

    counters["aten_mm_info"][f"aten._scaled_grouped_mm.default_{g}_{m}_{n}_{k}"] += 1
    log.info(
        "Tuned aten._scaled_grouped_mm.default: g=%s, m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
        g,
        m,
        n,
        k,
        mat_a.get_dtype(),
        mat_b.get_dtype(),
        layout,
    )

    device_type = get_device_type(mat_a)
    check_supported_striding(mat_a, mat_b)

    scale_a, scale_b = realize_inputs(scale_a, scale_b)

    # workaround for Inductor not supporting optional tensor input arguments
    input_nodes: list[Any] = [mat_a, mat_b, scale_a, scale_b]
    if offs is not None:
        input_nodes.append(realize_inputs(offs))
    if bias is not None:
        input_nodes.append(realize_inputs(bias))

    aten_choice = aten__scaled_grouped_mm.bind(
        input_nodes,
        layout,
        out_dtype=out_dtype,
        use_fast_accum=use_fast_accum,
    )

    choices: list[ChoiceCaller] = []
    if use_aten_gemm_kernels():
        choices.append(aten_choice)

    _, is_nonzero = _is_static_problem(layout)

    scaled_grouped_mm_configs = V.choices.get_scaled_grouped_mm_configs(device_type)

    if is_nonzero and offs is not None and bias is None and has_triton_tma_device():
        for config in scaled_grouped_mm_configs(m, n, k):
            kwargs = {
                "G": g,
                "M": m,
                "M_BUCKET": next_power_of_2(m),
                "N": n,
                "K": k,
                "NUM_SMS": get_num_sms(),
                "USE_TMA_LOAD": True,
                "USE_TMA_STORE": False,
                "USE_FAST_ACCUM": use_fast_accum,
                "num_stages": config.num_stages,
                "num_warps": config.num_warps,
                **config.kwargs,
            }
            triton_scaled_grouped_mm_template.maybe_append_choice(
                choices,
                input_nodes=input_nodes,
                layout=layout,
                workspace_arg=get_tma_workspace_arg(
                    num_tma_descriptors=2,
                    device=mat_a.get_device(),
                ),
                **kwargs,
            )

    return autotune_select_algorithm("scaled_grouped_mm", choices, input_nodes, layout)
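The extern choice above is gated by use_aten_gemm_kernels(), which reads Inductor's autotune backend config. A hedged sketch of steering the tuner between the two implementations; these knobs exist in torch._inductor.config, though treating them as the way to isolate this particular template is my assumption:

import torch._inductor.config as inductor_config

# Benchmark both the aten extern kernel and the Triton template, keep the fastest;
# setting the string to "TRITON" alone would drop the extern choice.
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "ATEN,TRITON"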

torch/_inductor/scheduler.py

Lines changed: 1 addition & 0 deletions

@@ -948,6 +948,7 @@ def should_prune(dep: Dep) -> bool:
         "extern_kernels.bmm": torch.ops.aten.bmm,
         "extern_kernels.addmm": torch.ops.aten.addmm,
         "extern_kernels._scaled_mm": torch.ops.aten._scaled_mm,
+        "extern_kernels._scaled_grouped_mm": torch.ops.aten._scaled_grouped_mm,
     }