This issue tracks single node performance of MX training and inference: fast gemm, fast fused kernels.
training performance summary
As of 2025-03-27
- e2e pretraining speedup vs bf16 + compile on LLaMa 3 8B, 8 B200 GPUs, torchtitan with default settings
- 🟢 float8 tensorwise: 1.19x
- 🟡 mxfp8: 1.13x (should be similar to tensorwise's 1.19x once we fix all the issues). Right now scaling/casting to mx is slow
- 🟢 gemm speedup: cuBLAS mxfp8 gemm is 2x to 3x faster than bf16 - done for now
- 🟢 mx casting to dim0 with torch.compile achieves up to 5.4 TB/s (67% of B200 peak mem bw) - done for now
- 🟡 mx casting to dim1 is our main performance gap
- 🟡 torch.compile achieves only up to 3.9 TB/s (49% peak mem bw) with `TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1`, tracking here: request for faster inductor kernels for blockwise reduction across dim1 -> write pytorch#149982
- 🟡 handwritten triton kernel from mx: triton kernel to cast to mx and write in col-major #1932 achieves up to 3.5 TB/s (44% peak mem bw), can likely be improved further
- 🔲 mx casting to dim0 + dim1 at the same time is postponed for now until we make the individual dim0 and dim1 kernels better
individual components
system overview (for training)
```python
# There are three gemms in a forward + backward of a Linear layer:
#
# 1. input @ weight_t = output (forward pass)
# 2. grad_output @ weight = grad_input (backward pass)
# 3. input_t @ grad_output = grad_weight (backward pass)
#
# in Python pseudocode, we want the following (for mxfp8):
# forward pass
# inputs are in high precision
x_hp, w_hp = ...
# input @ weight_t = output
x_mx_dim0, x_scale_dim0 = to_mx(x_hp, dim=0)
w_mx_dim0, w_scale_dim0 = to_mx(w_hp, dim=0)
y = mx_gemm(x_mx_dim0, w_mx_dim0.t(), x_scale_dim0, w_scale_dim0)
# backward pass
# inputs are in high precision
x_hp, w_hp, go_hp = ...
# grad_output @ weight = grad_input
go_mx_dim0, go_scale_dim0 = to_mx(go_hp, dim=0)
w_mx_dim1, w_scale_dim1 = to_mx(w_hp.t().contiguous(), dim=0)
gi = mx_gemm(go_mx_dim0, w_mx_dim1.t(), go_scale_dim0, w_scale_dim1)
# input_t @ grad_output = grad_weight
go_mx_dim1, go_scale_dim1 = to_mx(go_hp.t().contiguous().t(), dim=0)
x_mx_dim1, x_scale_dim1 = to_mx(x_hp.t().contiguous(), dim=0)
gw = mx_gemm(go_mx_dim1, x_mx_dim1.t(), go_scale_dim1, x_scale_dim1)
```
We want:
- the mx gemm to be fast
- the cast from high precision to mx (`to_mx` in the pseudocode above) to be fast
- the cast from high precision to mx to be fused with preceding/subsequent ops where possible
gemm kernel
Expected peak TFLOPs on NVIDIA B200, without sparsity: 2.25 petaFLOPs for bf16, 4.5 petaFLOPs for fp8/fp6 (2x from bf16), 9.0 petaFLOPs for fp4 (4x from bf16) (source: https://resources.nvidia.com/en-us-blackwell-architecture, pages 19-20)
kernel | wrapper | current TFLOPs | peak TFLOPs | notes |
---|---|---|---|---|
mxfp8 cuBLAS | `torch._scaled_mm` | TBD | 4.5 petaFLOPs | landed, pytorch/pytorch#147548 |
mxfp8 CUTLASS | `torchao.ops.mx_fp8_bf16` | TBD | 4.5 petaFLOPs | landed, #1637 |
mxfp4 CUTLASS | `torchao.ops.mx_fp4_bf16` | TBD | 9.0 petaFLOPs | landed, #1661 |
nvfp4 cuBLAS | `torch._scaled_mm` | TBD | 9.0 petaFLOPs | in progress, pytorch/pytorch#148792 |
Once we have machines where benchmarking is possible, we should add easily reproducible gemm benchmarks and fill out the TFLOP column in the table above.
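A minimal, easily reproducible harness could look like the sketch below. This is only a sketch, not an existing torchao benchmark: the shape is a placeholder, and the bf16 baseline is shown since the exact call signatures of the mx wrappers in the table differ; each wrapper would be dropped into the same timing loop.

```python
import torch

def bench_gemm_tflops(fn, M, K, N, iters=100):
    # fn is a zero-arg callable that launches one (M, K) @ (K, N) gemm
    for _ in range(10):  # warmup
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    sec_per_iter = start.elapsed_time(end) / 1e3 / iters  # elapsed_time is in ms
    return 2 * M * K * N / sec_per_iter / 1e12            # achieved TFLOPs

# bf16 baseline; the mxfp8/mxfp4 wrappers from the table above would be
# benchmarked with the same harness to fill out the "current TFLOPs" column.
M = K = N = 8192  # placeholder shape
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
print(f"bf16: {bench_gemm_tflops(lambda: a @ b, M, K, N):.0f} TFLOPs")
```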
scaling/casting kernels
Our current plan is to use `torch.compile`, same as we are doing with float8.
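For reference, the kind of function we want torch.compile to turn into a single bandwidth-bound kernel looks roughly like the simplified sketch below. This is not torchao's actual `to_mx`: the scale is kept as a float32 power of two rather than `float8_e8m0fnu`, edge cases are ignored, and the blocked scale layout required by the gemm is not produced.

```python
import torch

BLOCK_SIZE = 32     # MX block size
F8E4M3_MAX = 448.0  # max magnitude representable in float8_e4m3fn

def to_mxfp8_dim0(x_hp: torch.Tensor):
    # Simplified sketch: share one power-of-two scale per contiguous block of
    # 32 elements along the last dim, then cast the data to float8_e4m3fn.
    # Assumes K is divisible by 32.
    M, K = x_hp.shape
    x_blocks = x_hp.reshape(M, K // BLOCK_SIZE, BLOCK_SIZE).float()
    amax = x_blocks.abs().amax(dim=-1, keepdim=True).clamp(min=2**-127)
    scale = torch.exp2(torch.floor(torch.log2(amax / F8E4M3_MAX)))
    data_fp8 = (x_blocks / scale).to(torch.float8_e4m3fn).reshape(M, K)
    return data_fp8, scale.reshape(M, K // BLOCK_SIZE)

# ideally this compiles to a single fused, bandwidth-bound kernel
to_mxfp8_dim0_compiled = torch.compile(to_mxfp8_dim0)
```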
- we should ensure we can generate a single fused kernel for scaling and casting a tensor to mxfp8. Today, torch.compile generates two kernels: torch.compile cast to mxfp8 should only require one kernel #1769
- once we have a single fused kernel, we should make sure it's bandwidth bound. As of 2025-02-24, the casting to MX code is numerically correct but researchy and has not been optimized for performance. TODO issue.
- the `float8_e8m0fnu` dtype was added to PyTorch in "add the `torch.float8_e8m0fnu` dtype to PyTorch" pytorch#147466; we need to update `torchao` to use this dtype for scales, and then ensure that PT2 works e2e. TODO issue
- we need to ensure torch.compile generates good fused kernels for the custom scale packing layout required by B200s: torch.compile cast to mxfp8 with blocked scales should be performant #1773
- we should ensure the cast across dim0 and dim1 is performant: mx cast to mxfp8 across dim0 and dim1 should be performant #1788
- given an MXLinear (fwd + bwd), we should expect at most six scale+cast kernels: two for each of `input`, `weight`, `grad_output`. The kernels for `input` and `grad_output` should be fused with preceding/subsequent ops as appropriate. TODO issue.
e2e training performance
From https://resources.nvidia.com/en-us-blackwell-architecture pages 19-20, on B200 the single GPU memory bandwidth we expect is 8 TB/s, the fp8/fp6 tensor core peak FLOPS is 4.5 petaFLOPS (without sparsity), and the fp4 tensor core peak FLOPS is 9.0 petaFLOPS (without sparsity).
- we need a roofline of mx scaling/casting to get the shapes which are expected to see speedups, and we should have a benchmark to compare observed to theoretical performance (a back-of-envelope sketch follows this list)
- [blocked] eventually we should get to SOTA performance in torchtitan. Currently, this work is blocked by general issues with Blackwell support in PyTorch, such as NCCL not working. Tracking is here: [CUDA][Blackwell] Blackwell Tracking Issue pytorch#145949
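As a rough illustration of the shape dependence such a roofline would expose, the sketch below models only one forward gemm and its two input casts, assumes the casts are bandwidth bound, ignores overlap, fusion, and the small e8m0 scale traffic, and uses the B200 peak numbers quoted above. It is a back-of-envelope estimate, not the planned roofline tool.

```python
# Back-of-envelope roofline: mxfp8 forward gemm + its two input casts vs bf16.
PEAK_MEM_BW = 8e12          # bytes/s
PEAK_BF16_FLOPS = 2.25e15   # FLOP/s
PEAK_FP8_FLOPS = 4.5e15     # FLOP/s

def gemm_time_s(M, K, N, peak_flops):
    return 2 * M * K * N / peak_flops

def cast_time_s(numel, bytes_in=2, bytes_out=1):
    # read one bf16 element, write one fp8 element; bandwidth bound
    return numel * (bytes_in + bytes_out) / PEAK_MEM_BW

def mxfp8_fwd_speedup(M, K, N):
    bf16 = gemm_time_s(M, K, N, PEAK_BF16_FLOPS)
    mx = (gemm_time_s(M, K, N, PEAK_FP8_FLOPS)
          + cast_time_s(M * K)    # cast input
          + cast_time_s(K * N))   # cast weight
    return bf16 / mx

# small shapes are dominated by cast overhead, large shapes approach 2x
for size in (1024, 4096, 16384):
    print(size, round(mxfp8_fwd_speedup(size, size, size), 2))
```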
e2e inference performance
- need an inference roofline
- need to decide where to benchmark