8000 BUG: ensure nomask in comparison result is not broadcast · numpy/numpy@3499c94 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3499c94

Browse files
mhvkcharris
authored andcommitted
BUG: ensure nomask in comparison result is not broadcast
1 parent 806c829 commit 3499c94

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed

numpy/ma/core.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4140,17 +4140,18 @@ def _comparison(self, other, compare):
41404140
if isinstance(check, (np.bool_, bool)):
41414141
return masked if mask else check
41424142

4143-
if mask is not nomask and compare in (operator.eq, operator.ne):
4144-
# Adjust elements that were masked, which should be treated
4145-
# as equal if masked in both, unequal if masked in one.
4146-
# Note that this works automatically for structured arrays too.
4147-
# Ignore this for operations other than `==` and `!=`
4148-
check = np.where(mask, compare(smask, omask), check)
4149-
4150-
if mask.shape != check.shape:
4151-
# Guarantee consistency of the shape, making a copy since the
4152-
# the mask may need to get written to later.
4153-
mask = np.broadcast_to(mask, check.shape).copy()
4143+
if mask is not nomask:
4144+
if compare in (operator.eq, operator.ne):
4145+
# Adjust elements that were masked, which should be treated
4146+
# as equal if masked in both, unequal if masked in one.
4147+
# Note that this works automatically for structured arrays too.
4148+
# Ignore this for operations other than `==` and `!=`
4149+
check = np.where(mask, compare(smask, omask), check)
4150+
4151+
if mask.shape != check.shape:
4152+
# Guarantee consistency of the shape, making a copy since the
4153+
# the mask may need to get written to later.
4154+
mask = np.broadcast_to(mask, check.shape).copy()
41544155

41554156
check = check.view(type(self))
41564157
check._update_from(self)

numpy/ma/tests/test_core.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1770,6 +1770,15 @@ def test_eq_broadcast_with_unmasked(self, op):
17701770
assert_(result.mask.shape == b.shape)
17711771
assert_equal(result.mask, np.zeros(b.shape, bool) | a.mask)
17721772

1773+
@pytest.mark.parametrize("op", [operator.eq, operator.gt])
1774+
def test_comp_no_mask_not_broadcast(self, op):
1775+
# Regression test for failing doctest in MaskedArray.nonzero
1776+
# after gh-24556.
1777+
a = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
1778+
result = op(a, 3)
1779+
assert_(not result.mask.shape)
1780+
assert_(result.mask is nomask)
1781+
17731782
@pytest.mark.parametrize('dt1', num_dts, ids=num_ids)
17741783
@pytest.mark.parametrize('dt2', num_dts, ids=num_ids)
17751784
@pytest.mark.parametrize('fill', [None, 1])

0 commit comments

Comments
 (0)
0