Update `_scaled_mm` to support more scaling formats via cuBLAS #153555
Labels
- Blackwell (specific failures or issues related to sm100 + CUDA arches)
- enhancement (not a big feature, but technically not a bug; should be easy to fix)
- module: cuda (related to torch.cuda and CUDA support in general)
- module: float8 (for torch.float8_e5m2 and torch.float8_e4m3)
- topic: performance (topic category)
- triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Summary
In CUDA 12.9, cuBLAS released support for an expanded set of scaling strategies beyond just per-tensor: https://developer.nvidia.com/blog/boosting-matrix-multiplication-speed-and-flexibility-with-nvidia-cublas-12-9/
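As I read the blog post, the new strategies cover finer granularities such as per-row/per-column ("outer vector") and blocked scaling, in addition to the existing per-tensor scale. A minimal reference emulation in plain PyTorch of what each granularity means for an FP8 operand (this is not the cuBLAS API; the 128-wide block size is my reading of the blog):

```python
import torch

# Reference emulation of scale granularities for a (M, K) operand that will
# be quantized to float8_e4m3fn. One scale value per "scaling group".
M, K = 256, 512
x = torch.randn(M, K)
fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

# Per-tensor: a single fp32 scale for the whole matrix (the pre-12.9 path).
tensorwise_scale = x.abs().amax() / fp8_max

# Per-row ("outer vector"): one fp32 scale per row -> shape (M, 1).
rowwise_scale = x.abs().amax(dim=1, keepdim=True) / fp8_max

# 2D block scaling: one fp32 scale per (128, 128) tile -> (M//128, K//128).
tiles = x.reshape(M // 128, 128, K // 128, 128)
blockwise_scale = tiles.abs().amax(dim=(1, 3)) / fp8_max
```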
Currently on CUDA:
- SM89: `_scaled_mm` dispatches to one of 2 backends
- H100: `_scaled_mm` dispatches to one of 2 backends
- B200: `_scaled_mm` dispatches to one of 2 backends

We should add new cuBLAS bindings to enable this more performant code path (see the sketch below).
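For context, a minimal sketch of how the two paths are exercised from Python today, assuming the current `torch._scaled_mm` signature. As I understand the dispatch, scalar fp32 scales take the cuBLAS backend while row-wise (M, 1)/(1, N) scales take the CUTLASS-template backend this issue proposes replacing; exact backend selection may vary by arch and PyTorch version:

```python
import torch

# Requires a CUDA device with SM89+ for FP8 matmul support.
M, K, N = 128, 256, 64
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# _scaled_mm expects mat2 column-major, so transpose a row-major (N, K) tensor.
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()

# Per-tensor scaling: scalar fp32 scales -> cuBLAS path today.
out_tw = torch._scaled_mm(
    a, b,
    scale_a=torch.ones((), device="cuda"),
    scale_b=torch.ones((), device="cuda"),
    out_dtype=torch.bfloat16,
)

# Row-wise scaling: (M, 1) and (1, N) fp32 scales -> CUTLASS path today,
# the path the new cuBLAS 12.9 bindings could take over.
out_rw = torch._scaled_mm(
    a, b,
    scale_a=torch.ones(M, 1, device="cuda"),
    scale_b=torch.ones(1, N, device="cuda"),
    out_dtype=torch.bfloat16,
)
```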
Blockers
- Ideally we would remove the CUTLASS templates, since cuBLAS appears (per NVIDIA's claims) to be universally more performant. The main blocker is that we would lose support for SM89 hardware.
- We don't currently ship a prebuilt version of PyTorch for CUDA 12.9.
cc @ptrblck @msaroufim @eqy @jerryzh168 @yanbing-j @vkuzo @albanD @kadeng @penguinwu @ngimel @lw