diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index 19f83a35e96d..940bc24dbd12 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -1146,6 +1146,64 @@ def foo(x, y, z): # Singleton splits should be discarded. self._assert_pointwise_ndims(triton_code, 2) + # Integration test to ensure that matched dims & strides from match_mod_div_block_expr + # are unsigned and signed integers respectively. This test case has the following + # index:=(ModularIndexing(xindex, 4, 4)) + 4*(ModularIndexing(xindex, 32, 2)) + # and the match below is a candidate that is invalid: + # match={ + # dim_mod4_: 32, dim_mod3_: 2, stride_mod3_: 4, dim_mod2_: 1/16, + # dim_mod1_: 4, stride_mod1_: 1, stride_mod4_: 0, stride_mod2_: 0, stride_mod0_: 0 + # } + # This is now fixed by ensuring that wild symbols only match integers + def test_ensure_integral_dims_and_strides(self): + def model(data, *args): + return torch.nn.functional.unfold(data, *args) + + data = torch.zeros( + [2, 3, 5, 5], dtype=torch.float16, requires_grad=True, device=self.device + ) + args = [2, 1, 0, 1] + run_and_compare( + self, + model, + data, + *args, + expected_num_triton_kernels=2, + expected_num_block_pointers=4, + compile_kwargs={"fullgraph": True}, + ) + + # Integration test to test block analysis with index expressions using + # negative strides.
+ # This test case has the following index: + # index_relative_to_xyr_index = -256*((xindex//64)) - (ModularIndexing(xindex, 1, 8)) + # - 16*(ModularIndexing(xindex, 8, 8)) + 1911 + # subexpr = -256*((xindex//64)) - (ModularIndexing(xindex, 1, 8)) - 16*(ModularIndexing(xindex, 8, 8)) + # Block analysis should produce the following: + # BlockParameters( + # shape=[8, 8, 8], + # block_shape=[((XBLOCK + 63)//64), Min(8, ((XBLOCK + 7)//8)), Min(8, XBLOCK) ], + # strides=[-256, -16, -1], + # offsets=[(xoffset//64), ModularIndexing(xoffset, 8, 8), ModularIndexing(xoffset, 1, 8)] + # ) + # constant_offset = 1911 + def test_negative_strides(self): + def model(x, y): + # Slice in reverse order via a negative stride + return torch.flip(x, [0, 1, 2]) + y + + x, y = ( + self._discontiguous_tensor((8, 8, 8), device=self.device) for _ in range(2) + ) + run_and_compare( + self, + model, + x, + y, + expected_num_triton_kernels=1, + expected_num_block_pointers=3, + ) + @config.patch("triton.prefer_nd_tiling", True) @config.patch("triton.max_tiles", 3) @parametrize( diff --git a/torch/_inductor/codegen/block_analysis.py b/torch/_inductor/codegen/block_analysis.py index b99f7f786cff..b47c8325e215 100644 --- a/torch/_inductor/codegen/block_analysis.py +++ b/torch/_inductor/codegen/block_analysis.py @@ -17,6 +17,13 @@ class BlockPatternMatcher: Matches block indexing expressions. """ + _indexing_wild_signed_int = functools.partial( + sympy.Wild, properties=[lambda x: x.is_integer] + ) + _indexing_wild_unsigned_int = functools.partial( + sympy.Wild, properties=[lambda x: x.is_integer and x.is_nonnegative] + ) + @classmethod def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr: """ @@ -63,9 +70,18 @@ def match_mod_div_block_expr( index = cls._preprocess(index) # Pattern match to find the strides and offset. 
- wild = functools.partial(sympy.Wild, exclude=[index_var]) - dims: list[Expr] = [wild(f"dim_mod{idx}") for idx in range(num_dims)] - strides: list[Expr] = [wild(f"stride_mod{idx}") for idx in range(num_dims)] + wild_unsigned_int = functools.partial( + cls._indexing_wild_unsigned_int, exclude=[index_var] + ) + wild_signed_int = functools.partial( + cls._indexing_wild_signed_int, exclude=[index_var] + ) + dims: list[Expr] = [ + wild_unsigned_int(f"dim_mod{idx}") for idx in range(num_dims) + ] + strides: list[Expr] = [ + wild_signed_int(f"stride_mod{idx}") for idx in range(num_dims) + ] # The first dimension's index is computed by division. # The remaining are computed by modulo. @@ -83,7 +99,8 @@ def match_mod_div_block_expr( # for more details. In short, here we check that each subexpression in sympy.Add contains # only FloorDiv or ModularIndexing expressions. if num_dims >= 5: - stride, denom, other = sympy.symbols("stride denominator other", cls=wild) + stride = sympy.symbols("stride", cls=wild_signed_int) + denom, other = sympy.symbols("denominator other", cls=wild_unsigned_int) mod_div_pattern = stride * ModularIndexing(index_var, denom, other) floor_div_pattern = stride * FloorDiv(index_var, denom) first_dim_floor_div_matched = False @@ -167,7 +184,7 @@ def match_affine_block_expr( stride. """ index = cls._preprocess(index) - stride = sympy.Wild("stride", exclude=[index_var]) + stride = cls._indexing_wild_signed_int(name="stride", exclude=[index_var]) m = index.match(index_var * stride) if m is None: return None