8000 MAINT: clean up assert_array_compare a bit further. · numpy/numpy@3ad49aa · GitHub
[go: up one dir, main page]

Skip to content

Commit 3ad49aa

Browse files
committed
MAINT: clean up assert_array_compare a bit further.
This brought to light two bugs in tests, which are fixed here, viz., that a sample ndarray subclass that tested propagation of an added parameter was incomplete, in that in propagating the parameter in __array_wrap__ it assumed it was there on self, but that assumption could be broken when a view of self was taken (as is done by x[~flagged] in the test routine), since there was no __array_finalize__ defined. The other subclass bug counted, incorrectly, on only needing to provide one type of comparison, the __lt__ being explicitly tested. But flags are compared with __eq__ and those flags will have the same subclass.
1 parent 5718b33 commit 3ad49aa

File tree

3 files changed

+23
-10
lines changed

3 files changed

+23
-10
lines changed

numpy/lib/tests/test_ufunclike.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ def __array_wrap__(self, obj, context=None):
5555
obj.metadata = self.metadata
5656
return obj
5757

58+
def __array_finalize__(self, obj):
59+
self.metadata = getattr(obj, 'metadata', None)
60+
return self
61+
5862
a = nx.array([1.1, -1.1])
5963
m = MyArray(a, metadata='foo')
6064
f = ufl.fix(m)

numpy/testing/_private/utils.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
685685
header='', precision=6, equal_nan=True,
686686
equal_inf=True):
687687
__tracebackhide__ = True # Hide traceback for py.test
688-
from numpy.core import array, isnan, any, inf, ndim
688+
from numpy.core import array, isnan, inf, bool_
689689
x = array(x, copy=False, subok=True)
690690
y = array(y, copy=False, subok=True)
691691

@@ -698,22 +698,25 @@ def istime(x):
698698
def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
699699
"""Handling nan/inf: combine results of running func on x and y,
700700
checking that they are True at the same locations."""
701+
# Both the != True comparison here and the cast to bool_ at
702+
# the end are done to deal with `masked`, which cannot be
703+
# compared usefully, and for which .all() yields masked.
701704
x_id = func(x)
702705
y_id = func(y)
703-
if not any(x_id) and not any(y_id):
704-
return False
705-
706-
try:
707-
assert_array_equal(x_id, y_id)
708-
except AssertionError:
706+
if (x_id == y_id).all() != True:
709707
msg = build_err_msg([x, y],
710708
err_msg + '\nx and y %s location mismatch:'
711709
% (hasval), verbose=verbose, header=header,
712710
names=('x', 'y'), precision=precision)
713711
raise AssertionError(msg)
714712
# If there is a scalar, then here we know the array has the same
715713
# flag as it everywhere, so we should return the scalar flag.
716-
return x_id if x_id.ndim == 0 else y_id
714+
if x_id.ndim == 0:
715+
return bool_(x_id)
716+
elif y_id.ndim == 0:
717+
return bool_(y_id)
718+
else:
719+
return y_id
717720

718721
try:
719722
cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
@@ -726,7 +729,7 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
726729
names=('x', 'y'), precision=precision)
727730
raise AssertionError(msg)
728731

729-
flagged = False
732+
flagged = bool_(False)
730733
if isnumber(x) and isnumber(y):
731734
if equal_nan:
732735
flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')
@@ -744,7 +747,7 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
744747
if equal_nan and x.dtype.type == y.dtype.type:
745748
flagged = func_assert_same_pos(x, y, func=isnat, hasval="NaT")
746749

747-
if ndim(flagged):
750+
if flagged.ndim > 0:
748751
x, y = x[~flagged], y[~flagged]
749752
# Only do the comparison if actual values are left
750753
if x.size == 0:

numpy/testing/tests/test_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,9 @@ def test_subclass_that_cannot_be_bool(self):
401401
# comparison operators, not on them being able to store booleans
402402
# (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
403403
class MyArray(np.ndarray):
404+
def __eq__(self, other):
405+
return super(MyArray, self).__eq__(other).view(np.ndarray)
406+
404407
def __lt__(self, other):
405408
return super(MyArray, self).__lt__(other).view(np.ndarray)
406409

@@ -500,6 +503,9 @@ def test_subclass_that_cannot_be_bool(self):
500503
# comparison operators, not on them being able to store booleans
501504
# (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
502505
class MyArray(np.ndarray):
506+
def __eq__(self, other):
507+
return super(MyArray, self).__eq__(other).view(np.ndarray)
508+
503509
def __lt__(self, other):
504510
return super(MyArray, self).__lt__(other).view(np.ndarray)
505511

0 commit comments

Comments
 (0)
0