8000 FIX: boolean subset with incompatible axes raises an error (closes #1… · larray-project/larray@4c0bab3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4c0bab3

Browse files
committed
FIX: boolean subset with incompatible axes raises an error (closes #1085)
1 parent afb9fc4 commit 4c0bab3

File tree

3 files changed

+34
-2
lines changed

3 files changed

+34
-2
lines changed

doc/source/changes/version_0_34_3.rst.inc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ Miscellaneous improvements
5555
Fixes
5656
^^^^^
5757

58-
* fixed something (closes :issue:`1`).
58+
* using a boolean array as a filter to take a subset of another array now raise an error when the
59+
two arrays have incompatible axes instead of producing wrong result (closes :issue:`1085`).
5960

6061
* fixed converting a scalar Array (an Array with 0 dimensions) to string with numpy 1.22+.
6162

larray/core/axis.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2878,6 +2878,22 @@ def _key_to_axis_indices_dict(self, key):
28782878
extra_key_axes = axis_key.axes - self
28792879
if extra_key_axes:
28802880
raise ValueError(f"boolean subset key contains more axes ({axis_key.axes}) than array ({self})")
2881+
2882+
# TODO: factorize with check_compatible
2883+
for i, subset_axis in enumerate(axis_key.axes):
2884+
array_axis = self.get_by_pos(subset_axis, i)
2885+
if not array_axis.iscompatible(subset_axis):
2886+
msg = f"""boolean subset array has incompatible axes with array:
2887+
array axes: {self}
2888+
subset array axes: {axis_key.axes}
2889+
incompatible axes:
2890+
array axis:
2891+
{array_axis!r}
2892+
subset array axis:
2893+
{subset_axis!r}
2894+
"""
2895+
raise ValueError(msg)
2896+
28812897
# nonzero (currently) returns a tuple of IGroups containing 1D Arrays (one IGroup per axis)
28822898
filtered_key.extend(axis_key.nonzero())
28832899
# drop slice(None) and Ellipsis since they are meaningless because of guess_axis.

larray/tests/test_array.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,8 @@ def test_getitem_bool_larray_key_arr_whout_bool_axis():
771771
raw = arr.data
772772

773773
# all dimensions
774-
res = arr[arr < 5]
774+
filter_ = arr < 5
775+
res = arr[filter_]
775776
assert isinstance(res, Array)
776777
assert res.ndim == 1
777778
assert_nparray_equal(res.data, raw[raw < 5])
@@ -786,6 +787,20 @@ def test_getitem_bool_larray_key_arr_whout_bool_axis():
786787
raw_d1, raw_d3 = raw_key.nonzero()
787788
assert_nparray_equal(res.data, raw[raw_d1, :, raw_d3])
788789

790+
# filter with smaller axis than array
791+
filter_ = arr < 10
792+
filter2 = filter_['c0,c2,c3']
793+
with must_raise(ValueError, """boolean subset array has incompatible axes with array:
794+
array axes: {a, b, c}
795+
subset array axes: {a, b, c}
796+
incompatible axes:
797+
array axis:
798+
Axis(['c0', 'c1', 'c2', 'c3'], 'c')
799+
subset array axis:
800+
Axis(['c0', 'c2', 'c3'], 'c')
801+
"""):
802+
_ = arr[filter2]
803+
789804
# using an Axis object
790805
arr = ndtest('a=a0,a1;b=0..3')
791806
raw = arr.data

0 commit comments

Comments
 (0)
0