[inductor] Improve codegen for argmax+max #146643
I will try to fix this one.
def triton_red_fused_max_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 2
    r0_numel = 5000
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_1 + 5000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
        tmp3 = triton_helpers.maximum(_tmp2, tmp1)
        _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2)
    tmp2 = triton_helpers.max2(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp2, xmask)

This is the current code generation. How can I confirm that the problem is still present? For example, what would the wrapper code for the original function look like, and which commands should I run for code generation?
It looks like the code you showed is only computing max, not max+argmax. I'd expect two outputs (two calls to tl.store), not one.
@jansel Thanks for the quick response.

import torch

@torch.compile(dynamic=True)
def fn(x):
    return torch.max(x, -1)

new_fn = torch.compile(fn)
input_tensor = torch.randn(10001).to(device="cuda:0")
a = new_fn(input_tensor)

Running TORCH_COMPILE_DEBUG=1 python main3.py generates:

def triton_red_fused_max_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 2
    r0_numel = 5001
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp5 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = r0_1 + 5001*x0
        tmp1 = tl.full([1, 1], 10001, tl.int32)
        tmp2 = tmp0 < tmp1
        tmp3 = tl.load(in_ptr0 + (r0_1 + 5001*x0), xmask & r0_mask & tmp2, eviction_policy='evict_first', other=float("-inf"))
        tmp4 = tl.broadcast_to(tmp3, [XBLOCK, R0_BLOCK])
        tmp6 = triton_helpers.maximum(_tmp5, tmp4)
        _tmp5 = tl.where(r0_mask & xmask, tmp6, _tmp5)
    tmp5 = triton_helpers.max2(_tmp5, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp5, xmask)
''', device_str='cuda')
I'm having trouble formatting the code.
I think you need to pass a 2D tensor as input.
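For reference, a minimal 2D repro along those lines might look like the sketch below. The (31, 10) shape is only inferred from the xnumel/r0_numel constants in the kernel that follows, so treat this as a guess rather than the exact script that was used:

import torch

@torch.compile(dynamic=True)
def fn(x):
    # torch.max over a dim returns (values, indices), i.e. both max and argmax
    return torch.max(x, -1)

# 2D input so the reduction runs along the last dimension of each row;
# the (31, 10) shape is a guess based on the kernel constants below.
input_tensor = torch.randn(31, 10, device="cuda:0")
values, indices = fn(input_tensor)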
def triton_per_fused_max_0(in_ptr0, out_ptr0, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 31
    r0_numel = 10
    R0_BLOCK: tl.constexpr = 16
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = r0_index < r0_numel
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 10*x0), xmask & r0_mask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
    tmp3 = tl.where(r0_mask & xmask, tmp1, float("-inf"))
    tmp4 = triton_helpers.max2(tmp3, 1)[:, None]
    tmp6 = tl.broadcast_to(rindex, tmp3.shape)
    tmp5_val, tmp5_idx = triton_helpers.max_with_index(tmp3, tmp6, 1)
    tmp5 = tmp5_idx[:, None]
    tmp5 = tmp5.to(tl.float64)
    tl.store(out_ptr0 + (x0), tmp4, xmask)
    tl.store(out_ptr1 + (x0), tmp5, xmask)

It worked, I'm working on it.
Sorry for writing again, but where can I find documentation of how PyTorch works in depth so I can fix this? I have searched but have only found user guides.
I think the easiest way to do this would be by using the reduction_cache:

pytorch/torch/_inductor/codegen/triton.py Lines 2528 to 2530 in e4f2282

Right now, you will see two entries:

You could combine these into a single entry:

Then select the correct element out of the tuple based on what you need. You should be able to get something working with some smallish changes to that.

The next step after the above works would be to clean up the codegen for:

pytorch/torch/_inductor/utils.py Line 1340 in e4f2282

This would be something like:

if argmax_was_used:
    print line to compute max+argmax
else:
    print line to compute max

This allows you to create the cache entry before you know if the argmax is needed, then swap out the right reduction at the end.
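A very rough sketch of that caching idea is below. All of the names (the plain-dict reduction_cache, emit_max_with_index, the variable-name strings) are illustrative stand-ins, not the actual Inductor internals:

# Illustrative sketch only: emit_max_with_index and the plain-dict cache are
# made-up stand-ins for the real Inductor reduction_cache machinery.
reduction_cache = {}

def emit_max_with_index(src):
    # Stand-in for emitting the single fused max+argmax reduction line;
    # returns the names of the variables holding the value and the index.
    print(f"{src}_max, {src}_argmax = triton_helpers.max_with_index({src}, idx, 1)")
    return (f"{src}_max", f"{src}_argmax")

def reduction(src, reduction_type):
    # One combined cache entry (a tuple) instead of separate max/argmax entries,
    # so the second lookup reuses the reduction emitted by the first.
    if src not in reduction_cache:
        reduction_cache[src] = emit_max_with_index(src)
    max_var, argmax_var = reduction_cache[src]
    return max_var if reduction_type in ("max", "amax") else argmax_var

# Both lookups hit the same cache entry, so only one reduction line is emitted.
print(reduction("tmp3", "amax"))    # tmp3_max
print(reduction("tmp3", "argmax"))  # tmp3_argmax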
generates the following code:

This could be improved by doing:

because the argmax already computes the amax, so we don't need a separate reduction. We could either:
1) add a fused amax+argmax reduction, or
2) switch between triton_helpers.max_with_index and triton_helpers.max2 based on if the output is used.

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @desertfire @aakhundov
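For illustration, the kind of pattern the issue describes is a graph where both reductions run over the same dimension. This is a guess at the shape of the original snippet (which is not preserved above), not the snippet itself:

import torch

# Illustrative only: both amax and argmax reduce over the same dim, so the
# argmax reduction already has to find the max value that amax recomputes.
@torch.compile
def fn(x):
    return x.amax(-1), x.argmax(-1)

out_max, out_idx = fn(torch.randn(32, 128, device="cuda"))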