@@ -1302,53 +1302,69 @@ def test_to(self):
    # Test dtype precedence (casting order) in binary operations by comparing to CPU result
    def test_binops_dtype_precedence(self):
+        # Example values for all dtypes supported on the MPS backend
        sample_vals = {
-            torch.bool: True,
-            torch.int16: 8,
-            torch.int32: -376,
-            torch.int64: 123898458,
-            # torch.float16: -234.5,  # TODO: broken
-            torch.float32: 111.99
+            torch.bool: [False, True],
+            torch.int16: [-15, 0, 1, 10],
+            torch.int32: [-376, 0, 1, 13],
+            torch.int64: [-8, 0, 1, 77],
+            # torch.float16: [-234.5],  # TODO: broken, cause currently unknown
+            torch.float32: [-1.0, 0, 0.1, 111.99]
        }
        # Test all combinations of dtypes, operations, dimensionality
+        # TODO: The 'div' operation is broken and needs special rules that are not implemented yet,
+        # because div is the only arithmetic operation that can produce a float result
+        # from integer or bool inputs. Infinities can also occur.
        for dtype1, dtype2, binop in itertools.product(
-                sample_vals.keys(), sample_vals.keys(), ['add', 'sub', 'mul']):  # TODO: 'div' broken
+                sample_vals.keys(), sample_vals.keys(), ['add', 'sub', 'mul']):
            if binop == 'sub' and (dtype1 == torch.bool or dtype2 == torch.bool):
-                # Not supported, so skip
-                continue
-            self.assertEqual(
-                getattr(torch.tensor(sample_vals[dtype1], dtype=dtype1, device='mps'), binop)
-                (torch.tensor(sample_vals[dtype2], dtype=dtype2, device='mps')),
-                getattr(torch.tensor(sample_vals[dtype1], dtype=dtype1, device='cpu'), binop)
-                (torch.tensor(sample_vals[dtype2], dtype=dtype2, device='cpu')))
-            self.assertEqual(
-                getattr(torch.tensor([sample_vals[dtype1]], dtype=dtype1, device='mps'), binop)
-                (torch.tensor([sample_vals[dtype2]], dtype=dtype2, device='mps')),
-                getattr(torch.tensor([sample_vals[dtype1]], dtype=dtype1, device='cpu'), binop)
-                (torch.tensor([sample_vals[dtype2]], dtype=dtype2, device='cpu')))
-            self.assertEqual(
-                getattr(torch.tensor(sample_vals[dtype1], dtype=dtype1, device='mps'), binop)
-                (torch.tensor([sample_vals[dtype2]], dtype=dtype2, device='mps')),
-                getattr(torch.tensor(sample_vals[dtype1], dtype=dtype1, device='cpu'), binop)
-                (torch.tensor([sample_vals[dtype2]], dtype=dtype2, device='cpu')))
-            self.assertEqual(
-                getattr(torch.tensor([sample_vals[dtype1]], dtype=dtype1, device='mps'), binop)
-                (torch.tensor(sample_vals[dtype2], dtype=dtype2, device='mps')),
-                getattr(torch.tensor([sample_vals[dtype1]], dtype=dtype1, device='cpu'), binop)
-                (torch.tensor(sample_vals[dtype2], dtype=dtype2, device='cpu')))
-            '''
-            # TODO: broken because [MPSGraph constantWithScalar:::] does not support MPSDataTypeBool
-            self.assertEqual(
-                getattr(torch.full((100,), sample_vals[dtype1], dtype=dtype1, device='mps'), binop)
-                (torch.tensor(sample_vals[dtype2], dtype=dtype2, device='mps')),
-                getattr(torch.full((100,), sample_vals[dtype1], dtype=dtype1, device='cpu'), binop)
-                (torch.tensor(sample_vals[dtype2], dtype=dtype2, device='cpu')))
-            self.assertEqual(
-                getattr(torch.tensor(sample_vals[dtype1], dtype=dtype1, device='mps'), binop)
-                (torch.full((100,), sample_vals[dtype2], dtype=dtype2, device='mps')),
-                getattr(torch.tensor(sample_vals[dtype1], dtype=dtype1, device='cpu'), binop)
-                (torch.full((100,), sample_vals[dtype2], dtype=dtype2, device='cpu')))
-            '''
+                continue  # Not supported, so skip
+            #print(dtype1, dtype2, binop)
+            full_sh = (20,)
+            #print('assert1')
+            for val1, val2 in itertools.product(sample_vals[dtype1], sample_vals[dtype2]):
+                self.assertEqual(
+                    getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
+                    (torch.tensor(val2, dtype=dtype2, device='mps')),
+                    getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
+                    (torch.tensor(val2, dtype=dtype2, device='cpu')))
+                #print('assert2')
+                self.assertEqual(
+                    getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop)
+                    (torch.tensor([val2], dtype=dtype2, device='mps')),
+                    getattr(torch.tensor([val1], dtype=dtype1, device='cpu'), binop)
+                    (torch.tensor([val2], dtype=dtype2, device='cpu')))
+                #print('assert3')
+                self.assertEqual(
+                    getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
+                    (torch.tensor([val2], dtype=dtype2, device='mps')),
+                    getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
+                    (torch.tensor([val2], dtype=dtype2, device='cpu')))
+                #print('assert4')
+                self.assertEqual(
+                    getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop)
+                    (torch.tensor(val2, dtype=dtype2, device='mps')),
+                    getattr(torch.tensor([val1], dtype=dtype1, device='cpu'), binop)
+                    (torch.tensor(val2, dtype=dtype2, device='cpu')))
+                #'''
+                # Multiple problems with [MPSGraph constantWithScalar:shape:dataType:] currently
+                # prevent these tests from completing successfully
+                # TODO: Research the problem with int16; is it also related to constantWithScalar?
+                # TODO: Stateful bug with (False, False, 'add') in assert5? Related to the cache key
+                # or a more serious problem?
+                #print('assert5', val1, val2)
+                self.assertEqual(
+                    getattr(torch.full(full_sh, val1, dtype=dtype1, device='mps'), binop)
+                    (torch.tensor(val2, dtype=dtype2, device='mps')),
+                    getattr(torch.full(full_sh, val1, dtype=dtype1, device='cpu'), binop)
+                    (torch.tensor(val2, dtype=dtype2, device='cpu')))
+                #print('assert6')
+                self.assertEqual(
+                    getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
+                    (torch.full(full_sh, val2, dtype=dtype2, device='mps')),
+                    getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
+                    (torch.full(full_sh, val2, dtype=dtype2, device='cpu')))
+                #'''


class TestSmoothL1Loss(TestCase):
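The casting order the test checks against comes from PyTorch's standard type-promotion rules, which can be inspected directly with torch.result_type. Below is a minimal standalone sketch of those rules; it is not part of the PR, and the dtype list simply mirrors the sample_vals keys above for illustration:

import itertools

import torch

# Dtypes mirroring the sample_vals keys above (assumed set, for illustration only)
dtypes = [torch.bool, torch.int16, torch.int32, torch.int64, torch.float32]

# torch.result_type reports the dtype that a binary op on the two operands promotes to,
# e.g. int16 + int32 -> int32, bool * float32 -> float32.
for dt1, dt2 in itertools.product(dtypes, dtypes):
    promoted = torch.result_type(torch.tensor(1, dtype=dt1), torch.tensor(1, dtype=dt2))
    print(dt1, dt2, '->', promoted)

# 'div' is the exception noted in the TODO above: true division promotes integer and
# bool inputs to the default float dtype, so its result dtype does not follow
# result_type of the inputs alone, e.g. torch.tensor(3).div(torch.tensor(2)) -> tensor(1.5000)
print(torch.tensor(3).div(torch.tensor(2)))

If the MPS backend follows the same promotion, the per-element comparison against the CPU result in the test above implicitly verifies the casting order as well.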