From 83ab0bbdbab3eefa19660d8c1d35ae950afa91d9 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 18 Aug 2025 08:30:36 -0700 Subject: [PATCH] [MPS] Fix error check for torch.var on scalar Fixes https://github.com/pytorch/pytorch/issues/160738 [ghstack-poisoned] --- aten/src/ATen/native/mps/operations/ReduceOps.mm | 2 +- test/test_mps.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index 4b209403f853..ae13504d9003 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -456,7 +456,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t, errMessage += ": reduction dim must be in the range of input shape"; for (const auto dim : dim_value) { auto wrap_dim = maybe_wrap_dim(dim, num_input_dims); - TORCH_CHECK(wrap_dim < static_cast(input_shape.size()), errMessage.c_str()) + TORCH_CHECK(wrap_dim < (num_input_dims ? num_input_dims : 1), errMessage.c_str()) } } diff --git a/test/test_mps.py b/test/test_mps.py index e0bf6a8a08ed..deaec2886d32 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -5321,6 +5321,9 @@ def helper(): helper() + # Regression test for https://github.com/pytorch/pytorch/issues/160738 + self.assertTrue(torch.var(torch.tensor(3.13, device='mps'), dim=0).isnan().item()) + # Test forward amax def test_amax(self): def helper(shape, dim, keepdim):