Ensure block analysis only matches integer dims and strides · pytorch/pytorch@fd6cbc3 · GitHub
[go: up one dir, main page]

Skip to content

Commit fd6cbc3

Browse files
committed
Ensure block analysis only matches integer dims and strides
1 parent 00a2c68 commit fd6cbc3

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

test/inductor/test_torchinductor_strided_blocks.py

+25
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,31 @@ def test_pointwise_index_order(self):
10031003
xindex = xoffset + tl.arange(0, XBLOCK)[None, None, :]""", # noqa: B950
10041004
)
10051005

1006+
# Integration test to ensure that matched dims & strides from match_mod_div_block_expr
1007+
# are nonnegative integers. This test case has the following index
1008+
# index=(ModularIndexing(xindex, 4, 4)) + 4*(ModularIndexing(xindex, 32, 2))
1009+
# and the match below is a candidate that is invalid:
1010+
# match={
1011+
# dim_mod4_: 32, dim_mod3_: 2, stride_mod3_: 4, dim_mod2_: 1/16,
1012+
# dim_mod1_: 4, stride_mod1_: 1, stride_mod4_: 0, stride_mod2_: 0, stride_mod0_: 0
1013+
# }
1014+
# This is now fixed by ensuring that wild symbols only match nonnegative integers
1015+
def test_ensure_integral_dims_and_strides(self):
1016+
def model(data, *args):
1017+
return torch.nn.functional.unfold(data, *args)
1018+
1019+
data = torch.zeros([2, 3, 5, 5], dtype=torch.float16, requires_grad=True)
1020+
args = [2, 1, 0, 1]
1021+
run_and_compare(
1022+
self,
1023+
model,
1024+
data,
1025+
*args,
1026+
expected_num_triton_kernels=2,
1027+
expected_num_block_pointers=4,
1028+
compile_kwargs={"fullgraph": True},
1029+
)
1030+
10061031

10071032
@unittest.skipIf(not TRITON_HAS_CPU, "requires triton CPU backend")
10081033
@config.patch(cpu_backend="triton")

torch/_inductor/codegen/block_analysis.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,11 @@ def match_mod_div_block_expr(
6363
index = cls._preprocess(index)
6464

6565
# Pattern match to find the strides and offset.
66-
wild = functools.partial(sympy.Wild, exclude=[index_var])
66+
wild = functools.partial(
67+
sympy.Wild,
68+
exclude=[index_var],
69+
properties=[lambda x: x.is_integer and x.is_nonnegative],
70+
)
6771
dims: list[Expr] = [wild(f"dim_mod{idx}") for idx in range(num_dims)]
6872
strides: list[Expr] = [wild(f"stride_mod{idx}") for idx in range(num_dims)]
6973

@@ -167,7 +171,11 @@ def match_affine_block_expr(
167171
stride.
168172
"""
169173
index = cls._preprocess(index)
170-
stride = sympy.Wild("stride", exclude=[index_var])
174+
stride = sympy.Wild(
175+
"stride",
176+
exclude=[index_var],
177+
properties=[lambda x: x.is_integer and x.is_nonnegative],
178+
)
171179
m = index.match(index_var * stride)
172180
if m is None:
173181
return None

0 commit comments

Comments
 (0)
0