8000 [Inductor] Restrict block analysis to only match integer dims and str… · pytorch/pytorch@ce97a5d · GitHub
[go: up one dir, main page]

Skip to content

Commit ce97a5d

Browse files
kundaMwizapytorchmergebot
authored andcommitted
[Inductor] Restrict block analysis to only match integer dims and strides (#149615)
Restrict block analysis to only match dimension sizes and strides that are integers. E.g. `sympy` can match index expressions like `ModularIndexing(xindex, 4, 4)) + 4*(ModularIndexing(xindex, 32, 2))` with the candidate below that is invalid. ```python match_expr = stride_mod0_*((xindex//(dim_mod1_*dim_mod2_*dim_mod3_*dim_mod4_))) + stride_mod1_*(ModularIndexing(xindex, dim_mod2_*dim_mod3_*dim_mod4_, dim_mod1_)) + stride_mod2_*(ModularIndexing(xindex, dim_mod3_*dim_mod4_, dim_mod2_)) + stride_mod3_*(ModularIndexing(xindex, dim_mod4_, dim_mod3_)) + stride_mod4_*(ModularIndexing(xindex, 1, dim_mod4_)) match={ dim_mod4_: 32, dim_mod3_: 2, stride_mod3_: 4, dim_mod2_: 1/16, dim_mod1_: 4, stride_mod1_: 1, stride_mod4_: 0, stride_mod2_: 0, stride_mod0_: 0 } ``` Pull Request resolved: #149615 Approved by: https://github.com/blaine-rister
1 parent c48d0f4 commit ce97a5d

File tree

2 files changed

+80
-5
lines changed

2 files changed

+80
-5
lines changed

test/inductor/test_torchinductor_strided_blocks.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,6 +1146,64 @@ def foo(x, y, z):
11461146
# Singleton splits should be discarded.
11471147
self._assert_pointwise_ndims(triton_code, 2)
11481148

1149+
# Integration test to ensure that matched dims & strides from match_mod_div_expr
1150+
# are unsigned and signed integers respectively. This test case has the following
1151+
# index:=(ModularIndexing(xindex, 4, 4)) + 4*(ModularIndexing(xindex, 32, 2))
1152+
# and the match below is a candidate that is invalid:
1153+
# match={
1154+
# dim_mod4_: 32, dim_mod3_: 2, stride_mod3_: 4, dim_mod2_: 1/16,
1155+
# dim_mod1_: 4, stride_mod1_: 1, stride_mod4_: 0, stride_mod2_: 0, stride_mod0_: 0
1156+
# }
1157+
# This is now fixed by ensuring that that wild symbols only match integers
1158+
def test_ensure_integral_dims_and_strides(self):
1159+
def model(data, *args):
1160+
return torch.nn.functional.unfold(data, *args)
1161+
1162+
data = torch.zeros(
1163+
[2, 3, 5, 5], dtype=torch.float16, requires_grad=True, device=self.device
1164+
)
1165+
args = [2, 1, 0, 1]
1166+
run_and_compare(
1167+
self,
1168+
model,
1169+
data,
1170+
*args,
1171+
expected_num_triton_kernels=2,
1172+
expected_num_block_pointers=4,
1173+
compile_kwargs={"fullgraph": True},
1174+
)
1175+
1176+
# Integration test to test block analysis with index expressions using
1177+
# negative strides.
1178+
# This test case has the following index:
1179+
# index_relative_to_xyr_index = -256*((xindex//64)) - (ModularIndexing(xindex, 1, 8))
1180+
# - 16*(ModularIndexing(xindex, 8, 8)) + 1911
1181+
# subexpr = -256*((xindex//64)) - (ModularIndexing(xindex, 1, 8)) - 16*(ModularIndexing(xindex, 8, 8))
1182+
# Block analysis should produce the following:
1183+
# BlockParameters(
1184+
# shape=[8, 8, 8],
1185+
# block_shape=[((XBLOCK + 63)//64), Min(8, ((XBLOCK + 7)//8)), Min(8, XBLOCK) ],
1186+
# strides=[-256, -16, -1],
1187+
# offsets=[(xoffset//64), ModularIndexing(xoffset, 8, 8), ModularIndexing(xoffset, 1, 8)]
1188+
# )
1189+
# constant_offset = 1911
1190+
def test_negative_strides(self):
1191+
def model(x, y):
1192+
# Slice in reverse order via a negative stride
1193+
return torch.flip(x, [0, 1, 2]) + y
1194+
1195+
x, y = (
1196+
self._discontiguous_tensor((8, 8, 8), device=self.device) for _ in range(2)
1197+
)
1198+
run_and_compare(
1199+
self,
1200+
model,
1201+
x,
1202+
y,
1203+
expected_num_triton_kernels=1,
1204+
expected_num_block_pointers=3,
1205+
)
1206+
11491207
@config.patch("triton.prefer_nd_tiling", True)
11501208
@config.patch("triton.max_tiles", 3)
11511209
@parametrize(

torch/_inductor/codegen/block_analysis.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ class BlockPatternMatcher:
1717
Matches block indexing expressions.
1818
"""
1919

20+
_indexing_wild_signed_int = functools.partial(
21+
sympy.Wild, properties=[lambda x: x.is_integer]
22+
)
23+
_indexing_wild_unsigned_int = functools.partial(
24+
sympy.Wild, properties=[lambda x: x.is_integer and x.is_nonnegative]
25+
)
26+
2027
@classmethod
2128
def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr:
2229
"""
@@ -63,9 +70,18 @@ def match_mod_div_block_expr(
6370
index = cls._preprocess(index)
6471

6572
# Pattern match to find the strides and offset.
66-
wild = functools.partial(sympy.Wild, exclude=[index_var])
67-
dims: list[Expr] = [wild(f"dim_mod{idx}") for idx in range(num_dims)]
68-
strides: list[Expr] = [wild(f"stride_mod{idx}") for idx in range(num_dims)]
73+
wild_unsigned_int = functools.partial(
74+
cls._indexing_wild_unsigned_int, exclude=[index_var]
75+
)
76+
wild_signed_int = functools.partial(
77+
cls._indexing_wild_signed_int, exclude=[index_var]
78+
)
79+
dims: list[Expr] = [
80+
wild_unsigned_int(f"dim_mod{idx}") for idx in range(num_dims)
81+
]
82+
strides: list[Expr] = [
83+
wild_signed_int(f"stride_mod{idx}") for idx in range(num_dims)
84+
]
6985

7086
# The first dimension's index is computed by division.
7187
# The remaining are computed by modulo.
@@ -83,7 +99,8 @@ def match_mod_div_block_expr(
8399
# for more details. In short, here we check that each subexpression in sympy.Add contains
84100
# only FloorDiv or ModularIndexing expressions.
85101
if num_dims >= 5:
86-
stride, denom, other = sympy.symbols("stride denominator other", cls=wild)
102+
stride = sympy.symbols("stride", cls=wild_signed_int)
103+
denom, other = sympy.symbols("denominator other", cls=wild_unsigned_int)
87104
mod_div_pattern = stride * ModularIndexing(index_var, denom, other)
88105
floor_div_pattern = stride * FloorDiv(index_var, denom)
89106
first_dim_floor_div_matched = False
@@ -167,7 +184,7 @@ def match_affine_block_expr(
167184
stride.
168185
"""
169186
index = cls._preprocess(index)
170-
stride = sympy.Wild("stride", exclude=[index_var])
187+
stride = cls._indexing_wild_signed_int(name="stride", exclude=[index_var])
171188
m = index.match(index_var * stride)
172189
if m is None:
173190
return None

0 commit comments

Comments
 (0)
0