[Inductor] Restrict block analysis to only match integer dims and strides #149615
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149615
Note: Links to docs will display an error until the docs builds have been completed.
❌ 4 New Failures, 7 Unrelated Failures as of commit fb3bce7 with merge base 98bd2bd.
NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 9f49c4b to fd6cbc3.
@pytorchbot label "topic: not user facing"
Similar here - @blaine-rister would you mind reviewing this one?
```python
# dim_mod1_: 4, stride_mod1_: 1, stride_mod4_: 0, stride_mod2_: 0, stride_mod0_: 0
# }
# This is now fixed by ensuring that wild symbols only match nonnegative integers
def test_ensure_integral_dims_and_strides(self):
```
Thanks for the test case! This is a good one.
```diff
@@ -167,7 +171,11 @@ def match_affine_block_expr(
             stride.
         """
         index = cls._preprocess(index)
-        stride = sympy.Wild("stride", exclude=[index_var])
+        stride = sympy.Wild(
```
nit: it might make sense to use a helper function here, since these properties are non-trivial. Something like this might work:

```python
class BlockPatternMatcher:
    _indexing_wild = functools.partial(sympy.Wild, properties=[lambda x: x.is_integer])
```
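As a standalone sketch of this suggestion (the helper name and surrounding setup are hypothetical, not the actual PR code), a `functools.partial` can bake the integrality property into `sympy.Wild` so call sites only supply the name and exclusions:

```python
import functools

import sympy

index_var = sympy.Symbol("xindex", integer=True)

# Hypothetical helper: a partial application of sympy.Wild that always
# carries the "must be an integer" property.
_indexing_wild = functools.partial(sympy.Wild, properties=[lambda s: s.is_integer])

# Call sites then read like plain Wild construction.
stride = _indexing_wild("stride", exclude=[index_var])
match = (8 * index_var).match(stride * index_var)
print(match[stride])  # 8 — an integer stride matches
```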
```python
wild = functools.partial(
    sympy.Wild,
    exclude=[index_var],
    properties=[lambda x: x.is_integer and x.is_nonnegative],
)
```
Good catch -- strides should be integers.

Do we also need them to be positive, or is that too restrictive? I'm not sure how Inductor will handle this under the hood, but at least in normal Python we can have negative strides. This would be a good test case:

```python
import torch

def foo(x, y):
    # In NumPy, x[::-1] would slice in reverse order via a negative stride;
    # PyTorch rejects negative slice steps, so flip the tensor instead.
    return torch.flip(x, [0]) + y

x, y = (torch.randn(8) for _ in range(2))
foo(x, y)
# TODO test with torch.compile + block pointers...
```
For dims, I agree that they need to be positive integers.
Yes, you're right, although support for generating negative strides seems limited to functions like torch.flip (see #59786). I've changed the code to include negative values for strides.
```python
import torch
from torch._inductor import config

# Matches the following:
# index_relative_to_xyr_index = -256*((xindex//64)) - (ModularIndexing(xindex, 1, 8)) - 16*(ModularIndexing(xindex, 8, 8)) + 1911
# subexpr = -256*((xindex//64)) - (ModularIndexing(xindex, 1, 8)) - 16*(ModularIndexing(xindex, 8, 8))
# BlockParameters(shape=[8, 8, 8], block_shape=[
#     ((XBLOCK + 63)//64),
#     Min(8, ((XBLOCK + 7)//8)),
#     Min(8, XBLOCK)
# ], strides=[-256, -16, -1],
# offsets=[(xoffset//64), ModularIndexing(xoffset, 8, 8),
#     ModularIndexing(xoffset, 1, 8)])
# constant_offset = 1911

def fn(x, y):
    # Reverse all dims via torch.flip, producing negative strides
    return torch.flip(x, [0, 1, 2]) + y

def discontiguous_tensor(view_size, device="cpu"):
    full_size = tuple(2 * dim for dim in view_size)
    full = torch.randn(full_size).to(device)
    view = torch.as_strided(full, view_size, full.stride())
    return view

x, y = (discontiguous_tensor((8, 8, 8)) for _ in range(2))
result_eager = fn(x, y)

config.triton.use_block_ptr = True
fn_compiled = torch.compile(fn)
result_compiled = fn_compiled(x, y)
torch.testing.assert_close(result_eager, result_compiled)
```
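A minimal sympy sketch (illustrative only, not the PR's actual matcher code) of why the wildcard property matters for the negative strides in the example above: an integer-only wildcard accepts the `-256` stride produced by `torch.flip`, while a nonnegative-only wildcard would reject it and the match would fail.

```python
import sympy

xindex = sympy.Symbol("xindex", integer=True)

# Integer-only wildcard: accepts the negative stride produced by flip.
int_stride = sympy.Wild("int_stride", exclude=[xindex],
                        properties=[lambda s: s.is_integer])
m_int = (-256 * xindex).match(int_stride * xindex)
print(m_int[int_stride])  # -256

# Nonnegative wildcard: rejects -256, so no match is found at all.
nn_stride = sympy.Wild("nn_stride", exclude=[xindex],
                       properties=[lambda s: s.is_integer and s.is_nonnegative])
m_nn = (-256 * xindex).match(nn_stride * xindex)
print(m_nn)  # None
```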
Nice find! It wouldn't hurt to add this test case to the CI as well.
This is a good fix and test case. I'm requesting follow-up on the issue of negative strides. Happy to approve once that's cleared up.
Force-pushed from fd6cbc3 to 89aa333.
LGTM! I left a comment about possibly adding a test case for negative strides prior to merging.
Restrict block analysis to only match dimension sizes and strides that are integers. E.g. `sympy` can match index expressions like `ModularIndexing(xindex, 4, 4) + 4*ModularIndexing(xindex, 32, 2)` with the candidate below, which is invalid.

Fixes #ISSUE_NUMBER
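As a standalone sketch of the failure mode described here (illustrative only; the real matcher lives in Inductor's `BlockPatternMatcher`): an unrestricted `sympy.Wild` can "match" a fractional stride, which is meaningless for tensor indexing, whereas restricting the wildcard to integers rejects such candidates.

```python
import sympy

xindex = sympy.Symbol("xindex", integer=True)

# Unrestricted wildcard: happily matches a fractional "stride".
stride = sympy.Wild("stride", exclude=[xindex])
m_loose = (xindex / 2).match(stride * xindex)
print(m_loose)  # {stride_: 1/2}

# Integer-restricted wildcard: the fractional candidate is rejected.
int_stride = sympy.Wild("int_stride", exclude=[xindex],
                        properties=[lambda s: s.is_integer])
m_strict = (xindex / 2).match(int_stride * xindex)
print(m_strict)  # None
```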
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov