@@ -1146,6 +1146,64 @@ def foo(x, y, z):
1146
1146
# Singleton splits should be discarded.
1147
1147
self ._assert_pointwise_ndims (triton_code , 2 )
1148
1148
1149
+ # Integration test to ensure that matched dims & strides from match_mod_div_expr
1150
+ # are unsigned and signed integers respectively. This test case has the following
1151
+ # index:=(ModularIndexing(xindex, 4, 4)) + 4*(ModularIndexing(xindex, 32, 2))
1152
+ # and the match below is a candidate that is invalid:
1153
+ # match={
1154
+ # dim_mod4_: 32, dim_mod3_: 2, stride_mod3_: 4, dim_mod2_: 1/16,
1155
+ # dim_mod1_: 4, stride_mod1_: 1, stride_mod4_: 0, stride_mod2_: 0, stride_mod0_: 0
1156
+ # }
1157
+ # This is now fixed by ensuring that that wild symbols only match integers
1158
+ def test_ensure_integral_dims_and_strides (self ):
1159
+ def model (data , * args ):
1160
+ return torch .nn .functional .unfold (data , * args )
1161
+
1162
+ data = torch .zeros (
1163
+ [2 , 3 , 5 , 5 ], dtype = torch .float16 , requires_grad = True , device = self .device
1164
+ )
1165
+ args = [2 , 1 , 0 , 1 ]
1166
+ run_and_compare (
1167
+ self ,
1168
+ model ,
1169
+ data ,
1170
+ * args ,
1171
+ expected_num_triton_kernels = 2 ,
1172
+ expected_num_block_pointers = 4 ,
1173
+ compile_kwargs = {"fullgraph" : True },
1174
+ )
1175
+
1176
+ # Integration test to test block analysis with index expressions using
1177
+ # negative strides.
1178
+ # This test case has the following index:
1179
+ # index_relative_to_xyr_index = -256*((xindex//64)) - (ModularIndexing(xindex, 1, 8))
1180
+ # - 16*(ModularIndexing(xindex, 8, 8)) + 1911
1181
+ # subexpr = -256*((xindex//64)) - (ModularIndexing(xindex, 1, 8)) - 16*(ModularIndexing(xindex, 8, 8))
1182
+ # Block analysis should produce the following:
1183
+ # BlockParameters(
1184
+ # shape=[8, 8, 8],
1185
+ # block_shape=[((XBLOCK + 63)//64), Min(8, ((XBLOCK + 7)//8)), Min(8, XBLOCK) ],
1186
+ # strides=[-256, -16, -1],
1187
+ # offsets=[(xoffset//64), ModularIndexing(xoffset, 8, 8), ModularIndexing(xoffset, 1, 8)]
1188
+ # )
1189
+ # constant_offset = 1911
1190
+ def test_negative_strides (self ):
1191
+ def model (x , y ):
1192
+ # Slice in reverse order via a negative stride
1193
+ return torch .flip (x , [0 , 1 , 2 ]) + y
1194
+
1195
+ x , y = (
1196
+ self ._discontiguous_tensor ((8 , 8 , 8 ), device = self .device ) for _ in range (2 )
1197
+ )
1198
+ run_and_compare (
1199
+ self ,
1200
+ model ,
1201
+ x ,
1202
+ y ,
1203
+ expected_num_triton_kernels = 1 ,
1204
+ expected_num_block_pointers = 3 ,
1205
+ )
1206
+
1149
1207
@config .patch ("triton.prefer_nd_tiling" , True )
1150
1208
@config .patch ("triton.max_tiles" , 3 )
1151
1209
@parametrize (
0 commit comments