8000 FIX: fixed all binops (except ==) between Array and unknown types (cl… · larray-project/larray@2acb45e · GitHub
[go: up one dir, main page]

Skip to content

Commit 2acb45e

Browse files
committed
FIX: fixed all binops (except ==) between Array and unknown types (closes #1064)
For == this was correct (returns False) but all binops were wrong. As an added bonus, we now let the other type handle the operation if it can.
1 parent 4d99410 commit 2acb45e

File tree

2 files changed

+56
-8
lines changed

2 files changed

+56
-8
lines changed

larray/core/array.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5791,15 +5791,10 @@ def opmethod(self, other) -> 'Array':
57915791

57925792
# we could pass scalars through asarray too but it is too costly performance-wise for only suppressing one
57935793
# isscalar test and an if statement.
5794-
# TODO: ndarray should probably be converted to larrays because that would harmonize broadcasting rules, but
5795-
# it makes some tests fail for some reason.
5794+
# TODO: ndarray should probably be converted to larrays too because that would harmonize broadcasting rules,
5795+
# but it makes some tests fail for some reason.
57965796
if isinstance(other, (list, Axis)):
57975797
other = asarray(other)
5798-
elif other is not None and not isinstance(other, (Array, np.ndarray)) and not np.isscalar(other):
5799-
# support for inspect.signature
5800-
# FIXME: this should only be the case for __eq__. For other operations, we should
5801-
# probably raise a TypeError (or return NotImplemented???)
5802-
return False
58035798

58045799
if isinstance(other, Array):
58055800
# TODO: first test if it is not already broadcastable
@@ -5809,9 +5804,13 @@ def opmethod(self, other) -> 'Array':
58095804
res_axes = self.axes
58105805
else:
58115806
(self_data, other_data), res_axes = raw_broadcastable((self, other))
5812-
else:
5807+
# We need to check for None explicitly because we consider None as a valid scalar, while numpy does not.
5808+
# i.e. we consider "arr == None" as valid code
5809+
elif isinstance(other, np.ndarray) or np.isscalar(other) or other is None:
58135810
self_data, other_data = self.data, other
58145811
res_axes = self.axes
5812+
else:
5813+
return NotImplemented
58155814
return Array(super_method(self_data, other_data), res_axes)
58165815
opmethod.__name__ = fullname
58175816
return opmethod

larray/tests/test_array.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2790,6 +2790,18 @@ def test_binary_ops(small_array):
27902790
res = arr + arr.a
27912791
assert_larray_equal(res, arr + asarray(arr.a))
27922792

2793+
# Array + <unsupported type>
2794+
with must_raise(TypeError, "unsupported operand type(s) for +: 'Array' and 'object'"):
2795+
res = arr + object()
2796+
2797+
# Array + <unsupported type which implements the reverse op>
2798+
class Test:
2799+
def __radd__(self, other):
2800+
return "success"
2801+
2802+
res = arr + Test()
2803+
assert res == 'success'
2804+
27932805

27942806
def test_binary_ops_no_name_axes(small_array):
27952807
raw = small_array.data
@@ -2881,6 +2893,43 @@ def test_binary_ops_with_scalar_group():
28812893
assert_larray_equal(arr + time.i[0], expected)
28822894

28832895

2896+
def test_comparison_ops():
2897+
# simple array equality (identity)
2898+
a = Axis('a=a0,a1,a2')
2899+
arr = ndtest(a)
2900+
res = arr == arr
2901+
expected = ones(a)
2902+
assert_larray_equal(res, expected)
2903+
2904+
# simple array equality
2905+
arr = ndtest(a)
2906+
res = arr == zeros(a)
2907+
expected = Array([True, False, False], a)
2908+
assert_larray_equal(res, expected)
2909+
2910+
# invalid types
2911+
# a) eq
2912+
arr = ndtest(3)
2913+
res = arr == object()
2914+
assert res is False
2915+
2916+
# b) ne
2917+
res = arr != object()
2918+
assert res is True
2919+
2920+
# c) others
2921+
with must_raise(TypeError, "'>' not supported between instances of 'Array' and 'object'"):
2922+
res = arr > object()
2923+
2924+
# d) other type implementing the reverse comparison
2925+
class Test:
2926+
def __lt__(self, other):
2927+
return "success"
2928+
2929+
res = arr > Test()
2930+
assert res == 'success'
2931+
2932+
28842933
def test_unary_ops(small_array):
28852934
raw = small_array.data
28862935

0 commit comments

Comments
 (0)
0