[Inductor] Restrict block analysis to only match integer dims and strides #149615
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149615

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 unrelated failures) As of commit 312d18e with merge base 4cd6e96:

FLAKY - The following job failed but was likely due to flakiness present on trunk.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 9f49c4b to fd6cbc3.
@pytorchbot label "topic: not user facing"
Similar here - @blaine-rister would you mind reviewing this one?
```python
# dim_mod1_: 4, stride_mod1_: 1, stride_mod4_: 0, stride_mod2_: 0, stride_mod0_: 0
# }
# This is now fixed by ensuring that wild symbols only match nonnegative integers
def test_ensure_integral_dims_and_strides(self):
```
Thanks for the test case! This is a good one.
```diff
@@ -167,7 +171,11 @@ def match_affine_block_expr(
         stride.
         """
         index = cls._preprocess(index)
-        stride = sympy.Wild("stride", exclude=[index_var])
+        stride = sympy.Wild(
```
nit: it might make sense to use a helper function here, since these properties are non-trivial. Something like this might work:

```python
class BlockPatternMatcher:
    _indexing_wild = functools.partial(
        sympy.Wild, properties=[lambda x: x.is_integer]
    )
```
```python
wild = functools.partial(
    sympy.Wild,
    exclude=[index_var],
    properties=[lambda x: x.is_integer and x.is_nonnegative],
)
```
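To see what this restriction buys, a standalone comparison (an illustration added here, not the Inductor code): without the `properties`, sympy happily binds the wild to the non-integral value 1/16, exactly the bogus match from the PR description.

```python
import sympy

xindex = sympy.Symbol("xindex", integer=True)

loose = sympy.Wild("stride", exclude=[xindex])
strict = sympy.Wild(
    "stride",
    exclude=[xindex],
    properties=[lambda x: x.is_integer and x.is_nonnegative],
)

expr = sympy.Rational(1, 16) * xindex
print(expr.match(loose * xindex))   # {stride_: 1/16} -- bogus non-integral match
print(expr.match(strict * xindex))  # None -- rejected by the integer property
```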
Good catch -- strides should be integers.

Do we also need them to be positive, or is that too restrictive? I'm not sure how Inductor will handle this under the hood, but at least in normal Python we can have negative strides. This would be a good test case:

```python
import torch

def foo(x, y):
    return x[::-1] + y  # Slice in reverse order via a negative stride

x, y = (torch.randn(8) for _ in range(2))
foo(x, y)
# TODO test with torch.compile + block pointers...
```
For dims, I agree that they need to be positive integers.
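As an editorial aside (not from the thread): the reverse slice proposed above only works at the NumPy level; eager PyTorch rejects negative slice steps outright, which is why the discussion below turns to `torch.flip`. A quick standalone check:

```python
# Illustrative check, not part of the PR: NumPy exposes genuinely
# negative strides for a reverse slice, while PyTorch eager raises.
import numpy as np
import torch

a = np.arange(8)
print(a[::-1].strides)  # (-8,) -- a genuinely negative stride

t = torch.arange(8)
try:
    t[::-1]
except ValueError as err:
    print(err)  # "step must be greater than zero"
```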
Yes, you're right, although support for generating negative strides seems limited to functions like torch.flip (see #59786). I've changed the code to allow negative values for strides. Here's the test case:
```python
import torch
from torch._inductor import config

# Matches the following:
# index_relative_to_xyr_index = -256*((xindex//64)) - (ModularIndexing(xindex, 1, 8)) - 16*(ModularIndexing(xindex, 8, 8)) + 1911
# subexpr = -256*((xindex//64)) - (ModularIndexing(xindex, 1, 8)) - 16*(ModularIndexing(xindex, 8, 8))
# BlockParameters(shape=[8, 8, 8], block_shape=[
#     ((XBLOCK + 63)//64),
#     Min(8, ((XBLOCK + 7)//8)),
#     Min(8, XBLOCK)
# ], strides=[-256, -16, -1],
# offsets=[(xoffset//64), ModularIndexing(xoffset, 8, 8),
#     ModularIndexing(xoffset, 1, 8)])
# constant_offset = 1911

def fn(x, y):
    # Reversing every dim with torch.flip yields negative strides in the index
    return torch.flip(x, [0, 1, 2]) + y

def discontiguous_tensor(view_size, device="cpu"):
    """Create a strided, discontiguous view into a larger backing tensor."""
    full_size = tuple(2 * dim for dim in view_size)
    full = torch.randn(full_size).to(device)
    view = torch.as_strided(full, view_size, full.stride())
    return view

x, y = (discontiguous_tensor((8, 8, 8)) for _ in range(2))
result_eager = fn(x, y)

config.triton.use_block_ptr = True
fn_compiled = torch.compile(fn)
result_compiled = fn_compiled(x, y)

torch.testing.assert_close(result_eager, result_compiled)
```
Nice find! It wouldn't hurt to add this test case to the CI as well.
This is a good fix and test case. I'm requesting follow-up on the issue of negative strides. Happy to approve once that's cleared up.
Force-pushed from fd6cbc3 to 89aa333.
LGTM! I left a comment about possibly adding a test case for negative strides prior to merging.
Force-pushed from fb3bce7 to 5e09e0b.
@blaine-rister I think you need to approve the workflows / merge
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: Command …

Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Pull workflow has not been scheduled for the PR yet. It could be because the author doesn't have permissions to run it, or because skip-checks keywords were added to the PR/commits; aborting merge. Please get/give approval for the workflows and/or remove skip-ci decorators before the next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
[Inductor] Restrict block analysis to only match integer dims and strides (pytorch#149615)

Restrict block analysis to only match dimension sizes and strides that are integers. E.g. `sympy` can match index expressions like `(ModularIndexing(xindex, 4, 4)) + 4*(ModularIndexing(xindex, 32, 2))` with the candidate below, which is invalid.

```python
match_expr = (
    stride_mod0_*((xindex//(dim_mod1_*dim_mod2_*dim_mod3_*dim_mod4_)))
    + stride_mod1_*(ModularIndexing(xindex, dim_mod2_*dim_mod3_*dim_mod4_, dim_mod1_))
    + stride_mod2_*(ModularIndexing(xindex, dim_mod3_*dim_mod4_, dim_mod2_))
    + stride_mod3_*(ModularIndexing(xindex, dim_mod4_, dim_mod3_))
    + stride_mod4_*(ModularIndexing(xindex, 1, dim_mod4_))
)
match = {
    dim_mod4_: 32, dim_mod3_: 2, stride_mod3_: 4, dim_mod2_: 1/16,
    dim_mod1_: 4, stride_mod1_: 1, stride_mod4_: 0, stride_mod2_: 0,
    stride_mod0_: 0,
}
```

Pull Request resolved: pytorch#149615
Approved by: https://github.com/blaine-rister
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov