Merge · pytorch/pytorch@b9d1d71 · GitHub

Commit b9d1d71

Merge
1 parent d028f9e commit b9d1d71

File tree

2 files changed: +8 -7 lines


aten/src/ATen/native/mps/operations/BinaryOps.mm

Lines changed: 2 additions & 1 deletion
@@ -55,7 +55,8 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
 
   MPSGraphCache* cache_ = MPSGraphCache::getInstance();
   @autoreleasepool {
-    string key = op_name + getTensorsStringKey({self, other, output_}, /*use_scalar_value*/ false);
+    string key = op_name + getTensorsStringKey({self, other}, /*use_scalar_value*/ false);
+    //std::cout << key << std::endl;
     BinaryOpCachedGraph* cachedGraph = static_cast<BinaryOpCachedGraph *>(cache_->LookUp(key));
 
     if (!cachedGraph) {
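
Note: the functional change above is that the graph-cache key is now built from the input tensors only ({self, other}) rather than inputs plus output ({self, other, output_}), with a commented-out key dump left in for debugging. The TODO in the test below suspects the cache key in a stateful bug. As a hypothetical illustration only (this is not PyTorch's MPSGraphCache, and every name here is invented), the sketch shows why a memoization key that omits the output can hand back a graph built for a different output dtype:

    # Hypothetical sketch of a graph cache keyed only on input metadata.
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class TensorMeta:
        dtype: str
        shape: tuple

    graph_cache = {}

    def lookup_or_build(op_name, self_t, other_t, out_t):
        # Mirrors the patched key shape: op name plus inputs, output excluded.
        key = op_name + repr((self_t, other_t))
        if key not in graph_cache:
            graph_cache[key] = f"graph<{op_name}, out={out_t.dtype}>"
        return graph_cache[key]

    a = b = TensorMeta("bool", ())
    print(lookup_or_build("add", a, b, TensorMeta("bool", ())))   # builds a bool-output graph
    print(lookup_or_build("add", a, b, TensorMeta("int64", ())))  # cache hit: reuses the bool-output graph

Whether this matters for the real code depends on what getTensorsStringKey encodes for the inputs; the sketch only shows the general hazard the TODO is worried about.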

test/test_mps.py

Lines changed: 6 additions & 6 deletions
@@ -1516,7 +1516,7 @@ def test_binops_dtype_precedence(self):
                 # MPS: True + True = False
                 if binop == 'add' and dtype1 == torch.bool and dtype2 == torch.bool and val1 and val2:
                     continue
-                # print(f'\n{dtype1},{dtype2}: ({val1}).{binop}({val2})')
+                print(f'{dtype1},{dtype2}: ({val1}).{binop}({val2})')
                 # print('assert1')
                 self.assertEqual(
                     getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
@@ -1548,16 +1548,16 @@ def test_binops_dtype_precedence(self):
                 # TODO: Stateful bug with False, False, add in assert5? Related to the cache key
                 # or more serious problem?
                 # - Cache key looks correct, behavior currently completely unexplained
-                '''
+                #'''
                 print('assert5')
                 x1 = torch.full(full_sh, val1, dtype=dtype1, device='mps')
                 y1 = torch.tensor(val2, dtype=dtype2, device='mps')
                 x2 = torch.full(full_sh, val1, dtype=dtype1, device='cpu')
                 y2 = torch.tensor(val2, dtype=dtype2, device='cpu')
                 print('x1', x1, hex(x1.data_ptr()))
                 print('y1', y1, hex(y1.data_ptr()))
-                #print('x2', x2, hex(x2.data_ptr()))
-                #print('y2', y2, hex(y2.data_ptr()))
+                print('x2', x2, hex(x2.data_ptr()))
+                print('y2', y2, hex(y2.data_ptr()))
                 self.assertEqual(getattr(x1, binop)(y1), getattr(x2, binop)(y2))
                 print('assert6')
                 x3 = torch.tensor(val1, dtype=dtype1, device='mps')
@@ -1566,8 +1566,8 @@ def test_binops_dtype_precedence(self):
                 y4 = torch.full(full_sh, val2, dtype=dtype2, device='cpu')
                 print('x3', x3, hex(x3.data_ptr()))
                 print('y3', y3, hex(y3.data_ptr()))
-                #print('x4', x4, hex(x4.data_ptr()))
-                #print('y4', y4, hex(y4.data_ptr()))
+                print('x4', x4, hex(x4.data_ptr()))
+                print('y4', y4, hex(y4.data_ptr()))
                 #breakpoint()
                 self.assertEqual(getattr(x3, binop)(y3), getattr(x4, binop)(y4))
                 #self.assertEqual(
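
Note: test_binops_dtype_precedence loops binary ops over dtype/value combinations and asserts that MPS results match CPU; this commit turns several debug prints back on and appears to re-enable the assert5/assert6 block that had been fenced off with a ''' string, while chasing the stateful bool-add failure noted in the TODO. Below is a condensed, standalone sketch of that cross-device check (illustrative only: check_binop and its defaults are not part of the test suite, and it assumes an MPS-enabled build):

    import torch

    def check_binop(binop, dtype1, dtype2, val1, val2, full_sh=(2, 2)):
        # Same shape as the test's assert5 check: build matching operands on
        # 'mps' and 'cpu' and require the binary op to agree across devices.
        x_mps = torch.full(full_sh, val1, dtype=dtype1, device='mps')
        y_mps = torch.tensor(val2, dtype=dtype2, device='mps')
        x_cpu = torch.full(full_sh, val1, dtype=dtype1, device='cpu')
        y_cpu = torch.tensor(val2, dtype=dtype2, device='cpu')
        out_mps = getattr(x_mps, binop)(y_mps)
        out_cpu = getattr(x_cpu, binop)(y_cpu)
        assert torch.equal(out_mps.cpu(), out_cpu), (binop, dtype1, dtype2, val1, val2)

    if torch.backends.mps.is_available():
        # The combination the TODO above calls out: False, False with 'add'.
        check_binop('add', torch.bool, torch.bool, False, False)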

0 commit comments

Comments
 (0)
0