From 17440d692dd648159c304676772ac1d5c6e40288 Mon Sep 17 00:00:00 2001 From: Marten van Kerkwijk Date: Sun, 27 Aug 2023 12:46:44 +0200 Subject: [PATCH] BUG: ensure nomask in comparison result is not broadcast --- numpy/ma/core.py | 23 ++++++++++++----------- numpy/ma/tests/test_core.py | 9 +++++++++ 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 545cd43381ef..5caffe90576e 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -4153,17 +4153,18 @@ def _comparison(self, other, compare): if isinstance(check, (np.bool_, bool)): return masked if mask else check - if mask is not nomask and compare in (operator.eq, operator.ne): - # Adjust elements that were masked, which should be treated - # as equal if masked in both, unequal if masked in one. - # Note that this works automatically for structured arrays too. - # Ignore this for operations other than `==` and `!=` - check = np.where(mask, compare(smask, omask), check) - - if mask.shape != check.shape: - # Guarantee consistency of the shape, making a copy since the - # the mask may need to get written to later. - mask = np.broadcast_to(mask, check.shape).copy() + if mask is not nomask: + if compare in (operator.eq, operator.ne): + # Adjust elements that were masked, which should be treated + # as equal if masked in both, unequal if masked in one. + # Note that this works automatically for structured arrays too. + # Ignore this for operations other than `==` and `!=` + check = np.where(mask, compare(smask, omask), check) + + if mask.shape != check.shape: + # Guarantee consistency of the shape, making a copy since the + # the mask may need to get written to later. + mask = np.broadcast_to(mask, check.shape).copy() check = check.view(type(self)) check._update_from(self) diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 98eda3325a51..a4f8deb93ab5 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -1770,6 +1770,15 @@ def test_eq_broadcast_with_unmasked(self, op): assert_(result.mask.shape == b.shape) assert_equal(result.mask, np.zeros(b.shape, bool) | a.mask) + @pytest.mark.parametrize("op", [operator.eq, operator.gt]) + def test_comp_no_mask_not_broadcast(self, op): + # Regression test for failing doctest in MaskedArray.nonzero + # after gh-24556. + a = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + result = op(a, 3) + assert_(not result.mask.shape) + assert_(result.mask is nomask) + @pytest.mark.parametrize('dt1', num_dts, ids=num_ids) @pytest.mark.parametrize('dt2', num_dts, ids=num_ids) @pytest.mark.parametrize('fill', [None, 1])