Add test_binops_dtype_precedence and todo notes · pytorch/pytorch@9debeaf · GitHub
Commit 9debeaf

Add test_binops_dtype_precedence and todo notes

1 parent 152a797 · commit 9debeaf

2 files changed: +51 -0 lines changed

aten/src/ATen/native/mps/operations/ConstantOps.mm

Lines changed: 1 addition & 0 deletions

@@ -51,6 +51,7 @@
   MPSGraph *mpsGraph = make_mps_graph();
   newCachedGraph = new CachedGraph(mpsGraph);
 
+  // TODO: Does not work for MPSDataTypeBool
   MPSGraphTensor* inputTensor = [mpsGraph constantWithScalar:value.toDouble()
                                                        shape:input_shape
                                                     dataType:getMPSScalarType(self.scalar_type())];
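
The commented-out assertions in the new test below tie this TODO to torch.full on MPS, which presumably reaches constantWithScalar through this operator. A minimal, hypothetical repro sketch (not part of the commit), assuming an MPS-enabled build:

    # Hypothetical repro of the path the TODO flags: materializing a bool-filled
    # tensor on MPS is what the commented-out part of the new test exercises.
    import torch

    if torch.backends.mps.is_available():
        # May fail or misbehave while constantWithScalar lacks MPSDataTypeBool support.
        x = torch.full((4,), True, dtype=torch.bool, device='mps')
        print(x.cpu())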

test/test_mps.py

Lines changed: 50 additions & 0 deletions

@@ -1300,6 +1300,56 @@ def test_to(self):
         self.assertEqual(torch.tensor(-8.34, device='cpu').to('mps', torch.int),
                          torch.tensor(-8.34, device='cpu').to('mps').to(torch.int))
 
+    # Test dtype precedence (casting order) in binary operations by comparing to the CPU result
+    def test_binops_dtype_precedence(self):
+        sample_vals = {
+            torch.bool: True,
+            torch.int16: 8,
+            torch.int32: -376,
+            torch.int64: 123898458,
+            #torch.float16: -234.5, # TODO: broken
+            torch.float32: 111.99
+        }
+        # Test all combinations of dtypes, operations, and dimensionality
+        for dtype1, dtype2, binop in itertools.product(
+                sample_vals.keys(), sample_vals.keys(), ['add', 'sub', 'mul']):  # TODO: 'div' broken
+            if binop == 'sub' and (dtype1 == torch.bool or dtype2 == torch.bool):
+                # Subtraction is not supported for bool tensors, so skip
+                continue
+            self.assertEqual(
+                getattr(torch.tensor(sample_vals[dtype1], dtype=dtype1, device='mps'), binop)
+                       (torch.tensor(sample_vals[dtype2], dtype=dtype2, device='mps')),
+                getattr(torch.tensor(sample_vals[dtype1], dtype=dtype1, device='cpu'), binop)
+                       (torch.tensor(sample_vals[dtype2], dtype=dtype2, device='cpu')))
+            self.assertEqual(
+                getattr(torch.tensor([sample_vals[dtype1]], dtype=dtype1, device='mps'), binop)
+                       (torch.tensor([sample_vals[dtype2]], dtype=dtype2, device='mps')),
+                getattr(torch.tensor([sample_vals[dtype1]], dtype=dtype1, device='cpu'), binop)
+                       (torch.tensor([sample_vals[dtype2]], dtype=dtype2, device='cpu')))
+            self.assertEqual(
+                getattr(torch.tensor(sample_vals[dtype1], dtype=dtype1, device='mps'), binop)
+                       (torch.tensor([sample_vals[dtype2]], dtype=dtype2, device='mps')),
+                getattr(torch.tensor(sample_vals[dtype1], dtype=dtype1, device='cpu'), binop)
+                       (torch.tensor([sample_vals[dtype2]], dtype=dtype2, device='cpu')))
+            self.assertEqual(
+                getattr(torch.tensor([sample_vals[dtype1]], dtype=dtype1, device='mps'), binop)
+                       (torch.tensor(sample_vals[dtype2], dtype=dtype2, device='mps')),
+                getattr(torch.tensor([sample_vals[dtype1]], dtype=dtype1, device='cpu'), binop)
+                       (torch.tensor(sample_vals[dtype2], dtype=dtype2, device='cpu')))
+            '''
+            # TODO: broken because [MPSGraph constantWithScalar:::] does not support MPSDataTypeBool
+            self.assertEqual(
+                getattr(torch.full((100,), sample_vals[dtype1], dtype=dtype1, device='mps'), binop)
+                       (torch.tensor(sample_vals[dtype2], dtype=dtype2, device='mps')),
+                getattr(torch.full((100,), sample_vals[dtype1], dtype=dtype1, device='cpu'), binop)
+                       (torch.tensor(sample_vals[dtype2], dtype=dtype2, device='cpu')))
+            self.assertEqual(
+                getattr(torch.tensor(sample_vals[dtype1], dtype=dtype1, device='mps'), binop)
+                       (torch.full((100,), sample_vals[dtype2], dtype=dtype2, device='mps')),
+                getattr(torch.tensor(sample_vals[dtype1], dtype=dtype1, device='cpu'), binop)
+                       (torch.full((100,), sample_vals[dtype2], dtype=dtype2, device='cpu')))
+            '''
+
 
 class TestSmoothL1Loss(TestCase):
 
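
For context (not part of the commit): the "precedence" being tested is PyTorch's standard type-promotion order, and each assertion requires the MPS result to carry the same dtype and values as the CPU reference. A small illustration using torch.result_type:

    # Illustration of the type promotion the test relies on: the result dtype of a
    # mixed-dtype binary op follows torch.result_type, and the test asserts that
    # the MPS result matches the CPU reference.
    import torch

    a = torch.tensor(8, dtype=torch.int16)
    b = torch.tensor(111.99, dtype=torch.float32)

    print(torch.result_type(a, b))  # torch.float32 -- floating point outranks integer
    print((a * b).dtype)            # torch.float32 on CPU; the test expects 'mps' to match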

0 commit comments