8000 Merge pull request #24561 from charris/backport-24556-24559 · numpy/numpy@21f3719 · GitHub
[go: up one dir, main page]

Skip to content

Commit 21f3719

Browse files
authored
Merge pull request #24561 from charris/backport-24556-24559
BUG: fix comparisons between masked and unmasked structured arrays
2 parents 498bf30 + 3499c94 commit 21f3719

File tree

2 files changed

+48
-9
lines changed

2 files changed

+48
-9
lines changed

numpy/ma/core.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4127,6 +4127,9 @@ def _comparison(self, other, compare):
41274127
# Now take care of the mask; the merged mask should have an item
41284128
# masked if all fields were masked (in one and/or other).
41294129
mask = (mask == np.ones((), mask.dtype))
4130+
# Ensure we can compare masks below if other was not masked.
4131+
if omask is np.False_:
4132+
omask = np.zeros((), smask.dtype)
41304133

41314134
else:
41324135
# For regular arrays, just use the data as they come.
@@ -4137,12 +4140,14 @@ def _comparison(self, other, compare):
41374140
if isinstance(check, (np.bool_, bool)):
41384141
return masked if mask else check
41394142

4140-
if mask is not nomask and compare in (operator.eq, operator.ne):
4141-
# Adjust elements that were masked, which should be treated
4142-
# as equal if masked in both, unequal if masked in one.
4143-
# Note that this works automatically for structured arrays too.
4144-
# Ignore this for operations other than `==` and `!=`
4145-
check = np.where(mask, compare(smask, omask), check)
4143+
if mask is not nomask:
4144+
if compare in (operator.eq, operator.ne):
4145+
# Adjust elements that were masked, which should be treated
4146+
# as equal if masked in both, unequal if masked in one.
4147+
# Note that this works automatically for structured arrays too.
4148+
# Ignore this for operations other than `==` and `!=`
4149+
check = np.where(mask, compare(smask, omask), check)
4150+
41464151
if mask.shape != check.shape:
41474152
# Guarantee consistency of the shape, making a copy since the
41484153
# the mask may need to get written to later.

numpy/ma/tests/test_core.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,8 +1310,8 @@ def test_minmax_dtypes(self):
13101310
m1 = [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]
13111311
xm = masked_array(x, mask=m1)
13121312
xm.set_fill_value(1e+20)
1313-
float_dtypes = [np.half, np.single, np.double,
1314-
np.longdouble, np.cfloat, np.cdouble, np.clongdouble]
1313+
float_dtypes = [np.float16, np.float32, np.float64, np.longdouble,
1314+
np.complex64, np.complex128, np.clongdouble]
13151315
for float_dtype in float_dtypes:
13161316
assert_equal(masked_array(x, mask=m1, dtype=float_dtype).max(),
13171317
float_dtype(a10))
@@ -1614,6 +1614,23 @@ def test_ne_on_structured(self):
16141614
assert_equal(test.mask, [[False, False], [False, True]])
16151615
assert_(test.fill_value == True)
16161616

1617+
def test_eq_ne_structured_with_non_masked(self):
1618+
a = array([(1, 1), (2, 2), (3, 4)],
1619+
mask=[(0, 1), (0, 0), (1, 1)], dtype='i4,i4')
1620+
eq = a == a.data
1621+
ne = a.data != a
1622+
# Test the obvious.
1623+
assert_(np.all(eq))
1624+
assert_(not np.any(ne))
1625+
# Expect the mask set only for items with all fields masked.
1626+
expected_mask = a.mask == np.ones((), a.mask.dtype)
1627+
assert_array_equal(eq.mask, expected_mask)
1628+
assert_array_equal(ne.mask, expected_mask)
1629+
# The masked element will indicated not equal, because the
1630+
# masks did not match.
1631+
assert_equal(eq.data, [True, True, False])
1632+
assert_array_equal(eq.data, ~ne.data)
1633+
16171634
def test_eq_ne_structured_extra(self):
16181635
# ensure simple examples are symmetric and make sense.
16191636
# from https://github.com/numpy/numpy/pull/8590#discussion_r101126465
@@ -1745,6 +1762,23 @@ def test_eq_for_numeric(self, dt1, dt2, fill):
17451762
assert_equal(test.mask, [True, False])
17461763
assert_(test.fill_value == True)
17471764

1765+
@pytest.mark.parametrize("op", [operator.eq, operator.lt])
1766+
def test_eq_broadcast_with_unmasked(self, op):
1767+
a = array([0, 1], mask=[0, 1])
1768+
b = np.arange(10).reshape(5, 2)
1769+
result = op(a, b)
1770+
assert_(result.mask.shape == b.shape)
1771+
assert_equal(result.mask, np.zeros(b.shape, bool) | a.mask)
1772+
1773+
@pytest.mark.parametrize("op", [operator.eq, operator.gt])
1774+
def test_comp_no_mask_not_broadcast(self, op):
1775+
# Regression test for failing doctest in MaskedArray.nonzero
1776+
# after gh-24556.
1777+
a = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
1778+
result = op(a, 3)
1779+
assert_(not result.mask.shape)
1780+
assert_(result.mask is nomask)
1781+
17481782
@pytest.mark.parametrize('dt1', num_dts, ids=num_ids)
17491783
@pytest.mark.parametrize('dt2', num_dts, ids=num_ids)
17501784
@pytest.mark.parametrize('fill', [None, 1])
@@ -3444,7 +3478,7 @@ def test_ravel_order(self, order, data_order):
34443478
raveled = x.ravel(order)
34453479
assert (raveled.filled(0) == 0).all()
34463480

3447-
# NOTE: Can be wrong if arr order is neither C nor F and `order="K"`
3481+
# NOTE: Can be wrong if arr order is neither C nor F and `order="K"`
34483482
assert_array_equal(arr.ravel(order), x.ravel(order)._data)
34493483

34503484
def test_reshape(self):

0 commit comments

Comments
 (0)
0