@@ -1688,6 +1688,68 @@ def test_copy_non_contiguous(self):
1688
1688
y .permute (3 , 2 , 1 , 0 )[1 ::, ::2 ] = z
1689
1689
self .assertEqual (x , y .to ('cpu' ))
1690
1690
1691
+ # See https://github.com/pytorch/pytorch/pull/84742
1692
+ # and https://github.com/pytorch/pytorch/pull/78319
1693
+ def test_binops_dtype_precedence (self ):
1694
+ # Test dtype precedence (casting order) in binary operations by comparing to CPU result
1695
+ # Example values for all dtypes supported on the MPS backend
1696
+ sample_vals = {
1697
+ torch .bool : [False , True ],
1698
+ torch .int16 : [- 15 , 0 , 1 , 10 ],
1699
+ torch .int32 : [- 376 , 0 , 1 , 13 ],
1700
+ torch .int64 : [- 8 , 0 , 1 , 77 ],
1701
+ torch .float16 : [- 234.5 , 0.0 , 1.0 , 2.0 ],
1702
+ torch .float32 : [- 1.0 , 0.0 , 0.1 , 111.99 ],
1703
+ }
1704
+ # Test all combinations of dtypes, operations, dimensionality
1705
+ for dtype1 , dtype2 , binop in itertools .product (
1706
+ sample_vals .keys (), sample_vals .keys (), ['add' , 'sub' , 'mul' , 'div' ]):
1707
+ # bool minus bool is generally unsupported, so skip
1708
+ if binop == 'sub' and (dtype1 == torch .bool or dtype2 == torch .bool ):
1709
+ continue
1710
+ full_shape = (10 ,)
1711
+ for val1 , val2 in itertools .product (sample_vals [dtype1 ], sample_vals [dtype2 ]):
1712
+ # print(f'{dtype1},{dtype2}: ({val1}).{binop}({val2})')
1713
+ # print(getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
1714
+ # (torch.tensor(val2, dtype=dtype2, device='mps')))
1715
+ # print(getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
1716
+ # (torch.tensor(val2, dtype=dtype2, device='cpu')))
1717
+ self .assertEqual (
1718
+ getattr (torch .tensor (val1 , dtype = dtype1 , device = 'mps' ), binop )
1719
+ (torch .tensor (val2 , dtype = dtype2 , device = 'mps' )),
1720
+ getattr (torch .tensor (val1 , dtype = dtype1 , device = 'cpu' ), binop )
1721
+ (torch .tensor (val2 , dtype = dtype2 , device = 'cpu' )))
1722
+ self .assertEqual (
1723
+ getattr (torch .tensor ([val1 ], dtype = dtype1 , device = 'mps' ), binop )
1724
+ (torch .tensor ([val2 ], dtype = dtype2 , device = 'mps' )),
1725
+ getattr (torch .tensor ([val1 ], dtype = dtype1 , device = 'cpu' ), binop )
1726
+ (torch .tensor ([val2 ], dtype = dtype2 , device = 'cpu' )))
1727
+ self .assertEqual (
1728
+ getattr (torch .tensor (val1 , dtype = dtype1 , device = 'mps' ), binop )
1729
+ (torch .tensor ([val2 ], dtype = dtype2 , device = 'mps' )),
1730
+ getattr (torch .tensor (val1 , dtype = dtype1 , device = 'cpu' ), binop )
1731
+ (torch .tensor ([val2 ], dtype = dtype2 , device = 'cpu' )))
1732
+ self .assertEqual (
1733
+ getattr (torch .tensor ([val1 ], dtype = dtype1 , device = 'mps' ), binop )
1734
+ (torch .tensor (val2 , dtype = dtype2 , device = 'mps' )),
1735
+ getattr (torch .tensor ([val1 ], dtype = dtype1 , device = 'cpu' ), binop )
1736
+ (torch .tensor (val2 , dtype = dtype2 , device = 'cpu' )))
1737
+ # Test tensors created with torch.full
1738
+ x1 = torch .full (full_shape , val1 , dtype = dtype1 , device = 'mps' )
1739
+ y1 = torch .tensor (val2 , dtype = dtype2 , device = 'mps' )
1740
+ x2 = torch .full (full_shape , val1 , dtype = dtype1 , device = 'cpu' )
1741
+ y2 = torch .tensor (val2 , dtype = dtype2 , device = 'cpu' )
1742
+ self .assertEqual (getattr (x1 , binop )(y1 ), getattr (x2 , binop )(y2 ))
1743
+ x3 = torch .tensor (val1 , dtype = dtype1 , device = 'mps' )
1744
+ y3 = torch .full (full_shape , val2 , dtype = dtype2 , device = 'mps' )
1745
+ x4 = torch .tensor (val1 , dtype = dtype1 , device = 'cpu' )
1746
+ y4 = torch .full (full_shape , val2 , dtype = dtype2 , device = 'cpu' )
1747
+ self .assertEqual (getattr (x3 , binop )(y3 ), getattr (x4 , binop )(y4 ))
1748
+ self .assertEqual (
1749
+ getattr (torch .tensor (val1 , dtype = dtype1 , device = 'mps' ), binop )
1750
+ (torch .full (full_shape , val2 , dtype = dtype2 , device = 'mps' )),
1751
+ getattr (torch .tensor (val1 , dtype = dtype1 , device = 'cpu' ), binop )
1752
+ (torch .full (full_shape , val2 , dtype = dtype2 , device = 'cpu' )))
1691
1753
1692
1754
1693
1755
class TestLogical (TestCase ):
0 commit comments