Attempt to fix torch.full on MPS for bool values, more problem descriptions · pytorch/pytorch@c84c93d · GitHub
[go: up one dir, main page]

Skip to content

Commit c84c93d

Browse files
committed
Attempt to fix torch.full on MPS for bool values, more problem descriptions
1 parent 9debeaf commit c84c93d

File tree

2 files changed

+87
-49
lines changed

2 files changed

+87
-49
lines changed

aten/src/ATen/native/mps/operations/ConstantOps.mm

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,35 @@
5151
MPSGraph *mpsGraph = make_mps_graph();
5252
newCachedGraph = new CachedGraph(mpsGraph);
5353

54-
// TODO: Does not work for MPSDataTypeBool
55-
MPSGraphTensor* inputTensor = [mpsGraph constantWithScalar:value.toDouble()
56-
shape:input_shape
57-
dataType:getMPSScalarType(self.scalar_type())];
58-
MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor
59-
name:nil];
60-
54+
MPSDataType self_dtype = getMPSScalarType(self.scalar_type());
55+
MPSGraphTensor* outputTensor;
56+
if (self_dtype == MPSDataTypeBool) {
57+
MPSGraphTensor* inputTensor;
58+
if (value.toDouble()) {
59+
// TODO: Simply using value.toDouble() (1.0 for True) does not work!
60+
// Results in outputTensor having value of 255,
61+
// which displays as "False" in Python frontend! Whats going on...?
62+
inputTensor = [mpsGraph constantWithScalar:1.1 shape:input_shape dataType:MPSDataTypeFloat32];
63+
} else {
64+
inputTensor = [mpsGraph constantWithScalar:0.0 shape:input_shape dataType:MPSDataTypeFloat32];
65+
}
66+
outputTensor = [mpsGraph castTensor:inputTensor toType:MPSDataTypeBool name:@"castToBool"];
67+
} else {
68+
// TODO: constantWithScalar output is incorrect for large integers because
69+
// it only accepts double scalars and furthermore MPS only supports single precision...
70+
// therefore bottlenecked by float32 precision even for ints, test:
71+
// >>> torch.tensor(16777217, dtype=torch.float32, device="mps")
72+
// >>> torch.full((1,), 16777217, dtype=torch.int32, device="mps")
73+
// Returning tensor([16777216.], device='mps:0') and
74+
// tensor([16777216], device='mps:0', dtype=torch.int32), respectively.
75+
// The first one is expected while the second one is not, and works on CPU as well as with
76+
// torch.tensor(16777217, device="mps"), which I think goes through CPU first and then
77+
// copies over to the MPS device.
78+
MPSGraphTensor* inputTensor = [mpsGraph constantWithScalar:value.toDouble()
79+
shape:input_shape
80+
dataType:self_dtype];
81+
outputTensor = [mpsGraph identityWithTensor:inputTensor name:nil];
82+
}
6183
newCachedGraph->outputTensor_ = outputTensor;
6284
}
6385
return newCachedGraph;

test/test_mps.py

Lines changed: 58 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,53 +1302,69 @@ def test_to(self):
13021302

13031303
# Test dtype precedence (casting order) in binary operations by comparing to CPU result
13041304
def test_binops_dtype_precedence(self):
    """Check that dtype precedence (casting order) in binary operations on
    the MPS backend matches CPU, by computing every combination on both
    devices and comparing results.

    Covers all dtype pairs from ``sample_vals``, the 'add'/'sub'/'mul'
    operations, and several dimensionality combinations: 0-d scalar tensors,
    1-element tensors, and filled (torch.full) tensors of shape ``full_sh``.
    """
    # Example values for all dtypes supported on the MPS backend.
    sample_vals = {
        torch.bool: [False, True],
        torch.int16: [-15, 0, 1, 10],
        torch.int32: [-376, 0, 1, 13],
        torch.int64: [-8, 0, 1, 77],
        # torch.float16: [-234.5],  # TODO: Broken, unknown why currently
        torch.float32: [-1.0, 0, 0.1, 111.99],
    }
    # Shape for the torch.full variants below. Loop-invariant, so it is
    # computed once here instead of once per dtype pair.
    full_sh = (20,)
    # Test all combinations of dtypes, operations, dimensionality.
    # TODO: 'div' operation broken; it needs special rules currently not
    # implemented, because div is the only arithmetic operation that can
    # produce a float result from integer or bool inputs. Also infinities
    # can occur.
    for dtype1, dtype2, binop in itertools.product(
            sample_vals.keys(), sample_vals.keys(), ['add', 'sub', 'mul']):
        if binop == 'sub' and (dtype1 == torch.bool or dtype2 == torch.bool):
            continue  # Not supported, so skip
        for val1, val2 in itertools.product(sample_vals[dtype1], sample_vals[dtype2]):
            # 0-d scalar tensor (op) 0-d scalar tensor
            self.assertEqual(
                getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
                (torch.tensor(val2, dtype=dtype2, device='mps')),
                getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
                (torch.tensor(val2, dtype=dtype2, device='cpu')))
            # 1-element tensor (op) 1-element tensor
            self.assertEqual(
                getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop)
                (torch.tensor([val2], dtype=dtype2, device='mps')),
                getattr(torch.tensor([val1], dtype=dtype1, device='cpu'), binop)
                (torch.tensor([val2], dtype=dtype2, device='cpu')))
            # 0-d scalar tensor (op) 1-element tensor
            self.assertEqual(
                getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
                (torch.tensor([val2], dtype=dtype2, device='mps')),
                getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
                (torch.tensor([val2], dtype=dtype2, device='cpu')))
            # 1-element tensor (op) 0-d scalar tensor
            self.assertEqual(
                getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop)
                (torch.tensor(val2, dtype=dtype2, device='mps')),
                getattr(torch.tensor([val1], dtype=dtype1, device='cpu'), binop)
                (torch.tensor(val2, dtype=dtype2, device='cpu')))
            # Multiple problems with [MPSGraph constantWithScalar:shape:dataType:]
            # prevent the two torch.full variants below from completing
            # successfully currently.
            # TODO: Research problem with int16 — is it also related to
            # constantWithScalar?
            # TODO: Stateful bug with (False, False, 'add') in the first
            # torch.full assert? Related to the graph cache key, or a more
            # serious problem?
            # filled tensor (op) 0-d scalar tensor
            self.assertEqual(
                getattr(torch.full(full_sh, val1, dtype=dtype1, device='mps'), binop)
                (torch.tensor(val2, dtype=dtype2, device='mps')),
                getattr(torch.full(full_sh, val1, dtype=dtype1, device='cpu'), binop)
                (torch.tensor(val2, dtype=dtype2, device='cpu')))
            # 0-d scalar tensor (op) filled tensor
            self.assertEqual(
                getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
                (torch.full(full_sh, val2, dtype=dtype2, device='mps')),
                getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
                (torch.full(full_sh, val2, dtype=dtype2, device='cpu')))
13521368

13531369

13541370
class TestSmoothL1Loss(TestCase):

0 commit comments

Comments (0)