@@ -1268,8 +1268,16 @@ def false_fn(x):
1268
1268
1269
1269
return torch .cond (x .shape [0 ] > 5 , true_fn , false_fn , (x ,))
1270
1270
1271
- input1 = (torch .ones (3 , 3 ), torch .ones (5 ), torch .ones (3 , 3 ))
1272
- input2 = (torch .ones (10 , 3 ), torch .ones (6 ), torch .ones (10 , 3 ))
1271
+ input1 = (
1272
+ torch .ones (3 , 3 , device = self .device ),
1273
+ torch .ones (5 , device = self .device ),
1274
+ torch .ones (3 , 3 , device = self .device ),
1275
+ )
1276
+ input2 = (
1277
+ torch .ones (10 , 3 , device = self .device ),
1278
+ torch .ones (6 , device = self .device ),
1279
+ torch .ones (10 , 3 , device = self .device ),
1280
+ )
1273
1281
inputs = (input1 , input2 )
1274
1282
dynamic_shapes = {"x" : {0 : Dim ("d" )}, "y" : {0 : Dim ("d1" )}, "z" : {0 : Dim ("d" )}}
1275
1283
self .check_model_with_multiple_inputs (
@@ -1395,6 +1403,9 @@ def forward(self, x):
1395
1403
self .check_model (M (self .device ), (torch .randn (5 , 5 , device = self .device ),))
1396
1404
1397
1405
def test_zero_grid_with_backed_symbols (self ):
1406
+ if self .device != GPU_TYPE :
1407
+ raise unittest .SkipTest ("requires GPU" )
1408
+
1398
1409
class Repro (torch .nn .Module ):
1399
1410
def __init__ (self ) -> None :
1400
1411
super ().__init__ ()
@@ -1417,7 +1428,7 @@ def forward(self, x, b):
1417
1428
example_inputs ,
1418
1429
dynamic_shapes = dynamic_shapes ,
1419
1430
)
1420
- aot_inductor_module = AOTIRunnerUtil .load (GPU_TYPE , so_path )
1431
+ aot_inductor_module = AOTIRunnerUtil .load (self . device , so_path )
1421
1432
aot_inductor_module (* example_inputs )
1422
1433
1423
1434
# Re-run where dynamic dim size is 0.
@@ -1920,7 +1931,7 @@ def __init__(self) -> None:
1920
1931
def forward (self , x ):
1921
1932
return torch .ops .aten .normal_functional .default (x )
1922
1933
1923
- self .check_model (Model (), (torch .empty (4 , 1 , 4 , 4 ),))
1934
+ self .check_model (Model (), (torch .empty (4 , 1 , 4 , 4 , device = self . device ),))
1924
1935
1925
1936
def test_empty_graph (self ):
1926
1937
class Model (torch .nn .Module ):
0 commit comments