10000 Merge pull request #11122 from mhvk/assert-array-comparison-with-masked · numpy/numpy@5cbb982 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5cbb982

Browse files
authored
< 8000 div class="color-bg-default position-relative border rounded-2 color-border-default mt-2 d-flex flex-column pt-0">
Merge pull request #11122 from mhvk/assert-array-comparison-with-masked
BUG,MAINT: Ensure masked elements can be tested against nan and inf.
2 parents 41d306a + 3ad49aa commit 5cbb982

File tree

3 files changed

+57
-41
lines changed

3 files changed

+57
-41
lines changed

numpy/lib/tests/test_ufunclike.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ def __array_wrap__(self, obj, context=None):
5555
obj.metadata = self.metadata
5656
return obj
5757

58+
def __array_finalize__(self, obj):
59+
self.metadata = getattr(obj, 'metadata', None)
60+
return self
61+
5862
a = nx.array([1.1, -1.1])
5963
m = MyArray(a, metadata='foo')
6064
f = ufl.fix(m)

numpy/testing/_private/utils.py

Lines changed: 35 additions & 41 deletions
Orig 8000 inal file line numberDiff line numberDiff line change
@@ -686,7 +686,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
686686
header='', precision=6, equal_nan=True,
687687
equal_inf=True):
688688
__tracebackhide__ = True # Hide traceback for py.test
689-
from numpy.core import array, isnan, isinf, any, inf
689+
from numpy.core import array, isnan, inf, bool_
690690
x = array(x, copy=False, subok=True)
691691
y = array(y, copy=False, subok=True)
692692

@@ -696,17 +696,28 @@ def isnumber(x):
696696
def istime(x):
697697
return x.dtype.char in "Mm"
698698

699-
def chk_same_position(x_id, y_id, hasval='nan'):
700-
"""Handling nan/inf: check that x and y have the nan/inf at the same
701-
locations."""
702-
try:
703-
assert_array_equal(x_id, y_id)
704-
except AssertionError:
699+
def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
700+
"""Handling nan/inf: combine results of running func on x and y,
701+
checking that they are True at the same locations."""
702+
# Both the != True comparison here and the cast to bool_ at
703+
# the end are done to deal with `masked`, which cannot be
704+
# compared usefully, and for which .all() yields masked.
705+
x_id = func(x)
706+
y_id = func(y)
707+
if (x_id == y_id).all() != True:
705708
msg = build_err_msg([x, y],
706709
err_msg + '\nx and y %s location mismatch:'
707710
% (hasval), verbose=verbose, header=header,
708711
names=('x', 'y'), precision=precision)
709712
raise AssertionError(msg)
713+
# If there is a scalar, then here we know the array has the same
714+
# flag as it everywhere, so we should return the scalar flag.
715+
if x_id.ndim == 0:
716+
return bool_(x_id)
717+
elif y_id.ndim == 0:
718+
return bool_(y_id)
719+
else:
720+
return y_id
710721

711722
try:
712723
cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
@@ -719,49 +730,32 @@ def chk_same_position(x_id, y_id, hasval='nan'):
719730
names=('x', 'y'), precision=precision)
720731
raise AssertionError(msg)
721732

733+
flagged = bool_(False)
722734
if isnumber(x) and isnumber(y):
723-
has_nan = has_inf = False
724735
if equal_nan:
725-
x_isnan, y_isnan = isnan(x), isnan(y)
726-
# Validate that NaNs are in the same place
727-
has_nan = any(x_isnan) or any(y_isnan)
728-
if has_nan:
729-
chk_same_position(x_isnan, y_isnan, hasval='nan')
736+
flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')
730737

731738
if equal_inf:
732-
x_isinf, y_isinf = isinf(x), isinf(y)
733-
# Validate that infinite values are in the same place
734-
has_inf = any(x_isinf) or any(y_isinf)
735-
if has_inf:
736-
# Check +inf and -inf separately, since they are different
737-
chk_same_position(x == +inf, y == +inf, hasval='+inf')
738-
chk_same_position(x == -inf, y == -inf, hasval='-inf')
739-
740-
if has_nan and has_inf:
741-
x = x[~(x_isnan | x_isinf)]
742-
y = y[~(y_isnan | y_isinf)]
743-
elif has_nan:
744-
x = x[~x_isnan]
745-
y = y[~y_isnan]
746-
elif has_inf:
747-
x = x[~x_isinf]
748-
y = y[~y_isinf]
749-
750-
# Only do the comparison if actual values are left
751-
if x.size == 0:
752-
return
739+
flagged |= func_assert_same_pos(x, y,
740+
func=lambda xy: xy == +inf,
741+
hasval='+inf')
742+
flagged |= func_assert_same_pos(x, y,
743+
func=lambda xy: xy == -inf,
744+
hasval='-inf')
753745

754746
elif istime(x) and istime(y):
755747
# If one is datetime64 and the other timedelta64 there is no point
756748
if equal_nan and x.dtype.type == y.dtype.type:
757-
x_isnat, y_isnat = isnat(x), isnat(y)
758-
759-
if any(x_isnat) or any(y_isnat):
760-
chk_same_position(x_isnat, y_isnat, hasval="NaT")
749+
flagged = func_assert_same_pos(x, y, func=isnat, hasval="NaT")
761750

762-
if any(x_isnat) or any(y_isnat):
763-
x = x[~x_isnat]
764-
y = y[~y_isnat]
751+
if flagged.ndim > 0:
752+
x, y = x[~flagged], y[~flagged]
753+
# Only do the comparison if actual values are left
754+
if x.size == 0:
755+
return
756+
elif flagged:
757+
# no sense doing comparison if everything is flagged.
758+
return
765759

766760
val = comparison(x, y)
767761

numpy/testing/tests/test_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,17 @@ def test_recarrays(self):
151151
self._test_not_equal(c, b)
152152
assert_equal(len(l), 1)
153153

154+
def test_masked_nan_inf(self):
155+
# Regression test for gh-11121
156+
a = np.ma.MaskedArray([3., 4., 6.5], mask=[False, True, False])
157+
b = np.array([3., np.nan, 6.5])
158+
self._test_equal(a, b)
159+
self._test_equal(b, a)
160+
a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, False, False])
161+
b = np.array([np.inf, 4., 6.5])
162+
self._test_equal(a, b)
163+
self._test_equal(b, a)
164+
154165

155166
class TestBuildErrorMessage(object):
156167

@@ -390,6 +401,9 @@ def test_subclass_that_cannot_be_bool(self):
390401
# comparison operators, not on them being able to store booleans
391402
# (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
392403
class MyArray(np.ndarray):
404+
def __eq__(self, other):
405+
return super(MyArray, self).__eq__(other).view(np.ndarray)
406+
393407
def __lt__(self, other):
394408
return super(MyArray, self).__lt__(other).view(np.ndarray)
395409

@@ -489,6 +503,9 @@ def test_subclass_that_cannot_be_bool(self):
489503
# comparison operators, not on them being able to store booleans
490504
# (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
491505
class MyArray(np.ndarray):
506+
def __eq__(self, other):
507+
return super(MyArray, self).__eq__(other).view(np.ndarray)
508+
492509
def __lt__(self, other):
493510
return super(MyArray, self).__lt__(other).view(np.ndarray)
494511

@@ -650,6 +667,7 @@ def test_inf_compare_array(self):
650667
assert_raises(AssertionError, lambda: self._assert_func(-ainf, -x))
651668
self._assert_func(-ainf, x)
652669

670+
653671
@pytest.mark.skip(reason="The raises decorator depends on Nose")
654672
class TestRaises(object):
655673

0 commit comments

Comments
 (0)
0