FIX _safe_indexing for pyarrow (#31040) · scikit-learn/scikit-learn@81bb708 · GitHub
[go: up one dir, main page]

Skip to content

Commit 81bb708

Browse files
authored
FIX _safe_indexing for pyarrow (#31040)
1 parent 7a88bf1 commit 81bb708

File tree

6 files changed

+133
-17
lines changed

6 files changed

+133
-17
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- The private helper function :func:`utils._safe_indexing` now officially supports
2+
pyarrow data. For instance, passing a pyarrow `Table` as `X` in a
3+
:class:`compose.ColumnTransformer` is now possible.
4+
By :user:`Christian Lorentzen <lorentzenchr>`

sklearn/utils/_indexing.py

+72-6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
_is_arraylike_not_scalar,
1919
_is_pandas_df,
2020
_is_polars_df_or_series,
21+
_is_pyarrow_data,
2122
_use_interchange_protocol,
2223
check_array,
2324
check_consistent_length,
@@ -65,7 +66,7 @@ def _list_indexing(X, key, key_dtype):
6566

6667

6768
def _polars_indexing(X, key, key_dtype, axis):
68-
"""Indexing X with polars interchange protocol."""
69+
"""Index a polars dataframe or series."""
6970
# Polars behavior is more consistent with lists
7071
if isinstance(key, np.ndarray):
7172
# Convert each element of the array to a Python scalar
@@ -93,6 +94,55 @@ def _polars_indexing(X, key, key_dtype, axis):
9394
return X_indexed
9495

9596

97+
def _pyarrow_indexing(X, key, key_dtype, axis):
98+
"""Index a pyarrow data."""
99+
scalar_key = np.isscalar(key)
100+
if isinstance(key, slice):
101+
if isinstance(key.stop, str):
102+
start = X.column_names.index(key.start)
103+
stop = X.column_names.index(key.stop) + 1
104+
else:
105+
start = 0 if not key.start else key.start
106+
stop = key.stop
107+
step = 1 if not key.step else key.step
108+
key = list(range(start, stop, step))
109+
110+
if axis == 1:
111+
# Here we are certain that X is a pyarrow Table or RecordBatch.
112+
if key_dtype == "int" and not isinstance(key, list):
113+
# pyarrow's X.select behavior is more consistent with integer lists.
114+
key = np.asarray(key).tolist()
115+
if key_dtype == "bool":
116+
key = np.asarray(key).nonzero()[0].tolist()
117+
118+
if scalar_key:
119+
return X.column(key)
120+
121+
return X.select(key)
122+
123+
# axis == 0 from here on
124+
if scalar_key:
125+
if hasattr(X, "shape"):
126+
# X is a Table or RecordBatch
127+
key = [key]
128+
else:
129+
return X[key].as_py()
130+
elif not isinstance(key, list):
131+
key = np.asarray(key)
132+
133+
if key_dtype == "bool":
134+
X_indexed = X.filter(key)
135+
else:
136+
X_indexed = X.take(key)
137+
138+
if scalar_key and len(getattr(X, "shape", [0])) == 2:
139+
# X_indexed is a dataframe-like with a single row; we return a Series to be
140+
# consistent with pandas
141+
pa = sys.modules["pyarrow"]
142+
return pa.array(X_indexed.to_pylist()[0].values())
143+
return X_indexed
144+
145+
96146
def _determine_key_type(key, accept_slice=True):
97147
"""Determine the data type of key.
98148
@@ -245,11 +295,11 @@ def _safe_indexing(X, indices, *, axis=0):
245295
if axis == 1 and isinstance(X, list):
246296
raise ValueError("axis=1 is not supported for lists")
247297

248-
if axis == 1 and hasattr(X, "shape") and len(X.shape) != 2:
298+
if axis == 1 and (ndim := len(getattr(X, "shape", [0]))) != 2:
249299
raise ValueError(
250300
"'X' should be a 2D NumPy array, 2D sparse matrix or "
251301
"dataframe when indexing the columns (i.e. 'axis=1'). "
252-
"Got {} instead with {} dimension(s).".format(type(X), len(X.shape))
302+
f"Got {type(X)} instead with {ndim} dimension(s)."
253303
)
254304

255305
if (
@@ -262,12 +312,28 @@ def _safe_indexing(X, indices, *, axis=0):
262312
)
263313

264314
if hasattr(X, "iloc"):
265-
# TODO: we should probably use _is_pandas_df_or_series(X) instead but this
266-
# would require updating some tests such as test_train_test_split_mock_pandas.
315+
# TODO: we should probably use _is_pandas_df_or_series(X) instead but:
316+
# 1) Currently, it (probably) works for dataframes compliant to pandas' API.
317+
# 2) Updating would require updating some tests such as
318+
# test_train_test_split_mock_pandas.
267319
return _pandas_indexing(X, indices, indices_dtype, axis=axis)
268320
elif _is_polars_df_or_series(X):
269321
return _polars_indexing(X, indices, indices_dtype, axis=axis)
270-
elif hasattr(X, "shape"):
322+
elif _is_pyarrow_data(X):
323+
return _pyarrow_indexing(X, indices, indices_dtype, axis=axis)
324+
elif _use_interchange_protocol(X): # pragma: no cover
325+
# Once the dataframe X is converted into its dataframe interchange protocol
326+
# version by calling X.__dataframe__(), it becomes very hard to turn it back
327+
# into its original type, e.g., a pyarrow.Table, see
328+
# https://github.com/data-apis/dataframe-api/issues/85.
329+
raise warnings.warn(
330+
message="A data object with support for the dataframe interchange protocol"
331+
"was passed, but scikit-learn does currently not know how to handle this "
332+
"kind of data. Some array/list indexing will be tried.",
333+
category=UserWarning,
334+
)
335+
336+
if hasattr(X, "shape"):
271337
return _array_indexing(X, indices, indices_dtype, axis=axis)
272338
else:
273339
return _list_indexing(X, indices, indices_dtype)

sklearn/utils/_testing.py

+4
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,7 @@ def _convert_container(
10211021
elif constructor_name == "pyarrow":
10221022
pa = pytest.importorskip("pyarrow", minversion=minversion)
10231023
array = np.asarray(container)
1024+
array = array[:, None] if array.ndim == 1 else array
10241025
if columns_name is None:
10251026
columns_name = [f"col{i}" for i in range(array.shape[1])]
10261027
data = {name: array[:, i] for i, name in enumerate(columns_name)}
@@ -1042,6 +1043,9 @@ def _convert_container(
10421043
elif constructor_name == "series":
10431044
pd = pytest.importorskip("pandas", minversion=minversion)
10441045
return pd.Series(container, dtype=dtype)
1046+
elif constructor_name == "pyarrow_array":
1047+
pa = pytest.importorskip("pyarrow", minversion=minversion)
1048+
return pa.array(container)
10451049
elif constructor_name == "polars_series":
10461050
pl = pytest.importorskip("polars", minversion=minversion)
10471051
return pl.Series(values=container)

sklearn/utils/tests/test_indexing.py

+29-10
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def test_determine_key_type_array_api(array_namespace, device, dtype_name):
134134

135135

136136
@pytest.mark.parametrize(
137-
"array_type", ["list", "array", "sparse", "dataframe", "polars"]
137+
"array_type", ["list", "array", "sparse", "dataframe", "polars", "pyarrow"]
138138
)
139139
@pytest.mark.parametrize("indices_type", ["list", "tuple", "array", "series", "slice"])
140140
def test_safe_indexing_2d_container_axis_0(array_type, indices_type):
@@ -149,7 +149,9 @@ def test_safe_indexing_2d_container_axis_0(array_type, indices_type):
149149
)
150150

151151

152-
@pytest.mark.parametrize("array_type", ["list", "array", "series", "polars_series"])
152+
@pytest.mark.parametrize(
153+
"array_type", ["list", "array", "series", "polars_series", "pyarrow_array"]
154+
)
153155
@pytest.mark.parametrize("indices_type", ["list", "tuple", "array", "series", "slice"])
154156
def test_safe_indexing_1d_container(array_type, indices_type):
155157
indices = [1, 2]
@@ -161,7 +163,9 @@ def test_safe_indexing_1d_container(array_type, indices_type):
161163
assert_allclose_dense_sparse(subset, _convert_container([2, 3], array_type))
162164

163165

164-
@pytest.mark.parametrize("array_type", ["array", "sparse", "dataframe", "polars"])
166+
@pytest.mark.parametrize(
167+
"array_type", ["array", "sparse", "dataframe", "polars", "pyarrow"]
168+
)
165169
@pytest.mark.parametrize("indices_type", ["list", "tuple", "array", "series", "slice"])
166170
@pytest.mark.parametrize("indices", [[1, 2], ["col_1", "col_2"]])
167171
def test_safe_indexing_2d_container_axis_1(array_type, indices_type, indices):
@@ -177,7 +181,7 @@ def test_safe_indexing_2d_container_axis_1(array_type, indices_type, indices):
177181
)
178182
indices_converted = _convert_container(indices_converted, indices_type)
179183

180-
if isinstance(indices[0], str) and array_type not in ("dataframe", "polars"):
184+
if isinstance(indices[0], str) and array_type in ("array", "sparse"):
181185
err_msg = (
182186
"Specifying the columns using strings is only supported for dataframes"
183187
)
@@ -192,7 +196,9 @@ def test_safe_indexing_2d_container_axis_1(array_type, indices_type, indices):
192196

193197
@pytest.mark.parametrize("array_read_only", [True, False])
194198
@pytest.mark.parametrize("indices_read_only", [True, False])
195-
@pytest.mark.parametrize("array_type", ["array", "sparse", "dataframe", "polars"])
199+
@pytest.mark.parametrize(
200+
"array_type", ["array", "sparse", "dataframe", "polars", "pyarrow"]
201+
)
196202
@pytest.mark.parametrize("indices_type", ["array", "series"])
197203
@pytest.mark.parametrize(
198204
"axis, expected_array", [(0, [[4, 5, 6], [7, 8, 9]]), (1, [[2, 3], [5, 6], [8, 9]])]
@@ -212,7 +218,9 @@ def test_safe_indexing_2d_read_only_axis_1(
212218
assert_allclose_dense_sparse(subset, _convert_container(expected_array, array_type))
213219

214220

215-
@pytest.mark.parametrize("array_type", ["list", "array", "series", "polars_series"])
221+
@pytest.mark.parametrize(
222+
"array_type", ["list", "array", "series", "polars_series", "pyarrow_array"]
223+
)
216224
@pytest.mark.parametrize("indices_type", ["list", "tuple", "array", "series"])
217225
def test_safe_indexing_1d_container_mask(array_type, indices_type):
218226
indices = [False] + [True] * 2 + [False] * 6
@@ -222,7 +230,9 @@ def test_safe_indexing_1d_container_mask(array_type, indices_type):
222230
assert_allclose_dense_sparse(subset, _convert_container([2, 3], array_type))
223231

224232

225-
@pytest.mark.parametrize("array_type", ["array", "sparse", "dataframe", "polars"])
233+
@pytest.mark.parametrize(
234+
"array_type", ["array", "sparse", "dataframe", "polars", "pyarrow"]
235+
)
226236
@pytest.mark.parametrize("indices_type", ["list", "tuple", "array", "series"])
227237
@pytest.mark.parametrize(
228238
"axis, expected_subset",
@@ -250,6 +260,7 @@ def test_safe_indexing_2d_mask(array_type, indices_type, axis, expected_subset):
250260
("sparse", "sparse"),
251261
("dataframe", "series"),
252262
("polars", "polars_series"),
263+
("pyarrow", "pyarrow_array"),
253264
],
254265
)
255266
def test_safe_indexing_2d_scalar_axis_0(array_type, expected_output_type):
@@ -260,7 +271,9 @@ def test_safe_indexing_2d_scalar_axis_0(array_type, expected_output_type):
260271
assert_allclose_dense_sparse(subset, expected_array)
261272

262273

263-
@pytest.mark.parametrize("array_type", ["list", "array", "series", "polars_series"])
274+
@pytest.mark.parametrize(
275+
"array_type", ["list", "array", "series", "polars_series", "pyarrow_array"]
276+
)
264277
def test_safe_indexing_1d_scalar(array_type):
265278
array = _convert_container([1, 2, 3, 4, 5, 6, 7, 8, 9], array_type)
266279
indices = 2
@@ -275,6 +288,7 @@ def test_safe_indexing_1d_scalar(array_type):
275288
("sparse", "sparse"),
276289
("dataframe", "series"),
277290
("polars", "polars_series"),
291+
("pyarrow", "pyarrow_array"),
278292
],
279293
)
280294
@pytest.mark.parametrize("indices", [2, "col_2"])
@@ -284,7 +298,7 @@ def test_safe_indexing_2d_scalar_axis_1(array_type, expected_output_type, indice
284298
[[1, 2, 3], [4, 5, 6], [7, 8, 9]], array_type, columns_name
285299
)
286300

287-
if isinstance(indices, str) and array_type not in ("dataframe", "polars"):
301+
if isinstance(indices, str) and array_type in ("array", "sparse"):
288302
err_msg = (
289303
"Specifying the columns using strings is only supported for dataframes"
290304
)
@@ -321,7 +335,9 @@ def test_safe_indexing_error_axis(axis):
321335
_safe_indexing(X_toy, [0, 1], axis=axis)
322336

323337

324-
@pytest.mark.parametrize("X_constructor", ["array", "series", "polars_series"])
338+
@pytest.mark.parametrize(
339+
"X_constructor", ["array", "series", "polars_series", "pyarrow_array"]
340+
)
325341
def test_safe_indexing_1d_array_error(X_constructor):
326342
# check that we are raising an error if the array-like passed is 1D and
327343
# we try to index on the 2nd dimension
@@ -334,6 +350,9 @@ def test_safe_indexing_1d_array_error(X_constructor):
334350
elif X_constructor == "polars_series":
335351
pl = pytest.importorskip("polars")
336352
X_constructor = pl.Series(values=X)
353+
elif X_constructor == "pyarrow_array":
354+
pa = pytest.importorskip("pyarrow")
355+
X_constructor = pa.array(X)
337356

338357
err_msg = "'X' should be a 2D NumPy array, 2D sparse matrix or dataframe"
339358
with pytest.raises(ValueError, match=err_msg):

sklearn/utils/tests/test_testing.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,10 @@ def test_create_memmap_backed_data(monkeypatch):
896896
("dataframe", lambda: pytest.importorskip("pandas").DataFrame),
897897
("series", lambda: pytest.importorskip("pandas").Series),
898898
("index", lambda: pytest.importorskip("pandas").Index),
899+
("pyarrow", lambda: pytest.importorskip("pyarrow").Table),
900+
("pyarrow_array", lambda: pytest.importorskip("pyarrow").Array),
901+
("polars", lambda: pytest.importorskip("polars").DataFrame),
902+
("polars_series", lambda: pytest.importorskip("polars").Series),
899903
("slice", slice),
900904
],
901905
)
@@ -916,7 +920,15 @@ def test_convert_container(
916920
):
917921
"""Check that we convert the container to the right type of array with the
918922
right data type."""
919-
if constructor_name in ("dataframe", "polars", "series", "polars_series", "index"):
923+
if constructor_name in (
924+
"dataframe",
925+
"index",
926+
"polars",
927+
"polars_series",
928+
"pyarrow",
929+
"pyarrow_array",
930+
"series",
931+
):
920932
# delay the import of pandas/polars within the function to only skip this test
921933
# instead of the whole file
922934
container_type = container_type()
@@ -933,6 +945,8 @@ def test_convert_container(
933945
# list and tuple will use Python class dtype: int, float
934946
# pandas index will always use high precision: np.int64 and np.float64
935947
assert np.issubdtype(type(container_converted[0]), superdtype)
948+
elif constructor_name in ("polars", "polars_series", "pyarrow", "pyarrow_array"):
949+
return
936950
elif hasattr(container_converted, "dtype"):
937951
assert container_converted.dtype == dtype
938952
elif hasattr(container_converted, "dtypes"):

sklearn/utils/validation.py

+9
Original file line numberDiff line numberDiff line change
@@ -2348,6 +2348,15 @@ def _is_pandas_df(X):
23482348
return isinstance(X, pd.DataFrame)
23492349

23502350

2351+
def _is_pyarrow_data(X):
2352+
"""Return True if the X is a pyarrow Table, RecordBatch, Array or ChunkedArray."""
2353+
try:
2354+
pa = sys.modules["pyarrow"]
2355+
except KeyError:
2356+
return False
2357+
return isinstance(X, (pa.Table, pa.RecordBatch, pa.Array, pa.ChunkedArray))
2358+
2359+
23512360
def _is_polars_df_or_series(X):
23522361
"""Return True if the X is a polars dataframe or series."""
23532362
try:

0 commit comments

Comments
 (0)
0