Support tuning of _scaled_grouped_mm #150421
Conversation
This includes the default aten implementation, as well as a Triton implementation imported from FBGEMM (https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py).
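For context, a minimal usage sketch of the op this PR tunes, covering the 2D/3D "jagged" case where offs marks the per-group row boundaries of the LHS; the shapes, scales, and column-major layout of mat_b below are illustrative assumptions, not code from the PR:

import torch

G, M, N, K = 4, 64, 128, 256  # illustrative sizes
device = "cuda"

# 2D row-major FP8 LHS with all groups stacked along dim 0.
a = torch.randn(G * M, K, device=device).to(torch.float8_e4m3fn)
# 3D FP8 RHS whose last two dims are column-major, as the striding checks require.
b = torch.randn(G, N, K, device=device).to(torch.float8_e4m3fn).transpose(-2, -1)

# Row-wise scales for A, per-group row-wise scales for B.
scale_a = torch.rand(G * M, device=device, dtype=torch.float32)
scale_b = torch.rand(G, N, device=device, dtype=torch.float32)

# Cumulative end offsets of each group's rows in `a`.
offs = torch.arange(M, G * M + 1, M, device=device, dtype=torch.int32)

out = torch._scaled_grouped_mm(
    a, b, scale_a, scale_b, offs=offs, out_dtype=torch.bfloat16
)
print(out.shape)  # (G * M, N)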
device_type = get_device_type(mat_a)
check_supported_striding(mat_a, mat_b)
Should you also check here that mat_a is 2D and mat_b is 3D? The aten op supports different 2D/3D combinations, but the Triton kernel supports only one.
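For illustration, a hedged sketch of the kind of guard being suggested here (the helper name is made up, not part of the PR):

def _is_triton_supported_layout(mat_a, mat_b) -> bool:
    # The Triton grouped-gemm kernel in question handles only the 2D x 3D case,
    # while the aten op accepts several other 2D/3D combinations.
    return mat_a.dim() == 2 and mat_b.dim() == 3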
c = accumulator.to(tl.float32) * a_scale * b_scale
if USE_TMA_STORE:
I don't like this USE_TMA_STORE option (I know we are not using it) because it can't do masking, and we absolutely need m masking (which the regular store below does), because we might have computed completely bogus results when m_size is not a multiple of the required TMA size.
Interesting. This is actually a little surprising to me, because I thought the way the TMA store descriptor was being set up accounted for m_size... but I'm also not sure how that interacts with the required alignment/size.
Since inductor currently can't use this anyway, I'm thinking I'll just delete the USE_TMA_STORE branch; there's no reason to leave dead code around.
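For reference, a hedged sketch of the masked "regular" store being contrasted with the TMA path; the variable names mirror the kernel's conventions, but the snippet is illustrative rather than the PR's code:

# Inside the Triton kernel, after computing the BLOCK_M x BLOCK_N tile `c`:
offs_m = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
# Rows beyond m_size (and columns beyond N) are masked out, so partial tiles
# never write the bogus values that an unmasked TMA store could emit.
mask = (offs_m[:, None] < m_size) & (offs_n[None, :] < N)
c_ptrs = c_ptr + offs_m[:, None] * ldc + offs_n[None, :]
tl.store(c_ptrs, c, mask=mask)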
test/test_matmul_cuda.py (outdated)
print(a)
print("alist:", alist)
asplit = [x.squeeze(1) for x in a.view(m, n_groups, k).split(1, dim=1)]
print("asplit:", asplit)
stray print?
Whoops, I need to be more careful with my debug code 😬
This isn't quite ready yet; I'm still seeing some buggy output with smaller problem sizes (e.g. G,M,N,K = 4,16,64,64). I'll ping the thread when it's fixed.
@bertmaher: Unaware of your PR, I was going to work on the same thing; do you need/want some help? 🙂 I started with the attached proof-of-concept script, which validates/benchmarks/profiles a Triton kernel (based on the Triton tutorial rather than FBGEMM) against _scaled_grouped_mm and _scaled_mm in a loop. My code covers the 3D/3D case, so I can't immediately plug in your implementation to compare, but my experience so far is that the Triton version of grouped GEMM is quite fast (and a single kernel should easily handle all the 2D/3D combinations). The performance issue may instead lie in the code that prepares the arrays of pointers passed as kernel inputs, so the attached script contains a minuscule Triton kernel to do that (for the 3D/3D case), and it seemed to help.

grouped-gemm-triton-vs-pytorch.py:

from itertools import product
from typing import Optional
import torch
import triton
import triton.language as tl
DEVICE = triton.runtime.driver.active.get_active_torch_device()
def is_cuda():
return triton.runtime.driver.active.get_current_target().backend == "cuda"
def supports_tma():
return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
def num_sms():
if is_cuda():
return torch.cuda.get_device_properties("cuda").multi_processor_count
return 148
tma_configs = [
triton.Config(
{"BLOCK_SIZE_M": BM, "BLOCK_SIZE_N": BN, "BLOCK_SIZE_K": BK},
num_stages=s,
num_warps=w,
)
for BM in [128]
for BN in [128, 256]
for BK in [64, 128]
for s in ([3, 4])
for w in [4, 8]
]
@triton.autotune(
tma_configs,
key=["hash_value"],
)
@triton.jit
def grouped_matmul_tma_kernel(
hash_value,
# device tensor of matrices pointers
group_a_ptrs,
group_b_ptrs,
group_a_scale_ptrs,
group_b_scale_ptrs,
group_c_ptrs,
# device tensor of gemm sizes. its shape is [group_size, 3]
# dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
group_gemm_sizes,
# device tensor of leading dimension sizes. its shape is [group_size, 3]
# dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
g_lds,
# number of gemms
group_size,
# number of virtual SM
NUM_SM: tl.constexpr,
# tile sizes
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
# is the output FP8 or FP16
FP8: tl.constexpr,
):
dtype = tl.float8e4nv if FP8 else tl.bfloat16
dtype_scale = tl.float32
dtype_out = tl.bfloat16
tile_idx = tl.program_id(0)
last_problem_end = 0
for g in range(group_size):
# get the gemm size of the current problem
gm = tl.load(group_gemm_sizes + g * 3)
gn = tl.load(group_gemm_sizes + g * 3 + 1)
gk = tl.load(group_gemm_sizes + g * 3 + 2)
num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
num_tiles = num_m_tiles * num_n_tiles
if tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:
# pick up a tile from the current gemm problem
lda = tl.load(g_lds + g * 3)
ldb = tl.load(g_lds + g * 3 + 1)
ldc = tl.load(g_lds + g * 3 + 2)
a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(dtype))
b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(dtype))
a_scale_ptr = tl.load(group_a_scale_ptrs + g).to(
tl.pointer_type(dtype_scale)
)
b_scale_ptr = tl.load(group_b_scale_ptrs + g).to(
tl.pointer_type(dtype_scale)
)
c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(dtype_out))
a_desc = tl._experimental_make_tensor_descriptor(
a_ptr,
shape=[gm, gk],
strides=[lda, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
)
b_desc = tl._experimental_make_tensor_descriptor(
b_ptr,
shape=[gn, gk],
strides=[ldb, 1],
block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
)
c_desc = tl._experimental_make_tensor_descriptor(
c_ptr,
shape=[gm, gn],
strides=[ldc, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
)
# iterate through the tiles in the current gemm problem
while (
tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles
):
k = gk
# figure out tile coordinates
tile_idx_in_gemm = tile_idx - last_problem_end
tile_m_idx = tile_idx_in_gemm // num_n_tiles
tile_n_idx = tile_idx_in_gemm % num_n_tiles
# do regular gemm here
offs_am = tile_m_idx * BLOCK_SIZE_M
offs_bn = tile_n_idx * BLOCK_SIZE_N
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
a = a_desc.load([offs_am, kk * BLOCK_SIZE_K])
b = b_desc.load([offs_bn, kk * BLOCK_SIZE_K])
accumulator += tl.dot(a, b.T)
# apply scales
offsets = offs_am + tl.arange(0, BLOCK_SIZE_M)
a_scale_mask = offsets < gm
a_scale = tl.load(a_scale_ptr + offsets, a_scale_mask).expand_dims(-1)
accumulator *= a_scale
offsets = offs_bn + tl.arange(0, BLOCK_SIZE_N)
b_scale_mask = offsets < gn
b_scale = tl.load(b_scale_ptr + offsets, b_scale_mask).expand_dims(0)
accumulator *= b_scale
offs_cm = tile_m_idx * BLOCK_SIZE_M
offs_cn = tile_n_idx * BLOCK_SIZE_N
c = accumulator.to(dtype_out)
c_desc.store([offs_cm, offs_cn], c)
# go to the next tile by advancing NUM_SM
tile_idx += NUM_SM
# get ready to go to the next gemm problem
last_problem_end = last_problem_end + num_tiles
@triton.jit
def grouped_matmul_tma_prepare_data_kernel(
group_size,
a_data_ptr,
b_data_ptr,
a_scale_data_ptr,
b_scale_data_ptr,
c_data_ptr,
M,
N,
K,
a_stride_group,
b_stride_group,
a_scale_stride_group,
b_scale_stride_group,
c_stride_group,
lda,
ldb,
ldc,
d_a_ptrs,
d_b_ptrs,
d_a_scale_ptrs,
d_b_scale_ptrs,
d_c_ptrs,
d_g_sizes,
d_g_lds,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < group_size
a_ptrs = a_data_ptr + offsets * a_stride_group
tl.store(d_a_ptrs + offsets, a_ptrs, mask)
b_ptrs = b_data_ptr + offsets * b_stride_group
tl.store(d_b_ptrs + offsets, b_ptrs, mask)
a_scale_ptrs = a_scale_data_ptr + offsets * a_scale_stride_group
tl.store(d_a_scale_ptrs + offsets, a_scale_ptrs, mask)
b_scale_ptrs = b_scale_data_ptr + offsets * b_scale_stride_group
tl.store(d_b_scale_ptrs + offsets, b_scale_ptrs, mask)
c_ptrs = c_data_ptr + offsets * c_stride_group
tl.store(d_c_ptrs + offsets, c_ptrs, mask)
offsets = block_start + tl.arange(0, BLOCK_SIZE)
for i in range(3):
rem = offsets % 3
sel0 = rem == 0
sel1 = rem == 1
sel2 = rem == 2
g_sizes = M * sel0 + N * sel1 + K * sel2
g_lds = lda * sel0 + ldb * sel1 + ldc * sel2
mask = offsets < 3 * group_size
tl.store(d_g_sizes + offsets, g_sizes.to(tl.int32), mask)
tl.store(d_g_lds + offsets, g_lds.to(tl.int32), mask)
offsets += BLOCK_SIZE
def group_gemm_tma_fn(A, B, A_scale, B_scale, dtype_C):
# B comes in shape (group_size, K, N), but last two dims are in
# column major order.
group_size, M, K = A.shape
_, _, N = B.shape
C = torch.empty((group_size, M, N), device=DEVICE, dtype=dtype_C)
# Apparently, an array of pointers has to be aligned to 128 bits
# (due to CUDA < 12.4 bug).
group_size_aligned = (group_size + 1) // 2 * 2
# torch.int64 -> 8 bytes, torch.int32 -> 4 bytes
nbytes = 5 * group_size_aligned * 8 + group_size * 3 * 4 + group_size * 3 * 4
buffer = torch.empty(nbytes, device=DEVICE, dtype=torch.int8)
off = 0
incr = group_size_aligned * 8
d_a_ptrs = buffer[off : off + incr].view(torch.int64)
off += incr
d_b_ptrs = buffer[off : off + incr].view(torch.int64)
off += incr
d_a_scale_ptrs = buffer[off : off + incr].view(torch.int64)
off += incr
d_b_scale_ptrs = buffer[off : off + incr].view(torch.int64)
off += incr
d_c_ptrs = buffer[off : off + incr].view(torch.int64)
off += incr
incr = group_size * 3 * 4
d_g_sizes = buffer[off : off + incr].view(torch.int32)
off += incr
d_g_lds = buffer[off : off + incr].view(torch.int32)
grid = lambda meta: (triton.cdiv(group_size, meta["BLOCK_SIZE"]),)
grouped_matmul_tma_prepare_data_kernel[grid](
group_size,
A,
B,
A_scale,
B_scale,
C,
M,
N,
K,
A.stride(0),
B.stride(0),
A_scale.stride(0),
B_scale.stride(0),
C.stride(0),
A.stride(-2),
B.stride(-1),
C.stride(-2),
d_a_ptrs,
d_b_ptrs,
d_a_scale_ptrs,
d_b_scale_ptrs,
d_c_ptrs,
d_g_sizes,
d_g_lds,
BLOCK_SIZE=1024,
)
# we use a fixed number of CTA, and it's auto-tunable
# TMA descriptors require a global memory allocation
def alloc_fn(size: int, alignment: int, stream: Optional[int]):
return torch.empty(size, device=A.device, dtype=torch.int8)
triton.set_allocator(alloc_fn)
# Simple hashing for auto-tuning.
assert group_size < 2**10 and M < 2**18 and N < 2**18 and K < 2**18
hash_value = group_size << 54 | M << 36 | N << 18 | K
grid = lambda META: (META["NUM_SM"],)
grouped_matmul_tma_kernel[grid](
hash_value,
d_a_ptrs,
d_b_ptrs,
d_a_scale_ptrs,
d_b_scale_ptrs,
d_c_ptrs,
d_g_sizes,
d_g_lds,
group_size,
FP8=torch.float8_e4m3fn == A.dtype,
NUM_SM=num_sms(),
)
return C
def generate_data(group_size, M, N, K, device, dtype_AB, dtype_scale, strided):
A = torch.randn(group_size * (1 + strided), M, K * (1 + strided), device=device).to(
dtype_AB
)[:: (1 + strided), :, :K]
B = (
torch.randn(group_size * (1 + strided), N, K * (1 + strided), device=device)
.to(dtype_AB)[:: (1 + strided), :, :K]
.transpose(-2, -1)
)
A_scale = torch.rand((group_size * (1 + strided), M), device=device).to(
dtype_scale
)[:: (1 + strided), :]
B_scale = torch.rand((group_size * (1 + strided), N), device=device).to(
dtype_scale
)[:: (1 + strided), :]
return A, B, A_scale, B_scale
def validate_group_gemm_tma_fn():
def validate_helper(
group_size, M, N, K, device, dtype_AB, dtype_scale, dtype_C, strided, atol, rtol
):
A, B, A_scale, B_scale = generate_data(
group_size, M, N, K, device, dtype_AB, dtype_scale, strided
)
C_ref_1 = torch.stack(
[
torch._scaled_mm(
A[i, :, :],
B[i, :, :],
A_scale[i, :, None],
B_scale[i, None, :],
out_dtype=dtype_C,
use_fast_accum=True,
)
for i in range(group_size)
]
)
C_ref_2 = torch._scaled_grouped_mm(
A, B, A_scale, B_scale, out_dtype=dtype_C, use_fast_accum=True
)
assert torch.allclose(C_ref_2, C_ref_1, atol=atol, rtol=rtol)
if supports_tma():
C = group_gemm_tma_fn(A, B, A_scale, B_scale, dtype_C)
assert torch.allclose(C, C_ref_1, atol=atol, rtol=rtol)
group_size = 4
device = DEVICE
dtype_AB = torch.float8_e4m3fn
dtype_scale = torch.float32
atol = 1e-2
rtol = 1e-2
MNK_range = [2**i for i in range(4, 8)]
strided_range = [False, True]
for M, N, K, strided in product(MNK_range, MNK_range, MNK_range, strided_range):
dtype_C = torch.bfloat16
validate_helper(
group_size,
M,
N,
K,
device,
dtype_AB,
dtype_scale,
dtype_C,
strided,
atol,
rtol,
)
def pytorch_loop_fn(A, B, A_scale, B_scale, dtype_C):
return torch.stack(
[
torch._scaled_mm(
A[i, :, :],
B[i, :, :],
A_scale[i, :, None],
B_scale[i, None, :],
out_dtype=dtype_C,
use_fast_accum=True,
)
for i in range(A.shape[0])
]
)
def pytorch_grouped_fn(A, B, A_scale, B_scale, dtype_C):
torch._scaled_grouped_mm(
A, B, A_scale, B_scale, out_dtype=dtype_C, use_fast_accum=True
)
def triton_grouped_fn(A, B, A_scale, B_scale, type_C):
group_gemm_tma_fn(A, B, A_scale, B_scale, type_C)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["M", "NK"],
x_vals=list(
product([2**i for i in range(6, 13)], [2**i for i in range(6, 13)])
),
line_arg="provider",
line_vals=["pytorch-loop", "pytorch-grouped"]
+ (["triton-grouped"] if supports_tma() else []),
line_names=["PyTorch loop", "PyTorch grouped"]
+ (["Triton grouped"] if supports_tma() else []),
styles=[("green", "-"), ("blue", "-")]
+ ([("red", "-")] if supports_tma() else []),
ylabel="runtime(ms)",
plot_name="scaled grouped GEMM performance",
args={},
)
)
def benchmark_batches(M, NK, provider):
group_size = 4
device = DEVICE
dtype_AB = torch.float8_e4m3fn
dtype_scale = torch.float32
strided = False
A, B, A_scale, B_scale = generate_data(
    group_size, M, NK, NK, device, dtype_AB, dtype_scale, strided
)
dtype_C = torch.bfloat16
quantiles = [0.5, 0.2, 0.8]
if provider == "pytorch-loop":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: pytorch_loop_fn(A, B, A_scale, B_scale, dtype_C),
quantiles=quantiles,
)
if provider == "pytorch-grouped":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: pytorch_grouped_fn(A, B, A_scale, B_scale, dtype_C),
quantiles=quantiles,
)
if provider == "triton-grouped":
# Call once in order to compile the kernels, ...
triton_grouped_fn(A, B, A_scale, B_scale, dtype_C)
# ...and then do the benchmark.
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: triton_grouped_fn(A, B, A_scale, B_scale, dtype_C),
quantiles=quantiles,
)
return ms, max_ms, min_ms
def profile_group_gemm_tma_fn():
dtype_AB = torch.float8_e4m3fn
dtype_scale = torch.float32
A, B, A_scale, B_scale = generate_data(
4, 1024, 1024, 1024, DEVICE, dtype_AB, dtype_scale, True
)
dtype_C = torch.bfloat16
torch.profiler._utils._init_for_cuda_graphs()
C = triton_grouped_fn(A, B, A_scale, B_scale, dtype_C)
prof = torch.profiler.profile()
with prof:
# C = torch._scaled_grouped_mm(
# A, B, A_scale, B_scale, out_dtype=dtype_C, use_fast_accum=True
# )
C = triton_grouped_fn(A, B, A_scale, B_scale, dtype_C)
prof.export_chrome_trace("foo.json")
###validate_group_gemm_tma_fn()
benchmark_batches.run(show_plots=True, print_data=True)
###profile_group_gemm_tma_fn()
v cool!
_AMD_CONFIGS = [
Is there a reason to include way more configs in AMD than NV?
This was my mistake as I cut down the set of configs while debugging. Fixed now!
# 1. make sure we have enough smem
max_shared_memory = driver.active.utils.get_device_properties(device)[
Can we factor this out and cache it? That has been the existing pattern, although I don't know how long these calls actually take.
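A hedged sketch of the factor-out-and-cache pattern being suggested (the helper name and the property key are assumptions on my part):

import functools


@functools.lru_cache(maxsize=None)
def _get_max_shared_memory(device_index: int) -> int:
    # Query Triton's driver once per device and memoize the answer; the value
    # never changes for a given device, so repeated lookups are pure overhead.
    from triton.runtime import driver

    return driver.active.utils.get_device_properties(device_index)["max_shared_mem"]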
]
if torch.version.hip:
    required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize
else:
nit: would you mind commenting why there is a difference?
I had to check with some more AMD-knowledgeable folks, and the answer was quite interesting. Apparently:
- AMD often doesn't even software pipeline gemms, and instead uses additional wavefronts to hide latency
- This is partly because the AMD design has a lot less shared memory (64 KB versus 228 KB on an H100) and, it seems, more registers (although the register file size doesn't seem especially well documented)
- So often even when pipelining they try to put A in registers and use shared mem for B
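To make the asymmetry concrete, a hedged sketch using the standard shared-memory estimate for a pipelined GEMM on the CUDA side (the CUDA expression is my illustration; only the HIP branch appears in the diff above):

import torch


def required_shared_memory_bytes(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, dtsize):
    if torch.version.hip:
        # ROCm: only the B tile is staged through shared memory; A tends to live
        # in registers, with extra wavefronts hiding latency instead of deep pipelining.
        return BLOCK_N * BLOCK_K * num_stages * dtsize
    # CUDA: both the A and B tiles are buffered in shared memory, one copy per
    # software-pipeline stage.
    return (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize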
offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
a_scale = tl.load(
Do you have any sense of how important the TMA store is?
Not really, but it's apparently not used by FBGEMM on Nvidia (USE_TMA_STORE is set to False on all the paths I can find), so that seems like decent evidence that it's not a deal breaker.
It would still be worth re-enabling some day, just because, but maybe not at this moment.
m_size = M_end_offset - M_start_offset
if m_size > 0:
    N_start_offset = g.to(tl.int64) * N
Do we want to gate tl.int64 on input sizes?
I think this is a good idea, but maybe we should save it for an optimization down the line; I don't think it significantly slows down this kernel as-is.
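A hedged sketch of what host-side gating of the 64-bit indexing could look like later (the function name and threshold handling are illustrative, not the PR's code):

def needs_int64_indexing(total_M: int, N: int, K: int) -> bool:
    # Offsets into A, B, and C can overflow int32 once any operand (or the
    # output) has more than 2**31 - 1 elements; only then pay for 64-bit math.
    int32_max = 2**31 - 1
    return max(total_M * K, N * K, total_M * N) > int32_max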
@alexsamardzic Awesome, I'd love some help on this work! It's encouraging to see that you're getting great results with the tutorial grouped gemm too -- I went with FBGEMM because it's been tested in prod here and we know it has good perf characteristics, but it seems like either could be a good starting point. I think this PR is pretty close to landing 🤞, but if you're interested, a diff expanding coverage to the 3D/3D case would be a great contribution!
def early_config_prune(configs, named_args):
    from triton.runtime import driver

    device = torch.cuda.current_device()
May I suggest we use the general device interface here? E.g.:
from torch._inductor.utils import get_gpu_type
from torch._dynamo.device_interface import get_interface_for_device
device_interface = get_interface_for_device(get_gpu_type())
device = device_interface.current_device()
I ended up reusing an existing function to get shared memory size, so I didn't need this code after all 😁
@pytorchbot revert -m "Looks like it broke lint, see https://hud.pytorch.org/hud/pytorch/pytorch/a0ab243c3a5dfe12b392e4074d69360fd013f842/1?per_page=50&name_filter=lint&mergeEphemeralLF=true" -c nosignal
@pytorchbot successfully started a revert job. Check the current status here.
This reverts commit 8efcf21. Reverted #150421 on behalf of https://github.com/malfet due to: Looks like it broke lint, see https://hud.pytorch.org/hud/pytorch/pytorch/a0ab243c3a5dfe12b392e4074d69360fd013f842/1?per_page=50&name_filter=lint&mergeEphemeralLF=true
@bertmaher your PR has been successfully reverted.
Awh man, this is annoying, it raced with the addition of a docstring linter 😠 😠 😠
This reverts commit 6a65f2c.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This includes the default aten implementation, as well as a Triton implementation imported from FBGEMM (https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py) Pull Request resolved: pytorch#150421 Approved by: https://github.com/ngimel
Sorry - I should have caught this on first run, but would you mind adding a test of the form used in test_max_autotune, where we run with backends set to TRITON and max-autotune enabled, and check the output code? See
pytorch/test/inductor/test_max_autotune.py
Lines 669 to 678 in 916f6ba
with config.patch(
    {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
):

    @torch.compile()
    def foo(mod, x):
        return mod(x)

    with torch.no_grad():
        out, code = run_and_get_code(foo, conv1x1, input_tensor)
Since the compile here is not using max-autotune, I'm not sure this is being run.
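A hedged sketch of the requested test shape, modeled on the test_max_autotune.py pattern quoted above; the FileCheck string and helper name are assumptions rather than the final test:

import torch
from torch._inductor import config
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck


def check_scaled_grouped_mm_autotunes(a, b, scale_a, scale_b, offs):
    with config.patch(
        {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
    ):

        @torch.compile()
        def foo(a, b, scale_a, scale_b, offs):
            return torch._scaled_grouped_mm(
                a, b, scale_a, scale_b, offs=offs, out_dtype=torch.bfloat16
            )

        with torch.no_grad():
            out, code = run_and_get_code(foo, a, b, scale_a, scale_b, offs)

    # Assert that inductor emitted a Triton kernel for the grouped mm instead of
    # falling back to the aten op ("triton_" is an assumed marker string).
    FileCheck().check("triton_").run(code[0])
    return out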
Stack from ghstack (oldest at bottom):
This includes the default aten implementation, as well as a Triton
implementation imported from FBGEMM
(https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py)
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov