8000 ENH Adds polars output support to `set_output` API (#27315) · scikit-learn/scikit-learn@831c49a · GitHub
[go: up one dir, main page]

Skip to content

Commit 831c49a

Browse files
authored
ENH Adds polars output support to set_output API (#27315)
1 parent 5c4288d commit 831c49a

File tree

20 files changed

+712
-250
lines changed

20 files changed

+712
-250
lines changed

doc/whats_new/v1.4.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ random sampling procedures.
4141
Changes impacting all modules
4242
-----------------------------
4343

44+
- |MajorFeature| Transformers now support polars output with `set_output(transform="polars")`.
45+
:pr:`27315` by `Thomas Fan`_.
46+
4447
- |Enhancement| All estimators now recognizes the column names from any dataframe
4548
that adopts the
4649
`DataFrame Interchange Protocol <https://data-apis.org/dataframe-protocol/latest/purpose_and_scope.html>`__.

sklearn/_config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,12 @@ def set_config(
134134
135135
- `"default"`: Default output format of a transformer
136136
- `"pandas"`: DataFrame output
137+
- `"polars"`: Polars output
137138
- `None`: Transform configuration is unchanged
138139
139140
.. versionadded:: 1.2
141+
.. versionadded:: 1.4
142+
`"polars"` option was added.
140143
141144
enable_metadata_routing : bool, default=None
142145
Enable metadata routing. By default this feature is disabled.
@@ -281,9 +284,12 @@ def config_context(
281284
282285
- `"default"`: Default output format of a transformer
283286
- `"pandas"`: DataFrame output
287+
- `"polars"`: Polars output
284288
- `None`: Transform configuration is unchanged
285289
286290
.. versionadded:: 1.2
291+
.. versionadded:: 1.4
292+
`"polars"` option was added.
287293
288294
enable_metadata_routing : bool, default=None
289295
Enable metadata routing. By default this feature is disabled.

sklearn/_min_dependencies.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
"black": ("23.3.0", "tests"),
4141
"mypy": ("1.3", "tests"),
4242
"pyamg": ("4.0.0", "tests"),
43-
"polars": ("0.18.2", "tests"),
43+
"polars": ("0.19.12", "tests"),
4444
"pyarrow": ("12.0.0", "tests"),
4545
"sphinx": ("6.0.0", "docs"),
4646
"sphinx-copybutton": ("0.5.2", "docs"),

sklearn/compose/_column_transformer.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
to work with heterogeneous data and to apply different transformers to
44
different columns.
55
"""
6+
67
# Author: Andreas Mueller
78
# Joris Van den Bossche
89
# License: BSD
@@ -16,11 +17,15 @@
1617
from ..base import TransformerMixin, _fit_context, clone
1718
from ..pipeline import _fit_transform_one, _name_estimators, _transform_one
1819
from ..preprocessing import FunctionTransformer
19-
from ..utils import Bunch, _get_column_indices, _safe_indexing, check_pandas_support
20+
from ..utils import Bunch, _get_column_indices, _safe_indexing
2021
from ..utils._estimator_html_repr import _VisualBlock
2122
from ..utils._metadata_requests import METHODS
2223
from ..utils._param_validation import HasMethods, Hidden, Interval, StrOptions
23-
from ..utils._set_output import _get_output_config, _safe_set_output
24+
from ..utils._set_output import (
25+
_get_container_adapter,
26+
_get_output_config,
27+
_safe_set_output,
28+
)
2429
from ..utils.metadata_routing import (
2530
MetadataRouter,
2631
MethodMapping,
@@ -310,8 +315,12 @@ def set_output(self, *, transform=None):
310315
311316
- `"default"`: Default output format of a transformer
312317
- `"pandas"`: DataFrame output
318+
- `"polars"`: Polars output
313319
- `None`: Transform configuration is unchanged
314320
321+
.. versionadded:: 1.4
322+
`"polars"` option was added.
323+
315324
Returns
316325
-------
317326
self : estimator instance
@@ -1006,10 +1015,9 @@ def _hstack(self, Xs):
10061015
return sparse.hstack(converted_Xs).tocsr()
10071016
else:
10081017
Xs = [f.toarray() if sparse.issparse(f) else f for f in Xs]
1009-
config = _get_output_config("transform", self)
1010-
if config["dense"] == "pandas" and all(hasattr(X, "iloc") for X in Xs):
1011-
pd = check_pandas_support("transform")
1012-
output = pd.concat(Xs, axis=1)
1018+
adapter = _get_container_adapter("transform", self)
1019+
if adapter and all(adapter.is_supported_container(X) for X in Xs):
1020+
output = adapter.hstack(Xs)
10131021

10141022
output_samples = output.shape[0]
10151023
if any(_num_samples(X) != output_samples for X in Xs):
@@ -1042,8 +1050,7 @@ def _hstack(self, Xs):
10421050
names_out = self._add_prefix_for_feature_names_out(
10431051
list(zip(transformer_names, feature_names_outs))
10441052
)
1045-
output.columns = names_out
1046-
return output
1053+
return adapter.rename_columns(output, names_out)
10471054

10481055
return np.hstack(Xs)
10491056

sklearn/decomposition/_dict_learning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1227,7 +1227,7 @@ def dict_learning(
12271227
positive_code=positive_code,
12281228
positive_dict=positive_dict,
12291229
transform_max_iter=method_max_iter,
1230-
)
1230+
).set_output(transform="default")
12311231
code = estimator.fit_transform(X)
12321232
if return_n_iter:
12331233
return (

sklearn/decomposition/_sparse_pca.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,9 @@ def _fit(self, X, n_components, random_state):
546546
callback=self.callback,
547547
tol=self.tol,
548548
max_no_improvement=self.max_no_improvement,
549-
).fit(X.T)
549+
)
550+
est.set_output(transform="default")
551+
est.fit(X.T)
550552

551553
self.components_, self.n_iter_ = est.transform(X.T).T, est.n_iter_
552554

sklearn/feature_selection/_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from ..base import TransformerMixin
1414
from ..utils import (
15+
_is_pandas_df,
1516
_safe_indexing,
1617
check_array,
1718
safe_sqr,
@@ -81,7 +82,7 @@ def transform(self, X):
8182
# Preserve X when X is a dataframe and the output is configured to
8283
# be pandas.
8384
output_config_dense = _get_output_config("transform", estimator=self)["dense"]
84-
preserve_X = hasattr(X, "iloc") and output_config_dense == "pandas"
85+
preserve_X = output_config_dense != "default" and _is_pandas_df(X)
8586

8687
# note: we use _safe_tags instead of _get_tags because this is a
8788
# public Mixin.

sklearn/pipeline.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,14 @@
1818
from .base import TransformerMixin, _fit_context, clone
1919
from .exceptions import NotFittedError
2020
from .preprocessing import FunctionTransformer
21-
from .utils import Bunch, _print_elapsed_time, check_pandas_support
21+
from .utils import Bunch, _print_elapsed_time
2222
from .utils._estimator_html_repr import _VisualBlock
2323
from .utils._metadata_requests import METHODS
2424
from .utils._param_validation import HasMethods, Hidden
25-
from .utils._set_output import _get_output_config, _safe_set_output
25+
from .utils._set_output import (
26+
_get_container_adapter,
27+
_safe_set_output,
28+
)
2629
from .utils._tags import _safe_tags
2730
from .utils.metadata_routing import (
2831
MetadataRouter,
@@ -179,9 +182,12 @@ def set_output(self, *, transform=None):
179182
Configure output of `transform` and `fit_transform`.
180183
181184
- `"default"`: Default output format of a transformer
182-
- `"pandas"`: DataFrame output
185+
- `"polars"`: Polars output
183186
- `None`: Transform configuration is unchanged
184187
188+
.. versionadded:: 1.4
189+
`"polars"` option was added.
190+
185191
Returns
186192
-------
187193
self : estimator instance
@@ -1674,10 +1680,9 @@ def transform(self, X):
16741680
return self._hstack(Xs)
16751681

16761682
def _hstack(self, Xs):
1677-
config = _get_output_config("transform", self)
1678-
if config["dense"] == "pandas" and all(hasattr(X, "iloc") for X in Xs):
1679-
pd = check_pandas_support("transform")
1680-
return pd.concat(Xs, axis=1)
1683+
adapter = _get_container_adapter("transform", self)
1684+
if adapter and all(adapter.is_supported_container(X) for X in Xs):
1685+
return adapter.hstack(Xs)
16811686

16821687
if any(sparse.issparse(f) for f in Xs):
16831688
Xs = sparse.hstack(Xs).tocsr()

sklearn/preprocessing/_encoders.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,11 +1031,13 @@ def transform(self, X):
10311031
"""
10321032
check_is_fitted(self)
10331033
transform_output = _get_output_config("transform", estimator=self)["dense"]
1034-
if transform_output == "pandas" and self.sparse_output:
1034+
if transform_output != "default" and self.sparse_output:
1035+
capitalize_transform_output = transform_output.capitalize()
10351036
raise ValueError(
1036-
"Pandas output does not support sparse data. Set sparse_output=False to"
1037-
" output pandas DataFrames or disable pandas output via"
1038-
' `ohe.set_output(transform="default").'
1037+
f"{capitalize_transform_output} output does not support sparse data."
1038+
f" Set sparse_output=False to output {transform_output} dataframes or"
1039+
f" disable {capitalize_transform_output} output via"
1040+
'` ohe.set_output(transform="default").'
10391041
)
10401042

10411043
# validation of X happens in _check_X called by _transform

sklearn/preprocessing/_function_transformer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,8 +346,12 @@ def set_output(self, *, transform=None):
346346
347347
- `"default"`: Default output format of a transformer
348348
- `"pandas"`: DataFrame output
349+
- `"polars"`: Polars output
349350
- `None`: Transform configuration is unchanged
350351
352+
.. versionadded:: 1.4
353+
`"polars"` option was added.
354+
351355
Returns
352356
-------
353357
self : estimator instance

0 commit comments

Comments
 (0)
0