8000 [MPS] Fix error check for torch.var on scalar · pytorch/pytorch@e691c86 · GitHub
[go: up one dir, main page]

Skip to content

Commit e691c86

Browse files
committed
[MPS] Fix error check for torch.var on scalar
Fixes #160738 ghstack-source-id: 8102c54 Pull Request resolved: #160889
1 parent 4751eb3 commit e691c86

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

aten/src/ATen/native/mps/operations/ReduceOps.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
456456
errMessage += ": reduction dim must be in the range of input shape";
457457
for (const auto dim : dim_value) {
458458
auto wrap_dim = maybe_wrap_dim(dim, num_input_dims);
459-
TORCH_CHECK(wrap_dim < static_cast<decltype(wrap_dim)>(input_shape.size()), errMessage.c_str())
459+
TORCH_CHECK(wrap_dim < (num_input_dims ? num_input_dims : 1), errMessage.c_str())
460460
}
461461
}
462462

test/test_mps.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5321,6 +5321,9 @@ def helper():
53215321

53225322
helper()
53235323

5324+
# Regression test for https://github.com/pytorch/pytorch/issues/160738
5325+
self.assertTrue(torch.var(torch.tensor(3.13, device='mps'), dim=0).isnan().item())
5326+
53245327
# Test forward amax
53255328
def test_amax(self):
53265329
def helper(shape, dim, keepdim):

0 commit comments

Comments
 (0)
0