[MRG + 1] ENH: allow to pass callable as column specifier in ColumnTransformer (#11592) · scikit-learn/scikit-learn@9caa982 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9caa982

Browse files
jorisvandenbossche authored and GaelVaroquaux committed
[MRG + 1] ENH: allow to pass callable as column specifier in ColumnTransformer (#11592)
* define select_types callable factory
* include dtype column selector example
* support selector function for remainder=passthrough (already worked for drop)
* add docstring to example file. apparently sphinx fails if missing
* remove example and select_dtype factory
* generalize callable case (all specification types) + add tests
* add docstring
1 parent 1e05884 commit 9caa982

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

sklearn/compose/_column_transformer.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,15 @@ class ColumnTransformer(_BaseComposition, TransformerMixin):
6161
strings 'drop' and 'passthrough' are accepted as well, to
6262
indicate to drop the columns or to pass them through untransformed,
6363
respectively.
64-
column(s) : string or int, array-like of string or int, slice or \
65-
boolean mask array
64+
column(s) : string or int, array-like of string or int, slice, \
65+
boolean mask array or callable
6666
Indexes the data on its second axis. Integers are interpreted as
6767
positional columns, while strings can reference DataFrame columns
6868
by name. A scalar string or int should be used where
6969
``transformer`` expects X to be a 1d array-like (vector),
7070
otherwise a 2d array will be passed to the transformer.
71+
A callable is passed the input data `X` and can return any of the
72+
above.
7173
7274
remainder : {'passthrough', 'drop'} or estimator, default 'passthrough'
7375
By default, all remaining columns that were not specified in
@@ -499,6 +501,7 @@ def _get_column(X, key):
499501
Supported key types (key):
500502
- scalar: output is 1D
501503
- lists, slices, boolean masks: output is 2D
504+
- callable that returns any of the above
502505
503506
Supported key data types:
504507
@@ -510,6 +513,9 @@ def _get_column(X, key):
510513
can use any hashable object as key).
511514
512515
"""
516+
if callable(key):
517+
key = key(X)
518+
513519
# check whether we have string column names or integers
514520
if _check_key_type(key, int):
515521
column_names = False
@@ -551,6 +557,9 @@ def _get_column_indices(X, key):
551557
"""
552558
n_columns = X.shape[1]
553559

560+
if callable(key):
561+
key = key(X)
562+
554563
if _check_key_type(key, int):
555564
if isinstance(key, int):
556565
return [key]

sklearn/compose/tests/test_column_transformer.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,12 @@ def test_column_transformer():
9999
assert_array_equal(ct.fit_transform(X_array), res)
100100
assert_array_equal(ct.fit(X_array).transform(X_array), res)
101101

102+
# callable that returns any of the allowed specifiers
103+
ct = ColumnTransformer([('trans', Trans(), lambda x: selection)],
104+
remainder='drop')
105+
assert_array_equal(ct.fit_transform(X_array), res)
106+
assert_array_equal(ct.fit(X_array).transform(X_array), res)
107+
102108
ct = ColumnTransformer([('trans1', Trans(), [0]),
103109
('trans2', Trans(), [1])])
104110
assert_array_equal(ct.fit_transform(X_array), X_res_both)
@@ -166,6 +172,12 @@ def test_column_transformer_dataframe():
166172
assert_array_equal(ct.fit_transform(X_df), res)
167173
assert_array_equal(ct.fit(X_df).transform(X_df), res)
168174

175+
# callable that returns any of the allowed specifiers
176+
ct = ColumnTransformer([('trans', Trans(), lambda X: selection)],
177+
remainder='drop')
178+
assert_array_equal(ct.fit_transform(X_df), res)
179+
assert_array_equal(ct.fit(X_df).transform(X_df), res)
180+
169181
ct = ColumnTransformer([('trans1', Trans(), ['first']),
170182
('trans2', Trans(), ['second'])])
171183
assert_array_equal(ct.fit_transform(X_df), X_res_both)
@@ -777,3 +789,31 @@ def test_column_transformer_no_estimators():
777789
assert len(ct.transformers_) == 1
778790
assert ct.transformers_[-1][0] == 'remainder'
779791
assert ct.transformers_[-1][2] == [0, 1, 2]
792+
793+
794+
def test_column_transformer_callable_specifier():
795+
# assert that function gets the full array / dataframe
796+
X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
797+
X_res_first = np.array([[0, 1, 2]]).T
798+
799+
def func(X):
800+
assert_array_equal(X, X_array)
801+
return [0]
802+
803+
ct = ColumnTransformer([('trans', Trans(), func)],
804+
remainder='drop')
805+
assert_array_equal(ct.fit_transform(X_array), X_res_first)
806+
assert_array_equal(ct.fit(X_array).transform(X_array), X_res_first)
807+
808+
pd = pytest.importorskip('pandas')
809+
X_df = pd.DataFrame(X_array, columns=['first', 'second'])
810+
811+
def func(X):
812+
assert_array_equal(X.columns, X_df.columns)
813+
assert_array_equal(X.values, X_df.values)
814+
return ['first']
815+
816+
ct = ColumnTransformer([('trans', Trans(), func)],
817+
remainder='drop')
818+
assert_array_equal(ct.fit_transform(X_df), X_res_first)
819+
assert_array_equal(ct.fit(X_df).transform(X_df), X_res_first)

0 commit comments

Comments
 (0)
0