8000 [MPS] Avoid outputing zeros from `exponential_` for MPS (#159386) · pytorch/pytorch@70d2e9b · GitHub
[go: up one dir, main page]

Skip to content

Commit 70d2e9b

Browse files
kurtamohlerpytorchmergebot
authored andcommitted
[MPS] Avoid outputing zeros from exponential_ for MPS (#159386)
Fixes #159103 Pull Request resolved: #159386 Approved by: https://github.com/malfet
1 parent 62f98db commit 70d2e9b

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

aten/src/ATen/native/mps/operations/Distributions.mm

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,8 +418,9 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, std::optional<Generator
418418
MPSGraphTensor* logTensor = [mpsGraph logarithmWithTensor:subtractTensor name:nil];
419419
return [mpsGraph divisionWithPrimaryTensor:logTensor secondaryTensor:minusLambdaTensor name:nil];
420420
};
421+
auto eps = std::numeric_limits<float>::epsilon();
421422
return mps::random_mps_impl<double>(self,
422-
0.0,
423+
eps,
423424
1.0,
424425
std::nullopt,
425426
std::nullopt,

test/test_mps.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7681,6 +7681,12 @@ def test_exponential_1(self):
76817681
self.assertEqual(Exponential(0.2).sample((1,)).size(), (1,))
76827682
self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,))
76837683

7684+
@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
7685+
def test_exponential_nonzero(self, dtype):
7686+
for _ in range(100):
7687+
a = torch.empty(32_000, device="mps", dtype=dtype).exponential_()
7688+
self.assertTrue((a != 0).all())
7689+
76847690
# Test add
76857691
def test_add_sub(self):
76867692
def helper(shape, alpha, op_name, inplace):

0 commit comments

Comments
 (0)
0