Update auto-tuning support for _scaled_grouped_mm by alexsamardzic · Pull Request #150944 · pytorch/pytorch

Update auto-tuning support for _scaled_grouped_mm #150944


Closed
alexsamardzic wants to merge 39 commits
39 commits
708190a · Update · alexsamardzic · Apr 9, 2025
e09a6c1 · Update · alexsamardzic · Apr 9, 2025
294d988 · Update · alexsamardzic · Apr 10, 2025
35cdb06 · Update · alexsamardzic · Apr 18, 2025
6e314db · Update · alexsamardzic · Apr 19, 2025
0a362c7 · Update · alexsamardzic · Apr 19, 2025
197fd49 · Update · alexsamardzic · Apr 20, 2025
5b958dd · Update · alexsamardzic · Apr 21, 2025
0304006 · Update · alexsamardzic · Apr 21, 2025
6a649e1 · Update · alexsamardzic · Apr 22, 2025
9158e2e · Update · alexsamardzic · Apr 23, 2025
cc83362 · Update · alexsamardzic · Apr 26, 2025
bc25f5a · Update · alexsamardzic · Apr 30, 2025
7f2e4bc · Update · alexsamardzic · May 1, 2025
ed56d19 · Update · alexsamardzic · May 11, 2025
8073549 · Update · alexsamardzic · May 11, 2025
19762fa · Update · alexsamardzic · May 13, 2025
834e757 · Update · alexsamardzic · May 13, 2025
41494d8 · Update · alexsamardzic · May 15, 2025
74b8536 · Update · alexsamardzic · May 15, 2025
aeda7d9 · Update · alexsamardzic · May 17, 2025
e403515 · Update · alexsamardzic · May 18, 2025
75b1014 · Update · alexsamardzic · May 18, 2025
0b14000 · Update · alexsamardzic · May 18, 2025
756e6a1 · Update · alexsamardzic · May 19, 2025
c8aa69f · Update · alexsamardzic · May 20, 2025
5bfd0fc · Update · alexsamardzic · May 23, 2025
55e8e12 · Update · alexsamardzic · May 25, 2025
3dd163b · Update · alexsamardzic · May 28, 2025
458e4ef · Update · alexsamardzic · Jun 2, 2025
ee5f69f · Update · alexsamardzic · Jun 4, 2025
f7559ad · Update · alexsamardzic · Jun 6, 2025
bb93a62 · Update · alexsamardzic · Jun 10, 2025
d36e059 · Update · alexsamardzic · Jun 11, 2025
c9a909b · Update · alexsamardzic · Jun 11, 2025
bc448fe · Update · alexsamardzic · Jun 11, 2025
6bf753a · Update · alexsamardzic · Jun 11, 2025
7b167c8 · Update · alexsamardzic · Jun 11, 2025
016b438 · Update · alexsamardzic · Jun 11, 2025

7 changes: 4 additions & 3 deletions aten/src/ATen/native/cuda/Blas.cpp
@@ -1532,7 +1532,7 @@ namespace {
"D, arg ",
arg_idx);
TORCH_CHECK(
scale.is_contiguous(), "scale_a must be contiguous for arg ", arg_idx);
scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
TORCH_CHECK(
scale.size(0) == mat.size(dim) * scale_multiplier,
"scale must have the same length as mat for arg ",
@@ -1545,8 +1545,8 @@ namespace {
"D for arg ",
arg_idx);
TORCH_CHECK(
- scale.stride(1),
- "scale_a must be contiguous in the last dimension for arg ",
+ scale.stride(1) == 1,
+ "scale must be contiguous in the last dimension for arg ",
arg_idx);
TORCH_CHECK(
scale.size(0) == mat.size(0),
@@ -1610,6 +1610,7 @@ bool use_fast_accum) {


TORCH_CHECK(!bias.has_value(), "Bias not supported yet");
+ TORCH_CHECK(!scale_result.has_value(), "Scale result not supported yet");
TORCH_CHECK(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");

if (offs.has_value()) {
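The `scale.stride(1) == 1` fix above matters because the previous check only tested the stride for truthiness, which any non-zero stride (i.e. a non-contiguous last dimension) would also satisfy. A minimal sketch of the difference, using a hypothetical transposed scale tensor rather than anything from the PR:

```python
import torch

# Hypothetical (groups, m) scale tensor whose last dimension is NOT contiguous:
# transposing a contiguous (16, 4) tensor yields strides (1, 4).
scale = torch.rand(16, 4).t()          # shape (4, 16), strides (1, 4)

old_check = bool(scale.stride(1))      # True: any non-zero stride passes
new_check = scale.stride(1) == 1       # False: last dimension is not contiguous

print(scale.shape, scale.stride(), old_check, new_check)
```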
58 changes: 39 additions & 19 deletions test/test_matmul_cuda.py
@@ -1616,7 +1616,7 @@ def scaled_grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist
for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist):
out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1),
out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum)
- self.assertEqual(out, out_ref)
+ self.assertEqual(out, out_ref, atol=1e-1, rtol=1e-2)

@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater
@@ -1626,14 +1626,19 @@ def scaled_grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist
@parametrize("use_torch_compile", [False, True])
def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile):
device = "cuda"
- m, n, k, n_groups = 16, 16, 16, 4  # all sizes have to be divisible by 16
+ m, n, k, n_groups = 16, 32, 64, 4  # all sizes have to be divisible by 16
a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
- scale_a = torch.arange(m * n_groups, device=device, dtype=torch.float32) / 4
- scale_b = torch.arange(n * n_groups, device=device, dtype=torch.float32) / 4
+ scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32)
+ scale_b = torch.rand(n * n_groups, device=device, dtype=torch.float32)
offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
f = torch._scaled_grouped_mm
- f = torch.compile(f) if use_torch_compile else f
+ f = torch.compile(
+     f,
+     options={
+         "max_autotune": True,
+         "max_autotune_gemm_backends": "TRITON",
+     }) if use_torch_compile else f
out = f(a, b.t(), scale_a, scale_b, offs=offs,
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
offs_cpu = offs.cpu()
@@ -1657,7 +1662,7 @@ def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile)
def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile):
device = "cuda"
s_int = int(strided)
- m, n, k, n_groups = 16, 32, 16, 4
+ m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
self.assertTrue(a.is_contiguous() is not strided)
@@ -1666,11 +1671,16 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile)
offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
if check_zero_size:
offs[0] = offs[1]
- scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32)
- scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
+ scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32)
+ scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)

f = torch._scaled_grouped_mm
- f = torch.compile(f, dynamic=False) if use_torch_compile else f
+ f = torch.compile(
+     f,
+     options={
+         "max_autotune": True,
+         "max_autotune_gemm_backends": "TRITON",
+     }) if use_torch_compile else f
out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)

@@ -1682,7 +1692,7 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile)
ascalelist.append(scale_a[start:offs_cpu[i]])
outlist.append(out[start:offs_cpu[i]])
start = offs_cpu[i]
- self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)
+     self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)
Contributor commented:

Was this indent accidental?

alexsamardzic (Collaborator, Author) replied:

Yes, reverted - as mentioned above, apparently messed up some indentation when updating this file.



@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@@ -1694,16 +1704,21 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile)
def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile):
device = "cuda"
s_int = int(strided)
- m, n, k, n_groups = 16, 32, 16, 4
+ m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
self.assertTrue(a.is_contiguous() is not strided)
self.assertTrue(b.is_contiguous() is not strided)
- scale_a = torch.ones(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
- scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
+ scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
+ scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)

f = torch._scaled_grouped_mm
- f = torch.compile(f) if use_torch_compile else f
+ f = torch.compile(
+     f,
+     options={
+         "max_autotune": True,
+         "max_autotune_gemm_backends": "TRITON",
+     }) if use_torch_compile else f
out = f(a, b.transpose(-2, -1), scale_a, scale_b,
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)

@@ -1719,20 +1734,25 @@ def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile)
def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile):
device = "cuda"
s_int = int(strided)
- m, n, k, n_groups = 16, 32, 16, 4
+ m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
self.assertTrue(a.is_contiguous() is not strided)
self.assertTrue(b.is_contiguous() is not strided)
- scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
- scale_b = torch.arange(n_groups * n, device="cuda", dtype=torch.float32)
+ scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
+ scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32)
for check_zero_size in (True, False):
offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
if check_zero_size:
offs[0] = offs[1]

f = torch._scaled_grouped_mm
- f = torch.compile(f) if use_torch_compile else f
+ f = torch.compile(
+     f,
+     options={
+         "max_autotune": True,
+         "max_autotune_gemm_backends": "TRITON",
+     }) if use_torch_compile else f
out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
offs_cpu = offs.cpu()
@@ -1743,7 +1763,7 @@ def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile)
bscalelist.append(scale_b[start:offs_cpu[i]])
outlist.append(out[:, start:offs_cpu[i]])
start = offs_cpu[i]
- self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)
+     self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)


@unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
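Taken together, the test changes exercise the Triton-based max-autotune path for torch._scaled_grouped_mm with randomized per-row/per-group scales and a larger k. A condensed, standalone sketch of the 2d-3d case, assuming a CUDA device with float8 support (shapes, compile options, and tolerances copied from the test above):

```python
import torch

device = "cuda"
m, n, k, n_groups = 16, 32, 64, 4

# fp8 operands: a is 2d (m * n_groups, k), b is 3d (n_groups, n, k).
a = torch.randn(m * n_groups, k, device=device).to(torch.float8_e4m3fn)
b = torch.randn(n_groups, n, k, device=device).to(torch.float8_e4m3fn)
scale_a = torch.rand(n_groups * m, device=device, dtype=torch.float32)
scale_b = torch.rand(n_groups * n, device=device, dtype=torch.float32).view(n_groups, n)
offs = torch.arange(m, n_groups * m + 1, m, device=device, dtype=torch.int32)

# Compile with the Triton max-autotune backend, as the updated tests do.
f = torch.compile(
    torch._scaled_grouped_mm,
    options={"max_autotune": True, "max_autotune_gemm_backends": "TRITON"},
)
out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
        out_dtype=torch.bfloat16, use_fast_accum=False)

# Reference for the first group only, mirroring scaled_grouped_mm_helper.
out_ref = torch._scaled_mm(a[:m], b[0].t(), scale_a[:m].view(-1, 1),
                           scale_b[0].view(1, -1), out_dtype=torch.bfloat16)
torch.testing.assert_close(out[:m], out_ref, atol=1e-1, rtol=1e-2)
```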
1 change: 0 additions & 1 deletion torch/_inductor/graph.py
@@ -217,7 +217,6 @@ def mark_nodes_dislike_padding(
aten.convolution,
aten.convolution_backward,
aten._scaled_mm,
- aten._scaled_grouped_mm,
]
)
# what's a better way to collect the reduction ops?
3 changes: 2 additions & 1 deletion torch/_inductor/kernel/mm_common.py
@@ -38,7 +38,8 @@ def persistent_mm_grid(M: int, N: int, meta: dict[str, Any], *, cdiv, min):


@SymbolicGridFn
- def persistent_grouped_mm_grid(m, n, meta):
+ def persistent_grouped_mm_grid(*args):
+     meta = args[-1]
return (meta["NUM_SMS"], 1, 1)


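The *args signature above lets the same grid callable serve grouped-mm variants whose generated kernels pass a different number of leading size arguments; for a persistent kernel only the SM count from the trailing meta dict matters. A small illustration (the call sites are hypothetical, not Inductor's actual ones):

```python
def persistent_grouped_mm_grid(*args):
    # The kernel meta-dict is assumed to be the last positional argument;
    # a persistent kernel launches one program per SM regardless of sizes.
    meta = args[-1]
    return (meta["NUM_SMS"], 1, 1)

# Hypothetical call sites with different numbers of leading size arguments.
assert persistent_grouped_mm_grid(16, 32, {"NUM_SMS": 132}) == (132, 1, 1)
assert persistent_grouped_mm_grid(4, 16, 32, 64, {"NUM_SMS": 132}) == (132, 1, 1)
```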