[Inductor] Restrict block analysis to only match integer dims and strides by kundaMwiza · Pull Request #149615 · pytorch/pytorch

[Inductor] Restrict block analysis to only match integer dims and strides #149615

Contributor
@kundaMwiza kundaMwiza commented Mar 20, 2025

Restrict block analysis to match only dimension sizes and strides that are integers. For example, sympy can match an index expression such as (ModularIndexing(xindex, 4, 4)) + 4*(ModularIndexing(xindex, 32, 2)) against the candidate pattern below, producing an invalid match in which dim_mod2_ takes the non-integer value 1/16.

match_expr = stride_mod0_*((xindex//(dim_mod1_*dim_mod2_*dim_mod3_*dim_mod4_))) + stride_mod1_*(ModularIndexing(xindex, dim_mod2_*dim_mod3_*dim_mod4_, dim_mod1_)) + stride_mod2_*(ModularIndexing(xindex, dim_mod3_*dim_mod4_, dim_mod2_)) + stride_mod3_*(ModularIndexing(xindex, dim_mod4_, dim_mod3_)) + stride_mod4_*(ModularIndexing(xindex, 1, dim_mod4_))
match = {
    dim_mod4_: 32, dim_mod3_: 2, stride_mod3_: 4, dim_mod2_: 1/16,
    dim_mod1_: 4, stride_mod1_: 1, stride_mod4_: 0, stride_mod2_: 0, stride_mod0_: 0
}
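
The failure mode can be reproduced in isolation. The sketch below is not taken from the PR: it uses hypothetical wilds named `dim` and shows that an unrestricted `sympy.Wild` accepts the rational value 1/16, while one created with an integrality predicate in `properties` does not.

```python
import sympy

xindex = sympy.Symbol("xindex", integer=True, nonnegative=True)

# Hypothetical wilds for illustration; the PR's actual names differ.
loose_dim = sympy.Wild("dim", exclude=[xindex])
strict_dim = sympy.Wild("dim", exclude=[xindex], properties=[lambda e: e.is_integer])

bad_value = sympy.Rational(1, 16)
good_value = sympy.Integer(4)

loose_match = bad_value.match(loose_dim)    # accepted: the wild is unconstrained
strict_bad = bad_value.match(strict_dim)    # rejected: 1/16 is not an integer
strict_good = good_value.match(strict_dim)  # accepted: 4 is an integer

print(loose_match, strict_bad, strict_good)
```

A predicate that returns `None` (unknown) is treated like `False`, so a symbol without an integer assumption would also be rejected by `strict_dim`.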

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

pytorch-bot bot commented Mar 20, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149615

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 312d18e with merge base 4cd6e96 (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@kundaMwiza kundaMwiza force-pushed the mwizak/restrict-block-ptr-dims-strides-integer branch from 9f49c4b to fd6cbc3 Compare March 20, 2025 10:48
@kundaMwiza
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Mar 20, 2025
@bdhirsh bdhirsh requested review from eellison and shunting314 March 24, 2025 14:30
@colesbury colesbury added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Mar 24, 2025
@eellison
Contributor

Similar here - @blaine-rister would you mind reviewing this one?

@eellison eellison requested a review from blaine-rister April 15, 2025 19:19
@eellison eellison removed their request for review April 22, 2025 19:25
# dim_mod1_: 4, stride_mod1_: 1, stride_mod4_: 0, stride_mod2_: 0, stride_mod0_: 0
# }
# This is now fixed by ensuring that wild symbols only match nonnegative integers
def test_ensure_integral_dims_and_strides(self):
Contributor

Thanks for the test case! This is a good one.

@@ -167,7 +171,11 @@ def match_affine_block_expr(
 stride.
 """
 index = cls._preprocess(index)
-stride = sympy.Wild("stride", exclude=[index_var])
+stride = sympy.Wild(
Contributor
@blaine-rister blaine-rister Apr 22, 2025

nit: it might make sense to use a helper function here, since these properties are non-trivial. Something like this might work:

class BlockPatternMatcher:
    # Each property must be a callable predicate applied to the candidate match.
    _indexing_wild = functools.partial(sympy.Wild, properties=[lambda x: x.is_integer])

wild = functools.partial(
    sympy.Wild,
    exclude=[index_var],
    properties=[lambda x: x.is_integer and x.is_nonnegative],
Contributor
@blaine-rister blaine-rister Apr 22, 2025

Good catch--strides should be integers.

Do we also need them to be positive, or is that too restrictive? I'm not sure how Inductor will handle this under the hood, but at least in normal Python we can have negative strides. This would be a good test case:

import torch

def foo(x, y):
    # Intended: reverse x via a negative step. Note that torch.Tensor indexing
    # rejects negative steps, so x[::-1] raises ValueError as written.
    return x[::-1] + y

x, y = (torch.randn(8) for _ in range(2))
foo(x, y)

# TODO test with torch.compile + block pointers...

For dims, I agree that they need to be positive integers.
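
One way to encode the distinction discussed here, integer strides of either sign but strictly positive integer dims, is a pair of wild factories built with `functools.partial`. This is a minimal sketch under assumptions: `make_stride_wild` and `make_dim_wild` are hypothetical names, not Inductor's API.

```python
import functools

import sympy

index_var = sympy.Symbol("xindex", integer=True, nonnegative=True)

# Hypothetical helpers (not Inductor's API): strides may be any integer,
# dims must be strictly positive integers.
make_stride_wild = functools.partial(
    sympy.Wild,
    exclude=[index_var],
    properties=[lambda e: e.is_integer],
)
make_dim_wild = functools.partial(
    sympy.Wild,
    exclude=[index_var],
    properties=[lambda e: e.is_integer and e.is_positive],
)

stride = make_stride_wild("stride")
dim = make_dim_wild("dim")

# A negative stride, as produced by a flipped view, still matches.
stride_match = (-256 * index_var).match(stride * index_var)
# A negative or fractional dim is rejected.
negative_dim = sympy.Integer(-8).match(dim)
fractional_dim = sympy.Rational(1, 16).match(dim)

print(stride_match, negative_dim, fractional_dim)
```

Keeping the predicates in one place means a later change of policy, such as allowing zero-sized dims, would touch a single definition.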

Contributor Author

Yes you're right, although support for generating negative strides seems limited to functions like torch.flip (see #59786). I've changed the code to include negative values for strides.

import torch
from torch._inductor import config

# Matches the following:
# index_relative_to_xyr_index = -256*((xindex//64)) - (ModularIndexing(xindex, 1, 8)) - 16*(ModularIndexing(xindex, 8, 8)) + 1911
# subexpr = -256*((xindex//64)) - (ModularIndexing(xindex, 1, 8)) - 16*(ModularIndexing(xindex, 8, 8))
# BlockParameters(
#     shape=[8, 8, 8],
#     block_shape=[((XBLOCK + 63)//64), Min(8, ((XBLOCK + 7)//8)), Min(8, XBLOCK)],
#     strides=[-256, -16, -1],
#     offsets=[(xoffset//64), ModularIndexing(xoffset, 8, 8), ModularIndexing(xoffset, 1, 8)],
# )
# constant_offset = 1911
def fn(x, y):
    return torch.flip(x, [0, 1, 2]) + y  # Reverse all three dims, producing negative strides


def discontiguous_tensor(view_size, device="cpu"):
    # Allocate a tensor twice as large in every dim, then view its leading corner.
    full_size = tuple(2 * dim for dim in view_size)
    full = torch.randn(full_size).to(device)
    view = torch.as_strided(full, view_size, full.stride())
    return view


x, y = (discontiguous_tensor((8, 8, 8)) for _ in range(2))
result_eager = fn(x, y)

config.triton.use_block_ptr = True
foo_compile = torch.compile(fn)
result_compile = foo_compile(x, y)
torch.testing.assert_close(result_eager, result_compile)
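
The constant offset 1911 and the negative strides in the comment above follow from the flip arithmetic. Assuming the input is the (8, 8, 8) corner view of a contiguous (16, 16, 16) tensor, so its strides are (256, 16, 1), and that `ModularIndexing(x, d, m)` means `(x // d) % m`, the check below compares the matched expression against the source offset computed directly. It is a plain-Python sketch, not Inductor code.

```python
# Strides of the (8, 8, 8) view into a contiguous (16, 16, 16) tensor.
STRIDES = (256, 16, 1)
SIZE = 8


def modular_indexing(x, div, mod):
    # Assumed semantics of Inductor's ModularIndexing: (x // div) % mod.
    return (x // div) % mod


def matched_offset(xindex):
    # Expression reported in the comment above: negative strides plus 1911.
    return (
        -256 * (xindex // 64)
        - 16 * modular_indexing(xindex, 8, 8)
        - modular_indexing(xindex, 1, 8)
        + 1911
    )


def flipped_offset(xindex):
    # Output element (i, j, k) of the flip reads input element (7-i, 7-j, 7-k).
    i, j, k = xindex // 64, (xindex // 8) % 8, xindex % 8
    return sum(s * (SIZE - 1 - idx) for s, idx in zip(STRIDES, (i, j, k)))


constant_offset = sum(s * (SIZE - 1) for s in STRIDES)
all_agree = all(matched_offset(x) == flipped_offset(x) for x in range(SIZE**3))
print(constant_offset, all_agree)
```

The constant term is the offset of the last element of the view, 7*256 + 7*16 + 7, from which each negative stride walks backwards.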

Contributor
@blaine-rister blaine-rister May 10, 2025

Nice find! It wouldn't hurt to add this test case to the CI as well.

Contributor
@blaine-rister blaine-rister left a comment

This is a good fix and test case. I'm requesting follow-up on the issue of negative strides. Happy to approve once that's cleared up.

@kundaMwiza kundaMwiza force-pushed the mwizak/restrict-block-ptr-dims-strides-integer branch from fd6cbc3 to 89aa333 Compare April 28, 2025 20:30
Contributor
@blaine-rister blaine-rister left a comment

LGTM! I left a comment about possibly adding a test case for negative strides prior to merging.

@kundaMwiza kundaMwiza force-pushed the mwizak/restrict-block-ptr-dims-strides-integer branch from fb3bce7 to 5e09e0b Compare May 21, 2025 19:51
@kundaMwiza kundaMwiza requested a review from blaine-rister May 21, 2025 20:17
@kundaMwiza
Contributor Author

@blaine-rister I think you need to approve the workflows / merge

@kundaMwiza
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 20, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Collaborator

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch rebase origin/main returned non-zero exit code 1

Rebasing (1/1)
Auto-merging test/inductor/test_torchinductor_strided_blocks.py
CONFLICT (content): Merge conflict in test/inductor/test_torchinductor_strided_blocks.py
error: could not apply 486e358dd37... [Inductor] Restrict block analysis to only match integer dims and strides (#149615)
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Could not apply 486e358dd37... [Inductor] Restrict block analysis to only match integer dims and strides (#149615)
Details for Dev Infra team Raised by workflow job

@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Jun 23, 2025
@kundaMwiza
Contributor Author

@pytorchbot merge

pytorch-bot bot commented Jun 23, 2025

Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.

@kundaMwiza
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 24, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

skarjala pushed a commit to skarjala/pytorch that referenced this pull request Jun 25, 2025
…ides (pytorch#149615)

Pull Request resolved: pytorch#149615
Approved by: https://github.com/blaine-rister
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged module: inductor open source topic: not user facing topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module