[ROCm] MI300X FP8 scaled_mm is extremely slow on nightly #143465
Closed
@functionstackx

Description


🐛 Describe the bug

Hi AMD Team,

torch._scaled_mm is extremely slow on MI300X, at ~100 TFLOP/s versus ~1200 TFLOP/s on H100.

Can you look into this?

cc: @hliuca

ROCm (MI300X), TFLOP/s:

m=16384 n=8192 k=1280: 108.07154472843483
m=16384 n=1024 k=8192: 110.56206220309926
m=16384 n=8192 k=7168: 109.66662842248034
m=16384 n=3584 k=8192: 110.59228182207659
m=8192 n=8192 k=8192: 109.86138366796457

H100, TFLOP/s:

m=16384 n=8192 k=1280: 1239.4133451945781
m=16384 n=1024 k=8192: 1347.0844475792383
m=16384 n=8192 k=7168: 1332.2623882545472
m=16384 n=3584 k=8192: 1309.4453003269748
m=8192 n=8192 k=8192: 1304.5406858844613
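
As a back-of-the-envelope sanity check on what these throughputs mean in wall-clock terms, using the m=n=k=8192 entry from the tables above:

# One 8192^3 GEMM costs 2 * 8192**3 ≈ 1.10e12 FLOPs.
flops = 2 * 8192**3
print(flops / 110e12 * 1e3)   # ≈ 10.0 ms per GEMM at ~110 TFLOP/s (MI300X above)
print(flops / 1304e12 * 1e3)  # ≈ 0.84 ms per GEMM at ~1304 TFLOP/s (H100 above)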

Repro

import time
import torch
from triton.testing import do_bench

torch.manual_seed(0)
repeats = 200
warmup = 30
sleep_s = 0.5  # pause between shapes

device = 'cuda'

# GEMM shapes (m, n, k)
shapes = [
    (16384, 8192, 1280),
    (16384, 1024, 8192),
    (16384, 8192, 7168),
    (16384, 3584, 8192),
    (8192, 8192, 8192),
]

for (m, n, k) in shapes:
    nFLOPS = 2 * m * n * k  # each output element is one multiply-add = 2 FLOPs
    # fnuz dtypes are the FP8 variants MI300X supports; a is e5m2, b is e4m3
    a_fp8 = torch.randn(m, k, device=device).to(torch.float8_e5m2fnuz)
    b_fp8 = torch.randn(n, k, device=device).to(torch.float8_e4m3fnuz).transpose(-1, -2)
    scale_a = torch.tensor(1.0, device=device, dtype=torch.float32)
    scale_b = torch.tensor(1.0, device=device, dtype=torch.float32)
    ms_fp8_scaled_mm = do_bench(lambda: torch._scaled_mm(a_fp8, b_fp8, scale_a, scale_b), warmup=warmup, rep=repeats)
    tflops_fp8_scaled_mm = nFLOPS / ms_fp8_scaled_mm * 1e-9  # FLOPs per ms -> TFLOP/s
    time.sleep(sleep_s)

    print(f"{m=} {n=} {k=}: {tflops_fp8_scaled_mm}")


Versions

pip list | grep torch
pytorch-triton-rocm      3.2.0+git35c6c7c6
torch                    2.6.0.dev20241216+rocm6.2.4
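
One thing worth ruling out while triaging (a suggestion, not a confirmed root cause): which BLAS backend torch._scaled_mm dispatches to on ROCm, since hipBLASLt and the fallback path can perform very differently. PyTorch exposes the preference via torch.backends.cuda.preferred_blas_library; on ROCm builds the "cublaslt" setting maps to hipBLASLt:

import torch

# Query the currently preferred BLAS backend, then request the Lt backend
# (hipBLASLt on ROCm). Diagnostic only; whether this changes the numbers
# above is exactly what needs checking.
print(torch.backends.cuda.preferred_blas_library())
torch.backends.cuda.preferred_blas_library("cublaslt")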

cc @msaroufim @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd

Labels: module: performance, module: rocm, triaged
