@@ -1300,6 +1300,56 @@ def test_to(self):
1300
1300
self .assertEqual (torch .tensor (- 8.34 , device = 'cpu' ).to ('mps' , torch .int ),
1301
1301
torch .tensor (- 8.34 , device = 'cpu' ).to ('mps' ).to (torch .int ))
1302
1302
1303
+ # Test dtype precedence (casting order) in binary operations by comparing to CPU result
1304
+ def test_binops_dtype_precedence (self ):
1305
+ sample_vals = {
1306
+ torch .bool : True ,
1307
+ torch .int16 : 8 ,
1308
+ torch .int32 : - 376 ,
1309
+ torch .int64 : 123898458 ,
1310
+ #torch.float16: -234.5, # TODO: broken
1311
+ torch .float32 : 111.99
1312
+ }
1313
+ # Test all combinations of dtypes, operations, dimensionality
1314
+ for dtype1 , dtype2 , binop in itertools .product (
1315
+ sample_vals .keys (), sample_vals .keys (), ['add' , 'sub' , 'mul' ]): # TODO: 'div' broken
1316
+ if binop == 'sub' and (dtype1 == torch .bool or dtype2 == torch .bool ):
1317
+ # Not supported, so skip
1318
+ continue
1319
+ self .assertEqual (
1320
+ getattr (torch .tensor (sample_vals [dtype1 ], dtype = dtype1 , device = 'mps' ), binop )
1321
+ (torch .tensor (sample_vals [dtype2 ], dtype = dtype2 , device = 'mps' )),
1322
+ getattr (torch .tensor (sample_vals [dtype1 ], dtype = dtype1 , device = 'cpu' ), binop )
1323
+ (torch .tensor (sample_vals [dtype2 ], dtype = dtype2 , device = 'cpu' )))
1324
+ self .assertEqual (
1325
+ getattr (torch .tensor ([sample_vals [dtype1 ]], dtype = dtype1 , device = 'mps' ), binop )
1326
+ (torch .tensor ([sample_vals [dtype2 ]], dtype = dtype2 , device = 'mps' )),
1327
+ getattr (torch .tensor ([sample_vals [dtype1 ]], dtype = dtype1 , device = 'cpu' ), binop )
1328
+ (torch .tensor ([sample_vals [dtype2 ]], dtype = dtype2 , device = 'cpu' )))
1329
+ self .assertEqual (
1330
+ getattr (torch .tensor (sample_vals [dtype1 ], dtype = dtype1 , device = 'mps' ), binop )
1331
+ (torch .tensor ([sample_vals [dtype2 ]], dtype = dtype2 , device = 'mps' )),
1332
+ getattr (torch .tensor (sample_vals [dtype1 ], dtype = dtype1 , device = 'cpu' ), binop )
1333
+ (torch .tensor ([sample_vals [dtype2 ]], dtype = dtype2 , device = 'cpu' )))
1334
+ self .assertEqual (
1335
+ getattr (torch .tensor ([sample_vals [dtype1 ]], dtype = dtype1 , device = 'mps' ), binop )
1336
+ (torch .tensor (sample_vals [dtype2 ], dtype = dtype2 , device = 'mps' )),
1337
+ getattr (torch .tensor ([sample_vals [dtype1 ]], dtype = dtype1 , device = 'cpu' ), binop )
1338
+ (torch .tensor (sample_vals [dtype2 ], dtype = dtype2 , device = 'cpu' )))
1339
+ '''
1340
+ # TODO: broken because [MPSGraph constantWithScalar:::] does not support MPSDataTypeBool
1341
+ self.assertEqual(
1342
+ getattr(torch.full((100,), sample_vals[dtype1], dtype=dtype1, device='mps'), binop)
1343
+ (torch.tensor(sample_vals[dtype2], dtype=dtype2, device='mps')),
1344
+ getattr(torch.full((100,), sample_vals[dtype1], dtype=dtype1, device='cpu'), binop)
1345
+ (torch.tensor(sample_vals[dtype2], dtype=dtype2, device='cpu')))
1346
+ self.assertEqual(
1347
+ getattr(torch.tensor(sample_vals[dtype1], dtype=dtype1, device='mps'), binop)
1348
+ (torch.full((100,), sample_vals[dtype2], dtype=dtype2, device='mps')),
1349
+ getattr(torch.tensor(sample_vals[dtype1], dtype=dtype1, device='cpu'), binop)
1350
+ (torch.full((100,), sample_vals[dtype2], dtype=dtype2, device='cpu')))
1351
+ '''
1352
+
1303
1353
1304
1354
class TestSmoothL1Loss (TestCase ):
1305
1355
0 commit comments