@@ -1299,13 +1299,7 @@ def forward(self, values, repeats, mask, embeddings, x, z, scalar):
                 unbacked_add_expr = backed + unbacked
                 repeated = x.repeat(unbacked_add_expr, 1)
-                return torch.cat(
-                    [
-                        repeated,
-                        index_select,
-                    ],
-                    dim=1,
-                )
+                return torch.cat([repeated, index_select], dim=1)
 
         example_inputs = (
             torch.ones(64, dtype=torch.int64, device=self.device),
@@ -1327,6 +1321,115 @@ def forward(self, values, repeats, mask, embeddings, x, z, scalar):
         }
         self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
 
+    def test_size_with_unbacked_add_expr_transitive(self):
+        # Edge case with torch._check(expr1 == expr2) + torch._check(expr2 == unbacked).
+        # When generating example input sizes for autotuning, it should coalesce
+        # expr1, expr2, and unbacked into a single size.
+        if self.device != GPU_TYPE:
+            raise unittest.SkipTest("requires GPU")
+
+        class Repro(torch.nn.Module):
+            def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
+                u0, u1, random_unbacked = lst.tolist()
+                torch._check_is_size(u0)
+                torch._check_is_size(u1)
+                backed = z.size(0)
+                backed1 = z.size(1)
+
+                repeated = x.repeat(backed + u0, 1)
+                repeated1 = y.repeat(backed1 + u1, 1)
+                out = torch.empty_like(repeated)
+                add_kernel[(out.numel(),)](
+                    repeated, repeated, out, out.numel(), BLOCK_SIZE=2
+                )
+
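+                # The two checks below chain repeated1.size(0) == out.size(0) == random_unbacked,
+                # tying all three sizes to a single value.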
+                torch._check(repeated1.size(0) == out.size(0))
+                torch._check(out.size(0) == random_unbacked)
+
+                index = torch.repeat_interleave(values, repeats)
+                index_select = torch.index_select(embeddings, 0, index)
+
+                cat = torch.cat([out, index_select], dim=1)
+                add = repeated + repeated1
+                return cat, add
+
+        example_inputs = (
+            torch.ones(64, dtype=torch.int64, device=self.device),
+            torch.ones(64, dtype=torch.int64, device=self.device) * 24,
+            torch.ones((768,), dtype=torch.int64, device=self.device).bool(),
+            torch.randn((401, 8), dtype=torch.bfloat16, device=self.device),
+            torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
+            torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
+            torch.ones(758, 758, dtype=torch.int64, device=self.device),
+            torch.tensor(
+                [10, 10, 2 * (758 + 10)], dtype=torch.int32, device=self.device
+            ),
+        )
+        spec = {
+            "values": (Dim.DYNAMIC,),
+            "repeats": (Dim.DYNAMIC,),
+            "mask": (Dim.DYNAMIC,),
+            "embeddings": (Dim.DYNAMIC, Dim.STATIC),
+            "x": (Dim.DYNAMIC, Dim.STATIC),
+            "y": (Dim.DYNAMIC, Dim.STATIC),
+            "z": (Dim.DYNAMIC, Dim.DYNAMIC),
+            "lst": (Dim.STATIC,),
+        }
+        self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
+
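+    # 1024 is used as the fallback size hint for unbacked symints when autotuning.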
+    @config.patch({"unbacked_symint_fallback": 1024})
+    def test_size_with_unbacked_add_and_mul_expr(self):
+        # Edge case with torch._check(add_expr == mul_expr). When generating example
+        # input sizes for autotuning, make sure they coalesce into a single size.
+        if self.device != GPU_TYPE:
+            raise unittest.SkipTest("requires GPU")
+
+        class Repro(torch.nn.Module):
+            def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
+                u0, u1, u2 = lst.tolist()
+                torch._check_is_size(u0)
+                torch._check_is_size(u1)
+                torch._check_is_size(u2)
+                backed = z.size(0)
+                backed1 = z.size(1)
+
+                unbacked_add_expr = backed + u0
+                unbacked_mul_expr = backed1 + (u1 * u2)
+                repeated0 = x.repeat(unbacked_add_expr, 1)
+                repeated1 = y.repeat(unbacked_mul_expr, 1)
+                out0 = torch.empty_like(repeated0)
+                out1 = torch.empty_like(repeated1)
+                add_kernel[(out0.numel(),)](
+                    repeated0, repeated0, out0, out0.numel(), BLOCK_SIZE=2
+                )
+                add_kernel[(out1.numel(),)](
+                    repeated1, repeated1, out1, out1.numel(), BLOCK_SIZE=2
+                )
+
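+                # cat along dim=1 requires out1 and out0 to have the same dim-0 size,
+                # which equates the add and mul expressions.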
+                return torch.cat([out1, out0], dim=1)
+
+        example_inputs = (
+            torch.ones(64, dtype=torch.int64, device=self.device),
+            torch.ones(64, dtype=torch.int64, device=self.device) * 24,
+            torch.ones((768,), dtype=torch.int64, device=self.device).bool(),
+            torch.randn((401, 8), dtype=torch.bfloat16, device=self.device),
+            torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
+            torch.randn((2, 256), dtype=torch.bfloat16, device=self.device),
+            torch.ones(758, 758, dtype=torch.int64, device=self.device),
+            torch.tensor([10, 5, 2], dtype=torch.int32, device=self.device),
+        )
+        spec = {
+            "values": (Dim.DYNAMIC,),
+            "repeats": (Dim.DYNAMIC,),
+            "mask": (Dim.DYNAMIC,),
+            "embeddings": (Dim.DYNAMIC, Dim.STATIC),
+            "x": (Dim.DYNAMIC, Dim.STATIC),
+            "y": (Dim.DYNAMIC, Dim.STATIC),
+            "z": (Dim.DYNAMIC, Dim.DYNAMIC),
+            "lst": (Dim.STATIC,),
+        }
+        self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
+
     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
     def test_fallback_kernel_with_symexpr_output(self):
         if self.device != GPU_TYPE: