10000 fixed bernoulli1 and bernoulli2 tests · pytorch/pytorch@97001e1 · GitHub
[go: up one dir, main page]

Skip to content

Commit 97001e1

Browse files
zero000064pytorchmergebot
authored andcommitted
fixed bernoulli1 and bernoulli2 tests
1 parent b4065a8 commit 97001e1

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

test/inductor/test_torchinductor.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7944,9 +7944,7 @@ def test_bernoulli1(self):
79447944
def fn(a):
79457945
b = a.clone()
79467946
# aten.bernoulli_() uses aten.bernoulli.p() behind the scene, so it will be decomposed.
7947-
return aten.bernoulli_(b).sum() / torch.prod(
7948-
torch.tensor(a.size(), device=a.device)
7949-
)
7947+
return aten.bernoulli_(b).sum() / torch.prod(torch.tensor(a.size()))
79507948

79517949
p = 0.3
79527950
self.common(
@@ -7961,9 +7959,7 @@ def fn(a):
79617959
@skip_if_triton_cpu
79627960
def test_bernoulli2(self):
79637961
def fn(a):
7964-
return aten.bernoulli(a).sum() / torch.prod(
7965-
torch.tensor(a.size(), device=a.device)
7966-
)
7962+
return aten.bernoulli(a).sum() / torch.prod(torch.tensor(a.size()))
79677963

79687964
p = 0.3
79697965
self.common(

0 commit comments

Comments
 (0)
0