Update auto-tuning support for _scaled_grouped_mm · pytorch/pytorch@9efaca7 · GitHub

Commit 9efaca7

Update auto-tuning support for _scaled_grouped_mm
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix the non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases where the group size along the K dimension is not a multiple of the block size along K
6. Update meta registration
7. Update synthetic offsets creation

ghstack-source-id: 4856f96
Pull Request resolved: #150944
1 parent 7fdd754 commit 9efaca7
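
For context, a rough sketch of how one of the newly covered layouts is invoked. It mirrors the "2d/3d" test case in the diff below, but is not part of the commit: the shapes, scale layouts, and offsets are illustrative, and it assumes a PyTorch build with torch._scaled_grouped_mm plus a CUDA device that supports float8_e4m3fn.

import torch

device = "cuda"
m, n, k, n_groups = 16, 32, 64, 4  # all sizes divisible by 16

# 2d lhs with groups stacked along M, 3d rhs with one matrix per group.
a = torch.randn(m * n_groups, k, device=device).to(torch.float8_e4m3fn)
b = torch.randn(n_groups, n, k, device=device).to(torch.float8_e4m3fn)

# One scale per row of a, and per (group, column) of b.
scale_a = torch.ones(m * n_groups, device=device, dtype=torch.float32)
scale_b = torch.ones(n_groups, n, device=device, dtype=torch.float32)

# Cumulative group ends along the 2d operand's M dimension.
offs = torch.arange(m, n_groups * m + 1, m, device=device, dtype=torch.int32)

out = torch._scaled_grouped_mm(a, b.transpose(-2, -1), scale_a, scale_b,
                               offs=offs, out_dtype=torch.bfloat16)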

File tree

8 files changed, +462 -277 lines changed

aten/src/ATen/native/cuda/Blas.cpp

+4 -3

@@ -1541,7 +1541,7 @@ namespace {
           "D, arg ",
           arg_idx);
       TORCH_CHECK(
-          scale.is_contiguous(), "scale_a must be contiguous for arg ", arg_idx);
+          scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
       TORCH_CHECK(
           scale.size(0) == mat.size(dim) * scale_multiplier,
           "scale must have the same length as mat for arg ",
@@ -1554,8 +1554,8 @@ namespace {
           "D for arg ",
           arg_idx);
       TORCH_CHECK(
-          scale.stride(1),
-          "scale_a must be contiguous in the last dimension for arg ",
+          scale.stride(1) == 1,
+          "scale must be contiguous in the last dimension for arg ",
           arg_idx);
       TORCH_CHECK(
           scale.size(0) == mat.size(0),
@@ -1619,6 +1619,7 @@ bool use_fast_accum) {


   TORCH_CHECK(!bias.has_value(), "Bias not supported yet");
+  TORCH_CHECK(!scale_result.has_value(), "Scale result not supported yet");
   TORCH_CHECK(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");

   if (offs.has_value()) {
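
The second hunk tightens a check that previously only tested scale.stride(1) for truthiness; it now requires the last-dimension stride to be exactly 1, and the message says "scale" because the check applies to either operand's scale. A rough Python illustration of what the fixed check accepts and rejects (shapes here are hypothetical, not taken from the commit):

import torch

n_groups, n = 4, 32
ok = torch.ones(n_groups, n, dtype=torch.float32)       # stride(1) == 1: passes
bad = torch.ones(n, n_groups, dtype=torch.float32).t()  # stride(1) == 4: rejected
print(ok.stride(), bad.stride())                        # (32, 1) (1, 4)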

test/test_matmul_cuda.py

+30 -10

@@ -1587,14 +1587,19 @@ def scaled_grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist
     @parametrize("use_torch_compile", [False, True])
     def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
-        m, n, k, n_groups = 16, 16, 16, 4  # all sizes have to be divisible by 16
+        m, n, k, n_groups = 16, 32, 64, 4  # all sizes have to be divisible by 16
         a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
         b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
         scale_a = torch.arange(m * n_groups, device=device, dtype=torch.float32) / 4
         scale_b = torch.arange(n * n_groups, device=device, dtype=torch.float32) / 4
         offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
         f = torch._scaled_grouped_mm
-        f = torch.compile(f) if use_torch_compile else f
+        f = torch.compile(
+            f,
+            options={
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            }) if use_torch_compile else f
         out = f(a, b.t(), scale_a, scale_b, offs=offs,
                 out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
         offs_cpu = offs.cpu()
@@ -1618,7 +1623,7 @@ def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile)
     def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
-        m, n, k, n_groups = 16, 32, 16, 4
+        m, n, k, n_groups = 16, 32, 64, 4
         a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
         b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         self.assertTrue(a.is_contiguous() is not strided)
@@ -1631,7 +1636,12 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile)
         scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)

         f = torch._scaled_grouped_mm
-        f = torch.compile(f, dynamic=False) if use_torch_compile else f
+        f = torch.compile(
+            f,
+            options={
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            }) if use_torch_compile else f
         out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
                 out_dtype=torch.bfloat16, use_fast_accum=fast_accum)

@@ -1643,7 +1653,7 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile)
             ascalelist.append(scale_a[start:offs_cpu[i]])
             outlist.append(out[start:offs_cpu[i]])
             start = offs_cpu[i]
-            self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)
+        self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)


     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@@ -1655,7 +1665,7 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile)
     def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
-        m, n, k, n_groups = 16, 32, 16, 4
+        m, n, k, n_groups = 16, 32, 64, 4
         a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         self.assertTrue(a.is_contiguous() is not strided)
@@ -1664,7 +1674,12 @@ def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile)
         scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)

         f = torch._scaled_grouped_mm
-        f = torch.compile(f) if use_torch_compile else f
+        f = torch.compile(
+            f,
+            options={
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            }) if use_torch_compile else f
         out = f(a, b.transpose(-2, -1), scale_a, scale_b,
                 out_dtype=torch.bfloat16, use_fast_accum=fast_accum)

@@ -1680,7 +1695,7 @@ def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile)
     def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
-        m, n, k, n_groups = 16, 32, 16, 4
+        m, n, k, n_groups = 16, 128, 64, 4
         a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
         self.assertTrue(a.is_contiguous() is not strided)
@@ -1693,7 +1708,12 @@ def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile)
         offs[0] = offs[1]

         f = torch._scaled_grouped_mm
-        f = torch.compile(f) if use_torch_compile else f
+        f = torch.compile(
+            f,
+            options={
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            }) if use_torch_compile else f
         out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
                 out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
         offs_cpu = offs.cpu()
@@ -1704,7 +1724,7 @@ def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile)
             bscalelist.append(scale_b[start:offs_cpu[i]])
             outlist.append(out[:, start:offs_cpu[i]])
             start = offs_cpu[i]
-            self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)
+        self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)


     @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
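
Outside the test harness, the compile wrapper these tests now use can be written as a standalone sketch; it simply routes the op through inductor's Triton autotuning path. Illustrative only, not part of the diff:

import torch

compiled_mm = torch.compile(
    torch._scaled_grouped_mm,
    options={
        "max_autotune": True,
        "max_autotune_gemm_backends": "TRITON",
    },
)
# Called exactly like the eager op, e.g.:
# out = compiled_mm(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
#                   out_dtype=torch.bfloat16, use_fast_accum=True)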

torch/_inductor/graph.py

-1

@@ -211,7 +211,6 @@ def mark_nodes_dislike_padding(
             aten.convolution,
             aten.convolution_backward,
             aten._scaled_mm,
-            aten._scaled_grouped_mm,
         ]
     )
     # what's a better way to collect the reduction ops?

torch/_inductor/kernel/mm_common.py

+2 -1

@@ -57,7 +57,8 @@ def persistent_mm_grid(M: int, N: int, meta: dict[str, Any], *, cdiv, min):


 @SymbolicGridFn
-def persistent_grouped_mm_grid(m, n, meta):
+def persistent_grouped_mm_grid(*args):
+    meta = args[-1]
     return (meta["NUM_SMS"], 1, 1)

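The grid helper now takes *args because, with the added 3d input combinations, callers may pass a different number of leading size arguments, with the kernel meta dict always last; since the kernel is persistent, only the SM count matters for the launch grid. A minimal sketch of that behaviour (the @SymbolicGridFn decorator is omitted and the particular argument counts are assumed for illustration):

def persistent_grouped_mm_grid(*args):
    # The meta dict is always the trailing argument; leading size args are ignored.
    meta = args[-1]
    return (meta["NUM_SMS"], 1, 1)

print(persistent_grouped_mm_grid(64, 128, {"NUM_SMS": 132}))     # (132, 1, 1)
print(persistent_grouped_mm_grid(4, 64, 128, {"NUM_SMS": 132}))  # (132, 1, 1)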