8000 BUG: Fix array_equal for numeric and non-numeric scalar types (#27275) · numpy/numpy@10533ca · GitHub
[go: up one dir, main page]

Skip to content

Commit 10533ca

Browse files
authored
BUG: Fix array_equal for numeric and non-numeric scalar types (#27275)
Mitigates #27271. The underlying issue (an array comparison returning a python bool instead of a numpy bool) is not addressed. The order of statements is slightly reordered, so that the if a1 is a2: check can be done before the calculation of cannot_have_nan
1 parent 3e1edef commit 10533ca

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

numpy/_core/numeric.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2554,17 +2554,17 @@ def array_equal(a1, a2, equal_nan=False):
25542554
if a1.shape != a2.shape:
25552555
return False
25562556
if not equal_nan:
2557-
return builtins.bool((a1 == a2).all())
2558-
cannot_have_nan = (_dtype_cannot_hold_nan(a1.dtype)
2559-
and _dtype_cannot_hold_nan(a2.dtype))
2560-
if cannot_have_nan:
2561-
if a1 is a2:
2562-
return True
2563-
return builtins.bool((a1 == a2).all())
2557+
return builtins.bool((asanyarray(a1 == a2)).all())
25642558

25652559
if a1 is a2:
25662560
# nan will compare equal so an array will compare equal to itself.
25672561
return True
2562+
2563+
cannot_have_nan = (_dtype_cannot_hold_nan(a1.dtype)
2564+
and _dtype_cannot_hold_nan(a2.dtype))
2565+
if cannot_have_nan:
2566+
return builtins.bool(asarray(a1 == a2).all())
2567+
25682568
# Handling NaN values if equal_nan is True
25692569
a1nan, a2nan = isnan(a1), isnan(a2)
25702570
# NaN's occur at different locations
@@ -2624,7 +2624,7 @@ def array_equiv(a1, a2):
26242624
except Exception:
26252625
return False
26262626

2627-
return builtins.bool((a1 == a2).all())
2627+
return builtins.bool(asanyarray(a1 == a2).all())
26282628

26292629

26302630
def _astype_dispatcher(x, dtype, /, *, copy=None, device=None):

numpy/_core/tests/test_numeric.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2192,6 +2192,13 @@ def test_array_equal_equal_nan(self, bx, by, equal_nan, expected):
21922192
assert_(res is expected)
21932193
assert_(type(res) is bool)
21942194

2195+
def test_array_equal_different_scalar_types(self):
2196+
# https://github.com/numpy/numpy/issues/27271
2197+
a = np.array("foo")
2198+
b = np.array(1)
2199+
assert not np.array_equal(a, b)
2200+
assert not np.array_equiv(a, b)
2201+
21952202
def test_none_compares_elementwise(self):
21962203
a = np.array([None, 1, None], dtype=object)
21972204
assert_equal(a == None, [True, False, True])

0 commit comments

Comments
 (0)
0