8000 Merge pull request #8590 from mhvk/ma/eq_ne_axis_bug · numpy/numpy@64111c5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 64111c5

Browse files
authored
Merge pull request #8590 from mhvk/ma/eq_ne_axis_bug
BUG MaskedArray __eq__ wrong for masked scalar, multi-d recarray
2 parents 2dd9125 + 3435dd9 commit 64111c5

File tree

3 files changed

+181
-90
lines changed

3 files changed

+181
-90
lines changed

doc/release/1.13.0-notes.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,15 @@ Better default repr for ``ndarray`` subclasses
179179
Subclasses of ndarray with no ``repr`` specialization now correctly indent
180180
their data and type lines.
181181

182+
More reliable comparisons of masked arrays
183+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
184+
Comparisons of masked arrays were buggy for masked scalars and failed for
185+
structured arrays with dimension higher than one. Both problems are now
186+
solved. In the process, it was ensured that in getting the result for a
187+
structured array, masked fields are properly ignored, i.e., the result is equal
188+
if all fields that are non-masked in both are equal, thus making the behaviour
189+
identical to what one gets by comparing an unstructured masked array and then
190+
doing ``.all()`` over some axis.
182191

183192
Changes
184193
=======

numpy/ma/core.py

Lines changed: 77 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from __future__ import division, absolute_import, print_function
2424

2525
import sys
26+
import operator
2627
import warnings
2728
from functools import reduce
2829

@@ -1602,21 +1603,11 @@ def make_mask(m, copy=False, shrink=True, dtype=MaskType):
16021603
"""
16031604
if m is nomask:
16041605
return nomask
1605-
elif isinstance(m, ndarray):
1606-
# We won't return after this point to make sure we can shrink the mask
1607-
# Fill the mask in case there are missing data
1608-
m = filled(m, True)
1609-
# Make sure the input dtype is valid
1610-
dtype = make_mask_descr(dtype)
1611-
if m.dtype == dtype:
1612-
if copy:
1613-
result = m.copy()
1614-
else:
1615-
result = m
1616-
else:
1617-
result = np.array(m, dtype=dtype, copy=copy)
1618-
else:
1619-
result = np.array(filled(m, True), dtype=MaskType)
1606+
1607+
# Make sure the input dtype is valid.
1608+
dtype = make_mask_descr(dtype)
1609+
# Fill the mask in case there are missing data; turn it into an ndarray.
1610+
result = np.array(filled(m, True), copy=copy, dtype=dtype, subok=True)
16201611
# Bas les masques !
16211612
if shrink and (not result.dtype.names) and (not result.any()):
16221613
return nomask
@@ -1733,7 +1724,8 @@ def _recursive_mask_or(m1, m2, newmask):
17331724
if (dtype1 != dtype2):
17341725
raise ValueError("Incompatible dtypes '%s'<>'%s'" % (dtype1, dtype2))
17351726
if dtype1.names:
1736-
newmask = np.empty_like(m1)
1727+
# Allocate an output mask array with the properly broadcast shape.
1728+
newmask = np.empty(np.broadcast(m1, m2).shape, dtype1)
17371729
_recursive_mask_or(m1, m2, newmask)
17381730
return newmask
17391731
return make_mask(umath.logical_or(m1, m2), copy=copy, shrink=shrink)
@@ -3873,81 +3865,84 @@ def _delegate_binop(self, other):
38733865
return True
38743866
return False
38753867

3876-
def __eq__(self, other):
3877-
"""
3878-
Check whether other equals self elementwise.
3868+
def _comparison(self, other, compare):
3869+
"""Compare self with other using operator.eq or operator.ne.
3870+
3871+
When either of the elements is masked, the result is masked as well,
3872+
but the underlying boolean data are still set, with self and other
3873+
considered equal if both are masked, and unequal otherwise.
38793874
3875+
For structured arrays, all fields are combined, with masked values
3876+
ignored. The result is masked if all fields were masked, with self
3877+
and other considered equal only if both were fully masked.
38803878
"""
3881-
if self is masked:
3882-
return masked
38833879
omask = getmask(other)
3884-
if omask is nomask:
3885-
check = self.filled(0).__eq__(other)
3886-
try:
3887-
check = check.view(type(self))
3888-
check._mask = self._mask
3889-
except AttributeError:
3890-
# Dang, we have a bool instead of an array: return the bool
3891-
return check
3880+
smask = self.mask
3881+
mask = mask_or(smask, omask, copy=True)
3882+
3883+
odata = getdata(other)
3884+
if mask.dtype.names:
3885+
# For possibly masked structured arrays we need to be careful,
3886+
# since the standard structured array comparison will use all
3887+
# fields, masked or not. To avoid masked fields influencing the
3888+
# outcome, we set all masked fields in self to other, so they'll
3889+
# count as equal. To prepare, we ensure we have the right shape.
3890+
broadcast_shape = np.broadcast(self, odata).shape
3891+
sbroadcast = np.broadcast_to(self, broadcast_shape, subok=True)
3892+
sbroadcast._mask = mask
3893+
sdata = sbroadcast.filled(odata)
3894+
# Now take care of the mask; the merged mask should have an item
3895+
# masked if all fields were masked (in one and/or other).
3896+
mask = (mask == np.ones((), mask.dtype))
3897+
38923898
else:
3893-
odata = filled(other, 0)
3894-
check = self.filled(0).__eq__(odata).view(type(self))
3895-
if self._mask is nomask:
3896-
check._mask = omask
3897-
else:
3898-
mask = mask_or(self._mask, omask)
3899-
if mask.dtype.names:
3900-
if mask.size > 1:
3901-
axis = 1
3902-
else:
3903-
axis = None
3904-
try:
3905-
mask = mask.view((bool_, len(self.dtype))).all(axis)
3906-
except (ValueError, np.AxisError):
3907-
# TODO: what error are we trying to catch here?
3908-
# invalid axis, or invalid view?
3909-
mask = np.all([[f[n].all() for n in mask.dtype.names]
3910-
for f in mask], axis=axis)
3911-
check._mask = mask
3899+
# For regular arrays, just use the data as they come.
3900+
sdata = self.data
3901+
3902+
check = compare(sdata, odata)
3903+
3904+
if isinstance(check, (np.bool_, bool)):
3905+
return masked if mask else check
3906+
3907+
if mask is not nomask:
3908+
# Adjust elements that were masked, which should be treated
3909+
# as equal if masked in both, unequal if masked in one.
3910+
# Note that this works automatically for structured arrays too.
3911+
check = np.where(mask, compare(smask, omask), check)
3912+
if mask.shape != check.shape:
3913+
# Guarantee consistency of the shape, making a copy since the
3914+
# the mask may need to get written to later.
3915+
mask = np.broadcast_to(mask, check.shape).copy()
3916+
3917+
check = check.view(type(self))
3918+
check._mask = mask
39123919
return check
39133920

3914-
def __ne__(self, other):
3921+
def __eq__(self, other):
3922+
"""Check whether other equals self elementwise.
3923+
3924+
When either of the elements is masked, the result is masked as well,
3925+
but the underlying boolean data are still set, with self and other
3926+
considered equal if both are masked, and unequal otherwise.
3927+
3928+
For structured arrays, all fields are combined, with masked values
3929+
ignored. The result is masked if all fields were masked, with self
3930+
and other considered equal only if both were fully masked.
39153931
"""
3916-
Check whether other doesn't equal self elementwise
3932+
return self._comparison(other, operator.eq)
39173933

3934+
def __ne__(self, other):
3935+
"""Check whether other does not equal self elementwise.
3936+
3937+
When either of the elements is masked, the result is masked as well,
3938+
but the underlying boolean data are still set, with self and other
3939+
considered equal if both are masked, and unequal otherwise.
3940+
3941+
For structured arrays, all fields are combined, with masked values
3942+
ignored. The result is masked if all fields were masked, with self
3943+
and other considered equal only if both were fully masked.
39183944
"""
3919-
if self is masked:
3920-
return masked
3921-
omask = getmask(other)
3922-
if omask is nomask:
3923-
check = self.filled(0).__ne__(other)
3924-
try:
3925-
check = check.view(type(self))
3926-
check._mask = self._mask
3927-
except AttributeError:
3928-
# In case check is a boolean (or a numpy.bool)
3929-
return check
3930-
else:
3931-
odata = filled(other, 0)
3932-
check = self.filled(0).__ne__(odata).view(type(self))
3933-
if self._mask is nomask:
3934-
check._mask = omask
3935-
else:
3936-
mask = mask_or(self._mask, omask)
3937-
if mask.dtype.names:
3938-
if mask.size > 1:
3939-
axis = 1
3940-
else:
3941-
axis = None
3942-
try:
3943-
mask = mask.view((bool_, len(self.dtype))).all(axis)
3944-
except (ValueError, np.AxisError):
3945-
# TODO: what error are we trying to catch here?
3946-
# invalid axis, or invalid view?
3947-
mask = np.all([[f[n].all() for n in mask.dtype.names]
3948-
for f in mask], axis=axis)
3949-
check._mask = mask
3950-
return check
3945+
return self._comparison(other, operator.ne)
39513946

39523947
def __add__(self, other):
39533948
"""

numpy/ma/tests/test_core.py

Lines changed: 95 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,32 +1335,96 @@ def test_eq_on_structured(self):
13351335
ndtype = [('A', int), ('B', int)]
13361336
a = array([(1, 1), (2, 2)], mask=[(0, 1), (0, 0)], dtype=ndtype)
13371337
test = (a == a)
1338-
assert_equal(test, [True, True])
1338+
assert_equal(test.data, [True, True])
1339+
assert_equal(test.mask, [False, False])
1340+
test = (a == a[0])
1341+
assert_equal(test.data, [True, False])
13391342
assert_equal(test.mask, [False, False])
13401343
b = array([(1, 1), (2, 2)], mask=[(1, 0), (0, 0)], dtype=ndtype)
13411344
test = (a == b)
1342-
assert_equal(test, [False, True])
1345+
assert_equal(test.data, [False, True])
1346+
assert_equal(test.mask, [True, False])
1347+
test = (a[0] == b)
1348+
assert_equal(test.data, [False, False])
13431349
assert_equal(test.mask, [True, False])
13441350
b = array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype)
13451351
test = (a == b)
1346-
assert_equal(test, [True, False])
1352+
assert_equal(test.data, [True, True])
13471353
assert_equal(test.mask, [False, False])
1354+
# complicated dtype, 2-dimensional array.
1355+
ndtype = [('A', int), ('B', [('BA', int), ('BB', int)])]
1356+
a = array([[(1, (1, 1)), (2, (2, 2))],
1357+
[(3, (3, 3)), (4, (4, 4))]],
1358+
mask=[[(0, (1, 0)), (0, (0, 1))],
1359+
[(1, (0, 0)), (1, (1, 1))]], dtype=ndtype)
1360+
test = (a[0, 0] == a)
1361+
assert_equal(test.data, [[True, False], [False, False]])
1362+
assert_equal(test.mask, [[False, False], [False, True]])
13481363

13491364
def test_ne_on_structured(self):
13501365
# Test the equality of structured arrays
13511366
ndtype = [('A', int), ('B', int)]
13521367
a = array([(1, 1), (2, 2)], mask=[(0, 1), (0, 0)], dtype=ndtype)
13531368
test = (a != a)
1354-
assert_equal(test, [False, False])
1369+
assert_equal(test.data, [False, False])
1370+
assert_equal(test.mask, [False, False])
1371+
test = (a != a[0])
1372+
assert_equal(test.data, [False, True])
13551373
assert_equal(test.mask, [False, False])
13561374
b = array([(1, 1), (2, 2)], mask=[(1, 0), (0, 0)], dtype=ndtype)
13571375
test = (a != b)
1358-
assert_equal(test, [True, False])
1376+
assert_equal(test.data, [True, False])
1377+
assert_equal(test.mask, [True, False])
1378+
test = (a[0] != b)
1379+
assert_equal(test.data, [True, True])
13591380
assert_equal(test.mask, [True, False])
13601381
b = array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype)
13611382
test = (a != b)
1362-
assert_equal(test, [False, True])
1383+
assert_equal(test.data, [False, False])
13631384
assert_equal(test.mask, [False, False])
1385+
# complicated dtype, 2-dimensional array.
1386+
ndtype = [('A', int), ('B', [('BA', int), ('BB', int)])]
1387+
a = array([[(1, (1, 1)), (2, (2, 2))],
1388+
[(3, (3, 3)), (4, (4, 4))]],
1389+
mask=[[(0, (1, 0)), (0, (0, 1))],
1390+
[(1, (0, 0)), (1, (1, 1))]], dtype=ndtype)
1391+
test = (a[0, 0] != a)
1392+
assert_equal(test.data, [[False, True], [True, True]])
1393+
assert_equal(test.mask, [[False, False], [False, True]])
1394+
1395+
def test_eq_ne_structured_extra(self):
1396+
# ensure simple examples are symmetric and make sense.
1397+
# from https://github.com/numpy/numpy/pull/8590#discussion_r101126465
1398+
dt = np.dtype('i4,i4')
1399+
for m1 in (mvoid((1, 2), mask=(0, 0), dtype=dt),
1400+
mvoid((1, 2), mask=(0, 1), dtype=dt),
1401+
mvoid((1, 2), mask=(1, 0), dtype=dt),
1402+
mvoid((1, 2), mask=(1, 1), dtype=dt)):
1403+
ma1 = m1.view(MaskedArray)
1404+
r1 = ma1.view('2i4')
1405+
for m2 in (np.array((1, 1), dtype=dt),
1406+
mvoid((1, 1), dtype=dt),
1407+
mvoid((1, 0), mask=(0, 1), dtype=dt),
1408+
mvoid((3, 2), mask=(0, 1), dtype=dt)):
1409+
ma2 = m2.view(MaskedArray)
1410+
r2 = ma2.view('2i4')
1411+
eq_expected = (r1 == r2).all()
1412+
assert_equal(m1 == m2, eq_expected)
1413+
assert_equal(m2 == m1, eq_expected)
1414+
assert_equal(ma1 == m2, eq_expected)
1415+
assert_equal(m1 == ma2, eq_expected)
1416+
assert_equal(ma1 == ma2, eq_expected)
1417+
# Also check it is the same if we do it element by element.
1418+
el_by_el = [m1[name] == m2[name] for name in dt.names]
1419+
assert_equal(array(el_by_el, dtype=bool).all(), eq_expected)
1420+
ne_expected = (r1 != r2).any()
1421+
assert_equal(m1 != m2, ne_expected)
1422+
assert_equal(m2 != m1, ne_expected)
1423+
assert_equal(ma1 != m2, ne_expected)
1424+
assert_equal(m1 != ma2, ne_expected)
1425+
assert_equal(ma1 != ma2, ne_expected)
1426+
el_by_el = [m1[name] != m2[name] for name in dt.names]
1427+
assert_equal(array(el_by_el, dtype=bool).any(), ne_expected)
13641428

13651429
def test_eq_with_None(self):
13661430
# Really, comparisons with None should not be done, but check them
@@ -1393,6 +1457,22 @@ def test_eq_with_scalar(self):
13931457
assert_equal(a == 0, False)
13941458
assert_equal(a != 1, False)
13951459
assert_equal(a != 0, True)
1460+
b = array(1, mask=True)
1461+
assert_equal(b == 0, masked)
1462+
assert_equal(b == 1, masked)
1463+
assert_equal(b != 0, masked)
1464+
assert_equal(b != 1, masked)
1465+
1466+
def test_eq_different_dimensions(self):
1467+
m1 = array([1, 1], mask=[0, 1])
1468+
# test comparison with both masked and regular arrays.
1469+
for m2 in (array([[0, 1], [1, 2]]),
1470+
np.array([[0, 1], [1, 2]])):
1471+
test = (m1 == m2)
1472+
assert_equal(test.data, [[False, False],
1473+
[True, False]])
1474+
assert_equal(test.mask, [[False, True],
1475+
[False, True]])
13961476

13971477
def test_numpyarithmetics(self):
13981478
# Check that the mask is not back-propagated when using numpy functions
@@ -3978,7 +4058,15 @@ def test_make_mask(self):
39784058
test = make_mask(mask, dtype=mask.dtype)
39794059
assert_equal(test.dtype, bdtype)
39804060
assert_equal(test, np.array([(0, 0), (0, 1)], dtype=bdtype))
3981-
4061+
# Ensure this also works for void
4062+
mask = np.array((False, True), dtype='?,?')[()]
4063+
assert_(isinstance(mask, np.void))
4064+
test = make_mask(mask, dtype=mask.dtype)
4065+
assert_equal(test, mask)
4066+
assert_(test is not mask)
4067+
mask = np.array((0, 1), dtype='i4,i4')[()]
4068+
test2 = make_mask(mask, dtype=mask.dtype)
4069+
assert_equal(test2, test)
39824070
# test that nomask is returned when m is nomask.
39834071
bools = [True, False]
39844072
dtypes = [MaskType, np.float]
@@ -3987,7 +4075,6 @@ 5FA9 def test_make_mask(self):
39874075
res = make_mask(nomask, copy=cpy, shrink=shr, dtype=dt)
39884076
assert_(res is nomask, msgformat % (cpy, shr, dt))
39894077

3990-
39914078
def test_mask_or(self):
39924079
# Initialize
39934080
mtype = [('a', np.bool), ('b', np.bool)]

0 commit comments

Comments
 (0)
0