[Inductor] Investigate computing global amaxes via atomics (instead of a reduction-based approach) in triton codegen #153103
Labels
module: inductor
oncall: pt2
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Summary
Tensorwise or rowwise amax values are used to compute scaling factors in float8 quantization. Computing these values in a performant way is critical for float8 training with dynamic quantization, where tensors are dynamically scaled at runtime in the forward and backward passes.
Currently inductor codegen uses a reduction-based approach to compute global amaxes. Benchmarking has shown that an atomics-based approach outperforms it. We should investigate computing global amaxes via atomics (instead of a reduction-based approach) in triton codegen.
Additional context
In float8 training we compute a tensorwise amax or rowwise amaxes as part of computing the float8 scaling factor(s): code
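For reference, the scale computation itself is roughly the following (a minimal sketch mirroring the common torchao pattern, not the exact linked code; the `EPS` value and function name are assumptions):

```python
import torch

EPS = 1e-12  # assumed clamp value to avoid division by zero

def amax_to_scale(amax: torch.Tensor,
                  float8_dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    # Scale the tensor so its largest-magnitude value lands near the float8
    # dtype's max representable value (448.0 for e4m3fn).
    return torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
```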
Currently inductor codegen computes these amaxes with a reduction-based approach split across 2 separate triton kernels; a hand-written sketch of the two-pass structure follows:
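The actual generated kernels aren't reproduced here, but the two-pass structure looks roughly like this sketch (kernel names, block sizes, and the single-program final reduction are illustrative assumptions, not the actual inductor output):

```python
import triton
import triton.language as tl

@triton.jit
def partial_amax_kernel(x_ptr, partials_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Pass 1 (kernel 1): each program reduces one block of the input to a
    # single partial amax and writes it to a scratch buffer in HBM.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    tl.store(partials_ptr + pid, tl.max(tl.abs(x), axis=0))

@triton.jit
def final_amax_kernel(partials_ptr, out_ptr, n_partials, BLOCK_SIZE: tl.constexpr):
    # Pass 2 (kernel 2): a single program re-reads all partials from HBM and
    # reduces them to the global amax (assumes n_partials <= BLOCK_SIZE).
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_partials
    partials = tl.load(partials_ptr + offsets, mask=mask, other=0.0)
    tl.store(out_ptr, tl.max(partials, axis=0))
```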
This process can be optimized by using atomics to compute the global amaxes in a single kernel, reducing the amount of data movement between HBM and SRAM.
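For comparison, here is a minimal single-kernel sketch using `tl.atomic_max` (assuming the amax output buffer is zero-initialized before launch; since amaxes are non-negative, Triton's float atomic max behaves as expected):

```python
import triton
import triton.language as tl

@triton.jit
def atomic_amax_kernel(x_ptr, amax_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Single pass: each program computes its block-local amax and folds it
    # into the global result with one atomic op per block, so the partials
    # never make a round trip through HBM.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    block_amax = tl.max(tl.abs(x), axis=0)
    tl.atomic_max(amax_ptr, block_amax)
```

The trade-off is atomic contention on a single output location versus the extra kernel launch and HBM round trip of the two-pass reduction.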
I have implemented and benchmarked both approaches using hand-written triton kernels for the float8nocompile project (eager mode float8 training with improved performance via hand-written triton kernels), so I have some additional context here.
Microbenchmarking of dynamic float8 quantization implementations using these different approaches showed that the atomics-based kernels substantially outperformed both the torch.compile codegen and the hand-written reduction kernels (see benchmarks below). The hand-written reduction kernels were not as fast as the inductor-generated ones, but the atomics-based kernels still outperformed both. We should investigate using atomics in inductor codegen to improve dynamic float8 quantization performance.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov