8000 [MPS] Add binary operations dtype precedence test case (#87545) · pytorch/pytorch@81a8fdc · GitHub
[go: up one dir, main page]

Skip to content

Commit 81a8fdc

Browse files
lhoenig and pytorchmergebot
authored and committed
[MPS] Add binary operations dtype precedence test case (#87545)
See #84742 and #78319. The test case tests that - for the binary operations (add, sub, mul, div), - for all data types (dtypes), - for a range of representative values and their combinations, - for various shapes and ways of creating the test tensors, the contents and dtype of the result tensor is identical for the MPS and CPU backends. It adds about 15-18s runtime to `test_mps.py`. Pull Request resolved: #87545 Approved by: https://github.com/kit1980
1 parent 44c9185 commit 81a8fdc

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

test/test_mps.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,6 +1688,68 @@ def test_copy_non_contiguous(self):
16881688
y.permute(3, 2, 1, 0)[1::, ::2] = z
16891689
self.assertEqual(x, y.to('cpu'))
16901690

1691+
# See https://github.com/pytorch/pytorch/pull/84742
1692+
# and https://github.com/pytorch/pytorch/pull/78319
1693+
def test_binops_dtype_precedence(self):
1694+
# Test dtype precedence (casting order) in binary operations by comparing to CPU result
1695+
# Example values for all dtypes supported on the MPS backend
1696+
sample_vals = {
1697+
torch.bool: [False, True],
1698+
torch.int16: [-15, 0, 1, 10],
1699+
torch.int32: [-376, 0, 1, 13],
1700+
torch.int64: [-8, 0, 1, 77],
1701+
torch.float16: [-234.5, 0.0, 1.0, 2.0],
1702+
torch.float32: [-1.0, 0.0, 0.1, 111.99],
1703+
}
1704+
# Test all combinations of dtypes, operations, dimensionality
1705+
for dtype1, dtype2, binop in itertools.product(
1706+
sample_vals.keys(), sample_vals.keys(), ['add', 'sub', 'mul', 'div']):
1707+
# bool minus bool is generally unsupported, so skip
1708+
if binop == 'sub' and (dtype1 == torch.bool or dtype2 == torch.bool):
1709+
continue
1710+
full_shape = (10,)
1711+
for val1, val2 in itertools.product(sample_vals[dtype1], sample_vals[dtype2]):
1712+
# print(f'{dtype1},{dtype2}: ({val1}).{binop}({val2})')
1713+
# print(getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
1714+
# (torch.tensor(val2, dtype=dtype2, device='mps')))
1715+
# print(getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
1716+
# (torch.tensor(val2, dtype=dtype2, device='cpu')))
1717+
self.assertEqual(
1718+
getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
1719+
(torch.tensor(val2, dtype=dtype2, device='mps')),
1720+
getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
1721+
(torch.tensor(val2, dtype=dtype2, device='cpu')))
1722+
self.assertEqual(
1723+
getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop)
1724+
(torch.tensor([val2], dtype=dtype2, device='mps')),
1725+
getattr(torch.tensor([val1], dtype=dtype1, device='cpu'), binop)
1726+
(torch.tensor([val2], dtype=dtype2, device='cpu')))
1727+
self.assertEqual(
1728+
getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
1729+
(torch.tensor([val2], dtype=dtype2, device='mps')),
1730+
getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
1731+
(torch.tensor([val2], dtype=dtype2, device='cpu')))
1732+
self.assertEqual(
1733+
getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop)
1734+
(torch.tensor(val2, dtype=dtype2, device='mps')),
1735+
getattr(torch.tensor([val1], dtype=dtype1, device='cpu'), binop)
1736+
(torch.tensor(val2, dtype=dtype2, device='cpu')))
1737+
# Test tensors created with torch.full
1738+
x1 = torch.full(full_shape, val1, dtype=dtype1, device='mps')
1739+
y1 = torch.tensor(val2, dtype=dtype2, device='mps')
1740+
x2 = torch.full(full_shape, val1, dtype=dtype1, device='cpu')
1741+
y2 = torch.tensor(val2, dtype=dtype2, device='cpu')
1742+
self.assertEqual(getattr(x1, binop)(y1), getattr(x2, binop)(y2))
1743+
x3 = torch.tensor(val1, dtype=dtype1, device='mps')
1744+
y3 = torch.full(full_shape, val2, dtype=dtype2, device='mps')
1745+
x4 = torch.tensor(val1, dtype=dtype1, device='cpu')
1746+
y4 = torch.full(full_shape, val2, dtype=dtype2, device='cpu')
1747+
self.assertEqual(getattr(x3, binop)(y3), getattr(x4, binop)(y4))
1748+
self.assertEqual(
1749+
getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
1750+
(torch.full(full_shape, val2, dtype=dtype2, device='mps')),
1751+
getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
1752+
(torch.full(full_shape, val2, dtype=dtype2, device='cpu')))
16911753

16921754

16931755
class TestLogical(TestCase):

0 commit comments

Comments
 (0)
0