
Updated Scaled_mm to support more scaling formats via CuBlas #153555


Open

drisspg opened this issue May 14, 2025 · 0 comments
Labels
  • Blackwell - Specific failures or issues related to sm100 + Cuda arches
  • enhancement - Not as big of a feature, but technically not a bug. Should be easy to fix
  • module: cuda - Related to torch.cuda, and CUDA support in general
  • module: float8 - For torch.float8_e5m2 and torch.float8_e4m3
  • topic: performance - topic category
  • triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@drisspg (Contributor) commented May 14, 2025

Summary

In CUDA 12.9, cuBLAS added support for an expanded set of scaling strategies beyond just per-tensor: https://developer.nvidia.com/blog/boosting-matrix-multiplication-speed-and-flexibility-with-nvidia-cublas-12-9/
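
For reference, a minimal sketch (not from this issue; the exact shape/dtype requirements are assumptions to check against the op's current signature) of the two granularities _scaled_mm already exposes today, assuming an SM89+ GPU:

```python
import torch

# Minimal sketch: the two scaling granularities _scaled_mm handles today.
# Shapes/dtypes below are assumptions; verify against the current op signature.
M, K, N = 128, 256, 64
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# The second operand is expected in column-major layout, so build (N, K) and transpose.
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()

# Per-tensor scaling: one fp32 scalar scale per operand -> cuBLASLt path today.
out_tensorwise = torch._scaled_mm(
    a, b,
    scale_a=torch.tensor(1.0, device="cuda"),
    scale_b=torch.tensor(1.0, device="cuda"),
    out_dtype=torch.bfloat16,
)

# Per-row scaling: one fp32 scale per row of `a` and per column of `b`
# -> rowwise CUTLASS kernel today; the new cuBLAS scale modes could cover this.
out_rowwise = torch._scaled_mm(
    a, b,
    scale_a=torch.ones(M, 1, device="cuda"),
    scale_b=torch.ones(1, N, device="cuda"),
    out_dtype=torch.bfloat16,
)
```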

Currently on CUDA:

SM89

_scaled_mm dispatches to one of 2 backends on SM89:

  • Per-Tensor scaling -> CublasLT
  • Per-Row scaling -> RowWise Cutlass kernel
  • GroupWise Scaling -> Not supported | some support in AO
  • BlockWise Scaling -> Not supported | some support in AO

H100

_scaled_mm dispatches to one of 2 backends on H100:

  • Per-Tensor scaling -> CublasLT
  • Per-Row scaling -> RowWise Cutlass kernel
  • GroupWise Scaling -> Not supported | some support in AO
  • BlockWise Scaling -> Not supported | some support in AO

B200

_scaled_mm dispatches to one of 2 backends on B200:

  • Per-Tensor scaling -> CublasLT
  • Per-Row scaling -> RowWise Cutlass kernel (template is not optimal)
  • GroupWise Scaling -> MXFP8 BlockWise scaling is supported via CublasLT
  • BlockWise Scaling -> Not supported

We should add new cuBLAS bindings to enable these more performant code paths.
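
Purely for illustration, a hypothetical sketch of the first decision such a binding layer has to make (ScalingType and classify_scaling are made-up names, not PyTorch APIs): infer the requested granularity from the scale tensor's shape before choosing a cuBLASLt scale mode or a CUTLASS fallback.

```python
from enum import Enum, auto

import torch


class ScalingType(Enum):
    TENSORWISE = auto()   # single scalar scale for the whole operand
    ROWWISE = auto()      # one scale per row (outer-vector scaling)
    BLOCKWISE = auto()    # e.g. MXFP8: one scale per 1x32 block along K
    UNSUPPORTED = auto()


def classify_scaling(mat: torch.Tensor, scale: torch.Tensor, block_k: int = 32) -> ScalingType:
    """Guess the scaling granularity for `mat` (shape [M, K]) from `scale`'s shape."""
    m, k = mat.shape
    if scale.numel() == 1:
        return ScalingType.TENSORWISE
    if tuple(scale.shape) in ((m, 1), (m,)):
        return ScalingType.ROWWISE
    if tuple(scale.shape) == (m, k // block_k):
        return ScalingType.BLOCKWISE
    return ScalingType.UNSUPPORTED
```

Each branch would then map to whichever cuBLAS scale mode the toolkit/hardware supports, with the existing CUTLASS kernels kept as the fallback for older arches.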

Blockers

Ideally we would remove the CUTLASS templates, since cuBLAS appears (per NVIDIA's claims) to be universally more performant. The main blocker is that we would lose support for SM89 hardware.

We don't currently ship a prebuilt version of PyTorch for CUDA 12.9.

cc @ptrblck @msaroufim @eqy @jerryzh168 @yanbing-j @vkuzo @albanD @kadeng @penguinwu @ngimel, @lw

@drisspg drisspg added module: cuda Related to torch.cuda, and CUDA support in general topic: performance topic category Blackwell Specific failures or issues related to sm100 + Cuda arches labels May 14, 2025
@malfet malfet added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module enhancement Not as big of a feature, but technically not a bug. Should be easy to fix module: float8 For torch.float8_e5m2 and torch.float8_e4m3 labels May 14, 2025