Support tuning of _scaled_grouped_mm by bertmaher · Pull Request #150421 · pytorch/pytorch · GitHub

Support tuning of _scaled_grouped_mm #150421

Closed
wants to merge 13 commits

Conversation

bertmaher
Contributor
@bertmaher commented Apr 1, 2025

This includes the default aten implementation, as well as a Triton
implementation imported from FBGEMM
(https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py)

[ghstack-poisoned]
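For context, a minimal usage sketch of what this enables: compiling a call to torch._scaled_grouped_mm under max-autotune so that inductor can pick between the aten kernel and the new Triton template. The shapes and dtypes follow the 3D/3D validation helper attached later in this thread; the specific sizes and the assumption that this exact layout is accepted are part of the example, not of the PR.

import torch

G, M, N, K = 4, 64, 64, 64  # hypothetical problem sizes
A = torch.randn(G, M, K, device="cuda").to(torch.float8_e4m3fn)
# B column-major in its last two dims: build as (G, N, K) and transpose.
B = torch.randn(G, N, K, device="cuda").to(torch.float8_e4m3fn).transpose(-2, -1)
A_scale = torch.rand(G, M, device="cuda", dtype=torch.float32)
B_scale = torch.rand(G, N, device="cuda", dtype=torch.float32)

@torch.compile(mode="max-autotune")
def f(a, b, sa, sb):
    return torch._scaled_grouped_mm(a, b, sa, sb,
                                    out_dtype=torch.bfloat16,
                                    use_fast_accum=True)

out = f(A, B, A_scale, B_scale)  # tuning selects aten or the Triton template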
pytorch-bot bot commented Apr 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150421

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit 7765869 with merge base fe96167:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

bertmaher added a commit that referenced this pull request Apr 1, 2025
)

device_type = get_device_type(mat_a)
check_supported_striding(mat_a, mat_b)
Collaborator

Should you also check that mat_a is 2D and mat_b is 3D here? The aten op supports different 2D/3D combinations, but the Triton kernel supports only one.
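(For illustration, a hypothetical guard along the lines suggested above; the names mirror the diff context, and the exact check in the PR may differ:)

import torch

def _check_2d_3d(mat_a: torch.Tensor, mat_b: torch.Tensor) -> None:
    # The aten op accepts several 2D/3D combinations, but the Triton template
    # here handles only 2D mat_a x 3D mat_b.
    torch._check(mat_a.dim() == 2, lambda: f"expected 2D mat_a, got {mat_a.dim()}D")
    torch._check(mat_b.dim() == 3, lambda: f"expected 3D mat_b, got {mat_b.dim()}D")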

)
c = accumulator.to(tl.float32) * a_scale * b_scale
if USE_TMA_STORE:
Collaborator

I don't like this USE_TMA_STORE option (I know we are not using it) because it can't do masking, and we absolutely need M masking (which the regular store below does): with an m_size that isn't a multiple of the required TMA size, we might have computed completely bogus results.

Contributor Author

Interesting. This is actually a little surprising to me, because I thought the way the TMA store descriptor is set up accounts for m_size... but I'm also not sure how that interacts with the required alignment/size.

Since inductor currently can't use this anyway, I'm thinking I'll just delete the USE_TMA_STORE branch; there's no reason to leave dead code around.
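(For reference, a standalone sketch of the masked, non-TMA store being discussed; variable names are illustrative and do not come from this PR:)

import triton
import triton.language as tl

@triton.jit
def masked_store_kernel(c_ptr, m_size, N, ldc, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Store a BLOCK_M x BLOCK_N tile, masking rows >= m_size (and cols >= N)
    # so a partial tile never writes bogus values past the group's boundary.
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    tile = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)  # stands in for the accumulator
    mask = (offs_m[:, None] < m_size) & (offs_n[None, :] < N)
    tl.store(c_ptr + offs_m[:, None] * ldc + offs_n[None, :], tile, mask=mask)

# Usage sketch:
#   c = torch.full((16, 64), -1.0, device="cuda")
#   masked_store_kernel[(1,)](c, 10, 64, c.stride(0), BLOCK_M=16, BLOCK_N=64)
#   # rows 10..15 keep their original values thanks to the mask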

bertmaher added a commit that referenced this pull request Apr 2, 2025
@bertmaher added the "topic: not user facing" label Apr 3, 2025
bertmaher added a commit that referenced this pull request Apr 3, 2025
print(a)
print("alist:", alist)
asplit = [x.squeeze(1) for x in a.view(m, n_groups, k).split(1, dim=1)]
print("asplit:", asplit)
Collaborator

stray print?

Contributor Author

Whoops, I need to be more careful with my debug code 😬

@bertmaher
Contributor Author

This isn't quite ready yet; I'm still seeing some buggy output with smaller problem sizes (e.g. G,M,N,K = 4,16,64,64). I'll ping the thread when it's fixed.

bertmaher added a commit that referenced this pull request Apr 4, 2025
bertmaher added a commit that referenced this pull request Apr 4, 2025
@alexsamardzic
Collaborator

@bertmaher: Unaware of your PR, I was going to work on the same thing - do you need/want some help? 🙂

I started with the attached proof-of-concept script, which validates/benchmarks/profiles a Triton kernel (based on the Triton tutorial instead of FBGEMM) against _scaled_grouped_mm and against _scaled_mm in a loop. My code is for the 3D/3D case, so I can't immediately plug in your implementation to compare, but my experience so far is that the Triton version of grouped GEMM is quite fast (and a single kernel should easily handle all 2D/3D combinations). The performance issue may instead be in the code that prepares the arrays of pointers passed as inputs to the kernel, so the attached script contains a tiny Triton kernel to do that (for the 3D/3D case), and it seems to help.

grouped-gemm-triton-vs-pytorch.py
from itertools import product
from typing import Optional

import torch

import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def supports_tma():
    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9


def num_sms():
    if is_cuda():
        return torch.cuda.get_device_properties("cuda").multi_processor_count
    return 148


tma_configs = [
    triton.Config(
        {"BLOCK_SIZE_M": BM, "BLOCK_SIZE_N": BN, "BLOCK_SIZE_K": BK},
        num_stages=s,
        num_warps=w,
    )
    for BM in [128]
    for BN in [128, 256]
    for BK in [64, 128]
    for s in ([3, 4])
    for w in [4, 8]
]


@triton.autotune(
    tma_configs,
    key=["hash_value"],
)
@triton.jit
def grouped_matmul_tma_kernel(
    hash_value,
    # device tensor of matrices pointers
    group_a_ptrs,
    group_b_ptrs,
    group_a_scale_ptrs,
    group_b_scale_ptrs,
    group_c_ptrs,
    # device tensor of gemm sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
    group_gemm_sizes,
    # device tensor of leading dimension sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
    g_lds,
    # number of gemms
    group_size,
    # number of virtual SM
    NUM_SM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    # is the output FP8 or FP16
    FP8: tl.constexpr,
):
    dtype = tl.float8e4nv if FP8 else tl.bfloat16
    dtype_scale = tl.float32
    dtype_out = tl.bfloat16

    tile_idx = tl.program_id(0)
    last_problem_end = 0
    for g in range(group_size):
        # get the gemm size of the current problem
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
        num_tiles = num_m_tiles * num_n_tiles
        if tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:
            # pick up a tile from the current gemm problem
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)

            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(dtype))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(dtype))
            a_scale_ptr = tl.load(group_a_scale_ptrs + g).to(
                tl.pointer_type(dtype_scale)
            )
            b_scale_ptr = tl.load(group_b_scale_ptrs + g).to(
                tl.pointer_type(dtype_scale)
            )
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(dtype_out))

            a_desc = tl._experimental_make_tensor_descriptor(
                a_ptr,
                shape=[gm, gk],
                strides=[lda, 1],
                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
            )
            b_desc = tl._experimental_make_tensor_descriptor(
                b_ptr,
                shape=[gn, gk],
                strides=[ldb, 1],
                block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
            )
            c_desc = tl._experimental_make_tensor_descriptor(
                c_ptr,
                shape=[gm, gn],
                strides=[ldc, 1],
                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
            )

            # iterate through the tiles in the current gemm problem
            while (
                tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles
            ):
                k = gk
                # figure out tile coordinates
                tile_idx_in_gemm = tile_idx - last_problem_end
                tile_m_idx = tile_idx_in_gemm // num_n_tiles
                tile_n_idx = tile_idx_in_gemm % num_n_tiles

                # do regular gemm here
                offs_am = tile_m_idx * BLOCK_SIZE_M
                offs_bn = tile_n_idx * BLOCK_SIZE_N

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
                for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
                    a = a_desc.load([offs_am, kk * BLOCK_SIZE_K])
                    b = b_desc.load([offs_bn, kk * BLOCK_SIZE_K])
                    accumulator += tl.dot(a, b.T)

                # apply scales
                offsets = offs_am + tl.arange(0, BLOCK_SIZE_M)
                a_scale_mask = offsets < gm
                a_scale = tl.load(a_scale_ptr + offsets, a_scale_mask).expand_dims(-1)
                accumulator *= a_scale
                offsets = offs_bn + tl.arange(0, BLOCK_SIZE_N)
                b_scale_mask = offsets < gn
                b_scale = tl.load(b_scale_ptr + offsets, b_scale_mask).expand_dims(0)
                accumulator *= b_scale

                offs_cm = tile_m_idx * BLOCK_SIZE_M
                offs_cn = tile_n_idx * BLOCK_SIZE_N

                c = accumulator.to(dtype_out)
                c_desc.store([offs_cm, offs_cn], c)

                # go to the next tile by advancing NUM_SM
                tile_idx += NUM_SM

        # get ready to go to the next gemm problem
        last_problem_end = last_problem_end + num_tiles


@triton.jit
def grouped_matmul_tma_prepare_data_kernel(
    group_size,
    a_data_ptr,
    b_data_ptr,
    a_scale_data_ptr,
    b_scale_data_ptr,
    c_data_ptr,
    M,
    N,
    K,
    a_stride_group,
    b_stride_group,
    a_scale_stride_group,
    b_scale_stride_group,
    c_stride_group,
    lda,
    ldb,
    ldc,
    d_a_ptrs,
    d_b_ptrs,
    d_a_scale_ptrs,
    d_b_scale_ptrs,
    d_c_ptrs,
    d_g_sizes,
    d_g_lds,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE

    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < group_size

    a_ptrs = a_data_ptr + offsets * a_stride_group
    tl.store(d_a_ptrs + offsets, a_ptrs, mask)

    b_ptrs = b_data_ptr + offsets * b_stride_group
    tl.store(d_b_ptrs + offsets, b_ptrs, mask)

    a_scale_ptrs = a_scale_data_ptr + offsets * a_scale_stride_group
    tl.store(d_a_scale_ptrs + offsets, a_scale_ptrs, mask)

    b_scale_ptrs = b_scale_data_ptr + offsets * b_scale_stride_group
    tl.store(d_b_scale_ptrs + offsets, b_scale_ptrs, mask)

    c_ptrs = c_data_ptr + offsets * c_stride_group
    tl.store(d_c_ptrs + offsets, c_ptrs, mask)

    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    for i in range(3):
        rem = offsets % 3
        sel0 = rem == 0
        sel1 = rem == 1
        sel2 = rem == 2
        g_sizes = M * sel0 + N * sel1 + K * sel2
        g_lds = lda * sel0 + ldb * sel1 + ldc * sel2

        mask = offsets < 3 * group_size
        tl.store(d_g_sizes + offsets, g_sizes.to(tl.int32), mask)
        tl.store(d_g_lds + offsets, g_lds.to(tl.int32), mask)

        offsets += BLOCK_SIZE


def group_gemm_tma_fn(A, B, A_scale, B_scale, dtype_C):
    # B comes in shape (group_size, K, N), but last two dims are in
    # column major order.

    group_size, M, K = A.shape
    _, _, N = B.shape

    C = torch.empty((group_size, M, N), device=DEVICE, dtype=dtype_C)

    # Apparently, an array of pointers has to be aligned to 128 bits
    # (due to CUDA < 12.4 bug).
    group_size_aligned = (group_size + 1) // 2 * 2
    # torch.int64 -> 8 bytes, torch.int32 -> 4 bytes
    nbytes = 5 * group_size_aligned * 8 + group_size * 3 * 4 + group_size * 3 * 4
    buffer = torch.empty(nbytes, device=DEVICE, dtype=torch.int8)
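    # Buffer layout: five group_size-length pointer arrays (int64), followed by the
    # per-group <M, N, K> sizes (int32) and <lda, ldb, ldc> leading dims (int32).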
    off = 0
    incr = group_size_aligned * 8
    d_a_ptrs = buffer[off : off + incr].view(torch.int64)
    off += incr
    d_b_ptrs = buffer[off : off + incr].view(torch.int64)
    off += incr
    d_a_scale_ptrs = buffer[off : off + incr].view(torch.int64)
    off += incr
    d_b_scale_ptrs = buffer[off : off + incr].view(torch.int64)
    off += incr
    d_c_ptrs = buffer[off : off + incr].view(torch.int64)
    off += incr
    incr = group_size * 3 * 4
    d_g_sizes = buffer[off : off + incr].view(torch.int32)
    off += incr
    d_g_lds = buffer[off : off + incr].view(torch.int32)

    grid = lambda meta: (triton.cdiv(group_size, meta["BLOCK_SIZE"]),)
    grouped_matmul_tma_prepare_data_kernel[grid](
        group_size,
        A,
        B,
        A_scale,
        B_scale,
        C,
        M,
        N,
        K,
        A.stride(0),
        B.stride(0),
        A_scale.stride(0),
        B_scale.stride(0),
        C.stride(0),
        A.stride(-2),
        B.stride(-1),
        C.stride(-2),
        d_a_ptrs,
        d_b_ptrs,
        d_a_scale_ptrs,
        d_b_scale_ptrs,
        d_c_ptrs,
        d_g_sizes,
        d_g_lds,
        BLOCK_SIZE=1024,
    )

    # we use a fixed number of CTA, and it's auto-tunable

    # TMA descriptors require a global memory allocation
    def alloc_fn(size: int, alignment: int, stream: Optional[int]):
        return torch.empty(size, device=A.device, dtype=torch.int8)

    triton.set_allocator(alloc_fn)

    # Simple hashing for auto-tuning.
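    # Packs (group_size, M, N, K) into one integer (K: bits 0-17, N: 18-35,
    # M: 36-53, group_size: 54+) so the autotuner caches a config per problem shape.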
    assert group_size < 2**10 and M < 2**18 and N < 2**18 and K < 2**18
    hash_value = group_size << 54 | M << 36 | N << 18 | K

    grid = lambda META: (META["NUM_SM"],)
    grouped_matmul_tma_kernel[grid](
        hash_value,
        d_a_ptrs,
        d_b_ptrs,
        d_a_scale_ptrs,
        d_b_scale_ptrs,
        d_c_ptrs,
        d_g_sizes,
        d_g_lds,
        group_size,
        FP8=torch.float8_e4m3fn == A.dtype,
        NUM_SM=num_sms(),
    )

    return C


def generate_data(group_size, M, N, K, device, dtype_AB, dtype_scale, strided):
    A = torch.randn(group_size * (1 + strided), M, K * (1 + strided), device=device).to(
        dtype_AB
    )[:: (1 + strided), :, :K]
    B = (
        torch.randn(group_size * (1 + strided), N, K * (1 + strided), device=device)
        .to(dtype_AB)[:: (1 + strided), :, :K]
        .transpose(-2, -1)
    )

    A_scale = torch.rand((group_size * (1 + strided), M), device=device).to(
        dtype_scale
    )[:: (1 + strided), :]
    B_scale = torch.rand((group_size * (1 + strided), N), device=device).to(
        dtype_scale
    )[:: (1 + strided), :]

    return A, B, A_scale, B_scale


def validate_group_gemm_tma_fn():
    def validate_helper(
        group_size, M, N, K, device, dtype_AB, dtype_scale, dtype_C, strided, atol, rtol
    ):
        A, B, A_scale, B_scale = generate_data(
            group_size, M, N, K, device, dtype_AB, dtype_scale, strided
        )

        C_ref_1 = torch.stack(
            [
                torch._scaled_mm(
                    A[i, :, :],
                    B[i, :, :],
                    A_scale[i, :, None],
                    B_scale[i, None, :],
                    out_dtype=dtype_C,
                    use_fast_accum=True,
                )
                for i in range(group_size)
            ]
        )

        C_ref_2 = torch._scaled_grouped_mm(
            A, B, A_scale, B_scale, out_dtype=dtype_C, use_fast_accum=True
        )
        assert torch.allclose(C_ref_2, C_ref_1, atol=atol, rtol=rtol)

        if supports_tma():
            C = group_gemm_tma_fn(A, B, A_scale, B_scale, dtype_C)
            assert torch.allclose(C, C_ref_1, atol=atol, rtol=rtol)

    group_size = 4
    device = DEVICE
    dtype_AB = torch.float8_e4m3fn
    dtype_scale = torch.float32
    atol = 1e-2
    rtol = 1e-2
    MNK_range = [2**i for i in range(4, 8)]
    strided_range = [False, True]
    for M, N, K, strided in product(MNK_range, MNK_range, MNK_range, strided_range):
        dtype_C = torch.bfloat16
        validate_helper(
            group_size,
            M,
            N,
            K,
            device,
            dtype_AB,
            dtype_scale,
            dtype_C,
            strided,
            atol,
            rtol,
        )


def pytorch_loop_fn(A, B, A_scale, B_scale, dtype_C):
    return torch.stack(
        [
            torch._scaled_mm(
                A[i, :, :],
                B[i, :, :],
                A_scale[i, :, None],
                B_scale[i, None, :],
                out_dtype=dtype_C,
                use_fast_accum=True,
            )
            for i in range(A.shape[0])
        ]
    )


def pytorch_grouped_fn(A, B, A_scale, B_scale, dtype_C):
    # Return the result so callers (e.g. the profiling helper) can use it.
    return torch._scaled_grouped_mm(
        A, B, A_scale, B_scale, out_dtype=dtype_C, use_fast_accum=True
    )


def triton_grouped_fn(A, B, A_scale, B_scale, type_C):
    # Return the result so the profiling helper's `C = ...` assignments aren't None.
    return group_gemm_tma_fn(A, B, A_scale, B_scale, type_C)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["M", "NK"],
        x_vals=list(
            product([2**i for i in range(6, 13)], [2**i for i in range(6, 13)])
        ),
        line_arg="provider",
        line_vals=["pytorch-loop", "pytorch-grouped"]
        + (["triton-grouped"] if supports_tma() else []),
        line_names=["PyTorch loop", "PyTorch grouped"]
        + (["Triton grouped"] if supports_tma() else []),
        styles=[("green", "-"), ("blue", "-")]
        + ([("red", "-")] if supports_tma() else []),
        ylabel="runtime(ms)",
        plot_name="scaled grouped GEMM performance",
        args={},
    )
)
def benchmark_batches(M, NK, provider):
    group_size = 4
    device = DEVICE
    dtype_AB = torch.float8_e4m3fn
    dtype_scale = torch.float32
    strided = False
    A, B, A_scale, B_scale = generate_data(
        group_size, M, NK, NK, device, dtype_AB, dtype_scale, strided
    )

    dtype_C = torch.bfloat16

    quantiles = [0.5, 0.2, 0.8]
    if provider == "pytorch-loop":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: pytorch_loop_fn(A, B, A_scale, B_scale, dtype_C),
            quantiles=quantiles,
        )
    if provider == "pytorch-grouped":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: pytorch_grouped_fn(A, B, A_scale, B_scale, dtype_C),
            quantiles=quantiles,
        )
    if provider == "triton-grouped":
        # Call once in order to compile the kernels, ...
        triton_grouped_fn(A, B, A_scale, B_scale, dtype_C)

        # ...and then do the benchmark.
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: triton_grouped_fn(A, B, A_scale, B_scale, dtype_C),
            quantiles=quantiles,
        )
    return ms, max_ms, min_ms


def profile_group_gemm_tma_fn():
    dtype_AB = torch.float8_e4m3fn
    dtype_scale = torch.float32
    A, B, A_scale, B_scale = generate_data(
        4, 1024, 1024, 1024, DEVICE, dtype_AB, dtype_scale, True
    )
    dtype_C = torch.bfloat16

    torch.profiler._utils._init_for_cuda_graphs()
    C = triton_grouped_fn(A, B, A_scale, B_scale, dtype_C)
    prof = torch.profiler.profile()
    with prof:
        # C = torch._scaled_grouped_mm(
        #    A, B, A_scale, B_scale, out_dtype=dtype_C, use_fast_accum=True
        # )
        C = triton_grouped_fn(A, B, A_scale, B_scale, dtype_C)
    prof.export_chrome_trace("foo.json")


###validate_group_gemm_tma_fn()
benchmark_batches.run(show_plots=True, print_data=True)
###profile_group_gemm_tma_fn()

Contributor
@eellison left a comment

v cool!

Comment on lines +44 to +46
]

_AMD_CONFIGS = [
Contributor

Is there a reason to include way more configs in AMD than NV?

Contributor Author

This was my mistake: I had cut down the set of configs while debugging. Fixed now!

)

# 1. make sure we have enough smem
max_shared_memory = driver.active.utils.get_device_properties(device)[
Contributor

Can we factor this out and cache it? That has been the existing pattern, although I don't know how long these calls actually take.
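(A sketch of the suggested pattern, assuming Triton's get_device_properties returns a dict with a "max_shared_mem" key as it does for the CUDA backend; the helper name is made up:)

from functools import lru_cache

@lru_cache(maxsize=None)
def _cached_max_shared_memory(device: int) -> int:
    # Query the device once and reuse the value across config-pruning calls.
    from triton.runtime import driver
    return driver.active.utils.get_device_properties(device)["max_shared_mem"]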

]
if torch.version.hip:
required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize
else:
Contributor

Nit: would you mind adding a comment on why there is a difference?

Contributor Author
@bertmaher commented Apr 7, 2025

I had to check with some more AMD-knowledgeable folks, and the answer was quite interesting. Apparently:

  • AMD often doesn't even software-pipeline GEMMs, and instead uses additional wavefronts to hide latency
  • This is partly because the AMD design has a lot less shared memory (64 KB versus 228 KB on H100) and, it seems, more registers (although the register file size doesn't appear to be well documented)
  • So even when pipelining, they often try to keep A in registers and use shared memory only for B, hence the different estimate (see the sketch below)
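A sketch of the estimate the branch above computes, under the assumption that the NVIDIA path stages both the A and B tiles per pipeline stage while the hip path stages only B (keeping A in registers); the exact formula in the PR may differ:

def fits_in_shared_memory(block_m: int, block_n: int, block_k: int,
                          num_stages: int, dtsize: int,
                          max_shared_memory: int, is_hip: bool) -> bool:
    if is_hip:
        # AMD: only the B tile is staged in shared memory.
        required = block_n * block_k * num_stages * dtsize
    else:
        # NVIDIA (assumption): both A and B tiles are staged per stage.
        required = (block_m + block_n) * block_k * num_stages * dtsize
    return required <= max_shared_memory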

offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
a_scale = tl.load(
Contributor

Do you have any sense of how important the TMA store is?

Contributor Author

Not really, but it's apparently not used by FBGEMM on Nvidia (USE_TMA_STORE is set to False on all the paths I can find), so that seems like decent evidence that it's not a deal breaker.

It would still be worth re-enabling some day, but maybe not right now.

m_size = M_end_offset - M_start_offset
if m_size > 0:
N_start_offset = g.to(tl.int64) * N
Contributor

Do we want to gate the use of tl.int64 on input sizes?

Contributor Author

I think this is a good idea, but I'd rather save it for an optimization down the line; I don't think it significantly slows down this kernel as-is.
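(A host-side sketch of what such gating could look like, purely illustrative:)

import torch

def needs_int64_offsets(*tensors: torch.Tensor) -> bool:
    # 64-bit offsets are only needed when a linear index can overflow int32.
    INT32_MAX = 2**31 - 1
    return any(t.numel() > INT32_MAX for t in tensors)

# e.g. pass USE_INT64=needs_int64_offsets(mat_a, mat_b) as a constexpr to the kernel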

bertmaher added a commit that referenced this pull request Apr 7, 2025
bertmaher added a commit that referenced this pull request Apr 7, 2025
@bertmaher
Contributor Author

@alexsamardzic Awesome, I'd love some help on this work! It's encouraging to see that you're getting great results with the tutorial grouped GEMM too -- I went with FBGEMM because it's been tested in prod here and we know it has good perf characteristics, but it seems like either could be a good starting point.

I think this PR is pretty close to landing 🤞 , but if you're interested, a diff to expand coverage to 3d/3d would be a great contribution!

def early_config_prune(configs, named_args):
from triton.runtime import driver

device = torch.cuda.current_device()
Collaborator

May I suggest we use the general device interface here? E.g.:

from torch._inductor.utils import get_gpu_type
from torch._dynamo.device_interface import get_interface_for_device
device_interface = get_interface_for_device(get_gpu_type())
device = device_interface.current_device()

Contributor Author

I ended up reusing an existing function to get shared memory size, so I didn't need this code after all 😁

@malfet
Contributor
malfet commented Apr 10, 2025

@pytorchbot revert -m "Looks like it broke lint, see https://hud.pytorch.org/hud/pytorch/pytorch/a0ab243c3a5dfe12b392e4074d69360fd013f842/1?per_page=50&name_filter=lint&mergeEphemeralLF=true" -c nosignal

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@bertmaher your PR has been successfully reverted.

@pytorchmergebot added the Reverted and ci-no-td (Do not run TD on this PR) labels Apr 10, 2025
@bertmaher
Contributor Author

Awh man, this is annoying, it raced with the addition of a docstring linter 😠 😠 😠

bertmaher added a commit that referenced this pull request Apr 11, 2025
bertmaher added a commit that referenced this pull request Apr 11, 2025
This reverts commit 6a65f2c.

ghstack-source-id: e5b0868
Pull Request resolved: #151128
bertmaher added a commit that referenced this pull request Apr 11, 2025
bertmaher added a commit that referenced this pull request Apr 11, 2025
@bertmaher
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

timocafe pushed a commit to timocafe/pytorch that referenced this pull request Apr 16, 2025
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025

Pull Request resolved: pytorch#150421
Approved by: https://github.com/ngimel
Contributor
@eellison left a comment

Sorry - I should have caught this on the first run, but would you mind adding a test of the form used in test_max_autotune, where we run with backends set to TRITON and max-autotune enabled, and check the output code? See:

with config.patch(
    {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
):

    @torch.compile()
    def foo(mod, x):
        return mod(x)

    with torch.no_grad():
        out, code = run_and_get_code(foo, conv1x1, input_tensor)

Since the compile here is not using max-autotune, I'm not sure this path is actually being exercised.
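(A rough sketch of that kind of test, following the snippet above; run_and_get_code and the config keys are the ones used elsewhere in inductor tests, and the assertion on the generated code is only illustrative:)

import torch
from torch._inductor import config
from torch._inductor.utils import run_and_get_code

def check_grouped_mm_uses_triton(a, b, scale_a, scale_b):
    with config.patch({"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}):

        @torch.compile()
        def f(a, b, scale_a, scale_b):
            return torch._scaled_grouped_mm(
                a, b, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True
            )

        out, code = run_and_get_code(f, a, b, scale_a, scale_b)
    # Expect a Triton kernel (not the aten fallback) in the generated output code.
    assert "triton" in code[0].lower()
    return out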

@alexsamardzic
Collaborator

Sorry - I should have caught this on first run, but would you mind adding a test for the form in test_max_autotune where we run with backends set to TRITON, max-autotune set, and check output code?

Updated in the meantime in #150944, like this - would that do?
