8000 FEA Add make_column_selector for ColumnTransformer (#12371) · rasbt/scikit-learn@37ac3fd · GitHub
[go: up one dir, main page]

Skip to content

Commit 37ac3fd

Browse files
thomasjpfanNicolasHug
authored andcommitted
FEA Add make_column_selector for ColumnTransformer (scikit-learn#12371)
1 parent 35c86fc commit 37ac3fd

File tree

6 files changed

+194
-5
lines changed

6 files changed

+194
-5
lines changed

doc/modules/classes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ details.
152152
:template: function.rst
153153

154154
compose.make_column_transformer
155+
compose.make_column_selector
155156

156157
.. _covariance_ref:
157158

doc/modules/compose.rst

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,25 @@ as most of other transformers expects 2D data, therefore in that case you need
462462
to specify the column as a list of strings (``['city']``).
463463

464464
Apart from a scalar or a single item list, the column selection can be specified
465-
as a list of multiple items, an integer array, a slice, or a boolean mask.
465+
as a list of multiple items, an integer array, a slice, a boolean mask, or
466+
with a :func:`~sklearn.compose.make_column_selector`. The
467+
:func:`~sklearn.compose.make_column_selector` is used to select columns based
468+
on data type or column name::
469+
470+
>>> from sklearn.preprocessing import StandardScaler
471+
>>> from sklearn.compose import make_column_selector
472+
>>> ct = ColumnTransformer([
473+
... ('scale', StandardScaler(),
474+
... make_column_selector(dtype_include=np.number)),
475+
... ('onehot',
476+
... OneHotEncoder(),
477+
... make_column_selector(pattern='city', dtype_include=object))])
478+
>>> ct.fit_transform(X)
479+
array([[ 0.904..., 0. , 1. , 0. , 0. ],
480+
[-1.507..., 1.414..., 1. , 0. , 0. ],
481+
[-0.301..., 0. , 0. , 1. , 0. ],
482+
[ 0.904..., -1.414..., 0. , 0. , 1. ]])
483+
466484
Strings can reference columns if the input is a DataFrame, integers are always
467485
interpreted as the positional columns.
468486

doc/whats_new/v0.22.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,10 @@ Changelog
160160
:mod:`sklearn.compose`
161161
......................
162162

163+
- |Feature| Adds :func:`compose.make_column_selector` which is used with
164+
:class:`compose.ColumnTransformer` to select DataFrame columns on the basis
165+
of name and dtype. :pr:`12303` by `Thomas Fan `_.
166+
163167
- |Fix| Fixed a bug in :class:`compose.ColumnTransformer` which failed to
164168
select the proper columns when using a boolean list, with NumPy older than
165169
1.12.

sklearn/compose/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
66
"""
77

8-
from ._column_transformer import ColumnTransformer, make_column_transformer
8+
from ._column_transformer import (ColumnTransformer, make_column_transformer,
9+
make_column_selector)
910
from ._target import TransformedTargetRegressor
1011

1112

1213
__all__ = [
1314
'ColumnTransformer',
1415
'make_column_transformer',
1516
'TransformedTargetRegressor',
17+
'make_column_selector',
1618
]

sklearn/compose/_column_transformer.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
from ..utils.validation import check_array, check_is_fitted
2626

2727

28-
__all__ = ['ColumnTransformer', 'make_column_transformer']
28+
__all__ = [
29+
'ColumnTransformer', 'make_column_transformer', 'make_column_selector'
30+
]
2931

3032

3133
_ERR_MSG_1DCOLUMN = ("1D data passed to a transformer that expects 2D data. "
@@ -69,7 +71,8 @@ class ColumnTransformer(TransformerMixin, _BaseComposition):
6971
``transformer`` expects X to be a 1d array-like (vector),
7072
otherwise a 2d array will be passed to the transformer.
7173
A callable is passed the input data `X` and can return any of the
72-
above.
74+
above. To select multiple columns by name or dtype, you can use
75+
:obj:`make_column_transformer`.
7376
7477
remainder : {'drop', 'passthrough'} or estimator, default 'drop'
7578
By default, only the specified columns in `transformers` are
@@ -145,6 +148,8 @@ class ColumnTransformer(TransformerMixin, _BaseComposition):
145148
sklearn.compose.make_column_transformer : convenience function for
146149
combining the outputs of multiple transformer objects applied to
147150
column subsets of the original feature space.
151+
sklearn.compose.make_column_selector : convenience function for selecting
152+
columns based on datatype or the columns name with a regex pattern.
148153
149154
Examples
150155
--------
@@ -759,3 +764,76 @@ def is_neg(x): return isinstance(x, numbers.Integral) and x < 0
759764
elif _determine_key_type(key) == 'int':
760765
return np.any(np.asarray(key) < 0)
761766
return False
767+
768+
769+
class make_column_selector:
770+
"""Create a callable to select columns to be used with
771+
:class:`ColumnTransformer`.
772+
773+
:func:`make_column_selector` can select columns based on datatype or the
774+
columns name with a regex. When using multiple selection criteria, **all**
775+
criteria must match for a column to be selected.
776+
777+
Parameters
778+
----------
779+
pattern : str, default=None
780+
Name of columns containing this regex pattern will be included. If
781+
None, column selection will not be selected based on pattern.
782+
783+
dtype_include : column dtype or list of column dtypes, default=None
784+
A selection of dtypes to include. For more details, see
785+
:meth:`pandas.DataFrame.select_dtypes`.
786+
787+
dtype_exclude : column dtype or list of column dtypes, default=None
788+
A selection of dtypes to exclude. For more details, see
789+
:meth:`pandas.DataFrame.select_dtypes`.
790+
791+
Returns
792+
-------
793+
selector : callable
794+
Callable for column selection to be used by a
795+
:class:`ColumnTransformer`.
796+
797+
See also
798+
--------
799+
sklearn.compose.ColumnTransformer : Class that allows combining the
800+
outputs of multiple transformer objects used on column subsets
801+
of the data into a single feature space.
802+
803+
Examples
804+
--------
805+
>>> from sklearn.preprocessing import StandardScaler, OneHotEncoder
806+
>>> from sklearn.compose import make_column_transformer
807+
>>> from sklearn.compose import make_column_selector
808+
>>> import pandas as pd # doctest: +SKIP
809+
>>> X = pd.DataFrame({'city': ['London', 'London', 'Paris', 'Sallisaw'],
810+
... 'rating': [5, 3, 4, 5]}) # doctest: +SKIP
811+
>>> ct = make_column_transformer(
812+
... (StandardScaler(),
813+
... make_column_selector(dtype_include=np.number)), # rating
814+
... (OneHotEncoder(),
815+
... make_column_selector(dtype_include=object))) # city
816+
>>> ct.fit_transform(X) # doctest: +SKIP
817+
array([[ 0.90453403, 1. , 0. , 0. ],
818+ [-1.50755672, 1. , 0. , 0. ],
819+
[-0.30151134, 0. , 1. , 0. ],
820+
[ 0.90453403, 0. , 0. , 1. ]])
821+
"""
822+
823+
def __init__(self, pattern=None, dtype_include=None, dtype_exclude=None):
824+
self.pattern = pattern
825+
self.dtype_include = dtype_include
826+
self.dtype_exclude = dtype_exclude
827+
828+
def __call__(self, df):
829+
if not hasattr(df, 'iloc'):
830+
raise ValueError("make_column_selector can only be applied to "
831+
"pandas dataframes")
832+
df_row = df.iloc[:1]
833+
if self.dtype_include is not None or self.dtype_exclude is not None:
834+
df_row = df_row.select_dtypes(include=self.dtype_include,
835+
exclude=self.dtype_exclude)
836+
cols = df_row.columns
837+
if self.pattern is not None:
838+
cols = cols[cols.str.contains(self.pattern, regex=True)]
839+
return cols.tolist()

sklearn/compose/tests/test_column_transformer.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,23 @@
22
Test the ColumnTransformer.
33
"""
44
import re
5+
import pickle
56

67
import warnings
78
import numpy as np
89
from scipy import sparse
910
import pytest
1011

12+
from numpy.testing import assert_allclose
1113
from sklearn.utils._testing import assert_raise_message
1214
from sklearn.utils._testing import assert_array_equal
1315
from sklearn.utils._testing import assert_allclose_dense_sparse
1416
from sklearn.utils._testing import assert_almost_equal
1517

1618
from sklearn.base import BaseEstimator
17-
from sklearn.compose import ColumnTransformer, make_column_transformer
19+
from sklearn.compose import (
20+
ColumnTransformer, make_column_transformer, make_column_selector
21+
)
1822
from sklearn.exceptions import NotFittedError
1923
from sklearn.preprocessing import FunctionTransformer
2024
from sklearn.preprocessing import StandardScaler, Normalizer, OneHotEncoder
@@ -1180,3 +1184,85 @@ def test_column_transformer_mask_indexing(array_type):
11801184
)
11811185
X_trans = column_transformer.fit_transform(X)
11821186
assert X_trans.shape == (3, 2)
1187+
1188+
1189+
@pytest.mark.parametrize('cols, pattern, include, exclude', [
1190+
(['col_int', 'col_float'], None, np.number, None),
1191+
(['col_int', 'col_float'], None, None, object),
1192+
(['col_int', 'col_float'], None, [np.int, np.float], None),
1193+
(['col_str'], None, [np.object], None),
1194+
(['col_str'], None, np.object, None),
1195+
(['col_float'], None, float, None),
1196+
(['col_float'], 'at$', [np.number], None),
1197+
(['col_int'], None, [np.int], None),
1198+
(['col_int'], '^col_int', [np.number], None),
1199+
(['col_float', 'col_str'], 'float|str', None, None),
1200+
(['col_str'], '^col_s', None, [np.int]),
1201+
([], 'str$', np.float, None),
1202+
(['col_int', 'col_float', 'col_str'], None, [np.number, np.object], None),
1203+
])
1204+
def test_make_column_selector_with_select_dtypes(cols, pattern, include,
1205+
exclude):
1206+
pd = pytest.importorskip('pandas')
1207+
1208+
X_df = pd.DataFrame({
1209+
'col_int': np.array([0, 1, 2], dtype=np.int),
1210+
'col_float': np.array([0.0, 1.0, 2.0], dtype=np.float),
1211+
'col_str': ["one", "two", "three"],
1212+
}, columns=['col_int', 'col_float', 'col_str'])
1213+
1214+
selector = make_column_selector(
1215+
dtype_include=include, dtype_exclude=exclude, pattern=pattern)
1216+
1217+
assert_array_equal(selector(X_df), cols)
1218+
1219+
1220+
def test_column_transformer_with_make_column_selector():
1221+
# Functional test for column transformer + column selector
1222+
pd = pytest.importorskip('pandas')
1223+
X_df = pd.DataFrame({
1224+
'col_int': np.array([0, 1, 2], dtype=np.int),
1225+
'col_float': np.array([0.0, 1.0, 2.0], dtype=np.float),
1226+
'col_cat': ["one", "two", "one"],
1227+
'col_str': ["low", "middle", "high"]
1228+
}, columns=['col_int', 'col_float', 'col_cat', 'col_str'])
1229+
X_df['col_str'] = X_df['col_str'].astype('category')
1230+
1231+
cat_selector = make_column_selector(dtype_include=['category', object])
1232+
num_selector = make_column_selector(dtype_include=np.number)
1233+
1234+
ohe = OneHotEncoder()
1235+
scaler = StandardScaler()
1236+
1237+
ct_selector = make_column_transformer((ohe, cat_selector),
1238+
(scaler, num_selector))
1239+
ct_direct = make_column_transformer((ohe, ['col_cat', 'col_str']),
1240+
(scaler, ['col_float', 'col_int']))
1241+
1242+
X_selector = ct_selector.fit_transform(X_df)
1243+
X_direct = ct_direct.fit_transform(X_df)
1244+
1245+
assert_allclose(X_selector, X_direct)
1246+
1247+
1248+
def test_make_column_selector_error():
1249+
selector = make_column_selector(dtype_include=np.number)
1250+
X = np.array([[0.1, 0.2]])
1251+
msg = ("make_column_selector can only be applied to pandas dataframes")
1252+
with pytest.raises(ValueError, match=msg):
1253+
selector(X)
1254+
1255+
1256+
def test_make_column_selector_pickle():
1257+
pd = pytest.importorskip('pandas')
1258+
1259+
X_df = pd.DataFrame({
1260+
'col_int': np.array([0, 1, 2], dtype=np.int),
1261+
'col_float': np.array([0.0, 1.0, 2.0], dtype=np.float),
1262+
'col_str': ["one", "two", "three"],
1263+
}, columns=['col_int', 'col_float', 'col_str'])
1264+
1265+
selector = make_column_selector(dtype_include=[object])
1266+
selector_picked = pickle.loads(pickle.dumps(selector))
1267+
1268+
assert_array_equal(selector(X_df), selector_picked(X_df))

0 commit comments

Comments
 (0)
0