8000 API Change the tuple order in make_column_transformer (#12626) · scikit-learn/scikit-learn@a816af7 · GitHub
[go: up one dir, main page]

Skip to content

Commit a816af7

Browse files
adrinjalaliqinhanmin2014
authored andcommitted
API Change the tuple order in make_column_transformer (#12626)
1 parent 565262a commit a816af7

File tree

4 files changed

+92
-18
lines changed

4 files changed

+92
-18
lines changed

doc/modules/compose.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -493,8 +493,8 @@ above example would be::
493493

494494
>>> from sklearn.compose import make_column_transformer
495495
>>> column_trans = make_column_transformer(
496-
... ('city', CountVectorizer(analyzer=lambda x: [x])),
497-
... ('title', CountVectorizer()),
496+
... (CountVectorizer(analyzer=lambda x: [x]), 'city'),
497+
... (CountVectorizer(), 'title'),
498498
... remainder=MinMaxScaler())
499499
>>> column_trans # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
500500
ColumnTransformer(n_jobs=None, remainder=MinMaxScaler(copy=True, ...),

doc/whats_new/v0.20.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ Changelog
4848
even if all transformation results are sparse. :issue:`12304` by `Andreas
4949
Müller`_.
5050

51+
- |API| :func:`compose.make_column_transformer` now expects
52+
``(transformer, columns)`` instead of ``(columns, transformer)`` to keep
53+
consistent with :class:`compose.ColumnTransformer`.
54+
:issue:`12339` by :user:`Adrin Jalali <adrinjalali>`.
55+
5156
:mod:`sklearn.datasets`
5257
............................
5358

sklearn/compose/_column_transformer.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from itertools import chain
1212

1313
import numpy as np
14+
import warnings
1415
from scipy import sparse
1516

1617
from ..base import clone, TransformerMixin
@@ -681,14 +682,63 @@ def _is_empty_column_selection(column):
681682
return False
682683

683684

685+
def _validate_transformers(transformers):
686+
"""Checks if given transformers are valid.
687+
688+
This is a helper function to support the deprecated tuple order.
689+
XXX Remove in v0.22
690+
"""
691+
if not transformers:
692+
return True
693+
694+
for t in transformers:
695+
if t in ('drop', 'passthrough'):
696+
continue
697+
if (not (hasattr(t, "fit") or hasattr(t, "fit_transform")) or not
698+
hasattr(t, "transform")):
699+
return False
700+
701+
return True
702+
703+
704+
def _is_deprecated_tuple_order(tuples):
705+
"""Checks if the input follows the deprecated tuple order.
706+
707+
Returns
708+
-------
709+
Returns true if (transformer, columns) is not a valid assumption for the
710+
input, but (columns, transformer) is valid. The latter is deprecated and
711+
its support will stop in v0.22.
712+
713+
XXX Remove in v0.22
714+
"""
715+
transformers, columns = zip(*tuples)
716+
if (not _validate_transformers(transformers)
717+
and _validate_transformers(columns)):
718+
return True
719+
720+
return False
721+
722+
684723
def _get_transformer_list(estimators):
685724
"""
686725
Construct (name, trans, column) tuples from list
687726
688727
"""
689-
transformers = [trans[1] for trans in estimators]
690-
columns = [trans[0] for trans in estimators]
691-
names = [trans[0] for trans in _name_estimators(transformers)]
728+
message = ('`make_column_transformer` now expects (transformer, columns) '
729+
'as input tuples instead of (columns, transformer). This '
730+
'has been introduced in v0.20.1. `make_column_transformer` '
< F438 /code>
731+
'will stop accepting the deprecated (columns, transformer) '
732+
'order in v0.22.')
733+
734+
transformers, columns = zip(*estimators)
735+
736+
# XXX Remove in v0.22
737+
if _is_deprecated_tuple_order(estimators):
738+
transformers, columns = columns, transformers
739+
warnings.warn(message, DeprecationWarning)
740+
741+
names, _ = zip(*_name_estimators(transformers))
692742

693743
transformer_list = list(zip(names, transformers, columns))
694744
return transformer_list
@@ -704,7 +754,7 @@ def make_column_transformer(*transformers, **kwargs):
704754
705755
Parameters
706756
----------
707-
*transformers : tuples of column selections and transformers
757+
*transformers : tuples of transformers and column selections
708758
709759
remainder : {'drop', 'passthrough'} or estimator, default 'drop'
710760
By default, only the specified columns in `transformers` are
@@ -747,8 +797,8 @@ def make_column_transformer(*transformers, **kwargs):
747797
>>> from sklearn.preprocessing import StandardScaler, OneHotEncoder
748798
>>> from sklearn.compose import make_column_transformer
749799
>>> make_column_transformer(
750-
... (['numerical_column'], StandardScaler()),
751-
... (['categorical_column'], OneHotEncoder()))
800+
... (StandardScaler(), ['numerical_column']),
801+
... (OneHotEncoder(), ['categorical_column']))
752802
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
753803
ColumnTransformer(n_jobs=None, remainder='drop', sparse_threshold=0.3,
754804
transformer_weights=None,

sklearn/compose/tests/test_column_transformer.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from sklearn.utils.testing import assert_dict_equal
1414
from sklearn.utils.testing import assert_array_equal
1515
from sklearn.utils.testing import assert_allclose_dense_sparse
16+
from sklearn.utils.testing import assert_almost_equal
1617

1718
from sklearn.base import BaseEstimator
1819
from sklearn.externals import six
@@ -373,8 +374,8 @@ def test_column_transformer_mixed_cols_sparse():
373374
dtype='O')
374375

375376
ct = make_column_transformer(
376-
([0], OneHotEncoder()),
377-
([1, 2], 'passthrough'),
377+
(OneHotEncoder(), [0]),
378+
('passthrough', [1, 2]),
378379
sparse_threshold=1.0
379380
)
380381

@@ -386,8 +387,8 @@ def test_column_transformer_mixed_cols_sparse():
386387
[0, 1, 2, 0]]))
387388

388389
ct = make_column_transformer(
389-
([0], OneHotEncoder()),
390-
([0], 'passthrough'),
390+
(OneHotEncoder(), [0]),
391+
('passthrough', [0]),
391392
sparse_threshold=1.0
392393
)
393394
with pytest.raises(ValueError,
@@ -516,29 +517,47 @@ def predict(self, X):
516517
def test_make_column_transformer():
517518
scaler = StandardScaler()
518519
norm = Normalizer()
519-
ct = make_column_transformer(('first', scaler), (['second'], norm))
520+
ct = make_column_transformer((scaler, 'first'), (norm, ['second']))
520521
names, transformers, columns = zip(*ct.transformers)
521522
assert_equal(names, ("standardscaler", "normalizer"))
522523
assert_equal(transformers, (scaler, norm))
523524
assert_equal(columns, ('first', ['second']))
524525

526+
# XXX remove in v0.22
527+
with pytest.warns(DeprecationWarning,
528+
match='`make_column_transformer` now expects'):
529+
ct1 = make_column_transformer(([0], norm))
530+
ct2 = make_column_transformer((norm, [0]))
531+
X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
532+
assert_almost_equal(ct1.fit_transform(X_array),
533+
ct2.fit_transform(X_array))
534+
535+
with pytest.warns(DeprecationWarning,
536+
match='`make_column_transformer` now expects'):
537+
make_column_transformer(('first', 'drop'))
538+
539+
with pytest.warns(DeprecationWarning,
540+
match='`make_column_transformer` now expects'):
541+
make_column_transformer(('passthrough', 'passthrough'),
542+
('first', 'drop'))
543+
525544

526545
def test_make_column_transformer_kwargs():
527546
scaler = StandardScaler()
528547
norm = Normalizer()
529-
ct = make_column_transformer(('first', scaler), (['second'], norm),
548+
ct = make_column_transformer((scaler, 'first'), (norm, ['second']),
530549
n_jobs=3, remainder='drop',
531550
sparse_threshold=0.5)
532551
assert_equal(ct.transformers, make_column_transformer(
533-
('first', scaler), (['second'], norm)).transformers)
552+
(scaler, 'first'), (norm, ['second'])).transformers)
534553
assert_equal(ct.n_jobs, 3)
535554
assert_equal(ct.remainder, 'drop')
536555
assert_equal(ct.sparse_threshold, 0.5)
537556
# invalid keyword parameters should raise an error message
538557
assert_raise_message(
539558
TypeError,
540559
'Unknown keyword arguments: "transformer_weights"',
541-
make_column_transformer, ('first', scaler), (['second'], norm),
560+
make_column_transformer, (scaler, 'first'), (norm, ['second']),
542561
transformer_weights={'pca': 10, 'Transf': 1}
543562
)
544563

@@ -547,7 +566,7 @@ def test_make_column_transformer_remainder_transformer():
547566
scaler = StandardScaler()
548567
norm = Normalizer()
549568
remainder = StandardScaler()
550-
ct = make_column_transformer(('first', scaler), (['second'], norm),
569+
ct = make_column_transformer((scaler, 'first'), (norm, ['second']),
551570
remainder=remainder)
552571
assert ct.remainder == remainder
553572

@@ -757,7 +776,7 @@ def test_column_transformer_remainder():
757776
"or estimator.", ct.fit_transform, X_array)
758777

759778
# check default for make_column_transformer
760-
ct = make_column_transformer(([0], Trans()))
779+
ct = make_column_transformer((Trans(), [0]))
761780
assert ct.remainder == 'drop'
762781

763782

0 commit comments

Comments
 (0)
0