MX single node performance tracker #1768
@vkuzo

Description

This issue tracks single node performance of MX training and inference: fast gemm, fast fused kernels.

training performance summary

As of 2025-03-27

  • e2e pretraining speedup vs bf16 + compile on LLaMa 3 8B, 8 B200 GPUs, torchtitan with default settings
    • 🟢 float8 tensorwise: 1.19x
    • 🟡 mxfp8: 1.13x (should approach tensorwise's 1.19x once the remaining issues are fixed; right now, scaling/casting to mx is slow)
  • 🟢 gemm speedup: cuBLAS mxfp8 gemm is 2x to 3x faster than bf16 - done for now
  • 🟢 mx casting to dim0 with torch.compile achieves up to 5.4 TB/s (67% of B200 peak mem bw) - done for now
  • 🟡 mx casting to dim1 is our main performance gap
  • 🔲 mx casting to dim0 + dim1 at the same time is postponed for now until we make the individual dim0 and dim1 kernels better

individual components

system overview (for training)

# There are three gemms in a forward + backward of a Linear layer:
#
# 1.       input @ weight_t    = output     (forward pass)
# 2. grad_output @ weight      = grad_input (backward pass)
# 3.     input_t @ grad_output = grad_weight (backward pass)
# 
# in Python pseudocode, we want the following (for mxfp8):

# forward pass

# inputs are in high precision
x_hp, w_hp = ...

# input @ weight_t = output
x_mx_dim0, x_scale_dim0 = to_mx(x_hp, dim=0)
w_mx_dim0, w_scale_dim0 = to_mx(w_hp, dim=0)
y = mx_gemm(x_mx_dim0, w_mx_dim0.t(), x_scale_dim0, w_scale_dim0)

# backward pass

# inputs are in high precision
x_hp, w_hp, go_hp = ...

# grad_output @ weight = grad_input
go_mx_dim0, go_scale_dim0 = to_mx(go_hp, dim=0)
w_mx_dim1, w_scale_dim1 = to_mx(w_hp.t().contiguous(), dim=0)
gi = mx_gemm(go_mx_dim0, w_mx_dim1.t(), go_scale_dim0, w_scale_dim1)

# input_t @ grad_output = grad_weight
go_mx_dim1, go_scale_dim1 = to_mx(go_hp.t().contiguous(), dim=0)
x_mx_dim1, x_scale_dim1 = to_mx(x_hp.t().contiguous(), dim=0)
gw = mx_gemm(go_mx_dim1, x_mx_dim1.t(), go_scale_dim1, x_scale_dim1)
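
As a concrete reference for the pseudocode, here is a minimal eager-mode sketch of to_mx, assuming mxfp8 with a block size of 32 along the last dim, float8_e4m3fn elements, and power-of-two (e8m0-style) shared scales. This is a sketch, not torchao's implementation; the rounding of the shared exponent and the scale storage format differ in the real kernels.

import torch

F8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def to_mx(x_hp: torch.Tensor, dim: int = 0, block_size: int = 32):
    # This sketch only implements the dim0 variant; the pseudocode above
    # expresses dim1 casts as .t().contiguous() followed by a dim0 cast.
    assert dim == 0 and x_hp.shape[-1] % block_size == 0
    x = x_hp.reshape(-1, block_size).float()
    amax = x.abs().amax(dim=1, keepdim=True).clamp(min=2.0 ** -127)
    # One shared power-of-two scale per block of 32 elements; ceil keeps
    # the scaled values inside the fp8 range. The MX spec and torchao
    # round the shared exponent slightly differently.
    scale = torch.exp2(torch.ceil(torch.log2(amax / F8_E4M3_MAX)))
    x_fp8 = (x / scale).to(torch.float8_e4m3fn)
    # Real kernels store each scale as a single e8m0 exponent byte,
    # not as fp32.
    return x_fp8.reshape(x_hp.shape), scale.reshape(*x_hp.shape[:-1], -1)

Dequantizing via x_fp8.float().reshape(-1, block_size) * scale.reshape(-1, 1) should recover the input to within fp8 precision, which makes for a quick numerical sanity check.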

We want:

  1. the mx gemm to be fast
  2. the cast from high precision to mx (to_mx in pseudocode above) to be fast
  3. the cast from high precision to mx to be fused with preceding/subsequent ops where possible (see the fusion sketch below)
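
For item 3, the hope is that torch.compile fuses the cast into neighboring pointwise ops so the high-precision intermediate is never written to global memory. A sketch, reusing the hypothetical to_mx above (whether Inductor actually fuses here depends on the traced graph):

import torch
import torch.nn.functional as F

# One compiled region containing a pointwise op followed by the cast;
# Inductor can then emit a single kernel instead of materializing the
# bf16 intermediate in global memory.
@torch.compile
def act_then_cast(x_hp: torch.Tensor):
    y = F.silu(x_hp)  # stand-in for whatever op precedes the cast
    return to_mx(y)

x_mx, x_scale = act_then_cast(
    torch.randn(8192, 8192, device="cuda", dtype=torch.bfloat16)
)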

gemm kernel

Expected peak TFLOPS on NVIDIA B200, without sparsity: 2.25 petaFLOPS for bf16, 4.5 petaFLOPS for fp8/fp6 (2x over bf16), and 9.0 petaFLOPS for fp4 (4x over bf16) (source: https://resources.nvidia.com/en-us-blackwell-architecture, pages 19-20).

| kernel | wrapper | current TFLOPs | peak TFLOPs | notes |
| --- | --- | --- | --- | --- |
| mxfp8 cuBLAS | torch._scaled_mm | TBD | 4.5 petaFLOPS | landed, pytorch/pytorch#147548 |
| mxfp8 CUTLASS | torchao.ops.mx_fp8_bf16 | TBD | 4.5 petaFLOPS | landed, #1637 |
| mxfp4 CUTLASS | torchao.ops.mx_fp4_bf16 | TBD | 9.0 petaFLOPS | landed, #1661 |
| nvfp4 cuBLAS | torch._scaled_mm | TBD | 9.0 petaFLOPS | in progress, pytorch/pytorch#148792 |

Once we have machines where benchmarking is possible, we should add easily reproducible gemm benchmarks and fill out the current TFLOPs column in the table above; a sketch of such a harness follows.
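
A sketch of such a harness: time a zero-argument gemm callable with CUDA events and convert to TFLOP/s. Shown against a bf16 baseline; any of the mx wrappers from the table can be slotted into the same fn.

import torch

def gemm_tflops(fn, M: int, K: int, N: int,
                n_warmup: int = 10, n_iter: int = 100) -> float:
    # fn is a zero-argument callable wrapping one (M, K) @ (K, N) matmul
    for _ in range(n_warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iter):
        fn()
    end.record()
    torch.cuda.synchronize()
    sec_per_iter = start.elapsed_time(end) / 1e3 / n_iter
    return 2 * M * K * N / sec_per_iter / 1e12

# bf16 baseline; a wrapped mx gemm (e.g. torch._scaled_mm with mx scales,
# or torchao.ops.mx_fp8_bf16) slots into the same fn
M = N = K = 8192
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
print(f"bf16: {gemm_tflops(lambda: a @ b, M, K, N):.0f} TFLOP/s")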

scaling/casting kernels

Our current plan is to use torch.compile, same as we are doing with float8.
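
The relevant metric for these kernels is achieved memory bandwidth rather than FLOPS. A sketch of how the TB/s numbers above can be measured for a compiled cast, counting bf16 bytes read plus fp8 and scale bytes written (assumes the hypothetical to_mx sketch from earlier):

import torch

def cast_gbps(fn, x: torch.Tensor,
              n_warmup: int = 10, n_iter: int = 100) -> float:
    # achieved bandwidth = (bytes read + bytes written) / time
    compiled = torch.compile(fn)
    for _ in range(n_warmup):  # warmup also triggers compilation
        data, scale = compiled(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iter):
        data, scale = compiled(x)
    end.record()
    torch.cuda.synchronize()
    sec = start.elapsed_time(end) / 1e3 / n_iter
    nbytes = (x.numel() * x.element_size()             # bf16 read
              + data.numel() * data.element_size()     # fp8 written
              + scale.numel() * scale.element_size())  # scales written
    return nbytes / sec / 1e9

x = torch.randn(16384, 16384, device="cuda", dtype=torch.bfloat16)
print(f"dim0 cast: {cast_gbps(to_mx, x):.0f} GB/s")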

e2e training performance

From https://resources.nvidia.com/en-us-blackwell-architecture pages 19-20, on B200 the single GPU memory bandwidth we expect is 8 TB/s, the fp8/fp6 tensor core peak FLOPS is 4.5 petaFLOPS (without sparsity), and the fp4 tensor core peak FLOPS is 9.0 petaFLOPS (without sparsity).

  • we need a roofline of mx scaling/casting to get the shapes which are expected to see speedups, and we should have a benchmark to compare observed to theoretical performance (a first-order sketch follows this list)
  • [blocked] Eventually, we should get to SOTA performance in torchtitan. This work is currently blocked by general issues with Blackwell support in PyTorch, such as NCCL not working. Tracking: [CUDA][Blackwell] Blackwell Tracking Issue pytorch#145949
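
A first-order version of that roofline: model each linear's forward + backward as three gemms at peak FLOPS plus six memory-bound casts (dim0 and dim1 of input, weight, and grad_output) at peak bandwidth, using the B200 peaks quoted above. This is a sketch under optimistic assumptions (kernels at peak, no fusion, scale bytes ignored); the example shape is illustrative.

# First-order roofline for one mxfp8 linear fwd+bwd vs bf16 on B200.
PEAK_BF16_FLOPS = 2.25e15  # dense, from the Blackwell architecture paper
PEAK_FP8_FLOPS = 4.5e15
MEM_BW = 8e12              # bytes/s

def linear_fwd_bwd_speedup(M: int, K: int, N: int) -> float:
    gemm_flops = 3 * 2 * M * K * N  # three gemms, 2*M*K*N FLOPs each
    t_bf16 = gemm_flops / PEAK_BF16_FLOPS
    # mxfp8 adds dim0 + dim1 casts of input, weight, and grad_output:
    # each cast reads 2 bytes/elem (bf16) and writes 1 byte/elem (fp8);
    # scale bytes (~3% of the data) are ignored here.
    cast_bytes = 2 * (2 + 1) * (M * K + K * N + M * N)
    t_mx = gemm_flops / PEAK_FP8_FLOPS + cast_bytes / MEM_BW
    return t_bf16 / t_mx

# e.g. a LLaMa 3 8B MLP linear (hidden 4096, intermediate 14336)
print(f"{linear_fwd_bwd_speedup(16384, 4096, 14336):.2f}x")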

e2e inference performance

  • need an inference roofline
  • need to decide where to benchmark
