[inductor] Improve codegen for argmax+max · Issue #146643 · pytorch/pytorch · GitHub
Open
jansel opened this issue Feb 6, 2025 · 9 comments
Labels: internal ramp-up task, module: inductor, oncall: pt2, triaged

Comments

jansel commented Feb 6, 2025
@torch.compile(dynamic=True)
def fn(x):
    return torch.max(x, -1)

generates the following code:

@triton.jit
def triton_red_fused_max_0(in_ptr0, out_ptr0, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
    _tmp4 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
    _tmp4_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), xmask & r0_mask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
        tmp3 = triton_helpers.maximum(_tmp2, tmp1)
        _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2)
        _tmp4_next, _tmp4_index_next = triton_helpers.maximum_with_index(
            _tmp4, _tmp4_index, tmp1, rindex
        )
        _tmp4 = tl.where(r0_mask & xmask, _tmp4_next, _tmp4)
        _tmp4_index = tl.where(r0_mask & xmask, _tmp4_index_next, _tmp4_index)
    tmp2 = triton_helpers.max2(_tmp2, 1)[:, None]
    tmp4_val, tmp4_idx = triton_helpers.max_with_index(_tmp4, _tmp4_index, 1)
    tmp4 = tmp4_idx[:, None]
    tl.store(out_ptr0 + (x0), tmp2, xmask)
    tl.store(out_ptr1 + (x0), tmp4, xmask)

This could be improved by doing:

diff --git a/out.py b/out.py
index 5d0acd594f7..5c3879867ed 100644
--- a/out.py
+++ b/out.py
@@ -8,7 +8,6 @@ def triton_red_fused_max_0(in_ptr0, out_ptr0, out_ptr1, ks0, xnumel, r0_numel, X
     r0_base = tl.arange(0, R0_BLOCK)[None, :]
     rbase = r0_base
     x0 = xindex
-    _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
     _tmp4 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
     _tmp4_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64)
     for r0_offset in range(0, r0_numel, R0_BLOCK):
@@ -19,15 +18,13 @@ def triton_red_fused_max_0(in_ptr0, out_ptr0, out_ptr1, ks0, xnumel, r0_numel, X
         r0_1 = r0_index
         tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), xmask & r0_mask, eviction_policy='evict_first', other=0.0)
         tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
-        tmp3 = triton_helpers.maximum(_tmp2, tmp1)
-        _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2)
         _tmp4_next, _tmp4_index_next = triton_helpers.maximum_with_index(
             _tmp4, _tmp4_index, tmp1, rindex
         )
         _tmp4 = tl.where(r0_mask & xmask, _tmp4_next, _tmp4)
         _tmp4_index = tl.where(r0_mask & xmask, _tmp4_index_next, _tmp4_index)
-    tmp2 = triton_helpers.max2(_tmp2, 1)[:, None]
     tmp4_val, tmp4_idx = triton_helpers.max_with_index(_tmp4, _tmp4_index, 1)
     tmp4 = tmp4_idx[:, None]
+    tmp2 = tmp4_val[:, None]
     tl.store(out_ptr0 + (x0), tmp2, xmask)
     tl.store(out_ptr1 + (x0), tmp4, xmask)

This works because the argmax already computes the amax, so we don't need a separate reduction.
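
For intuition, here is a quick eager-mode sanity check of that equivalence (a standalone snippet for illustration, not part of the generated kernel): the values returned by torch.max(x, -1) are exactly the amax, and the indices are exactly the argmax, so a single max_with_index reduction can produce both outputs.

import torch

x = torch.randn(4, 37)
values, indices = torch.max(x, -1)
# One reduction suffices because both of these hold:
assert torch.equal(values, x.amax(-1))     # the max values are the amax
assert torch.equal(indices, x.argmax(-1))  # the max indices are the argmax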

We could either:

  1. Have a single two-output reduction op that does both amax+argmax
  2. Combine the two at codegen time using the reduction cache. (We could use a DeferredLine to swap between triton_helpers.max_with_index and triton_helpers.max2 based on whether the output is used.)

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @desertfire @aakhundov

mikaylagawarecki added the oncall: pt2 and triaged labels and removed the triaged label Feb 6, 2025
williamwen42 added the triaged and module: inductor labels Feb 7, 2025
eellison added the internal ramp-up task label Mar 27, 2025
@vulkomilev

I will try to fix this one

vulkomilev commented Apr 28, 2025
def triton_red_fused_max_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 2
    r0_numel = 5000
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_1 + 5000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
        tmp3 = triton_helpers.maximum(_tmp2, tmp1)
        _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2)
    tmp2 = triton_helpers.max2(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp2, xmask)

This is the current code generation. How can I confirm that the problem is still present? For example, the wrapper code for the original function plus the commands for code generation.

jansel commented Apr 29, 2025

It looks like the code you showed is only computing max, not max+argmax. I'd expect two outputs (two calls to tl.store), not one.

vulkomilev commented Apr 30, 2025

@jansel Thanks for the quick response.
I am using the following code to generate the kernel listed below. Am I doing something wrong?

import torch
@torch.compile(dynamic=True)
def fn(x):
    return torch.max(x, -1)
new_fn = torch.compile(fn)
input_tensor = torch.randn(10001).to(device="cuda:0")
a = new_fn(input_tensor)

TORCH_COMPILE_DEBUG=1 python main3.py

def triton_red_fused_max_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 2
    r0_numel = 5001
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp5 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = r0_1 + 5001*x0
        tmp1 = tl.full([1, 1], 10001, tl.int32)
        tmp2 = tmp0 < tmp1
        tmp3 = tl.load(in_ptr0 + (r0_1 + 5001*x0), xmask & r0_mask & tmp2, eviction_policy='evict_first', other=float("-inf"))
        tmp4 = tl.broadcast_to(tmp3, [XBLOCK, R0_BLOCK])
        tmp6 = triton_helpers.maximum(_tmp5, tmp4)
        _tmp5 = tl.where(r0_mask & xmask, tmp6, _tmp5)
    tmp5 = triton_helpers.max2(_tmp5, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp5, xmask)

@vulkomilev

I am having trouble formatting the code

jansel commented Apr 30, 2025

I think you need to pass a 2D tensor as input.
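
For example, something along these lines should produce a kernel with both a max output and an argmax output (the shape is arbitrary; any 2D CUDA input reduced over the last dimension will do):

import torch

@torch.compile(dynamic=True)
def fn(x):
    return torch.max(x, -1)

# With a 2D input, torch.max(x, -1) reduces each row and returns both the
# values and the indices, so the generated kernel stores two outputs
# (two tl.store calls).
x = torch.randn(31, 10, device="cuda")
values, indices = fn(x)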

@vulkomilev
def triton_per_fused_max_0(in_ptr0, out_ptr0, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 31
    r0_numel = 10
    R0_BLOCK: tl.constexpr = 16
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = r0_index < r0_numel
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 10*x0), xmask & r0_mask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
    tmp3 = tl.where(r0_mask & xmask, tmp1, float("-inf"))
    tmp4 = triton_helpers.max2(tmp3, 1)[:, None]
    tmp6 = tl.broadcast_to(rindex, tmp3.shape)
    tmp5_val, tmp5_idx = triton_helpers.max_with_index(tmp3, tmp6, 1)
    tmp5 = tmp5_idx[:, None]
    tmp5 = tmp5.to(tl.float64)
    tl.store(out_ptr0 + (x0), tmp4, xmask)
    tl.store(out_ptr1 + (x0), tmp5, xmask)

It worked. Working on it.

@vulkomilev

Sorry for writing again, but where can I find documentation on how PyTorch works in depth so I can fix this? I have searched, but I have only found user guides.

jansel commented May 10, 2025

I think the easiest way to do this would be by using the reduction_cache:

cache_key = (src_dtype, reduction_type, value)
if cache_key in self.cse.reduction_cache:
    return self.cse.reduction_cache[cache_key]

Right now, you will see two entries:

  • (float32, "argmax", "tmp3")
  • (float32, "max", "tmp3")

You could combine these into a single entry:

  • (float32, "max_argmax", "tmp3")

Then select the correct element out of the tuple based on what you need. The argmax already computes the max; it is just unused.

You should be able to get something working with some smallish changes to that def reduction() function.
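
Roughly, the relevant part of reduction() might end up looking like the sketch below. This is only an illustration: emit_max_with_index and the exact tuple layout are made-up names for the idea, not existing Inductor code.

cache_key = (src_dtype, "max_argmax", value)
if cache_key not in self.cse.reduction_cache:
    # Emit a single maximum_with_index reduction and cache both results.
    # emit_max_with_index is a hypothetical helper standing in for whatever
    # currently generates the argmax reduction.
    max_var, argmax_var = self.emit_max_with_index(dtype, src_dtype, value)
    self.cse.reduction_cache[cache_key] = (max_var, argmax_var)
max_var, argmax_var = self.cse.reduction_cache[cache_key]
# Both the "max" and the "argmax" requests are served from the same pair.
return max_var if reduction_type == "max" else argmax_var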

The next step, once the above works, would be to clean up the codegen for amax so it doesn't compute an unnecessary argmax (if the argmax isn't needed). One way to do that would be to create a new type of DeferredLineBase:

class DeferredLineBase:

This would be something like:

if argmax_was_used:
   print line to compute max+argmax
else:
   print line to compute max

This allows you to create the cache entry before you know if the argmax is needed, then swap out the right reduction at the end.
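
A hypothetical subclass could look roughly like this. The constructor arguments, the argmax_was_used predicate, and the details of the DeferredLineBase interface are assumptions for illustration, not a working patch:

class DeferredMaxArgmaxLine(DeferredLineBase):
    # Emits the max+argmax reduction only if the argmax result is actually
    # consumed; otherwise falls back to a plain max reduction.

    def __init__(self, line, max_only_line, argmax_var):
        super().__init__(line)              # line using triton_helpers.max_with_index
        self.max_only_line = max_only_line  # line using triton_helpers.max2
        self.argmax_var = argmax_var        # CSE variable holding the index result

    def __call__(self):
        # argmax_was_used stands in for "check whether argmax_var was ever read".
        if argmax_was_used(self.argmax_var):
            return self.line
        return self.max_only_line

    def _new_line(self, line):
        return DeferredMaxArgmaxLine(line, self.max_only_line, self.argmax_var)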
