[Inductor] Investigate computing global amaxes via atomics (instead of a reduction-based approach) in triton codegen #153103

Open
danielvegamyhre opened this issue May 7, 2025 · 1 comment
Labels
module: inductor · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@danielvegamyhre (Contributor) commented May 7, 2025

Summary

Tensorwise or rowwise amax values are used to compute scaling factors in float8 quantization. Computing these values in a performant way is critical for float8 training with dynamic quantization, where tensors are dynamically scaled at runtime in the forward and backward passes.

Currently, inductor codegen uses a reduction-based approach to compute global amaxes. Benchmarking has shown that an atomics-based approach outperforms the reduction-based one. We should investigate computing global amaxes via atomics (instead of a reduction-based approach) in triton codegen.

  • tlparse link with example kernels showing the reduction-based approach

Additional context

In float8 training, we compute a tensorwise amax or rowwise amaxes as part of computing the float8 scaling factor(s): code
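For concreteness, here is a minimal sketch of the tensorwise version of that computation, assuming torch.float8_e4m3fn as the target dtype; the function name and clamp epsilon are illustrative, not the actual torchao code linked above:

```python
import torch

def tensorwise_scale(x: torch.Tensor, float8_dtype=torch.float8_e4m3fn) -> torch.Tensor:
    # Global amax over the whole tensor: this is the reduction the issue is about.
    amax = x.abs().max().to(torch.float64)
    # Map the observed dynamic range onto the representable float8 range.
    # The clamp epsilon is an illustrative guard against division by zero.
    scale = torch.finfo(float8_dtype).max / torch.clamp(amax, min=1e-12)
    return scale.to(torch.float32)
```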

Currently, inductor codegen computes these amaxes in two separate Triton kernels (a rough sketch of this two-kernel pattern follows the list):

  1. Compute block-local amaxes
    • The first kernel reads input tensor blocks from HBM, computes block-local amaxes, and writes them back out to a temporary buffer in HBM.
  2. Compute global amaxes by reducing the block-local amaxes
    • The second kernel loads the temporary buffer from HBM into SRAM and reduces it to compute the global amaxes. These are either written back out to HBM, or used to compute scales which are then written to HBM, depending on the fusion decision inductor makes.
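Roughly, the pattern looks like the following handwritten sketch. Kernel and buffer names are hypothetical, not the actual inductor-generated code, and the second kernel assumes the number of blocks fits in a single program for brevity:

```python
import triton
import triton.language as tl


@triton.jit
def block_amax_kernel(x_ptr, partial_ptr, numel, BLOCK_SIZE: tl.constexpr):
    # Kernel 1: each program reads one block from HBM, computes its local amax,
    # and writes it back out to a temporary buffer in HBM.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < numel
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    tl.store(partial_ptr + pid, tl.max(tl.abs(x), axis=0))


@triton.jit
def reduce_amax_kernel(partial_ptr, out_ptr, num_blocks, BLOCK_SIZE: tl.constexpr):
    # Kernel 2: a single program reloads all block-local amaxes from HBM and
    # reduces them to the global amax.
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_blocks
    partial = tl.load(partial_ptr + offsets, mask=mask, other=0.0)
    tl.store(out_ptr, tl.max(partial, axis=0))
```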

This process can be optimized by using atomics to compute the global amaxes in a single kernel, reducing the amount of data movement between HBM and SRAM.
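A minimal sketch of that single-kernel, atomics-based alternative (again with illustrative names, assuming a float32 input flattened to 1D and a one-element output buffer pre-initialized to zero):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def amax_atomic_kernel(x_ptr, amax_ptr, numel, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < numel
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    # The block-local amax stays on-chip; a single atomic per block folds it
    # into the global result, so no intermediate buffer round-trips through HBM.
    tl.atomic_max(amax_ptr, tl.max(tl.abs(x), axis=0))


def tensorwise_amax(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical driver: zero-initialized output is safe because amax of
    # absolute values is always >= 0.
    x = x.contiguous().view(-1).to(torch.float32)
    amax = torch.zeros(1, dtype=torch.float32, device=x.device)
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(x.numel(), BLOCK_SIZE),)
    amax_atomic_kernel[grid](x, amax, x.numel(), BLOCK_SIZE=BLOCK_SIZE)
    return amax
```

For rowwise amaxes the same idea would apply with one atomic slot per row instead of a single global slot.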

I have actually implemented and benchmarked both approaches using handwritten Triton kernels for the float8nocompile project (eager-mode float8 training with improved performance via handwritten Triton kernels), so I have some additional context here.

Microbenchmarks of dynamic float8 quantization implementations using these approaches showed that the atomics-based kernels substantially outperformed both the torch.compile codegen and the handwritten reduction kernels (see benchmarks below). We should investigate using atomics in inductor codegen to improve dynamic float8 quantization performance.

[Image: microbenchmark results comparing the atomics-based kernels, the handwritten reduction kernels, and the inductor codegen kernels]

As you can see, the handwritten reduction kernels were not as fast as the inductor codegen ones, but the atomics-based kernels still outperformed both.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

@masnesral added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on May 8, 2025
@danielvegamyhre (Contributor, Author) commented:

cc @eellison
