8000 ENH Introduces set_output API for pandas output (#23734) · scikit-learn/scikit-learn@2a6703d · GitHub
[go: up one dir, main page]

Skip to content

Commit 2a6703d

Browse files
authored
ENH Introduces set_output API for pandas output (#23734)
* Introduces set_output API for all transformers * TransformerMixin inherits from _SetOutputMixin * Adds tests * Adds whatsnew * Adds example on using set_output API * Adds developer docs for set_output
1 parent 93c7306 commit 2a6703d

21 files changed

+1196
-9
lines changed

doc/developers/develop.rst

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,35 @@ instantiated with an instance of ``LogisticRegression`` (or
635635
of these two models is somewhat idiosyncratic but both should provide robust
636636
closed-form solutions.
637637

638+
.. _developer_api_set_output:
639+
640+
Developer API for `set_output`
641+
==============================
642+
643+
With
644+
`SLEP018 <https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep018/proposal.html>`__,
645+
scikit-learn introduces the `set_output` API for configuring transformers to
646+
output pandas DataFrames. The `set_output` API is automatically defined if the
647+
transformer defines :term:`get_feature_names_out` and subclasses
648+
:class:`base.TransformerMixin`. :term:`get_feature_names_out` is used to get the
649+
column names of pandas output. You can opt-out of the `set_output` API by
650+
setting `auto_wrap_output_keys=None` when defining a custom subclass::
651+
652+
class MyTransformer(TransformerMixin, BaseEstimator, auto_wrap_output_keys=None):
653+
654+
def fit(self, X, y=None):
655+
return self
656+
def transform(self, X, y=None):
657+
return X
658+
def get_feature_names_out(self, input_features=None):
659+
...
660+
661+
For transformers that return multiple arrays in `transform`, auto wrapping will
662+
only wrap the first array and not alter the other arrays.
663+
664+
See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
665+
for an example on how to use the API.
666+
638667
.. _coding-guidelines:
639668

640669
Coding guidelines

doc/whats_new/v1.2.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@ random sampling procedures.
5252
Changes impacting all modules
5353
-----------------------------
5454

55+
- |MajorFeature| The `set_output` API has been adopted by all transformers.
56+
Meta-estimators that contain transformers such as :class:`pipeline.Pipeline`
57+
or :class:`compose.ColumnTransformer` also define a `set_output`.
58+
For details, see
59+
`SLEP018 <https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep018/proposal.html>`__.
60+
:pr:`23734` by `Thomas Fan`_.
61+
5562
- |Enhancement| Finiteness checks (detection of NaN and infinite values) in all
5663
estimators are now significantly more efficient for float32 data by leveraging
5764
NumPy's SIMD optimized primitives.
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""
2+
================================
3+
Introducing the `set_output` API
4+
================================
5+
6+
.. currentmodule:: sklearn
7+
8+
This example will demonstrate the `set_output` API to configure transformers to
9+
output pandas DataFrames. `set_output` can be configured per estimator by calling
10+
the `set_output` method or globally by setting `set_config(transform_output="pandas")`.
11+
For details, see
12+
`SLEP018 <https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep018/proposal.html>`__.
13+
""" # noqa
14+
15+
# %%
16+
# First, we load the iris dataset as a DataFrame to demonstrate the `set_output` API.
17+
from sklearn.datasets import load_iris
18+
from sklearn.model_selection import train_test_split
19+
20+
X, y = load_iris(as_frame=True, return_X_y=True)
21+
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
22+
X_train.head()
23+
24+
# %%
25+
# To configure an estimator such as :class:`preprocessing.StandardScalar` to return
26+
# DataFrames, call `set_output`. This feature requires pandas to be installed.
27+
28+
from sklearn.preprocessing import StandardScaler
29+
30+
scaler = StandardScaler().set_output(transform="pandas")
31+
32+
scaler.fit(X_train)
33+
X_test_scaled = scaler.transform(X_test)
34+
X_test_scaled.head()
35+
36+
# %%
37+
# `set_output` can be called after `fit` to configure `transform` after the fact.
38+
scaler2 = StandardScaler()
39+
40+
scaler2.fit(X_train)
41+
X_test_np = scaler2.transform(X_test)
42+
print(f"Default output type: {type(X_test_np).__name__}")
43+
44+
scaler2.set_output(transform="pandas")
45+
X_test_df = scaler2.transform(X_test)
46+
print(f"Configured pandas output type: {type(X_test_df).__name__}")
47+
48+
# %%
49+
# In a :class:`pipeline.Pipeline`, `set_output` configures all steps to output
50+
# DataFrames.
51+
from sklearn.pipeline import make_pipeline
52+
from sklearn.linear_model import LogisticRegression
53+
from sklearn.feature_selection import SelectPercentile
54+
55+
clf = make_pipeline(
56+
StandardScaler(), SelectPercentile(percentile=75), LogisticRegression()
57+
)
58+
clf.set_output(transform="pandas")
59+
clf.fit(X_train, y_train)
60+
61+
# %%
62+
# Each transformer in the pipeline is configured to return DataFrames. This
63+
# means that the final logistic regression step contain the feature names.
64+
clf[-1].feature_names_in_
65+
66+
# %%
67+
# Next we load the titanic dataset to demonstrate `set_output` with
68+
# :class:`compose.ColumnTransformer` and heterogenous data.
69+
from sklearn.datasets import fetch_openml
70+
71+
X, y = fetch_openml(
72+
"titanic", version=1, as_frame=True, return_X_y=True, parser="pandas"
73+
)
74+
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y)
75+
76+
# %%
77+
# The `set_output` API can be configured globally by using :func:`set_config` and
78+
# setting the `transform_output` to `"pandas"`.
79+
from sklearn.compose import ColumnTransformer
80+
from sklearn.preprocessing import OneHotEncoder, StandardScaler
81+
from sklearn.impute import SimpleImputer
82+
from sklearn import set_config
83+
84+
set_config(transform_output="pandas")
85+
86+
num_pipe = make_pipeline(SimpleImputer(), StandardScaler())
87+
ct = ColumnTransformer(
88+
(
89+
("numerical", num_pipe, ["age", "fare"]),
90+
(
91+
"categorical",
92+
OneHotEncoder(
93+
sparse_output=False, drop="if_binary", handle_unknown="ignore"
94+
),
95+
["embarked", "sex", "pclass"],
96+
),
97+
),
98+
verbose_feature_names_out=False,
99+
)
100+
clf = make_pipeline(ct, SelectPercentile(percentile=50), LogisticRegression())
101+
clf.fit(X_train, y_train)
102+
clf.score(X_test, y_test)
103+
104+
# %%
105+
# With the global configuration, all transformers output DataFrames. This allows us to
106+
# easily plot the logistic regression coefficients with the corresponding feature names.
107+
import pandas as pd
108+
109+
log_reg = clf[-1]
110+
coef = pd.Series(log_reg.coef_.ravel(), index=log_reg.feature_names_in_)
111+
_ = coef.sort_values().plot.barh()

sklearn/_config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
),
1515
"enable_cython_pairwise_dist": True,
1616
"array_api_dispatch": False,
17+
"transform_output": "default",
1718
}
1819
_threadlocal = threading.local()
1920

@@ -52,6 +53,7 @@ def set_config(
5253
pairwise_dist_chunk_size=None,
5354
enable_cython_pairwise_dist=None,
5455
array_api_dispatch=None,
56+
transform_output=None,
5557
):
5658
"""Set global scikit-learn configuration
5759
@@ -120,6 +122,11 @@ def set_config(
120122
121123
.. versionadded:: 1.2
122124
125+
transform_output : str, default=None
126+
Configure the output container for transform.
127+
128+
.. versionadded:: 1.2
129+
123130
See Also
124131
--------
125132
config_context : Context manager for global scikit-learn configuration.
@@ -141,6 +148,8 @@ def set_config(
141148
local_config["enable_cython_pairwise_dist"] = enable_cython_pairwise_dist
142149
if array_api_dispatch is not None:
143150
local_config["array_api_dispatch"] = array_api_dispatch
151+
if transform_output is not None:
152+
local_config["transform_output"] = transform_output
144153

145154

146155
@contextmanager
@@ -153,6 +162,7 @@ def config_context(
153162
pairwise_dist_chunk_size=None,
154163
enable_cython_pairwise_dist=None,
155164
array_api_dispatch=None,
165+
transform_output=None,
156166
):
157167
"""Context manager for global scikit-learn configuration.
158168
@@ -220,6 +230,11 @@ def config_context(
220230
221231
.. versionadded:: 1.2
222232
233+
transform_output : str, default=None
234+
Configure the output container for transform.
235+
236+
.. versionadded:: 1.2
237+
223238
Yields
224239
------
225240
None.
@@ -256,6 +271,7 @@ def config_context(
256271
pairwise_dist_chunk_size=pairwise_dist_chunk_size,
257272
enable_cython_pairwise_dist=enable_cython_pairwise_dist,
258273
array_api_dispatch=array_api_dispatch,
274+
transform_output=transform_output,
259275
)
260276

261277
try:

sklearn/base.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from . import __version__
1616
from ._config import get_config
1717
from .utils import _IS_32BIT
18+
from .utils._set_output import _SetOutputMixin
1819
from .utils._tags import (
1920
_DEFAULT_TAGS,
2021
)
@@ -98,6 +99,13 @@ def clone(estimator, *, safe=True):
9899
"Cannot clone object %s, as the constructor "
99100
"either does not set or modifies parameter %s" % (estimator, name)
100101
)
102+
103+
# _sklearn_output_config is used by `set_output` to configure the output
104+
# container of an estimator.
105+
if hasattr(estimator, "_sklearn_output_config"):
106+
new_object._sklearn_output_config = copy.deepcopy(
107+
estimator._sklearn_output_config
108+
)
101109
return new_object
102110

103111

@@ -798,8 +806,13 @@ def get_submatrix(self, i, data):
798806
return data[row_ind[:, np.newaxis], col_ind]
799807

800808

801-
class TransformerMixin:
802-
"""Mixin class for all transformers in scikit-learn."""
809+
class TransformerMixin(_SetOutputMixin):
810+
"""Mixin class for all transformers in scikit-learn.
811+
812+
If :term:`get_feature_names_out` is defined and `auto_wrap_output` is True,
813+
then `BaseEstimator` will automatically wrap `transform` and `fit_transform` to
814+
follow the `set_output` API. See the :ref:`developer_api_set_output` for details.
815+
"""
803816

804817
def fit_transform(self, X, y=None, **fit_params):
805818
"""

sklearn/compose/_column_transformer.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from ..utils import Bunch
2121
from ..utils import _safe_indexing
2222
from ..utils import _get_column_indices
23+
from ..utils._set_output import _get_output_config, _safe_set_output
24+
from ..utils import check_pandas_support
2325
from ..utils.metaestimators import _BaseComposition
2426
from ..utils.validation import check_array, check_is_fitted, _check_feature_names_in
2527
from ..utils.fixes import delayed
@@ -252,6 +254,35 @@ def _transformers(self, value):
252254
except (TypeError, ValueError):
253255
self.transformers = value
254256

257+
def set_output(self, transform=None):
258+
"""Set the output container when `"transform`" and `"fit_transform"` are called.
259+
260+
Calling `set_output` will set the output of all estimators in `transformers`
261+
and `transformers_`.
262+
263+
Parameters
264+
----------
265+
transform : {"default", "pandas"}, default=None
266+
Configure output of `transform` and `fit_transform`.
267+
268+
Returns
269+
-------
270+
self : estimator instance
271+
Estimator instance.
272+
"""
273+
super().set_output(transform=transform)
274+
transformers = (
275+
trans
276+
for _, trans, _ in chain(
277+
self.transformers, getattr(self, "transformers_", [])
278+
)
279+
if trans not in {"passthrough", "drop"}
280+
)
281+
for trans in transformers:
282+
_safe_set_output(trans, transform=transform)
283+
284+
return self
285+
255286
def get_params(self, deep=True):
256287
"""Get parameters for this estimator.
257288
@@ -302,7 +333,19 @@ def _iter(self, fitted=False, replace_strings=False, column_as_strings=False):
302333
303334
"""
304335
if fitted:
305-
transformers = self.transformers_
336+
if replace_strings:
337+
# Replace "passthrough" with the fitted version in
338+
# _name_to_fitted_passthrough
339+
def replace_passthrough(name, trans, columns):
340+
if name not in self._name_to_fitted_passthrough:
341+
return name, trans, columns
342+
return name, self._name_to_fitted_passthrough[name], columns
343+
344+
transformers = [
345+
replace_passthrough(*trans) for trans in self.transformers_
346+
]
347+
else:
348+
transformers = self.transformers_
306349
else:
307350
# interleave the validated column specifiers
308351
transformers = [
@@ -314,12 +357,17 @@ def _iter(self, fitted=False, replace_strings=False, column_as_strings=False):
314357
transformers = chain(transformers, [self._remainder])
315358
get_weight = (self.transformer_weights or {}).get
316359

360+
output_config = _get_output_config("transform", self)
317361
for name, trans, columns in transformers:
318362
if replace_strings:
319363
# replace 'passthrough' with identity transformer and
320364
# skip in case of 'drop'
321365
if trans == "passthrough":
322-
trans = FunctionTransformer(accept_sparse=True, check_inverse=False)
366+
trans = FunctionTransformer(
367+
accept_sparse=True,
368+
check_inverse=False,
369+
feature_names_out="one-to-one",
370+
).set_output(transform=output_config["dense"])
323371
elif trans == "drop":
324372
continue
325373
elif _is_empty_column_selection(columns):
@@ -505,15 +553,20 @@ def _update_fitted_transformers(self, transformers):
505553
# transformers are fitted; excludes 'drop' cases
506554
fitted_transformers = iter(transformers)
507555
transformers_ = []
556+
self._name_to_fitted_passthrough = {}
508557

509558
for name, old, column, _ in self._iter():
510559
if old == "drop":
511560
trans = "drop"
512561
elif old == "passthrough":
513562
# FunctionTransformer is present in list of transformers,
514563
# so get next transformer, but save original string
515-
next(fitted_transformers)
564+
func_transformer = next(fitted_transformers)
516565
trans = "passthrough"
566+
567+
# The fitted FunctionTransformer is saved in another attribute,
568+
# so it can be used during transform for set_output.
569+
self._name_to_fitted_passthrough[name] = func_transformer
517570
elif _is_empty_column_selection(column):
518571
trans = old
519572
else:
@@ -765,6 +818,10 @@ def _hstack(self, Xs):
765818
return sparse.hstack(converted_Xs).tocsr()
766819
else:
767820
Xs = [f.toarray() if sparse.issparse(f) else f for f in Xs]
821+
config = _get_output_config("transform", self)
822+
if config["dense"] == "pandas" and all(hasattr(X, "iloc") for X in Xs):
823+
pd = check_pandas_support("transform")
824+
return pd.concat(Xs, axis=1)
768825
return np.hstack(Xs)
769826

770827
def _sk_visual_block_(self):

0 commit comments

Comments
 (0)
0