Update auto-tuning support for _scaled_grouped_mm by alexsamardzic · Pull Request #150944 · pytorch/pytorch · GitHub

Update auto-tuning support for _scaled_grouped_mm #150944


Open
wants to merge 20 commits into base: gh/alexsamardzic/1/base

Conversation

alexsamardzic
Collaborator
@alexsamardzic alexsamardzic commented Apr 9, 2025

Stack from ghstack (oldest at bottom):

  1. Enable strided inputs
  2. Implement "2d/2d", "3d/2d" and "3d/3d" combinatio 8000 ns of inputs
  3. Fix non-TMA load variant
  4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
  5. Fix cases when group size along K dimension is not multiple of block size along K
  6. Update meta registration
  7. Update synthetic offsets creation
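
A minimal invocation that exercises the auto-tuned Triton path (condensed from the validation script posted in the comments below; shapes are illustrative, 2d/3d case shown):

import torch

f = torch.compile(
    torch._scaled_grouped_mm,
    options={
        "max_autotune": True,
        "max_autotune_gemm_backends": "TRITON",
    },
)

# 2d/3d case: A is (groups * M, K) row-major, B is (groups, N, K), transposed
# at the call site so that each group of B is column-major.
group_size, M, N, K = 4, 64, 128, 256
A = torch.randn(group_size * M, K, device="cuda").to(torch.float8_e4m3fn)
B = torch.randn(group_size, N, K, device="cuda").to(torch.float8_e4m3fn)
A_scale = torch.rand(group_size * M, device="cuda", dtype=torch.float32)
B_scale = torch.rand(group_size, N, device="cuda", dtype=torch.float32)
offs = torch.arange(M, group_size * M + 1, M, device="cuda", dtype=torch.int32)

C = f(A, B.transpose(-2, -1), A_scale, B_scale, offs, out_dtype=torch.bfloat16)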

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

pytorch-bot bot commented Apr 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150944

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 74b8536 with merge base b03e4f5:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

alexsamardzic added a commit that referenced this pull request Apr 9, 2025
1. Enable strided inputs
2. Implement "3d/3d" and "3d/2d" combinations of inputs
3. Fix non-TMA load variant

ghstack-source-id: 31add9f
Pull Request resolved: #150944
@alexsamardzic added the "topic: not user facing" label Apr 9, 2025
@alexsamardzic
Collaborator Author
alexsamardzic commented Apr 9, 2025
Validation script
from enum import Enum
from itertools import product

import torch


f_ref = torch._scaled_grouped_mm
f = torch.compile(
    f_ref,
    options={
        "max_autotune": True,
        "max_autotune_gemm_backends": "TRITON",
    },
)


class MMType(Enum):
    MM_2D_2D = 1
    MM_2D_3D = 2
    MM_3D_2D = 3
    MM_3D_3D = 4


def generate_data(
    mm_type, group_size, M, N, K, device, dtype_AB, dtype_scale, dtype_offset, strided
):
    if mm_type == MMType.MM_2D_2D:
        A = torch.randn(M, K * (group_size + strided), device=device).to(dtype_AB)[
            :, : K * group_size
        ]
        B = torch.randn(N, K * (group_size + strided), device=device).to(dtype_AB)[
            :, : K * group_size
        ]
        A_scale = torch.rand(group_size * M, device=device, dtype=dtype_scale)
        B_scale = torch.rand(group_size * N, device=device, dtype=dtype_scale)
        offs = torch.arange(K, group_size * K + 1, K, device=device, dtype=dtype_offset)

    if mm_type == MMType.MM_2D_3D:
        A = torch.randn(M * group_size, K * (1 + strided), device=device).to(dtype_AB)[
            :, :K
        ]
        B = torch.randn(
            group_size * (1 + strided), N, K * (1 + strided), device=device
        ).to(dtype_AB)[:: (1 + strided), :, :K]
        A_scale = torch.rand(group_size * M, device=device, dtype=dtype_scale)
        B_scale = torch.rand(
            group_size, N * (1 + strided), device=device, dtype=dtype_scale
        )[:, :N]
        offs = torch.arange(M, group_size * M + 1, M, device=device, dtype=dtype_offset)

    if mm_type == MMType.MM_3D_2D:
        A = torch.randn(
            group_size * (1 + strided), M, K * (1 + strided), device=device
        ).to(dtype_AB)[:: (1 + strided), :, :K]
        B = torch.randn(N * group_size, K * (1 + strided), device=device).to(dtype_AB)[
            :, :K
        ]
        A_scale = torch.rand(
            group_size, M * (1 + strided), device=device, dtype=dtype_scale
        )[:, :M]
        B_scale = torch.rand(group_size * N, device=device, dtype=dtype_scale)
        offs = torch.arange(N, group_size * N + 1, N, device=device, dtype=dtype_offset)

    if mm_type == MMType.MM_3D_3D:
        A = torch.randn(
            group_size * (1 + strided), M, K * (1 + strided), device=device
        ).to(dtype_AB)[:: (1 + strided), :, :K]
        B = torch.randn(
            group_size * (1 + strided), N, K * (1 + strided), device=device
        ).to(dtype_AB)[:: (1 + strided), :, :K]
        A_scale = torch.rand(group_size, M * (1 + strided), device=device).to(
            dtype_scale
        )[:, :M]
        B_scale = torch.rand(group_size, N * (1 + strided), device=device).to(
            dtype_scale
        )[:, :N]
        offs = None

    if offs is not None:
        # Perturb the offsets to make group sizes uneven; check the larger
        # threshold first, otherwise the second branch can never be taken.
        if offs[0] >= 64:
            offs[0] -= 16
            offs[1] += 16
            offs[2] -= 32
        elif offs[0] >= 32:
            offs[0] -= 16
            offs[2] += 16

    return A, B, A_scale, B_scale, offs


def validate():
    def validate_helper(
        mm_type,
        group_size,
        M,
        N,
        K,
        device,
        dtype_AB,
        dtype_scale,
        dtype_offset,
        dtype_C,
        use_fast_accum,
        strided,
        atol,
        rtol,
    ):
        torch._dynamo.reset()

        A, B, A_scale, B_scale, offs = generate_data(
            mm_type,
            group_size,
            M,
            N,
            K,
            device,
            dtype_AB,
            dtype_scale,
            dtype_offset,
            strided,
        )

        C_ref = f_ref(
            A,
            B.transpose(-2, -1),
            A_scale,
            B_scale,
            offs,
            out_dtype=dtype_C,
            use_fast_accum=use_fast_accum,
        )
        C = f(
            A,
            B.transpose(-2, -1),
            A_scale,
            B_scale,
            offs,
            out_dtype=dtype_C,
            use_fast_accum=use_fast_accum,
        )
        assert torch.allclose(C, C_ref, atol=atol, rtol=rtol)

    device = "cuda"
    group_size = 4
    M_range = [2**i for i in range(4, 6)]
    N_range = [2**i for i in range(5, 8)]
    K_range = [2**i for i in range(6, 9)]
    dtype_AB = torch.float8_e4m3fn
    dtype_scale = torch.float32
    dtype_offset = torch.int32
    dtype_C = torch.bfloat16
    use_fast_accum_range = [False, True]
    strided_range = [False, True]
    atol = 1e-2
    rtol = 1e-2
    for mm_type, M, N, K, use_fast_accum, strided in product(
        MMType, M_range, N_range, K_range, use_fast_accum_range, strided_range
    ):
        validate_helper(
            mm_type,
            group_size,
            M,
            N,
            K,
            device,
            dtype_AB,
            dtype_scale,
            dtype_offset,
            dtype_C,
            use_fast_accum,
            strided,
            atol,
            rtol,
        )


validate()

(Note: to validate the non-TMA load variant, change "USE_TMA_LOAD": True to False in mm_scaled_grouped.py.)


Todo: handle the use_fast_accum case the way CUTLASS does it...

[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Apr 9, 2025
1. Enable strided inputs
2. Implement "3d/3d" and "3d/2d" combinations of inputs
3. Fix non-TMA load variant

ghstack-source-id: 31add9f
Pull Request resolved: #150944
alexsamardzic added a commit that referenced this pull request Apr 9, 2025
1. Enable strided inputs
2. Implement "3d/3d" and "3d/2d" combinations of inputs
3. Fix non-TMA load variant

ghstack-source-id: e775b83
Pull Request resolved: #150944
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Apr 10, 2025
1. Enable strided inputs
2. Implement "3d/3d" and "3d/2d" combinations of inputs
3. Fix non-TMA load variant

ghstack-source-id: 4be59f3
Pull Request resolved: #150944
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Apr 18, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant

ghstack-source-id: f01ac93
Pull Request resolved: #150944
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Apr 19, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant

ghstack-source-id: 5847dae
Pull Request resolved: #150944
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Apr 19, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant

ghstack-source-id: 2b0248e
Pull Request resolved: #150944
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Apr 20, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant

ghstack-source-id: a1ef2b7
Pull Request resolved: #150944
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Apr 21, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor

ghstack-source-id: 7a64328
Pull Request resolved: #150944
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Apr 21, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K

ghstack-source-id: e6016b7
Pull Request resolved: #150944
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Apr 22, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K

ghstack-source-id: 63a6271
Pull Request resolved: #150944
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Apr 26, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K

ghstack-source-id: bbff45c
Pull Request resolved: #150944
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Apr 30, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K
6. Implemented meta registration

ghstack-source-id: 5bbfaae
Pull Request resolved: #150944
alexsamardzic added a commit that referenced this pull request Apr 30, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K
6. Implement meta registration

ghstack-source-id: 5bbfaae
Pull Request resolved: #150944
[ghstack-poisoned]
@alexsamardzic
Collaborator Author

@bertmaher @ngimel This PR is ready for review. I'll update the test along the way, and then proceed to grouped (non-scaled) MM auto-tuning support.

alexsamardzic added a commit that referenced this pull request May 1, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K
6. Implement meta registration

ghstack-source-id: 111429c
Pull Request resolved: #150944
[ghstack-poisoned]
@bertmaher
Contributor

Hey Alex, just wanted to apologize for not getting to this sooner -- it's been a very crazy week or two for me, but this is finally near the top of my queue.

Btw you will probably want something like #152968 to fix a silly bug in autotuning that I introduced

@alexsamardzic
Collaborator Author

Btw you will probably want something like #152968 to fix a silly bug in autotuning that I introduced

NP - as soon as that PR, or yours, gets merged (I assume that will happen quickly), I'll rebase my PR and make the appropriate additions for the input layout combinations other than 2d/3d.

alexsamardzic added a commit that referenced this pull request May 11, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K
6. Update meta registration
7. Updated synthetic offsets creation

ghstack-source-id: ab5036b
Pull Request resolved: #150944
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request May 11, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K
6. Update meta registration
7. Updated synthetic offsets creation

ghstack-source-id: 8b476d2
Pull Request resolved: #150944
[ghstack-poisoned]
@ngimel
Collaborator
ngimel commented May 12, 2025

@alexsamardzic I added a meta kernel in my previous PR #153226, can you please rebase to use it?

@alexsamardzic
Collaborator Author
alexsamardzic commented May 12, 2025 via email

Collaborator
@ngimel ngimel left a comment


Looks good! Do you have some benchmarking results to make sure that the changes didn't regress the kernel perf compared to just dynamic M kernel? I'm especially curious what perf you are getting compared to eager for 2d-2d case where you require masking all the inputs.


@@ -1627,11 +1632,16 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile)
offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
if check_zero_size:
offs[0] = offs[1]
scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32)
scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32)
Collaborator


scale_a, scale_b should be constructed regardless of check_zero_size

Collaborator Author


Fixed, it seems I messed up indentation while updating this file.
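
For reference, the intended ordering is roughly the following (a sketch based on the diff above; m, n, n_groups and check_zero_size are the test's own variables):

offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
if check_zero_size:
    offs[0] = offs[1]  # force the first group to be empty
# scales are constructed unconditionally, outside the check_zero_size branch
scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32)
scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)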

@@ -53,8 +47,8 @@ class Config:
num_warps=num_warps,
)
for block_size_m in [64, 128]
for block_size_n in [64, 128, 256]
for block_size_k in [64, 128, 256]
for block_size_n in [64, 128]
Collaborator


was block size 256 never picked?

Collaborator Author


Reverted. It seemed to be rarely picked, but the primary reason for removing it was to speed up testing; the change then slipped into an update.
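
(For context, the configs being swept are generated by a comprehension roughly like the following - a sketch of the diff above; the actual entries are Config objects that also carry num_stages and num_warps - with 256 kept in the N/K ranges after the revert:)

from itertools import product

configs = [
    {"block_size_m": bm, "block_size_n": bn, "block_size_k": bk}
    for bm, bn, bk in product([64, 128], [64, 128, 256], [64, 128, 256])
]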

m, k1 = m1_size
g, k2, n = m2_size
k = V.graph.sizevars.guard_equals(k1, k2)
if is_nonzero and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result):
Collaborator


where are you checking that a and b are row- and column-major, respectively? Do you rely on the meta function to check that?

Collaborator Author


Yes, but - do you think guards should also be put here to check the strides for row-major/column-major ordering?

Collaborator


It's fine not to double-check those, but can you add a comment saying that you are relying on the meta function checks, so that it's easier to keep in sync?
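
(For reference, the check being relied on amounts to roughly the following on plain tensors - a sketch; the actual check lives in the meta registration:)

import torch

def is_row_major(t: torch.Tensor) -> bool:
    # last dimension contiguous
    return t.stride(-1) == 1

def is_col_major(t: torch.Tensor) -> bool:
    # second-to-last dimension contiguous
    return t.stride(-2) == 1

# mat_a is expected to be row-major and mat_b column-major
# (after the transpose at the call site).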

Contributor
@bertmaher bertmaher left a comment


Looks good to me! Thank you for doing this. I have a few inline questions/comments

@@ -1643,7 +1653,7 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile)
ascalelist.append(scale_a[start:offs_cpu[i]])
outlist.append(out[start:offs_cpu[i]])
start = offs_cpu[i]
self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)
self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)
Contributor


Was this indent accidental?

Collaborator Author


Yes, reverted - as mentioned above, I apparently messed up some indentation when updating this file.

@@ -6379,6 +6379,165 @@ def ceil_div(a, b):
return torch.empty(self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device)


@register_meta([aten._scaled_grouped_mm.default])
Contributor


Can you avoid moving meta_scaled_grouped_mm around in this PR? It makes it hard to see what all changed. (If it's being moved for a good reason, can you explain what that is? It's hard for me to tell from the diff)

Collaborator Author


Moved back. (I had added it in my PR before it got added to main, and I thought it might make sense to put it right after the meta registration for _scaled_mm.)

Collaborator


I put it next to grouped_mm, which also makes sense; scaled_grouped_mm has 2 attributes and it's unclear which one is more important ;-)

@@ -167,46 +123,127 @@ def early_config_prune(configs, named_args):

# Copied from fbgemm grouped_gemm.py
triton_scaled_grouped_mm_source = r"""
{{def_kernel("a_ptr", "b_ptr", "a_scale_ptr", "b_scale_ptr", "m_sizes")}}
{% if A_IS_2D or B_IS_2D %}
Contributor


I was kind of contemplating whether it would be better to have a separate kernel for each 2d/3d case, since nested control flow in jinja templates is kind of hard to read. But it would come with the downside of some amount of code duplication. Just curious to hear your opinion.

Collaborator Author


For me, it's just a matter of a trade-off... It seems to me the differences between the variations are not big enough to warrant separate kernels; it would mean too much code duplication. Also, the non-scaled version could easily be added into the same code, the same way (which I'm going to do next). Furthermore, the Jinja if-else statements keep the generated code readable (IMO); basically, the generated code is like a separate version for the given 2d/3d case. There are only two cases of nested if-else Jinja statements; I could use some kind of indentation to make them clearer. Overall, I'd prefer to keep it as is throughout development, as it's easier for me to work with this way; when we're happy with it, we could arrange it differently.

Collaborator


Would be cool if torchtitan (which has eager versions of these kernels) also adopted a single-kernel approach, cc @lessw2020

alexsamardzic added a commit that referenced this pull request May 13, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K
6. Update meta registration
7. Updated synthetic offsets creation

ghstack-source-id: 894be5f
Pull Request resolved: #150944
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request May 13, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K
6. Update meta registration
7. Updated synthetic offsets creation

ghstack-source-id: 4856f96
Pull Request resolved: #150944
[ghstack-poisoned]
@alexsamardzic
Collaborator Author

Looks good! Do you have some benchmarking results to make sure that the changes didn't regress the kernel perf compared to just dynamic M kernel? I'm especially curious what perf you are getting compared to eager for 2d-2d case where you require masking all the inputs.

I've been benchmarking over the last two days; indeed, it seems going "fully dynamic" is not good for performance. So I'm probably going to revert some things.

alexsamardzic added a commit that referenced this pull request May 15, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K
6. Update meta registration
7. Updated synthetic offsets creation

ghstack-source-id: 067deb6
Pull Request resolved: #150944
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request May 15, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K
6. Update meta registration
7. Updated synthetic offsets creation

ghstack-source-id: 55e56c4
Pull Request resolved: #150944
[ghstack-poisoned]
@ngimel
Collaborator
ngimel commented May 15, 2025

fwiw this blog post https://pytorch.org/blog/metashuffling-accelerating-llama-4-moe-inference/ contains a good set of benchmarking configs (and it's based on the same triton kernel, save for warp specialization, which typically has only a very small effect)

@alexsamardzic
Collaborator Author
alexsamardzic commented May 16, 2025

Looks good! Do you have some benchmarking results to make sure that the changes didn't regress the kernel perf compared to just dynamic M kernel? I'm especially curious what perf you are getting compared to eager for 2d-2d case where you require masking all the inputs.

Here is a benchmarking script, and some benchmarking results (the number of groups is 4 everywhere, "CUTLASS" here means the eager CUTLASS-based _scaled_grouped_mm() kernel, "Triton" means the compiled auto-tuned Triton kernel, speedup is vs. the CUTLASS results, and all latencies are in µs):

Benchmarking script
from enum import Enum
import pandas as pd

from tqdm import tqdm

import torch
from triton.testing import do_bench


class MMType(Enum):
    MM_2D_2D = 1
    MM_2D_3D = 2
    MM_3D_2D = 3
    MM_3D_3D = 4

    def __str__(self):
        if self == MMType.MM_2D_2D:
            return "2d_2d"
        elif self == MMType.MM_2D_3D:
            return "2d_3d"
        elif self == MMType.MM_3D_2D:
            return "3d_2d"
        elif self == MMType.MM_3D_3D:
            return "3d_3d"
        else:
            return ""


device = "cuda"
dtype_AB = torch.float8_e4m3fn
dtype_scale = torch.float32
dtype_offset = torch.int32
dtype_C = torch.bfloat16
group_size = 4
use_fast_accum = True
strided = True

f_ref = torch._scaled_grouped_mm
f = torch.compile(
    f_ref,
    options={
        "max_autotune": True,
        "max_autotune_gemm_backends": "TRITON",
    },
)


def benchmark_microseconds(f, *args, **kwargs):
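    # do_bench reports the median time in milliseconds; scale to microseconds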
    return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3


def get_problem(mm_type, M, N, K):
    if mm_type == MMType.MM_2D_2D:
        A = torch.randn(M, K * (group_size + strided), device=device).to(dtype_AB)[
            :, : K * group_size
        ]
        B = torch.randn(N, K * (group_size + strided), device=device).to(dtype_AB)[
            :, : K * group_size
        ]
        A_scale = torch.rand(group_size * M, device=device, dtype=dtype_scale)
        B_scale = torch.rand(group_size * N, device=device, dtype=dtype_scale)
        offs = torch.arange(K, group_size * K + 1, K, device=device, dtype=dtype_offset)

    if mm_type == MMType.MM_2D_3D:
        A = torch.randn(M * group_size, K * (1 + strided), device=device).to(dtype_AB)[
            :, :K
        ]
        B = torch.randn(
            group_size * (1 + strided), N, K * (1 + strided), device=device
        ).to(dtype_AB)[:: (1 + strided), :, :K]
        A_scale = torch.rand(group_size * M, device=device, dtype=dtype_scale)
        B_scale = torch.rand(
            group_size, N * (1 + strided), device=device, dtype=dtype_scale
        )[:, :N]
        offs = torch.arange(M, group_size * M + 1, M, device=device, dtype=dtype_offset)

    if mm_type == MMType.MM_3D_2D:
        A = torch.randn(
            group_size * (1 + strided), M, K * (1 + strided), device=device
        ).to(dtype_AB)[:: (1 + strided), :, :K]
        B = torch.randn(N * group_size, K * (1 + strided), device=device).to(dtype_AB)[
            :, :K
        ]
        A_scale = torch.rand(
            group_size, M * (1 + strided), device=device, dtype=dtype_scale
        )[:, :M]
        B_scale = torch.rand(group_size * N, device=device, dtype=dtype_scale)
        offs = torch.arange(N, group_size * N + 1, N, device=device, dtype=dtype_offset)

    if mm_type == MMType.MM_3D_3D:
        A = torch.randn(
            group_size * (1 + strided), M, K * (1 + strided), device=device
        ).to(dtype_AB)[:: (1 + strided), :, :K]
        B = torch.randn(
            group_size * (1 + strided), N, K * (1 + strided), device=device
        ).to(dtype_AB)[:: (1 + strided), :, :K]
        A_scale = torch.rand(group_size, M * (1 + strided), device=device).to(
            dtype_scale
        )[:, :M]
        B_scale = torch.rand(group_size, N * (1 + strided), device=device).to(
            dtype_scale
        )[:, :N]
        offs = None

    if mm_type not in [MMType.MM_2D_3D, MMType.MM_3D_3D]:
        if group_size >= 2:
            offs[0] = offs[0] + (offs[1] - offs[0]) // 4
        if group_size >= 4:
            offs[2] = offs[2] + (offs[3] - offs[2]) // 2

    return A, B, A_scale, B_scale, offs


def benchmark(mm_type: MMType, m: int, k: int, n: int):
    torch._dynamo.reset()  # FIXME: remove this!

    A, B, A_scale, B_scale, offs = get_problem(mm_type, m, n, k)
    eager_time = benchmark_microseconds(
        f_ref,
        A,
        B.transpose(-2, -1),
        A_scale,
        B_scale,
        offs,
        out_dtype=dtype_C,
        use_fast_accum=use_fast_accum,
    )
    compiled_time = benchmark_microseconds(
        f,
        A,
        B.transpose(-2, -1),
        A_scale,
        B_scale,
        offs,
        out_dtype=dtype_C,
        use_fast_accum=use_fast_accum,
    )

    return {
        "m": m,
        "n": n,
        "k": k,
        "Eager (CUTLASS) latency (ms)": eager_time,
        "Compiled (Triton) latency (ms)": compiled_time,
        "Compiled speedup (d/s)": eager_time / compiled_time,
    }


if __name__ == "__main__":
    k_vals = (1024, 2048, 4096)
    n_vals = (1024, 2048, 4096)

    for mm_type in MMType:
        results = []
        i_range = range(8) if mm_type != MMType.MM_2D_3D else range(4, 10)
        for m in tqdm([1 << i for i in i_range]):
            for n, k in zip(n_vals, k_vals):
                results.append(benchmark(mm_type, m, k, n))
                df = pd.DataFrame(results)
                df.to_csv(
                    f"scaled_grouped_mm_{str(mm_type)}_time_results.csv", index=False
                )
                print(df.to_markdown(index=False))
Benchmarking results for 2D/2D case
m n k CUTLASS latency Triton latency (D) Triton latency (ND) Triton speedup (D) Triton speedup (ND)
1 1024 1024 16.86 16.86 16.7 1 1.01
1 2048 2048 24.4 27.26 25.47 0.89 0.96
1 4096 4096 51.33 65.09 61.92 0.79 0.83
2 1024 1024 16.82 32.58 16.58 0.52 1.01
2 2048 2048 23.84 32.8 25.25 0.73 0.94
2 4096 4096 50.77 74.91 61.66 0.68 0.82
4 1024 1024 16.78 30.59 17.02 0.55 0.99
4 2048 2048 23.95 32.13 25.41 0.75 0.94
4 4096 4096 50.94 75.06 61.6 0.68 0.83
8 1024 1024 16.75 30.98 16.83 0.54 1
8 2048 2048 23.97 32.32 25.15 0.74 0.95
8 4096 4096 50.88 74.98 61.44 0.68 0.83
16 1024 1024 16.78 31.55 16.8 0.53 1
16 2048 2048 24.14 33.25 25.44 0.73 0.95
16 4096 4096 51.15 75.04 61.39 0.68 0.83
32 1024 1024 16.91 30.98 16.9 0.55 1
32 2048 2048 24.29 40.1 28.16 0.61 0.86
32 4096 4096 51.2 100.29 68.37 0.51 0.75
64 1024 1024 16.96 32.35 18.34 0.52 0.92
64 2048 2048 24.35 55.84 30.21 0.44 0.81
64 4096 4096 51.06 183.23 71.71 0.28 0.71
128 1024 1024 16.96 36.38 19.46 0.47 0.87
128 2048 2048 24.8 99.9 41.54 0.25 0.6
128 4096 4096 51.44 349.41 100.45 0.15 0.51
Benchmarking results for 2D/3D case
m n k CUTLASS latency Triton latency (D) Triton latency (ND) Triton speedup (D) Triton speedup (ND)
16 1024 1024 16.48 39.97 11.97 0.41 1.38
16 2048 2048 25.02 40.06 17.18 0.62 1.46
16 4096 4096 51.01 44.74 39.1 1.14 1.3
32 1024 1024 16.63 38.11 11.97 0.44 1.39
32 2048 2048 25.14 39.78 17.34 0.63 1.45
32 4096 4096 50.51 44.32 39.62 1.14 1.28
64 1024 1024 16.67 38.82 12.26 0.43 1.36
64 2048 2048 25.33 39.74 17.6 0.64 1.44
64 4096 4096 50.82 44.77 39.71 1.14 1.28
128 1024 1024 16.77 38.75 12.1 0.43 1.39
128 2048 2048 25.81 39.33 18.08 0.66 1.43
128 4096 4096 51.34 60.86 41.02 0.84 1.25
256 1024 1024 20.4 39.74 12.58 0.51 1.62
256 2048 2048 38.4 43.68 19.81 0.88 1.94
256 4096 4096 83.41 99.49 45.63 0.84 1.83
512 1024 1024 20.66 39.39 14.27 0.52 1.45
512 2048 2048 39.31 59.38 24.67 0.66 1.59
512 4096 4096 86.14 176.9 61.89 0.49 1.39
Benchmarking results for 3D/2D case
m n k CUTLASS latency Triton latency (D) Triton latency (ND) Triton speedup (D) Triton speedup (ND)
1 1024 1024 16.69 39.26 12.16 0.43 1.37
1 2048 2048 24.9 38.69 17.06 0.64 1.46
1 4096 4096 51.42 43.68 41.79 1.18 1.23
2 1024 1024 16.62 46.85 12.06 0.35 1.38
2 2048 2048 24.93 45.38 16.86 0.55 1.48
2 4096 4096 50.54 51.94 42.14 0.97 1.2
4 1024 1024 16.59 45.76 12.26 0.36 1.35
4 2048 2048 24.98 45.63 17.15 0.55 1.46
4 4096 4096 50.91 49.38 45.41 1.03 1.12
8 1024 1024 16.61 44.96 12.16 0.37 1.37
8 2048 2048 25.09 45.79 16.99 0.55 1.48
8 4096 4096 50.74 50.5 43.36 1 1.17
16 1024 1024 16.61 44.32 11.97 0.37 1.39
16 2048 2048 25.02 46.98 17.12 0.53 1.46
16 4096 4096 50.46 51.17 40.9 0.99 1.23
32 1024 1024 16.66 45.41 12.16 0.37 1.37
32 2048 2048 25.26 46.08 17.09 0.55 1.48
32 4096 4096 51.86 52.22 40.96 0.99 1.27
64 1024 1024 16.7 45.92 12.29 0.36 1.36
64 2048 2048 25.47 45.66 17.44 0.56 1.46
64 4096 4096 51.18 50.82 41.38 1.01 1.24
128 1024 1024 16.67 45.44 12.1 0.37 1.38
128 2048 2048 25.84 46.37 21.41 0.56 1.21
128 4096 4096 51.02 60.48 51.26 0.84 1
Benchmarking results for 3D/3D case
m n k CUTLASS latency Triton latency (D) Triton latency (ND) Triton speedup (D) Triton speedup (ND)
1 1024 1024 16.14 51.68 12.03 0.31 1.34
1 2048 2048 24.35 52.03 17.06 0.47 1.43
1 4096 4096 49.76 55.26 38.59 0.9 1.29
2 1024 1024 16.16 57.38 12 0.28 1.35
2 2048 2048 24.4 56.8 17.06 0.43 1.43
2 4096 4096 49.66 61.15 38.56 0.81 1.29
4 1024 1024 16.1 55.94 11.81 0.29 1.36
4 2048 2048 24.54 58.14 16.86 0.42 1.46
4 4096 4096 50.59 62.43 38.5 0.81 1.31
8 1024 1024 16.14 56.03 11.97 0.29 1.35
8 2048 2048 24.62 56.1 17.34 0.44 1.42
8 4096 4096 50.42 61.2 38.98 0.82 1.29
16 1024 1024 16.14 55.94 12.1 0.29 1.33
16 2048 2048 24.56 56.7 17.15 0.43 1.43
16 4096 4096 49.9 61.12 38.78 0.82 1.29
32 1024 1024 16.18 54.88 11.65 0.29 1.39
32 2048 2048 24.64 57.18 17.15 0.43 1.44
32 4096 4096 50.51 59.79 38.88 0.84 1.3
64 1024 1024 16.22 55.97 12 0.29 1.35
64 2048 2048 24.9 55.39 17.54 0.45 1.42
64 4096 4096 50.32 61.15 39.07 0.82 1.29
128 1024 1024 16.19 55.17 11.71 0.29 1.38
128 2048 2048 25.3 56.93 17.34 0.44 1.46
128 4096 4096 50.91 61.18 40.83 0.83 1.25

Some comments:

  • For the benchmarking, I had to change the block sizes along the M-axis from [64, 128] to [16, 32, 64, 128] - otherwise, as small M values are used for benchmarking, the Triton kernel is very slow. This is not specific to this kernel; in general, auto-tuning in Inductor needs a better search method than brute force over a limited number of configs.
  • The "D" means dynamic, i.e. by Inductor defaults, for each 2D/3D input combination auto-tuning would be run twice, once with static shapes, then with dynamic shapes, and then the later kernel would be used for the rest of input shapes. The "ND" means that auto-tuning would be forced (I did it by adding torch._dynamo.reset() at the beginning of the benchmark() function in the benchmarking script) for each row in the tables above. Obviously, "D" doesn't work well, and "ND" is basically forcing dynamic=False and thus not acceptable, so the right approach is somewhere in-between. However, so far I was not able to find how to make Inductor do for example dynamic M, and non-dynamic K and N - any help here appreciated.

Auxiliary script to produce the tables above from the .csv files with dynamic/non-dynamic results
import sys

import pandas as pd

assert len(sys.argv) == 3

df1 = pd.read_csv(sys.argv[1])
df1 = df1.drop(["Compiled speedup (d/s)"], axis=1)
df1 = df1.rename(
    columns={
        "Eager (CUTLASS) latency (ms)": "CUTLASS latency (D)",
        "Compiled (Triton) latency (ms)": "Triton latency (D)",
    }
)
df2 = pd.read_csv(sys.argv[2])
df2 = df2.drop(["Compiled speedup (d/s)"], axis=1)
df2 = df2.rename(
    columns={
        "Eager (CUTLASS) latency (ms)": "CUTLASS latency (ND)",
        "Compiled (Triton) latency (ms)": "Triton latency (ND)",
    }
)

df = pd.merge(df1, df2, on=["m", "n", "k"])
df.insert(
    loc=3,
    column="CUTLASS latency",
    value=(df["CUTLASS latency (D)"] + df["CUTLASS latency (ND)"]) / 2,
)
df = df.drop(["CUTLASS latency (D)", "CUTLASS latency (ND)"], axis=1)
df["Triton speedup (D)"] = df["CUTLASS latency"] / df["Triton latency (D)"]
df["Triton speedup (ND)"] = df["CUTLASS latency"] / df["Triton latency (ND)"]


float_columns = df.select_dtypes(include=['float64', 'float32']).columns
for col in float_columns:
    df[col] = df[col].apply(lambda x: f"{x:.2f}")
print(df.to_markdown(index=False))
