PERF: faster .iflat by not creating Array for the combined key · larray-project/larray@9033d2f · GitHub
[go: up one dir, main page]

Skip to content

Commit 9033d2f

Browse files
committed
PERF: faster .iflat by not creating Array for the combined key
1 parent 984ff1f commit 9033d2f

File tree

2 files changed

+39
-18
lines changed

2 files changed

+39
-18
lines changed

larray/core/array.py

Lines changed: 1 addition & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -516,15 +516,7 @@ def __getitem__(self, flat_key, sep='_'):
516516
flat_np_key = np.asarray(flat_key)
517517
axes = self.array.axes
518518
nd_key = np.unravel_index(flat_np_key, axes.shape)
519-
# the following lines are equivalent to (but faster than) "return array.ipoints[nd_key]"
520-
521-
# TODO: extract a function which only computes the combined axes because we do not use the actual Arrays
522-
# produced here, which is wasteful. AxisCollection._flat_lookup seems related (but not usable as-is).
523-
la_key = axes._adv_keys_to_combined_axis_la_keys(nd_key, sep=sep)
524-
first_axis_key_axes = la_key[0].axes
525-
assert all(isinstance(axis_key, ABCArray) and axis_key.axes is first_axis_key_axes
526-
for axis_key in la_key[1:])
527-
res_axes = first_axis_key_axes
519+
res_axes = axes._adv_keys_to_combined_axes(nd_key, sep=sep)
528520
return Array(self.array.data.flat[flat_np_key], res_axes)
529521

530522
def __setitem__(self, flat_key, value):

larray/core/axis.py

Lines changed: 38 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -3548,19 +3548,53 @@ def _adv_keys_to_combined_axis_la_keys(self, key, wildcard=False, sep='_'):
35483548
tuple
35493549
"""
35503550
from larray.core.array import Array
3551+
combined_axes = self._adv_keys_to_combined_axes(key, wildcard=wildcard, sep=sep)
3552+
if combined_axes is None:
3553+
return key
3554+
3555+
# transform all advanced non-Array keys to Array with the combined axis
3556+
ignored_types = (int, np.integer, slice, Array)
3557+
return tuple(axis_key if isinstance(axis_key, ignored_types) else Array(axis_key, combined_axes)
3558+
for axis_key in key)
3559+
3560+
def _adv_keys_to_combined_axes(self, key, wildcard=False, sep='_'):
3561+
r"""
3562+
Returns an AxisCollection corresponding to the combined axis of the non-Array "advanced indexing" key parts.
3563+
Scalar, slice and Array key parts are ignored.
3564+
3565+
Parameters
3566+
----------
3567+
key : tuple
3568+
Complete (len(key) == self.ndim) indices-based key.
3569+
wildcard : bool, optional
3570+
Whether or not to produce a wildcard axis. Defaults to False.
3571+
sep : str, optional
3572+
Separator to use for creating combined axis name and labels (when wildcard is False). Defaults to '_'.
3573+
3574+
Returns
3575+
-------
3576+
AxisCollection or None
3577+
"""
3578+
from larray.core.array import Array
35513579

35523580
assert isinstance(key, tuple) and len(key) == self.ndim
35533581

3554-
# 1) first compute combined axis
3555-
# ==============================
3582+
# TODO: we should explicitly raise an error if we detect np.ndarray keys with ndim > 1 as this would
3583+
# require more than one combined axis. Supporting that is impossible (because we cannot know what
3584+
# the corresponding labels are) so we should either return wildcard axes in that case or raise an
3585+
# explicit error. Given the probability our internal users ever use that is so close to 0, the easiest
3586+
# solution should win. I am unsure which it is, but I guess an error should be easier. Note that there
3587+
# is no such issue with ND Array keys because for those we know the labels already so nothing needs to
3588+
# be done here.
35563589

3590+
# XXX: can we use/factorize with AxisCollection._flat_lookup????
35573591
# TODO: use/factorize with AxisCollection.combine_axes. The problem is that it uses product(*axes_labels)
35583592
# while here we need zip(*axes_labels)
35593593
ignored_types = (int, np.integer, slice, Array)
35603594
adv_keys = [(axis_key, axis) for axis_key, axis in zip(key, self)
35613595
if not isinstance(axis_key, ignored_types)]
35623596
if not adv_keys:
3563-
return key
3597+
return None
35643598

35653599
# axes with a scalar key are not taken, since we want to kill them
35663600

@@ -3600,12 +3634,7 @@ def _adv_keys_to_combined_axis_la_keys(self, key, wildcard=False, sep='_'):
36003634
sepjoin = sep.join
36013635
combined_labels = [sepjoin(comb) for comb in zip(*axes_labels)]
36023636
combined_axis = Axis(combined_labels, combined_name)
3603-
combined_axes = AxisCollection(combined_axis)
3604-
3605-
# 2) transform all advanced non-Array keys to Array with the combined axis
3606-
# ========================================================================
3607-
return tuple(axis_key if isinstance(axis_key, ignored_types) else Array(axis_key, combined_axes)
3608-
for axis_key in key)
3637+
return AxisCollection(combined_axis)
36093638

36103639

36113640
class AxisReference(ABCAxisReference, ExprNode, Axis):

0 commit comments

Comments
 (0)
0