8000 Merge pull request #11756 from charris/fix-testing-utils · numpy/numpy@18f338a · GitHub
[go: up one dir, main page]

Skip to content

Commit 18f338a

Browse files
authored
Merge pull request #11756 from charris/fix-testing-utils
MAINT: Make assert_array_compare more generic.
2 parents 9245def + 78efe63 commit 18f338a

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

numpy/testing/_private/utils.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,8 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
687687
equal_inf=True):
688688
__tracebackhide__ = True # Hide traceback for py.test
689689
from numpy.core import array, isnan, inf, bool_
690+
from numpy.core.fromnumeric import all as npall
691+
690692
x = array(x, copy=False, subok=True)
691693
y = array(y, copy=False, subok=True)
692694

@@ -697,14 +699,21 @@ def istime(x):
697699
return x.dtype.char in "Mm"
698700

699701
def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
700-
"""Handling nan/inf: combine results of running func on x and y,
701-
checking that they are True at the same locations."""
702-
# Both the != True comparison here and the cast to bool_ at
703-
# the end are done to deal with `masked`, which cannot be
704-
# compared usefully, and for which .all() yields masked.
702+
"""Handling nan/inf.
703+
704+
Combine results of running func on x and y, checking that they are True
705+
at the same locations.
706+
707+
"""
708+
# Both the != True comparison here and the cast to bool_ at the end are
709+
# done to deal with `masked`, which cannot be compared usefully, and
710+
# for which np.all yields masked. The use of the function np.all is
711+
# for back compatibility with ndarray subclasses that changed the
712+
# return values of the all method. We are not committed to supporting
713+
# such subclasses, but some used to work.
705714
x_id = func(x)
706715
y_id = func(y)
707-
if (x_id == y_id).all() != True:
716+
if npall(x_id == y_id) != True:
708717
msg = build_err_msg([x, y],
709718
err_msg + '\nx and y %s location mismatch:'
710719
% (hasval), verbose=verbose, header=header,

0 commit comments

Comments
 (0)
0