8000 Update · pytorch/pytorch@391bff7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 391bff7

Browse files
committed
Update
[ghstack-poisoned]
1 parent c283ae3 commit 391bff7

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

aten/src/ATen/native/mps/operations/Distributions.mm

Lines changed: 2 additions & 1 deletion
Diff (unified), original line numbers 418-425 → new line numbers 418-426:

```diff
@@ -418,8 +418,9 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, std::optional<Generator
     MPSGraphTensor* logTensor = [mpsGraph logarithmWithTensor:subtractTensor name:nil];
     return [mpsGraph divisionWithPrimaryTensor:logTensor secondaryTensor:minusLambdaTensor name:nil];
   };
+  auto eps = std::numeric_limits<float>::epsilon();
   return mps::random_mps_impl<double>(self,
-                                      0.0,
+                                      eps,
                                       1.0,
                                       std::nullopt,
                                       std::nullopt,
```

test/test_mps.py

Lines changed: 6 additions & 0 deletions
Diff (unified), original line numbers 7722-7727 → new line numbers 7722-7733:

```diff
@@ -7722,6 +7722,12 @@ def test_exponential_1(self):
         self.assertEqual(Exponential(0.2).sample((1,)).size(), (1,))
         self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,))

+    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+    def test_exponential_nonzero(self, dtype):
+        for _ in range(100):
+            a = torch.empty(32_000, device="mps", dtype=dtype).exponential_()
+            self.assertTrue((a != 0).all())
+
     # Test add
     def test_add_sub(self):
         def helper(shape, alpha, op_name, inplace):
```

0 commit comments

Comments (0)