can we make inductor create faster fusions for tiled reductions across dim0 and dim1? #148682
Labels
module: inductor
oncall: pt2
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🐛 Describe the bug
Can we make fusions of reductions across Mx1 and 1xM tiles fast in inductor? The key use case for this is scaling tensors to MX across both dim0 and dim1 at the same time, which is important for microscaling (MX) training. Here is an example snippet to demonstrate a simplified version of the pattern:
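A minimal sketch of the pattern (reconstructed for illustration, not the original snippet; shapes, names, and the plain amax-based normalization are placeholders for the real MX recipe):

```python
import torch

def scale_dim0_dim1(x: torch.Tensor, tile_size: int = 32):
    # Illustrative only: compute per-tile amax scales across dim0 and dim1
    # of the same input and normalize by them. Assumes M and K are
    # multiples of tile_size.
    M, K = x.shape
    # dim1 scaling: amax over each 1 x tile_size tile
    x_dim1 = x.reshape(M, K // tile_size, tile_size)
    scale_dim1 = x_dim1.abs().amax(dim=-1, keepdim=True)
    x_dim1_normalized = (x_dim1 / scale_dim1).reshape(M, K)
    # dim0 scaling: amax over each tile_size x 1 tile (via transpose)
    x_dim0 = x.t().reshape(K, M // tile_size, tile_size)
    scale_dim0 = x_dim0.abs().amax(dim=-1, keepdim=True)
    x_dim0_normalized = (x_dim0 / scale_dim0).reshape(K, M).t()
    return x_dim0_normalized, x_dim1_normalized, scale_dim0, scale_dim1
```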
In the code above, we reduce (amax) across both dim0 and dim1 of the same input and use the results to normalize it.
Note: the "real" use case for MX (pytorch/ao#1788) differs from this simplified example in a few details.
When I run `torch.compile` on the above example today, I see two kernels, one for each dim (example logs: https://gist.github.com/vkuzo/7bfd4e23411f22fc25f94323bcd93794).

Claude and I wrote a triton kernel to load the input data in tiles and do the normalization across dim0 and dim1 inline: https://gist.github.com/vkuzo/a7374b1f1f5eabff4a6d774972248c22 (also at https://github.com/vkuzo/pytorch_scripts/blob/6c26861f2a7d0d31930006b63e538d56026b8aba/mx_cast_poc/20250305_mx_dim0_dim1_cast.py). It seems to be up to 2x faster than the current torch.compile behavior with tile size 32, and up to 4x faster if we increase tile size to 128, so a faster kernel is definitely possible. Note that the triton kernel linked here currently just normalizes across `tile_size`; it would need to be updated to normalize by `inner_tile_size` if `inner_tile_size != outer_tile_size`.

Output of a comparison of torch.compile vs the triton kernel across a couple of block sizes:
Can we improve this in inductor?
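For reference, the single-pass tiled approach described above can be sketched as follows (illustrative and much simplified from the linked gist; this only produces the two amaxes, not the full cast, and assumes a contiguous input whose dims are multiples of `TILE`):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def tiled_amax_kernel(
    x_ptr, amax_dim0_ptr, amax_dim1_ptr,
    K, NUM_K_TILES,
    TILE: tl.constexpr,
):
    # Each program loads one TILE x TILE block of x once and produces both
    # the per-column (dim0) and per-row (dim1) tile amaxes from that load.
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    offs_m = pid_m * TILE + tl.arange(0, TILE)
    offs_k = pid_k * TILE + tl.arange(0, TILE)
    x = tl.load(x_ptr + offs_m[:, None] * K + offs_k[None, :])
    amax0 = tl.max(tl.abs(x), axis=0)  # (TILE,) one value per column
    amax1 = tl.max(tl.abs(x), axis=1)  # (TILE,) one value per row
    # amax_dim0 has shape (M // TILE, K); amax_dim1 has shape (M, K // TILE)
    tl.store(amax_dim0_ptr + pid_m * K + offs_k, amax0)
    tl.store(amax_dim1_ptr + offs_m * NUM_K_TILES + pid_k, amax1)

def tiled_amax(x: torch.Tensor, tile: int = 32):
    M, K = x.shape
    amax_dim0 = torch.empty(M // tile, K, device=x.device, dtype=x.dtype)
    amax_dim1 = torch.empty(M, K // tile, device=x.device, dtype=x.dtype)
    grid = (M // tile, K // tile)
    tiled_amax_kernel[grid](x, amax_dim0, amax_dim1, K, K // tile, TILE=tile)
    return amax_dim0, amax_dim1
```

The point is that both reductions are computed from a single load of each tile, which is the memory-traffic win the two separate inductor kernels currently leave on the table.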
Versions
main branch
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @aakhundov