Commit 70be5ca

Support tuning of _scaled_grouped_mm
This includes the default aten implementation, as well as a Triton implementation imported from FBGEMM (https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py).

ghstack-source-id: 06a27e2
Pull Request resolved: #150421
1 parent fe96167 commit 70be5ca
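
For context, a minimal usage sketch of the op this commit makes tunable, mirroring the 2d_2d test case added below. This assumes an SM90 GPU with FP8 support; the "max-autotune" compile mode is a standard Inductor knob shown here for illustration, not something introduced by this commit, and whether the Triton choice actually wins depends on autotuning results.

import torch

# Sketch: compiling torch._scaled_grouped_mm lets Inductor pick between the
# default aten kernel and the FBGEMM-derived Triton kernel during autotuning.
m, n, k, n_groups = 16, 16, 16, 4  # all sizes have to be divisible by 16
a = torch.randn(m, k * n_groups, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(n, k * n_groups, device="cuda").to(torch.float8_e4m3fn)
scale_a = torch.ones(m * n_groups, device="cuda", dtype=torch.float32)
scale_b = torch.ones(n * n_groups, device="cuda", dtype=torch.float32)
offs = torch.arange(k, n_groups * k + 1, k, device="cuda", dtype=torch.int32)

f = torch.compile(torch._scaled_grouped_mm, mode="max-autotune")
out = f(a, b.t(), scale_a, scale_b, offs=offs,
        out_dtype=torch.bfloat16, use_fast_accum=True)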

File tree

6 files changed: +521 additions, -13 deletions

test/test_matmul_cuda.py

Lines changed: 24 additions & 12 deletions

@@ -1440,16 +1440,19 @@ def scaled_grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist
     @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
     @parametrize("fast_accum", [False, True])
     @parametrize("strided", [False, True])
-    def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided):
+    @parametrize("use_torch_compile", [False, True])
+    def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         m, n, k, n_groups = 16, 16, 16, 4  # all sizes have to be divisible by 16
         a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
         b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
         scale_a = torch.arange(m * n_groups, device=device, dtype=torch.float32) / 4
         scale_b = torch.arange(n * n_groups, device=device, dtype=torch.float32) / 4
         offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
-        out = torch._scaled_grouped_mm(a, b.t(), scale_a, scale_b, offs=offs,
-                                       out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
+        f = torch._scaled_grouped_mm
+        f = torch.compile(f) if use_torch_compile else f
+        out = f(a, b.t(), scale_a, scale_b, offs=offs,
+                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
         offs_cpu = offs.cpu()
         alist, blist, ascalelist, bscalelist = [], [], [], []
         start = 0
@@ -1466,7 +1469,8 @@ def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided):
     @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
     @parametrize("fast_accum", [False, True])
     @parametrize("strided", [False, True])
-    def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided):
+    @parametrize("use_torch_compile", [False, True])
+    def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
         m, n, k, n_groups = 16, 32, 16, 4
@@ -1478,8 +1482,10 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided):
         scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32)
         scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)

-        out = torch._scaled_grouped_mm(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
-                                       out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
+        f = torch._scaled_grouped_mm
+        f = torch.compile(f) if use_torch_compile else f
+        out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
+                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)

         offs_cpu = offs.cpu()
         alist, ascalelist, outlist = [], [], []
@@ -1496,7 +1502,8 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided):
     @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
     @parametrize("fast_accum", [False, True])
     @parametrize("strided", [False, True])
-    def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided):
+    @parametrize("use_torch_compile", [False, True])
+    def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
         m, n, k, n_groups = 16, 32, 16, 4
@@ -1507,8 +1514,10 @@ def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided):
         scale_a = torch.ones(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
         scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)

-        out = torch._scaled_grouped_mm(a, b.transpose(-2, -1), scale_a, scale_b,
-                                       out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
+        f = torch._scaled_grouped_mm
+        f = torch.compile(f) if use_torch_compile else f
+        out = f(a, b.transpose(-2, -1), scale_a, scale_b,
+                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)

         self.scaled_grouped_mm_helper(a, b, scale_a, scale_b, out, fast_accum)

@@ -1517,7 +1526,8 @@ def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided):
     @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
     @parametrize("fast_accum", [False, True])
     @parametrize("strided", [False, True])
-    def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided):
+    @parametrize("use_torch_compile", [False, True])
+    def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
         m, n, k, n_groups = 16, 32, 16, 4
@@ -1529,8 +1539,10 @@ def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided):
         scale_b = torch.arange(n_groups * n, device="cuda", dtype=torch.float32)
         offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)

-        out = torch._scaled_grouped_mm(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
-                                       out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
+        f = torch._scaled_grouped_mm
+        f = torch.compile(f) if use_torch_compile else f
+        out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
+                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
         offs_cpu = offs.cpu()
         blist, bscalelist, outlist = [], [], []
         start = 0

torch/_dynamo/trace_rules.py

Lines changed: 1 addition & 0 deletions

@@ -1577,6 +1577,7 @@
         "torch._scaled_dot_product_flash_attention_for_cpu",
         "torch._scaled_dot_product_cudnn_attention",
         "torch._scaled_mm",
+        "torch._scaled_grouped_mm",
         "torch._shape_as_tensor",
         "torch._sobol_engine_draw",
         "torch._sobol_engine_ff_",

torch/_inductor/graph.py

Lines changed: 1 addition & 0 deletions

@@ -204,6 +204,7 @@ def mark_nodes_dislike_padding(
             aten.convolution,
             aten.convolution_backward,
             aten._scaled_mm,
+            aten._scaled_grouped_mm,
         ]
     )
     # what's a better way to collect the reduction ops?

torch/_inductor/kernel/mm_common.py

Lines changed: 30 additions & 1 deletion

@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import logging
+from collections.abc import Sequence
 from typing import Any

 import sympy
@@ -10,7 +11,7 @@

 from .. import config as inductor_config
 from ..codegen.wrapper import PythonWrapperCodegen
-from ..ir import ChoiceCaller, Layout
+from ..ir import _IntLike, ChoiceCaller, Layout, TensorBox
 from ..utils import get_num_sms, TMA_DESCRIPTOR_SIZE, use_aten_gemm_kernels


@@ -54,6 +55,11 @@ def persistent_mm_grid(M: int, N: int, meta: dict[str, Any], *, cdiv, min):
     )


+@SymbolicGridFn
+def persistent_grouped_mm_grid(m, n, meta):
+    return (meta["NUM_SMS"], 1, 1)
+
+
 def acc_type(dtype):
     if dtype in (torch.float16, torch.bfloat16):
         return "tl.float32"
@@ -259,3 +265,26 @@ def _is_static_problem(layout: Layout) -> tuple[bool, bool]:
         numel *= dim
     nonzero = numel > 0
     return static_shape, nonzero
+
+
+def check_supported_striding(mat_a: TensorBox, mat_b: TensorBox) -> None:
+    def is_row_major(stride: Sequence[_IntLike]) -> bool:
+        return stride[-1] == 1
+
+    def is_col_major(stride: Sequence[_IntLike]) -> bool:
+        return stride[-2] == 1
+
+    def has_zero_dim(size: Sequence[_IntLike]) -> bool:
+        return bool(size[0] == 0 or size[1] == 0)
+
+    # Check mat_a (self) stride requirements
+    torch._check(
+        is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()),
+        lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}",
+    )
+
+    # Check mat_b stride requirements
+    torch._check(
+        is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()),
+        lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}",
+    )

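The new check_supported_striding helper requires mat_a to be row-major and mat_b to be column-major, with tensors that have a zero-sized dimension exempt. A quick standalone illustration of those stride conditions on plain tensors (the helper itself operates on Inductor TensorBox objects, so this is only an analogy; dtype is irrelevant since only strides matter):

import torch

mat_a = torch.empty(16, 64)      # contiguous, i.e. row-major
mat_b = torch.empty(32, 64).t()  # transposed view, shape (64, 32), column-major

print(mat_a.stride())  # (64, 1) -> row-major: stride[-1] == 1
print(mat_b.stride())  # (1, 64) -> col-major: stride[-2] == 1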