@@ -3548,19 +3548,53 @@ def _adv_keys_to_combined_axis_la_keys(self, key, wildcard=False, sep='_'):
3548
3548
tuple
3549
3549
"""
3550
3550
from larray .core .array import Array
3551
+ combined_axes = self ._adv_keys_to_combined_axes (key , wildcard = wildcard , sep = sep )
3552
+ if combined_axes is None :
3553
+ return key
3554
+
3555
+ # transform all advanced non-Array keys to Array with the combined axis
3556
+ ignored_types = (int , np .integer , slice , Array )
3557
+ return tuple (axis_key if isinstance (axis_key , ignored_types ) else Array (axis_key , combined_axes )
3558
+ for axis_key in key )
3559
+
3560
+ def _adv_keys_to_combined_axes (self , key , wildcard = False , sep = '_' ):
3561
+ r"""
3562
+ Returns an AxisCollection corresponding to the combined axis of the non-Array "advanced indexing" key parts.
3563
+ Scalar, slice and Array key parts are ignored.
3564
+
3565
+ Parameters
3566
+ ----------
3567
+ key : tuple
3568
+ Complete (len(key) == self.ndim) indices-based key.
3569
+ wildcard : bool, optional
3570
+ Whether or not to produce a wildcard axis. Defaults to False.
3571
+ sep : str, optional
3572
+ Separator to use for creating combined axis name and labels (when wildcard is False). Defaults to '_'.
3573
+
3574
+ Returns
3575
+ -------
3576
+ AxisCollection or None
3577
+ """
3578
+ from larray .core .array import Array
3551
3579
3552
3580
assert isinstance (key , tuple ) and len (key ) == self .ndim
3553
3581
3554
- # 1) first compute combined axis
3555
- # ==============================
3582
+ # TODO: we should explicitly raise an error if we detect np.ndarray keys with ndim > 1 as this would
3583
+ # require more than one combined axis. Supporting that is impossible (because we cannot know what
3584
+ # the corresponding labels are) so we should either return wildcard axes in that case or raise an
3585
+ # explicit error. Given the probability our internal users ever use that is so close to 0, the easiest
3586
+ # solution should win. I am unsure which it is, but I guess an error should be easier. Note that there
3587
+ # is no such issue with ND Array keys because for those we know the labels already so nothing needs to
3588
+ # be done here.
3556
3589
3590
+ # XXX: can we use/factorize with AxisCollection._flat_lookup????
3557
3591
# TODO: use/factorize with AxisCollection.combine_axes. The problem is that it uses product(*axes_labels)
3558
3592
# while here we need zip(*axes_labels)
3559
3593
ignored_types = (int , np .integer , slice , Array )
3560
3594
adv_keys = [(axis_key , axis ) for axis_key , axis in zip (key , self )
3561
3595
if not isinstance (axis_key , ignored_types )]
3562
3596
if not adv_keys :
3563
- return key
3597
+ return None
3564
3598
3565
3599
# axes with a scalar key are not taken, since we want to kill them
3566
3600
@@ -3600,12 +3634,7 @@ def _adv_keys_to_combined_axis_la_keys(self, key, wildcard=False, sep='_'):
3600
3634
sepjoin = sep .join
3601
3635
combined_labels = [sepjoin (comb ) for comb in zip (* axes_labels )]
3602
3636
combined_axis = Axis (combined_labels , combined_name )
3603
- combined_axes = AxisCollection (combined_axis )
3604
-
3605
- # 2) transform all advanced non-Array keys to Array with the combined axis
3606
- # ========================================================================
3607
- return tuple (axis_key if isinstance (axis_key , ignored_types ) else Array (axis_key , combined_axes )
3608
- for axis_key in key )
3637
+ return AxisCollection (combined_axis )
3609
3638
3610
3639
3611
3640
class AxisReference (ABCAxisReference , ExprNode , Axis ):
<
32AE
/tr>
0 commit comments