8000 [MPS] tril op not handling infs correctly (#150479) · pytorch/pytorch@a3cd7b0 · GitHub
[go: up one dir, main page]

Skip to content

Commit a3cd7b0

Browse files
pytorchbotIsalia20
andauthored
[MPS] tril op not handling infs correctly (#150479)
[MPS] tril op not handling infs correctly (#149866) Fixes #149813 Pull Request resolved: #149866 Approved by: https://github.com/malfet (cherry picked from commit ba46643) Co-authored-by: Isalia20 <irakli.salia854@gmail.com>
1 parent 8522972 commit a3cd7b0

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

aten/src/ATen/native/mps/operations/TriangularOps.mm

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,12 @@
107107
numLowerTensor:negDiagMinusOneTensor
108108
numUpperTensor:minusOneTensor
109109
name:nil];
110-
outputTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor secondaryTensor:complementTensor name:nil];
110+
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:getMPSDataType(self)];
111+
MPSGraphTensor* mask = [mpsGraph equalWithPrimaryTensor:complementTensor secondaryTensor:zeroTensor name:nil];
112+
outputTensor = [mpsGraph selectWithPredicateTensor:mask
113+
truePredicateTensor:inputTensor
114+
falsePredicateTensor:zeroTensor
115+
name:nil];
111116
}
112117

113118
newCachedGraph->inputTensor_ = inputTensor;

test/test_mps.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7849,13 +7849,21 @@ def helper(shape, diag=0):
78497849
self.assertEqual(tril_result, tril_result_cpu)
78507850
self.assertEqual(x.grad, cpu_x.grad)
78517851

7852-
helper((2, 8, 4, 5))
7853-
helper((2, 8, 4, 5), diag=1)
7854-
helper((2, 8, 4, 5), diag=2)
7855-
helper((2, 8, 4, 5), diag=3)
7856-
helper((2, 8, 4, 5), diag=-1)
7857-
helper((2, 8, 4, 5), diag=-2)
7858-
helper((2, 8, 4, 5), diag=-3)
7852+
for diag in [0, 1, 2, 3, -1, -2, -3]:
7853+
helper((2, 8, 4, 5), diag=diag)
7854+
7855+
def helper_nans_infs(value, diag_vals=(0, 1, -2)):
7856+
"""For nans and infs"""
7857+
mps_tensor = torch.full((2, 2, 5, 5), value, device="mps")
7858+
cpu_tensor = torch.full((2, 2, 5, 5), value, device="cpu")
7859+
for diag in diag_vals:
7860+
mps_result = torch.tril(mps_tensor, diagonal=diag)
7861+
cpu_result = torch.tril(cpu_tensor, diagonal=diag)
7862+
self.assertEqual(mps_result, cpu_result, f"Mismatch for diag={diag}")
7863+
7864+
helper_nans_infs(float("inf"))
7865+
helper_nans_infs(float("-inf"))
7866+
helper_nans_infs(float("nan"))
78597867

78607868
# test eye
78617869
def test_eye(self):

0 commit comments

Comments
 (0)
0