can we make inductor create faster fusions for tiled reductions across dim0 and dim1? #148682


Open
vkuzo opened this issue Mar 6, 2025 · 2 comments
Assignees: eellison
Labels: module: inductor, oncall: pt2, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

vkuzo (Contributor) commented on Mar 6, 2025

🐛 Describe the bug

Can we make fusions of reductions across Mx1 and 1xM tiles fast in inductor? The key use case for this is scaling tensors to MX across both dim0 and dim1 at the same time, which is important for microscaling (MX) training. Here is an example snippet to demonstrate a simplified version of the pattern:

from typing import Tuple

import torch

def scale_dim0_dim1_reference(x_hp: torch.Tensor, block_size: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

    # normalize across dim0
    x_hp_d0_block = x_hp.reshape(-1, block_size)
    x_hp_d0_block_abs = x_hp_d0_block.abs()
    amax_dim0 = torch.amax(x_hp_d0_block_abs, dim=1).unsqueeze(1)
    x_hp_d0_block_normalized = x_hp_d0_block / amax_dim0
    x_hp_d0_normalized = x_hp_d0_block_normalized.reshape(x_hp.shape)

    # normalize across dim1
    x_hp_d1 = x_hp.t().contiguous()
    x_hp_d1_block = x_hp_d1.reshape(-1, block_size)
    x_hp_d1_block_abs = x_hp_d1_block.abs()
    amax_dim1 = torch.amax(x_hp_d1_block_abs, dim=1).unsqueeze(1)
    x_hp_d1_block_normalized = x_hp_d1_block / amax_dim1
    x_hp_d1_normalized = x_hp_d1_block_normalized.reshape(x_hp_d1.shape)

    return x_hp_d0_normalized, x_hp_d1_normalized.t(), amax_dim0, amax_dim1

In the code above, we

  • start with a tensor and a block size (32 for MX)
  • for dim0, partition the tensor into contiguous chunks of block_size elements, normalize each chunk by its max absolute value, and write out a normalized tensor and the scales used for normalization (see the shape sketch after this list)
  • for dim1, repeat the same blocking and normalization on the transposed tensor
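
To make the blocking concrete, here is a shape walk-through of the dim0 branch using the benchmark sizes below (M=8192, K=4096, block_size=32); the concrete shapes are my illustration, not part of the original report:

x_hp = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)
x_hp_d0_block = x_hp.reshape(-1, 32)                                  # (1048576, 32), one row per block
amax_dim0 = torch.amax(x_hp_d0_block.abs(), dim=1).unsqueeze(1)       # (1048576, 1), one scale per block
x_hp_d0_normalized = (x_hp_d0_block / amax_dim0).reshape(x_hp.shape)  # back to (8192, 4096)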

Note: in the "real" use case for MX (pytorch/ao#1788), the differences are:

  1. instead of "calculate max absolute value", we will do "calculate MX e8m0 scale" (an illustrative sketch of that mapping follows this list)
  2. instead of "write out a normalized tensor", we will do "write out a normalized low precision tensor"
  3. we also need to swizzle the scales, but that can be done separately from this issue
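
For context, a rough sketch of the "amax to e8m0 scale" step; this is my simplification assuming an fp8 e4m3 element dtype, and the exact clamping/rounding rules live in the pytorch/ao implementation linked above:

F8E4M3_MAX_POW2 = 8  # floor(log2(448)), the largest power of two representable in fp8 e4m3

def amax_to_e8m0_scale(amax: torch.Tensor) -> torch.Tensor:
    # e8m0 is an exponent-only format, so the scale must be a power of two
    scale_exp = torch.floor(torch.log2(amax)) - F8E4M3_MAX_POW2
    return torch.exp2(scale_exp)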

When I run torch.compile on the reference code above today, inductor generates two kernels - one for each dim (example logs: https://gist.github.com/vkuzo/7bfd4e23411f22fc25f94323bcd93794).
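
For reference, a minimal way to reproduce this (my sketch, assuming a CUDA device and the benchmark sizes below; running with TORCH_LOGS="output_code" prints the generated Triton code):

x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)
scale_reference_c = torch.compile(scale_dim0_dim1_reference)
x_d0, x_d1, amax_d0, amax_d1 = scale_reference_c(x, 32)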

Claude and I wrote a triton kernel that loads the input data in tiles and does the normalization across dim0 and dim1 inline: https://gist.github.com/vkuzo/a7374b1f1f5eabff4a6d774972248c22 (also at https://github.com/vkuzo/pytorch_scripts/blob/6c26861f2a7d0d31930006b63e538d56026b8aba/mx_cast_poc/20250305_mx_dim0_dim1_cast.py). It seems to be up to 2x faster than the current torch.compile behavior with tile size 32, and up to 4x faster if we increase the tile size to 128, so a faster kernel is definitely possible. Note that the triton kernel linked here currently just normalizes across tile_size; it would need to be updated to normalize by inner_tile_size if inner_tile_size != outer_tile_size.
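
For illustration, here is a minimal sketch of the tiled approach (my simplification, not the linked kernel): each program loads one TILE x TILE block, reduces |x| along both axes in registers, and writes both normalized outputs plus both sets of scales. It assumes a row-major input, M and K divisible by TILE, block_size == TILE, and float32 outputs; the dim1-normalized output is written in transposed (K, M) layout, matching x_hp_d1_normalized above.

import torch
import triton
import triton.language as tl

@triton.jit
def _normalize_dim0_dim1_kernel(
    x_ptr, out_d0_ptr, out_d1_ptr, amax_d0_ptr, amax_d1_ptr,
    M, K, stride_m, stride_k,
    TILE: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    offs_m = pid_m * TILE + tl.arange(0, TILE)
    offs_k = pid_k * TILE + tl.arange(0, TILE)

    # load one TILE x TILE block of the input
    x = tl.load(x_ptr + offs_m[:, None] * stride_m + offs_k[None, :] * stride_k)
    x = x.to(tl.float32)
    x_abs = tl.abs(x)

    # one reduction per row of the tile -> scales for the dim0 normalization
    amax_row = tl.max(x_abs, axis=1)  # [TILE]
    # one reduction per column of the tile -> scales for the dim1 normalization
    amax_col = tl.max(x_abs, axis=0)  # [TILE]

    out_d0 = x / amax_row[:, None]
    out_d1 = x / amax_col[None, :]

    # dim0-normalized output keeps the original (M, K) layout
    tl.store(out_d0_ptr + offs_m[:, None] * K + offs_k[None, :], out_d0)
    # dim1-normalized output is written with shape (K, M)
    tl.store(out_d1_ptr + offs_k[:, None] * M + offs_m[None, :], tl.trans(out_d1))
    # one scale per (row, column-block) and one per (column, row-block)
    tl.store(amax_d0_ptr + offs_m * (K // TILE) + pid_k, amax_row)
    tl.store(amax_d1_ptr + offs_k * (M // TILE) + pid_m, amax_col)

A hypothetical launch wrapper for the sketch above:

def normalize_dim0_dim1(x: torch.Tensor, tile: int = 32):
    M, K = x.shape
    out_d0 = torch.empty(M, K, device=x.device, dtype=torch.float32)
    out_d1 = torch.empty(K, M, device=x.device, dtype=torch.float32)
    amax_d0 = torch.empty(M * K // tile, device=x.device, dtype=torch.float32)
    amax_d1 = torch.empty(K * M // tile, device=x.device, dtype=torch.float32)
    grid = (M // tile, K // tile)
    _normalize_dim0_dim1_kernel[grid](
        x, out_d0, out_d1, amax_d0, amax_d1,
        M, K, x.stride(0), x.stride(1), TILE=tile,
    )
    return out_d0, out_d1, amax_d0, amax_d1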

Output of comparison of torch.compile vs triton kernel across a couple of block sizes:

(pytorch) [vasiliy@devgpu023.atn1 ~/local/pytorch_scripts/mx_cast_poc (20250305_mx_max_dim0_dim1)]$ python 20250305_mx_dim0_dim1_cast.py --M 8192 -K 4096 --BLOCK_SIZE 32
M 8192 K 4096 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.7.0a0+gitd518490
triton version: 3.2.0
bf16 vs normalized reference sqnrs: dim0 57.75, dim1 57.75
normalized reference vs normalized triton are bitwise equivalent
time_reference_compile_us 182.9869150943399
time_triton_us 94.03462108262123
speedup 1.9459525969011233

(pytorch) [vasiliy@devgpu023.atn1 ~/local/pytorch_scripts/mx_cast_poc (20250305_mx_max_dim0_dim1)]$ python 20250305_mx_dim0_dim1_cast.py --M 8192 -K 4096 --BLOCK_SIZE 64
M 8192 K 4096 BLOCK_SIZE 64
GPU: NVIDIA B200
torch version: 2.7.0a0+gitd518490
triton version: 3.2.0
bf16 vs normalized reference sqnrs: dim0 56.5, dim1 56.5
normalized reference vs normalized triton are bitwise equivalent
time_reference_compile_us 183.88731220657243
time_triton_us 53.66455661375654
speedup 3.42660638249705

(pytorch) [vasiliy@devgpu023.atn1 ~/local/pytorch_scripts/mx_cast_poc (20250305_mx_max_dim0_dim1)]$ python 20250305_mx_dim0_dim1_cast.py --M 8192 -K 4096 --BLOCK_SIZE 128
M 8192 K 4096 BLOCK_SIZE 128
GPU: NVIDIA B200
torch version: 2.7.0a0+gitd518490
triton version: 3.2.0
bf16 vs normalized reference sqnrs: dim0 56.0, dim1 56.0
normalized reference vs normalized triton are bitwise equivalent
time_reference_compile_us 312.7817773722634
time_triton_us 67.17439390386868
speedup 4.656264972332706

(pytorch) [vasiliy@devgpu023.atn1 ~/local/pytorch_scripts/mx_cast_poc (20250305_mx_max_dim0_dim1)]$ python 20250305_mx_dim0_dim1_cast.py --M 8192 -K 4096 --BLOCK_SIZE 256
M 8192 K 4096 BLOCK_SIZE 256
GPU: NVIDIA B200
torch version: 2.7.0a0+gitd518490
triton version: 3.2.0
bf16 vs normalized reference sqnrs: dim0 56.25, dim1 56.25
normalized reference vs normalized triton are bitwise equivalent
time_reference_compile_us 362.2346390041493
time_triton_us 1091.8661034482752
speedup 0.33175738111125397

Can we improve this in inductor?

Versions

main branch

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @aakhundov

vkuzo (Contributor, Author) commented on Mar 6, 2025

cc @eellison

eellison self-assigned this on Mar 6, 2025
yf225 added the triaged label on Mar 8, 2025
vkuzo (Contributor, Author) commented on Mar 25, 2025

splitting the request for a fast dim1 kernel into a dedicated issue: #149982
