🐛 Describe the bug
Hi AMD Team,

`torch._scaled_mm` is extremely slow on MI300X at ~100 TFLOP/s versus ~1200 TFLOP/s on H100. Can you look into this?

cc: @hliuca
ROCm (MI300X), in TFLOP/s:

```
m=16384 n=8192 k=1280: 108.07154472843483
m=16384 n=1024 k=8192: 110.56206220309926
m=16384 n=8192 k=7168: 109.66662842248034
m=16384 n=3584 k=8192: 110.59228182207659
m=8192 n=8192 k=8192: 109.86138366796457
```
H100, in TFLOP/s:

```
m=16384 n=8192 k=1280: 1239.4133451945781
m=16384 n=1024 k=8192: 1347.0844475792383
m=16384 n=8192 k=7168: 1332.2623882545472
m=16384 n=3584 k=8192: 1309.4453003269748
m=8192 n=8192 k=8192: 1304.5406858844613
```
Repro
```python
import time

import torch
from triton.testing import do_bench

torch.manual_seed(0)
repeats = 200
warmup = 30
timeout = 0.5
device = 'cuda'

# GEMM shapes (m, n, k)
shapes = [
    (16384, 8192, 1280),
    (16384, 1024, 8192),
    (16384, 8192, 7168),
    (16384, 3584, 8192),
    (8192, 8192, 8192),
]

results = []
for (m, n, k) in shapes:
    # FLOPs for one GEMM
    nFLOPS = 2 * m * n * k
    # The *_fnuz fp8 dtypes are the MI300/ROCm variants of fp8.
    a_fp8_e5m2 = torch.randn(m, k, device=device).to(torch.float8_e5m2fnuz)
    b_fp8_e4m3 = torch.randn(n, k, device=device).to(torch.float8_e4m3fnuz).transpose(-1, -2)
    scale_a = torch.tensor(1.0, device=device, dtype=torch.float32)
    scale_b = torch.tensor(1.0, device=device, dtype=torch.float32)
    ms_fp8_scaled_mm_e4m3 = do_bench(
        lambda: torch._scaled_mm(a_fp8_e5m2, b_fp8_e4m3, scale_a, scale_b),
        warmup=warmup, rep=repeats,
    )
    # do_bench returns milliseconds, so FLOPs/ms * 1e-9 gives TFLOP/s
    tflops_fp8_scaled_mm_e4m3 = nFLOPS / ms_fp8_scaled_mm_e4m3 * 1e-9
    time.sleep(timeout)
    print(f"{m=} {n=} {k=}: {tflops_fp8_scaled_mm_e4m3}")
```
Versions

```
$ pip list | grep torch
pytorch-triton-rocm  3.2.0+git35c6c7c6
torch                2.6.0.dev20241216+rocm6.2.4
```
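If it helps with triage, the ROCm build and the visible device can also be confirmed from Python (a small sketch; only standard `torch` attributes are used):

```python
import torch

# Confirm the PyTorch build, the HIP runtime it targets, and the GPU
# the numbers above were collected on.
print(torch.__version__)              # e.g. 2.6.0.dev20241216+rocm6.2.4
print(torch.version.hip)              # HIP version for ROCm builds (None on CUDA builds)
print(torch.cuda.get_device_name(0))  # e.g. the MI300X device name
```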
cc @msaroufim @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd