@@ -1003,6 +1003,31 @@ def test_pointwise_index_order(self):
1003
1003
xindex = xoffset + tl.arange(0, XBLOCK)[None, None, :]""" , # noqa: B950
1004
1004
)
1005
1005
1006
+ # Integration test to ensure that matched dims & strides from match_mod_div_expr
1007
+ # are nonnegative integers. This test case has the following index
1008
+ # index=(ModularIndexing(xindex, 4, 4)) + 4*(ModularIndexing(xindex, 32, 2))
1009
+ # and the match below is a candidate that is invalid:
1010
+ # match={
1011
+ # dim_mod4_: 32, dim_mod3_: 2, stride_mod3_: 4, dim_mod2_: 1/16,
1012
+ # dim_mod1_: 4, stride_mod1_: 1, stride_mod4_: 0, stride_mod2_: 0, stride_mod0_: 0
1013
+ # }
1014
+ # This is now fixed by ensuring that wild symbols only match nonnegative integers.
1015
def test_ensure_integral_dims_and_strides(self):
    # Runs torch.nn.functional.unfold through run_and_compare on a small
    # float16 input that requires grad, passing the expected Triton kernel
    # and block-pointer counts along with fullgraph compilation.

    # Positional unfold arguments, named after the parameters they bind to
    # in torch.nn.functional.unfold(input, kernel_size, dilation, padding, stride).
    kernel_size, dilation, padding, stride = 2, 1, 0, 1
    input_shape = (2, 3, 5, 5)

    def unfold_model(inp, *unfold_args):
        # Thin wrapper so a plain Python callable is handed to run_and_compare.
        return torch.nn.functional.unfold(inp, *unfold_args)

    zeros_input = torch.zeros(
        input_shape,
        dtype=torch.float16,
        requires_grad=True,
    )
    compile_options = {"fullgraph": True}

    run_and_compare(
        self,
        unfold_model,
        zeros_input,
        kernel_size,
        dilation,
        padding,
        stride,
        expected_num_triton_kernels=2,
        expected_num_block_pointers=4,
        compile_kwargs=compile_options,
    )
1030
+
1006
1031
1007
1032
@unittest .skipIf (not TRITON_HAS_CPU , "requires triton CPU backend" )
1008
1033
@config .patch (cpu_backend = "triton" )
0 commit comments