8000 MNT use narwhals in _safe_indexing · scikit-learn/scikit-learn@fa5ed58 · GitHub
[go: up one dir, main page]

Skip to content

Commit fa5ed58

Browse files
committed
MNT use narwhals in _safe_indexing
1 parent 5b136b9 commit fa5ed58

File tree

3 files changed

+42
-43
lines changed

3 files changed

+42
-43
lines changed

sklearn/utils/_indexing.py

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
import numbers
5-
import sys
65
import warnings
76
from collections import UserList
87
from itertools import compress, islice
98

9+
import narwhals as nw
1010
import numpy as np
1111
from scipy.sparse import issparse
1212

@@ -17,7 +17,6 @@
1717
_check_sample_weight,
1818
_is_arraylike_not_scalar,
1919
_is_pandas_df,
20-
_is_polars_df_or_series,
2120
_use_interchange_protocol,
2221
check_array,
2322
check_consistent_length,
@@ -37,6 +36,41 @@ def _array_indexing(array, key, key_dtype, axis):
3736
return array[key, ...] if axis == 0 else array[:, key]
3837

3938

39+
def _narwhals_indexing(X, key, key_dtype, axis):
40+
"""Index a narhals dataframe or series."""
41+
X = nw.from_native(X, allow_series=True)
42+
if (
43+
not isinstance(key, (int, slice))
44+
and not (isinstance(key, list) and key_dtype in ("bool", "str"))
45+
and key is not None
46+
):
47+
# Note that at least tuples should be converted to either list or ndarray as
48+
# tupes in __getitem__ are special: x[(1, 2)] is equal to x[1, 2].
49+
key = np.asarray(key)
50+
51+
if key_dtype in ("bool", "str") and not isinstance(key, (list, slice)):
52+
key = key.tolist()
53+
54+
if axis == 1:
55+
return X[:, key].to_native()
56+
57+
# From here on axis == 0:
58+
if key_dtype == "bool":
59+
X_indexed = X.filter(key)
60+
else:
61+
X_indexed = X[key]
62+
63+
if np.isscalar(key):
64+
if len(X.shape) <= 1:
65+
return X_indexed
66+
# `X_indexed` is a DataFrame with a single row; we return a Series to be
67+
# consistent with pandas
68+
# Christian Lorentzen really dislikes this behaviour and favours to return a
69+
# dataframe.
70+
return np.array(X_indexed.row(0))
71+
return X_indexed.to_native()
72+
73+
4074
def _pandas_indexing(X, key, key_dtype, axis):
4175
"""Index a pandas dataframe or a series."""
4276
if _is_arraylike_not_scalar(key):
@@ -64,35 +98,6 @@ def _list_indexing(X, key, key_dtype):
6498
return [X[idx] for idx in key]
6599

66100

67-
def _polars_indexing(X, key, key_dtype, axis):
68-
"""Indexing X with polars interchange protocol."""
69-
# Polars behavior is more consistent with lists
70-
if isinstance(key, np.ndarray):
71-
# Convert each element of the array to a Python scalar
72-
key = key.tolist()
73-
elif not (np.isscalar(key) or isinstance(key, slice)):
74-
key = list(key)
75-
76-
if axis == 1:
77-
# Here we are certain to have a polars DataFrame; which can be indexed with
78-
# integer and string scalar, and list of integer, string and boolean
79-
return X[:, key]
80-
81-
if key_dtype == "bool":
82-
# Boolean mask can be indexed in the same way for Series and DataFrame (axis=0)
83-
return X.filter(key)
84-
85-
# Integer scalar and list of integer can be indexed in the same way for Series and
86-
# DataFrame (axis=0)
87-
X_indexed = X[key]
88-
if np.isscalar(key) and len(X.shape) == 2:
89-
# `X_indexed` is a DataFrame with a single row; we return a Series to be
90-
# consistent with pandas
91-
pl = sys.modules["polars"]
92-
return pl.Series(X_indexed.row(0))
93-
return X_indexed
94-
95-
96101
def _determine_key_type(key, accept_slice=True):
97102
"""Determine the data type of key.
98103
@@ -264,9 +269,12 @@ def _safe_indexing(X, indices, *, axis=0):
264269
if hasattr(X, "iloc"):
265270
# TODO: we should probably use _is_pandas_df_or_series(X) instead but this
266271
# would require updating some tests such as test_train_test_split_mock_pandas.
272+
# TODO: Should also work with _narwhals_indexing, but
273+
# test_safe_indexing_pandas_no_settingwithcopy_warning
274+
# does not pass.
267275
return _pandas_indexing(X, indices, indices_dtype, axis=axis)
268-
elif _is_polars_df_or_series(X):
269-
return _polars_indexing(X, indices, indices_dtype, axis=axis)
276+
elif nw.dependencies.is_into_dataframe(X) or nw.dependencies.is_into_series(X):
277+
return _narwhals_indexing(X, indices, indices_dtype, axis=axis)
270278
elif hasattr(X, "shape"):
271279
return _array_indexing(X, indices, indices_dtype, axis=axis)
272280
else:

sklearn/utils/tests/test_indexing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ def test_get_column_indices_interchange():
452452
pl = pytest.importorskip("polars")
453453

454454
# Polars dataframes go down the interchange path.
455-
df = pl.DataFrame([[1, 2, 3], [4, 5, 6]], schema=["a", "b", "c"])
455+
df = pl.DataFrame([[1, 2, 3], [4, 5, 6]], schema=["a", "b", "c"], orient="row")
456456

457457
key_results = [
458458
(slice(1, None), [1, 2]),

sklearn/utils/validation.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2349,15 +2349,6 @@ def _is_pandas_df(X):
23492349
return isinstance(X, pd.DataFrame)
23502350

23512351

2352-
def _is_polars_df_or_series(X):
2353-
"""Return True if the X is a polars dataframe or series."""
2354-
try:
2355-
pl = sys.modules["polars"]
2356-
except KeyError:
2357-
return False
2358-
return isinstance(X, (pl.DataFrame, pl.Series))
2359-
2360-
23612352
def _is_polars_df(X):
23622353
"""Return True if the X is a polars dataframe."""
23632354
try:

0 commit comments

Comments
 (0)
0