@@ -686,7 +686,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
686
686
header = '' , precision = 6 , equal_nan = True ,
687
687
equal_inf = True ):
688
688
__tracebackhide__ = True # Hide traceback for py.test
689
- from numpy .core import array , isnan , isinf , any , inf
689
+ from numpy .core import array , isnan , inf , bool_
690
690
x = array (x , copy = False , subok = True )
691
691
y = array (y , copy = False , subok = True )
692
692
@@ -696,17 +696,28 @@ def isnumber(x):
696
696
def istime (x ):
697
697
return x .dtype .char in "Mm"
698
698
699
- def chk_same_position (x_id , y_id , hasval = 'nan' ):
700
- """Handling nan/inf: check that x and y have the nan/inf at the same
701
- locations."""
702
- try :
703
- assert_array_equal (x_id , y_id )
704
- except AssertionError :
699
+ def func_assert_same_pos (x , y , func = isnan , hasval = 'nan' ):
700
+ """Handling nan/inf: combine results of running func on x and y,
701
+ checking that they are True at the same locations."""
702
+ # Both the != True comparison here and the cast to bool_ at
703
+ # the end are done to deal with `masked`, which cannot be
704
+ # compared usefully, and for which .all() yields masked.
705
+ x_id = func (x )
706
+ y_id = func (y )
707
+ if (x_id == y_id ).all () != True :
705
708
msg = build_err_msg ([x , y ],
706
709
err_msg + '\n x and y %s location mismatch:'
707
710
% (hasval ), verbose = verbose , header = header ,
708
711
names = ('x' , 'y' ), precision = precision )
709
712
raise AssertionError (msg )
713
+ # If there is a scalar, then here we know the array has the same
714
+ # flag as it everywhere, so we should return the scalar flag.
715
+ if x_id .ndim == 0 :
716
+ return bool_ (x_id )
717
+ elif y_id .ndim == 0 :
718
+ return bool_ (y_id )
719
+ else :
720
+ return y_id
710
721
711
722
try :
712
723
cond = (x .shape == () or y .shape == ()) or x .shape == y .shape
@@ -719,49 +730,32 @@ def chk_same_position(x_id, y_id, hasval='nan'):
719
730
names = ('x' , 'y' ), precision = precision )
720
731
raise AssertionError (msg )
721
732
733
+ flagged = bool_ (False )
722
734
if isnumber (x ) and isnumber (y ):
723
- has_nan = has_inf = False
724
735
if equal_nan :
725
- x_isnan , y_isnan = isnan (x ), isnan (y )
726
- # Validate that NaNs are in the same place
727
- has_nan = any (x_isnan ) or any (y_isnan )
728
- if has_nan :
729
- chk_same_position (x_isnan , y_isnan , hasval = 'nan' )
736
+ flagged = func_assert_same_pos (x , y , func = isnan , hasval = 'nan' )
730
737
731
738
if equal_inf :
732
- x_isinf , y_isinf = isinf (x ), isinf (y )
733
- # Validate that infinite values are in the same place
734
- has_inf = any (x_isinf ) or any (y_isinf )
735
- if has_inf :
736
- # Check +inf and -inf separately, since they are different
737
- chk_same_position (x == + inf , y == + inf , hasval = '+inf' )
738
- chk_same_position (x == - inf , y == - inf , hasval = '-inf' )
739
-
740
- if has_nan and has_inf :
741
- x = x [~ (x_isnan | x_isinf )]
742
- y = y [~ (y_isnan | y_isinf )]
743
- elif has_nan :
744
- x = x [~ x_isnan ]
745
- y = y [~ y_isnan ]
746
- elif has_inf :
747
- x = x [~ x_isinf ]
748
- y = y [~ y_isinf ]
749
-
750
- # Only do the comparison if actual values are left
751
- if x .size == 0 :
752
- return
739
+ flagged |= func_assert_same_pos (x , y ,
740
+ func = lambda xy : xy == + inf ,
741
+ hasval = '+inf' )
742
+ flagged |= func_assert_same_pos (x , y ,
743
+ func = lambda xy : xy == - inf ,
744
+ hasval = '-inf' )
753
745
754
746
elif istime (x ) and istime (y ):
755
747
# If one is datetime64 and the other timedelta64 there is no point
756
748
if equal_nan and x .dtype .type == y .dtype .type :
757
- x_isnat , y_isnat = isnat (x ), isnat (y )
758
-
759
- if any (x_isnat ) or any (y_isnat ):
760
- chk_same_position (x_isnat , y_isnat , hasval = "NaT" )
749
+ flagged = func_assert_same_pos (x , y , func = isnat , hasval = "NaT" )
761
750
762
- if any (x_isnat ) or any (y_isnat ):
763
- x = x [~ x_isnat ]
764
- y = y [~ y_isnat ]
751
+ if flagged .ndim > 0 :
752
+ x , y = x [~ flagged ], y [~ flagged ]
753
+ # Only do the comparison if actual values are left
754
+ if x .size == 0 :
755
+ return
756
+ elif flagged :
757
+ # no sense doing comparison if everything is flagged.
758
+ return
765
759
766
760
val = comparison (x , y )
767
761
0 commit comments