2
2
# SPDX-License-Identifier: BSD-3-Clause
3
3
4
4
import numbers
5
- import sys
6
5
import warnings
7
6
from collections import UserList
8
7
from itertools import compress , islice
9
8
9
+ import narwhals as nw
10
10
import numpy as np
11
11
from scipy .sparse import issparse
12
12
17
17
_check_sample_weight ,
18
18
_is_arraylike_not_scalar ,
19
19
_is_pandas_df ,
20
- _is_polars_df_or_series ,
21
20
_use_interchange_protocol ,
22
21
check_array ,
23
22
check_consistent_length ,
@@ -37,6 +36,41 @@ def _array_indexing(array, key, key_dtype, axis):
37
36
return array [key , ...] if axis == 0 else array [:, key ]
38
37
39
38
39
def _narwhals_indexing(X, key, key_dtype, axis):
    """Index a narwhals-compatible dataframe or series.

    Parameters
    ----------
    X : dataframe or series
        Native object that is wrapped with ``nw.from_native(X, allow_series=True)``.

    key : int, slice, list, None or other array-like
        Selector to apply. Anything that is not an ``int``, a ``slice``,
        ``None`` or a list of booleans/strings is first converted with
        ``np.asarray``.

    key_dtype : str or None
        Kind of ``key``. ``"bool"`` and ``"str"`` keys are passed on as Python
        lists; ``"bool"`` keys on ``axis=0`` are applied through ``filter``.

    axis : int
        ``1`` selects columns with ``X[:, key]``; any other value is handled as
        row selection.

    Returns
    -------
    subset
        The native (``to_native()``) result of the selection, except for a
        scalar row key: the raw result of ``X[key]`` when ``X`` has at most one
        dimension, otherwise a NumPy array holding the values of the selected
        row.
    """
    frame = nw.from_native(X, allow_series=True)
    list_like_dtype = key_dtype in ("bool", "str")

    # Keys that can be handed over unchanged; everything else goes through
    # NumPy. Note that at least tuples must be converted to either a list or an
    # ndarray, because tuples in __getitem__ are special: x[(1, 2)] is equal to
    # x[1, 2].
    # NOTE(review): NumPy integer scalars are not `int` instances, so they are
    # turned into 0-d arrays here and `np.isscalar` below no longer reports
    # them as scalars -- confirm this is intended.
    passthrough = (
        key is None
        or isinstance(key, (int, slice))
        or (list_like_dtype and isinstance(key, list))
    )
    if not passthrough:
        key = np.asarray(key)

    # Boolean and string selectors are given as plain Python lists.
    if list_like_dtype and not isinstance(key, (list, slice)):
        key = key.tolist()

    if axis == 1:
        # Column selection.
        return frame[:, key].to_native()

    # From here on axis == 0 (row selection): boolean masks go through
    # `filter`, everything else through `__getitem__`.
    selected = frame.filter(key) if key_dtype == "bool" else frame[key]

    if not np.isscalar(key):
        return selected.to_native()

    if len(frame.shape) <= 1:
        # Scalar key on a one-dimensional input: hand back what `X[key]` gave.
        return selected

    # `selected` is a dataframe with a single row; its values are returned as a
    # 1-D NumPy array.
    # NOTE(review): returning a single-row dataframe instead has been
    # advocated; revisit if the expected return type changes.
    return np.array(selected.row(0))
72
+
73
+
40
74
def _pandas_indexing (X , key , key_dtype , axis ):
41
75
"""Index a pandas dataframe or a series."""
42
76
if _is_arraylike_not_scalar (key ):
@@ -64,35 +98,6 @@ def _list_indexing(X, key, key_dtype):
64
98
return [X [idx ] for idx in key ]
65
99
66
100
67
- def _polars_indexing (X , key , key_dtype , axis ):
68
- """Indexing X with polars interchange protocol."""
69
- # Polars behavior is more consistent with lists
70
- if isinstance (key , np .ndarray ):
71
- # Convert each element of the array to a Python scalar
72
- key = key .tolist ()
73
- elif not (np .isscalar (key ) or isinstance (key , slice )):
74
- key = list (key )
75
-
76
- if axis == 1 :
77
- # Here we are certain to have a polars DataFrame; which can be indexed with
78
- # integer and string scalar, and list of integer, string and boolean
79
- return X [:, key ]
80
-
81
- if key_dtype == "bool" :
82
- # Boolean mask can be indexed in the same way for Series and DataFrame (axis=0)
83
- return X .filter (key )
84
-
85
- # Integer scalar and list of integer can be indexed in the same way for Series and
86
- # DataFrame (axis=0)
87
- X_indexed = X [key ]
88
- if np .isscalar (key ) and len (X .shape ) == 2 :
89
- # `X_indexed` is a DataFrame with a single row; we return a Series to be
90
- # consistent with pandas
91
- pl = sys .modules ["polars" ]
92
- return pl .Series (X_indexed .row (0 ))
93
- return X_indexed
94
-
95
-
96
101
def _determine_key_type (key , accept_slice = True ):
97
102
"""Determine the data type of key.
98
103
@@ -264,9 +269,12 @@ def _safe_indexing(X, indices, *, axis=0):
264
269
if hasattr (X , "iloc" ):
265
270
# TODO: we should probably use _is_pandas_df_or_series(X) instead but this
266
271
# would require updating some tests such as test_train_test_split_mock_pandas.
272
+ # TODO: Should also work with _narwhals_indexing, but
273
+ # test_safe_indexing_pandas_no_settingwithcopy_warning
274
+ # does not pass.
267
275
return _pandas_indexing (X , indices , indices_dtype , axis = axis )
268
- elif _is_polars_df_or_series (X ):
269
- return _polars_indexing (X , indices , indices_dtype , axis = axis )
276
+ elif nw . dependencies . is_into_dataframe ( X ) or nw . dependencies . is_into_series (X ):
277
+ return _narwhals_indexing (X , indices , indices_dtype , axis = axis )
270
278
elif hasattr (X , "shape" ):
271
279
return _array_indexing (X , indices , indices_dtype , axis = axis )
272
280
else :
0 commit comments