[Inductor] Restrict block analysis to only match integer dims and strides by kundaMwiza · Pull Request #149615 · pytorch/pytorch · GitHub

[Inductor] Restrict block analysis to only match integer dims and strides #149615

Open
wants to merge 3 commits into
base: main
from

kundaMwiza
Contributor
@kundaMwiza kundaMwiza commented Mar 20, 2025

Restrict block analysis to match only dimension sizes and strides that are integers. For example, sympy can match an index expression like (ModularIndexing(xindex, 4, 4)) + 4*(ModularIndexing(xindex, 32, 2)) against the candidate below, which is invalid because dim_mod2_ is matched to the non-integer value 1/16.

match_expr = stride_mod0_*((xindex//(dim_mod1_*dim_mod2_*dim_mod3_*dim_mod4_))) + stride_mod1_*(ModularIndexing(xindex, dim_mod2_*dim_mod3_*dim_mod4_, dim_mod1_)) + stride_mod2_*(ModularIndexing(xindex, dim_mod3_*dim_mod4_, dim_mod2_)) + stride_mod3_*(ModularIndexing(xindex, dim_mod4_, dim_mod3_)) + stride_mod4_*(ModularIndexing(xindex, 1, dim_mod4_))
match = {
    dim_mod4_: 32, dim_mod3_: 2, stride_mod3_: 4, dim_mod2_: 1/16,
    dim_mod1_: 4, stride_mod1_: 1, stride_mod4_: 0, stride_mod2_: 0, stride_mod0_: 0
}
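
The failure mode described above can be reproduced in isolation. The following is a minimal sketch, assuming a toy expression `xindex / 16` rather than a real Inductor index expression; the variable names are illustrative and not taken from the PR. It shows that an unrestricted `sympy.Wild` will happily bind to a fraction.

```python
import sympy

# Assumption: a toy index variable standing in for Inductor's xindex.
xindex = sympy.Symbol("xindex", integer=True, nonnegative=True)

# Unrestricted wildcard: anything that does not contain xindex may be bound.
dim = sympy.Wild("dim", exclude=[xindex])

# Toy expression (not a real Inductor index expression).
expr = xindex / 16
match = expr.match(dim * xindex)

# Collect the bindings that could never be a valid dimension size or stride.
non_integer = {wild: value for wild, value in match.items() if not value.is_integer}

assert match == {dim: sympy.Rational(1, 16)}
assert non_integer == match
```

Without a restriction on the wildcard, the only way to catch such a binding is to validate the match afterwards, which is why constraining the wildcard itself is attractive.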

Fixes #ISSUE_NUMBER

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

pytorch-bot bot commented Mar 20, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149615

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 7 Unrelated Failures

As of commit fb3bce7 with merge base 98bd2bd (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@kundaMwiza kundaMwiza force-pushed the mwizak/restrict-block-ptr-dims-strides-integer branch from 9f49c4b to fd6cbc3 Compare March 20, 2025 10:48
@kundaMwiza
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Mar 20, 2025
@bdhirsh bdhirsh requested review from eellison and shunting314 March 24, 2025 14:30
@colesbury colesbury added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Mar 24, 2025
@eellison
Contributor

Similar here - @blaine-rister would you mind reviewing this one?

@eellison eellison requested a review from blaine-rister April 15, 2025 19:19
@eellison eellison removed their request for review April 22, 2025 19:25
# dim_mod1_: 4, stride_mod1_: 1, stride_mod4_: 0, stride_mod2_: 0, stride_mod0_: 0
# }
# This is now fixed by ensuring that wild symbols only match nonnegative integers
def test_ensure_integral_dims_and_strides(self):
Contributor

Thanks for the test case! This is a good one.

@@ -167,7 +171,11 @@ def match_affine_block_expr(
stride.
"""
index = cls._preprocess(index)
stride = sympy.Wild("stride", exclude=[index_var])
stride = sympy.Wild(
Contributor
@blaine-rister blaine-rister Apr 22, 2025

nit: it might make sense to use a helper function here, since these properties are non-trivial. Something like this might work:

class BlockPatternMatcher:
    _indexing_wild = functools.partial(sympy.Wild, properties=[lambda x: x.is_integer])

wild = functools.partial(
sympy.Wild,
exclude=[index_var],
properties=[lambda x: x.is_integer and x.is_nonnegative],
Contributor
@blaine-rister blaine-rister Apr 22, 2025

Good catch--strides should be integers.

Do we also need them to be positive, or is that too restrictive? I'm not sure how Inductor will handle this under the hood, but at least in normal Python we can have negative strides. This would be a good test case:

import torch

def foo(x, y):
    return x[::-1] + y  # Slice in reverse order via a negative stride

x, y = (torch.randn(8) for _ in range(2))
foo(x, y)

# TODO test with torch.compile + block pointers...

For dims, I agree that they need to be positive integers.
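
The distinction discussed here (strides may be any integer, dims must be positive integers) can be sketched with two wildcard factories. This is only an illustration under stated assumptions: `stride_wild` and `dim_wild` are hypothetical helper names, not the helpers in the PR, and the expressions are toy examples rather than real Inductor index expressions.

```python
import functools

import sympy

index_var = sympy.Symbol("xindex", integer=True, nonnegative=True)

# Hypothetical helpers: strides accept any integer, including negative ones.
stride_wild = functools.partial(
    sympy.Wild,
    exclude=[index_var],
    properties=[lambda e: e.is_integer],
)

# Hypothetical helpers: dimension sizes must be positive integers.
dim_wild = functools.partial(
    sympy.Wild,
    exclude=[index_var],
    properties=[lambda e: e.is_integer and e.is_positive],
)

stride = stride_wild("stride")
dim = dim_wild("dim")

neg_stride = (-16 * index_var).match(stride * index_var)  # negative integer: accepted
frac_stride = (index_var / 16).match(stride * index_var)  # fraction: rejected
neg_dim = (-16 * index_var).match(dim * index_var)        # negative: rejected for dims
pos_dim = (8 * index_var).match(dim * index_var)          # positive integer: accepted

assert neg_stride == {stride: -16}
assert frac_stride is None
assert neg_dim is None
assert pos_dim == {dim: 8}
```

Keeping the predicates in one place, as the helper suggestion above proposes, avoids repeating the lambdas at every `sympy.Wild` call site.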

Contributor Author

Yes you're right, although support for generating negative strides seems limited to functions like torch.flip (see #59786). I've changed the code to include negative values for strides.

import torch
from torch._inductor import config

# Block analysis matches the following:
# index_relative_to_xyr_index = -256*((xindex//64)) - (ModularIndexing(xindex, 1, 8)) - 16*(ModularIndexing(xindex, 8, 8)) + 1911
# subexpr = -256*((xindex//64)) - (ModularIndexing(xindex, 1, 8)) - 16*(ModularIndexing(xindex, 8, 8))
# BlockParameters(
#     shape=[8, 8, 8],
#     block_shape=[((XBLOCK + 63)//64), Min(8, ((XBLOCK + 7)//8)), Min(8, XBLOCK)],
#     strides=[-256, -16, -1],
#     offsets=[(xoffset//64), ModularIndexing(xoffset, 8, 8), ModularIndexing(xoffset, 1, 8)])
# constant_offset = 1911
def fn(x, y):
    # Flipping all three dims reverses the order, which yields negative strides
    return torch.flip(x, [0, 1, 2]) + y

def discontiguous_tensor(view_size, device="cpu"):
    # Allocate a buffer twice as large in every dim, then view its leading corner
    full_size = tuple(2 * dim for dim in view_size)
    full = torch.randn(full_size).to(device)
    view = torch.as_strided(full, view_size, full.stride())
    return view

x, y = (discontiguous_tensor((8, 8, 8)) for _ in range(2))
result_eager = fn(x, y)

config.triton.use_block_ptr = True
foo_compile = torch.compile(fn)
result_compile = foo_compile(x, y)
torch.testing.assert_close(result_eager, result_compile)
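
The matched expression in the comment above can be checked by hand. The sketch below is plain Python with no PyTorch dependency. It assumes that `ModularIndexing(x, d, m)` means `(x // d) % m`, and that the (8, 8, 8) view sits in the leading corner of a contiguous (16, 16, 16) buffer, as `discontiguous_tensor` constructs it. The function names are illustrative.

```python
# Strides of the contiguous (16, 16, 16) buffer and the shape of the view into it.
full_strides = (256, 16, 1)
view_shape = (8, 8, 8)


def flipped_offset(xindex):
    """Buffer offset read by flip(x, [0, 1, 2]) at linear output index xindex."""
    i = xindex // 64        # outermost coordinate
    j = (xindex // 8) % 8   # assumed meaning of ModularIndexing(xindex, 8, 8)
    k = xindex % 8          # assumed meaning of ModularIndexing(xindex, 1, 8)
    # Flipping a dim of size n maps coordinate c to n - 1 - c.
    return sum(
        stride * (size - 1 - coord)
        for stride, size, coord in zip(full_strides, view_shape, (i, j, k))
    )


def matched_offset(xindex):
    """The expression reported by block analysis: strides [-256, -16, -1], offset 1911."""
    return -256 * (xindex // 64) - 16 * ((xindex // 8) % 8) - (xindex % 8) + 1911


# The two agree for every element of the 8 * 8 * 8 view.
assert all(flipped_offset(x) == matched_offset(x) for x in range(512))
```

The constant 1911 is the offset of the last element of the view, 7*256 + 7*16 + 7, which is where the flipped read starts before stepping backwards with the negative strides.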

Contributor
@blaine-rister blaine-rister May 10, 2025

Nice find! It wouldn't hurt to add this test case to the CI as well.

Contributor
@blaine-rister blaine-rister left a comment

This is a good fix and test case. I'm requesting follow-up on the issue of negative strides. Happy to approve once that's cleared up.

@kundaMwiza kundaMwiza force-pushed the mwizak/restrict-block-ptr-dims-strides-integer branch from fd6cbc3 to 89aa333 Compare April 28, 2025 20:30
Contributor
@blaine-rister blaine-rister left a comment

LGTM! I left a comment about possibly adding a test case for negative strides prior to merging.

Labels
module: inductor open source topic: not user facing topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

5 participants