8000 BUG: Fix array_equal for numeric and non-numeric scalar types · charris/numpy@da9f9c3 · GitHub
[go: up one dir, main page]

Skip to content

Commit da9f9c3

Browse files
eendebakptcharris
authored andcommitted
BUG: Fix array_equal for numeric and non-numeric scalar types
Backport of numpy#27275 Mitigates numpy#27271. The underlying issue (an array comparison returning a python bool instead of a numpy bool) is not addressed. The order of statements is slightly reordered, so that the if `a1 is a2:` check can be done before the calculation of `cannot_have_nan` Closes numpygh-27271
1 parent ee1cf96 commit da9f9c3

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

numpy/_core/numeric.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2503,17 +2503,17 @@ def array_equal(a1, a2, equal_nan=False):
25032503
if a1.shape != a2.shape:
25042504
return False
25052505
if not equal_nan:
2506-
return builtins.bool((a1 == a2).all())
2507-
cannot_have_nan = (_dtype_cannot_hold_nan(a1.dtype)
2508-
and _dtype_cannot_hold_nan(a2.dtype))
2509-
if cannot_have_nan:
2510-
if a1 is a2:
2511-
return True
2512-
return builtins.bool((a1 == a2).all())
2506+
return builtins.bool((asanyarray(a1 == a2)).all())
25132507

25142508
if a1 is a2:
25152509
# nan will compare equal so an array will compare equal to itself.
25162510
return True
2511+
2512+
cannot_have_nan = (_dtype_cannot_hold_nan(a1.dtype)
2513+
and _dtype_cannot_hold_nan(a2.dtype))
2514+
if cannot_have_nan:
2515+
return builtins.bool(asarray(a1 == a2).all())
2516+
25172517
# Handling NaN values if equal_nan is True
25182518
a1nan, a2nan = isnan(a1), isnan(a2)
25192519
# NaN's occur at different locations
@@ -2572,7 +2572,7 @@ def array_equiv(a1, a2):
25722572
except Exception:
25732573
return False
25742574

2575-
return builtins.bool((a1 == a2).all())
2575+
return builtins.bool(asanyarray(a1 == a2).all())
25762576

25772577

25782578
def _astype_dispatcher(x, dtype, /, *, copy=None):

numpy/_core/tests/test_numeric.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2131,6 +2131,13 @@ def test_array_equal_equal_nan(self, bx, by, equal_nan, expected):
21312131
assert_(res is expected)
21322132
assert_(type(res) is bool)
21332133

2134+
def test_array_equal_different_scalar_types(self):
2135+
# https://github.com/numpy/numpy/issues/27271
2136+
a = np.array("foo")
2137+
b = np.array(1)
2138+
assert not np.array_equal(a, b)
2139+
assert not np.array_equiv(a, b)
2140+
21342141
def test_none_compares_elementwise(self):
21352142
a = np.array([None, 1, None], dtype=object)
21362143
assert_equal(a == None, [True, False, True])

0 commit comments

Comments
 (0)
0