[ROCm] MI300X FP8 scaled_mm is extremely slow on nightly · Issue #143465 · pytorch/pytorch · GitHub

Closed
functionstackx opened this issue Dec 18, 2024 · 23 comments
Labels
module: performance (Issues related to performance, either of kernel code or framework glue) · module: rocm (AMD GPU support for PyTorch) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@functionstackx (Contributor) commented Dec 18, 2024

🐛 Describe the bug

Hi AMD Team,

torch._scaled_mm is extremely slow on MI300X at ~100 TFLOP/s versus ~1200 TFLOP/s on H100.

Can you look into this?

cc: @hliuca

ROCm

m=16384 n=8192 k=1280: 108.07154472843483
m=16384 n=1024 k=8192: 110.56206220309926
m=16384 n=8192 k=7168: 109.66662842248034
m=16384 n=3584 k=8192: 110.59228182207659
m=8192 n=8192 k=8192: 109.86138366796457

H100

m=16384 n=8192 k=1280: 1239.4133451945781
m=16384 n=1024 k=8192: 1347.0844475792383
m=16384 n=8192 k=7168: 1332.2623882545472
m=16384 n=3584 k=8192: 1309.4453003269748
m=8192 n=8192 k=8192: 1304.5406858844613

Reproducer

import time
import torch
from triton.testing import do_bench

torch.manual_seed(0)
repeats = 200
warmup = 30
timeout = 0.5

device = 'cuda'

# GEMM Shapes
shapes = [
    (16384, 8192, 1280),
    (16384, 1024, 8192),
    (16384, 8192, 7168),
    (16384, 3584, 8192),
    (8192, 8192, 8192)
]

results = []

for (m, n, k) in shapes:
    # FLOPS
    nFLOPS = 2 * m * n * k
    a_fp8_e5m2 = torch.randn(m, k, device=device).to(torch.float8_e5m2fnuz)
    b_fp8_e5m2 = torch.randn(n, k, device=device).to(torch.float8_e4m3fnuz).transpose(-1, -2)
    scale_a = torch.tensor(1.0, device=device, dtype=torch.float32)
    scale_b = torch.tensor(1.0, device=device, dtype=torch.float32)
    ms_fp8_scaled_mm_e4m3 = do_bench(lambda: torch._scaled_mm(a_fp8_e5m2, b_fp8_e5m2, scale_a, scale_b), warmup=warmup, rep=repeats)
    tflops_fp8_scaled_mm_e4m3 = nFLOPS / ms_fp8_scaled_mm_e4m3 * 1e-9
    time.sleep(timeout)

    print(f"{m=} {n=} {k=}: {tflops_fp8_scaled_mm_e4m3}")


Versions

pip list | grep torch
pytorch-triton-rocm      3.2.0+git35c6c7c6
torch                    2.6.0.dev20241216+rocm6.2.4

cc @msaroufim @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd

@pytorch-bot bot added the module: rocm label Dec 18, 2024
@jeffdaily (Collaborator)

Are you able to try building the latest tip of the hipblaslt develop branch and rerunning your numbers?

@jeffdaily (Collaborator)

Since you're using the nightly wheel, it will likely take some manual hacking of your torch install to use the latest hipblaslt. You might be able to build and install the latest hipblaslt to some other location on your system and then use LD_LIBRARY_PATH to point to its lib directory instead of the hipblaslt lib that is bundled in the nightly wheel.

In case you want to try the hack of copying the newer hipblaslt into your torch install location, note that hipblaslt ships not only the libhipblaslt.so component but also auxiliary files containing the GPU code objects (kernels), stored relative to libhipblaslt.so under hipblaslt/library/ as *.co and *.dat files. You'll need to copy all of it.
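
For reference, a minimal sketch of the LD_LIBRARY_PATH route; the install prefix /opt/hipblaslt-develop is only an assumed example, and the hipBLASLt build itself should follow the instructions in the ROCm/hipBLASLt repository:

# Sketch, assuming hipBLASLt develop has already been built and installed under
# /opt/hipblaslt-develop (example prefix), with libhipblaslt.so in lib/ and the
# hipblaslt/library/*.co and *.dat code objects alongside it.
export LD_LIBRARY_PATH=/opt/hipblaslt-develop/lib:$LD_LIBRARY_PATH

# The dynamic loader should now pick up the newer libhipblaslt.so instead of
# the copy bundled in the nightly wheel.
python reprod.py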

@hliuca commented Dec 18, 2024

Hi @OrenLeung we are on this. Thank you @jeffdaily for your response.

@functionstackx (Contributor, Author) commented Dec 18, 2024

I have also tried the latest official ROCm 6.3 torch 2.5 image and see the same issue:

rocm6.3_ubuntu22.04_py3.10_pytorch_release_2.5.0_preview

I also tried e4m3 x e4m3 and it doesn't work either on nightly or 2.5.

However, on rocm/pytorch:rocm6.3_ubuntu24.04_py3.12_pytorch_release_2.4.0, e4m3 x e4m3 works at ~1000 TFLOP/s.

e4m3 x e4m3

import time
import torch
from triton.testing import do_bench

torch.manual_seed(0)
repeats = 200
warmup = 30
timeout = 0.0

device = 'cuda'

# GEMM Shapes
shapes = [
    (16384, 8192, 1280),
    (16384, 1024, 8192),
    (16384, 8192, 7168),
    (16384, 3584, 8192),
    (8192, 8192, 8192)
]

results = []

for (m, n, k) in shapes:
    # FLOPS
    nFLOPS = 2 * m * n * k
    a_fp8_e4m3 = torch.randn(m, k, device=device).to(torch.float8_e4m3fnuz)
    b_fp8_e4m3 = torch.randn(n, k, device=device).to(torch.float8_e4m3fnuz).transpose(-1, -2)
    scale_a = torch.tensor(1.0, device=device, dtype=torch.float32)
    scale_b = torch.tensor(1.0, device=device, dtype=torch.float32)
    with torch.inference_mode():
        # Pass the scales explicitly, matching the first reproducer; recent torch releases require them.
        ms_fp8_scaled_mm_e4m3 = do_bench(lambda: torch._scaled_mm(a_fp8_e4m3, b_fp8_e4m3, scale_a, scale_b), warmup=warmup, rep=repeats)
    tflops_fp8_scaled_mm_e4m3 = nFLOPS / ms_fp8_scaled_mm_e4m3 * 1e-9
    time.sleep(timeout)
    print(f"{m=} {n=} {k=}: {tflops_fp8_scaled_mm_e4m3}")

@hliuca commented Dec 18, 2024

@OrenLeung it looks like e4m3 is more optimized while e5m2 isn't.

By the way, your code mixes e5m2 and e4m3:
a_fp8_e5m2 = torch.randn(m, k, device=device).to(torch.float8_e5m2fnuz)
b_fp8_e5m2 = torch.randn(n, k, device=device).to(torch.float8_e4m3fnuz).transpose(-1, -2)

I have reported this to the library owner. Thank you, Oren.

@functionstackx (Contributor, Author)

Hi @hliuca,

Thank you. Yes, it is mixed: e5m2 for the A tensor and e4m3 for the B tensor, as NVIDIA cuBLASLt doesn't support both input tensors being e5m2.

My understanding is that e5m2 x e5m2 isn't used in practice, and that is why NVIDIA doesn't support it. Please correct me if I'm wrong.

#132005

@hliuca commented Dec 18, 2024

Probably. I checked hipblaslt commits, and I do see more f8 (e4m3) commits than bf8 (e5m2). :-)

@hongxiayang moved this to In Progress in PyTorch on ROCm Dec 18, 2024
@malfet added the triaged and module: performance labels Dec 18, 2024
@naromero77amd (Collaborator)

e5m2 support is being handled by the hipblaslt team. So this is being closed.

@github-project-automation bot moved this from In Progress to Done in PyTorch on ROCm Dec 19, 2024
@hliuca commented Dec 20, 2024

@OrenLeung I will follow up tomorrow. Thank you. Even with e4m3 x e4m3, I got:

m=16384 n=8192 k=1280: 101.02149096903179
m=16384 n=1024 k=8192: 94.50917622036803
m=16384 n=8192 k=7168: 101.4627641159379
m=16384 n=3584 k=8192: 98.45815730495234
m=8192 n=8192 k=8192: 97.94690870292024

The perf is really bad.

@hliuca commented Dec 20, 2024

@OrenLeung we have 3 teams on this issue internally. I will keep you posted through other channels. Thank you.

@hongxiayang moved this from Done to In Progress in PyTorch on ROCm Dec 20, 2024
@hongxiayang reopened this Dec 20, 2024
@jeffdaily (Collaborator)

Setting the env var PYTORCH_TUNABLEOP_ENABLED=1 provides significantly improved performance for the reproducers provided in this issue.
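
For reference, a minimal sketch of running the reproducer with TunableOp enabled via environment variables (the results filename is just an example, and reprod.py stands in for the script from this issue):

# Enable TunableOp for this run; tuning results are written to the given CSV
# so later runs can reuse them instead of re-tuning.
PYTORCH_TUNABLEOP_ENABLED=1 \
PYTORCH_TUNABLEOP_FILENAME=tunableop_results.csv \
python reprod.py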

@functionstackx (Contributor, Author)

@jeffdaily thanks for the tip.

Any chance that torch._scaled_mm will be tuned out of the box?

PYTORCH_TUNABLEOP_ENABLED leads to long dev cycle times. On NVIDIA, it runs really well out of the box, without magic flags or runtime tuning, thanks to their GEMM heuristic model.
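
One way to reduce the repeated-tuning cost, sketched below under the assumption that the same results file can be shared across runs: tune once with tuning enabled, then reuse the recorded results with tuning turned off. This only amortizes the tuning step; it does not address the out-of-the-box request.

# One-time tuning run: records the selected GEMM solutions to a CSV.
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=1 \
PYTORCH_TUNABLEOP_FILENAME=tunableop_results.csv python reprod.py

# Subsequent runs: reuse the recorded solutions without re-tuning.
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 \
PYTORCH_TUNABLEOP_FILENAME=tunableop_results.csv python reprod.py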

@deke997 commented Jan 5, 2025

Hi,

We are also seeing this issue. Is there a fix coming without Tunable Ops?

@hliuca commented Jan 10, 2025

Hi @deke997, many F8/B8 GEMMs have been optimized and checked in.

https://github.com/ROCm/hipBLASLt/commits/develop/

We are actively working on this. If you wish to test, please compile the latest hipblaslt. Thank you.

@naromero77amd (Collaborator)

@OrenLeung This is with a relatively recent version of hipblasLt develop (maybe a month old):

m=16384 n=8192 k=1280: 989.8769855505444
m=16384 n=1024 k=8192: 1047.7430343012413
m=16384 n=8192 k=7168: 1149.413494621263
m=16384 n=3584 k=8192: 1185.4703718721548
m=8192 n=8192 k=8192: 1134.934285778508

Let us know if you want us to continue to keep this issue open or if it can be closed.

@functionstackx (Contributor, Author) commented Apr 23, 2025

Hi @naromero77amd, thanks for working on this!

Let's close this issue once the current PyPI torch nightly is able to get above ~1000 TFLOP/s too!

@naromero77amd (Collaborator)

The performance improvements come from the version of hipblasLt that is bundled with the ROCm stack. The particular version of PyTorch is not as relevant.

In the latest ROCm 6.4 docker images, e.g.

docker pull rocm/pytorch:rocm6.4_ubuntu22.04_py3.10_pytorch_release_2.6.0

The performance issue is resolved:

m=16384 n=8192 k=1280: 970.3676871709138
m=16384 n=1024 k=8192: 1046.7187985127473
m=16384 n=8192 k=7168: 1150.8146096596909
m=16384 n=3584 k=8192: 1193.3655142900766
m=8192 n=8192 k=8192: 1131.2118523426138

@functionstackx (Contributor, Author) commented Apr 26, 2025

@naromero77amd On CUDA, PyPI torch bundles cuBLAS automatically via the PyPI nvidia-* dependencies. Would the ROCm PyPI torch do the same?

Sorry if this is a dumb question.

pip list
Package                  Version
------------------------ ------------
nvidia-cublas-cu12       12.8.3.14
nvidia-cuda-cupti-cu12   12.8.57
nvidia-cuda-nvrtc-cu12   12.8.61
nvidia-cuda-runtime-cu12 12.8.57
nvidia-cudnn-cu12        9.7.1.26
nvidia-cufft-cu12        11.3.3.41
nvidia-cufile-cu12       1.13.0.11
nvidia-curand-cu12       10.3.9.55
nvidia-cusolver-cu12     11.7.2.55
nvidia-cusparse-cu12     12.5.7.53
nvidia-cusparselt-cu12   0.6.3
nvidia-nccl-cu12         2.26.2
nvidia-nvjitlink-cu12    12.8.61
nvidia-nvtx-cu12         12.8.55
torch                    2.7.0+cu128
torchaudio               2.7.0+cu128
torchvision              0.22.0+cu128
triton                   3.3.0
...


@sunway513

@OrenLeung At this point, we have bundled the ROCm acceleration libraries (hipBLASLt, MIOpen, etc.) inside the PyTorch WHL package. The torch domain libraries are bundled in the same way as the NV packages at the Python level.
In the near future, after we publish the ROCm base Python packages, the PyTorch ROCm WHL will bundle the hipBLASLt library WHLs instead.
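
For reference, a small sketch for checking which acceleration libraries are bundled inside an installed wheel, assuming they live under torch/lib as the nightly-wheel discussion above suggests:

# List bundled hipBLASLt files inside the installed torch package.
ls "$(python -c 'import torch, os; print(os.path.dirname(torch.__file__))')/lib" | grep -i hipblaslt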

@functionstackx (Contributor, Author) commented Apr 26, 2025

Thanks for the explanation @sunway513.

I was able to get the performance improvement inside the ROCm Docker container, but for the PyPI torch whl installation the improvement still isn't there, for either the stable or the nightly whl. I am following https://pytorch.org/get-started/locally/

Can the updated ROCm libraries be bundled inside the torch PyPI whl package too?

Docker

docker pull rocm/pytorch:rocm6.4_ubuntu22.04_py3.10_pytorch_release_2.6.0

PyPI Stable

Following the installation instructions from https://pytorch.org/get-started/locally/

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3
user@tw031:~$ python ./reprod.py 
m=16384 n=8192 k=1280: 102.20135224999864
m=16384 n=1024 k=8192: 102.85467618252625
m=16384 n=8192 k=7168: 104.92880160773375
m=16384 n=3584 k=8192: 105.19174292269251
m=8192 n=8192 k=8192: 104.96762894820529
user@tw031:~$ pip list | grep torch
pytorch-triton-rocm 3.3.0
torch               2.7.0+rocm6.3

PyPI Nightly

Following the installation instructions from https://pytorch.org/get-started/locally/

pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3
user@tw031:~$ python ./reprod.py 
m=16384 n=8192 k=1280: 101.688998175193
m=16384 n=1024 k=8192: 102.83025828797749
m=16384 n=8192 k=7168: 104.56994939801766
m=16384 n=3584 k=8192: 104.70265188641152
m=8192 n=8192 k=8192: 104.52714592170736
user@tw031:~$ pip list | grep torch
pytorch-triton-rocm 3.3.0+git96316ce5
torch               2.8.0.dev20250426+rocm6.3
torchaudio          2.6.0.dev20250426+rocm6.3
torchvision         0.22.0.dev20250426+rocm6.3


@sunway513

Thanks for trying out the WHL packages, @OrenLeung.
The current nightly and stable WHLs hosted under pytorch.org all use ROCm 6.3, hence the performance still lags compared to the ROCm 6.4 docker container you tried.
The team is in the process of upgrading the base ROCm version to ROCm 6.4; it should be available in the coming weeks.
We'll let you know once the nightly WHL with ROCm 6.4 becomes available under pytorch.org. cc @jithunnair-amd @jeffdaily

Also, if you would like early access to the PyTorch WHLs with a ROCm 6.4 base, you can try out the ones that ROCm has been hosting for fast availability:
http://repo.radeon.com/rocm/manylinux/rocm-rel-6.4/
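
As a quick way to check which ROCm base a given torch install was built against (a small sketch; torch.version.hip is the ROCm counterpart of torch.version.cuda):

# Print the installed torch version and the HIP/ROCm version it was built against.
python -c "import torch; print(torch.__version__, torch.version.hip)"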

@jeffdaily (Collaborator) commented Apr 30, 2025

Nightly wheels are at ROCm 6.4 now, but the get-started chooser hasn't been updated yet.

pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.4

@functionstackx (Contributor, Author) commented May 12, 2025

Thanks for fixing this issue! I see that the get-started website nightly is now ROCm 6.4 with the updated hipBLASLt.

@github-project-automation bot moved this from In Progress to Done in PyTorch on ROCm May 12, 2025
Labels: module: performance · module: rocm · triaged
Projects: PyTorch on ROCm (Status: Done)
Development: No branches or pull requests
8 participants