8000 [MRG] Allow only fit_transform to be present in pipeline by jnboehm · Pull Request #16714 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[MRG] Allow only fit_transform to be present in pipeline #16714

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 43 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
8a71ded
Allow only fit_transform to be present in pipeline
Mar 17, 2020
b973844
Fix linter error
Mar 17, 2020
774c10c
Check final estimator
Mar 17, 2020
3dff448
Fix faultly if condition
Mar 17, 2020
af07b9d
Fix variable typo
Mar 17, 2020
bd9232b
Add a test case for the new pipeline construction
Mar 17, 2020
5ceb7e2
Expand test cases with warning and error
Mar 17, 2020
3386e76
Fix pep8 style issue
Mar 17, 2020
2a295cf
Add blank lines
Mar 17, 2020
de0361c
Fix variable names in typos
Mar 17, 2020
81cc2c7
Fix conditional for pipeline error check
Mar 18, 2020
531250d
Fix conditional for feature union error check
Mar 18, 2020
90d3feb
Fix raised error messages
Mar 18, 2020
9eceb01
Fix linting
Mar 18, 2020
d707747
Improve warning for reusable pipe warning
Mar 18, 2020
0696f46
Remove last check in pipeline for final estimator
Mar 18, 2020
531c73f
Remove unnecessary check for fit_transform attr
Mar 18, 2020
1fa01a0
Fix over-indented line
Mar 18, 2020
370daa8
Revert changes to FeatureUnion
Mar 18, 2020
e4a5eef
Fix erroneous reverting reverting
Mar 18, 2020
a849b1d
Modify warning message
Mar 19, 2020
5a673a9
Remove unnecessary print params
Mar 19, 2020
8e32dce
Match literal strings in pytest.warns and raises
Mar 29, 2020
2c663d0
Pass the estimator to the pipeline
Mar 29, 2020
40bf2d4
Add dummy transformer to trigger correct warning
Mar 29, 2020
9ef3868
Move the correct code line into the pytest context
Mar 29, 2020
bbe900a
Fix error raising for degenerate pipeline
Mar 29, 2020
707cd3e
Remove a blank line
Mar 30, 2020
6ead4a0
Change unnecessary code format back to original
Mar 30, 2020
9db15a2
Expand FeatureUnion warning
jnboehm Apr 2, 2020
e25bf89
Fix typo
jnboehm Apr 2, 2020
69e9e83
Add entry to the changelog
jnboehm Apr 28, 2020
d188cdd
Remove change in FeatureUnion
jnboehm Apr 28, 2020
232a26b
Delay warning test until fit_transform is called
jnboehm Apr 28, 2020
667e744
Add test for nonreusable pipe with fit_predict
jnboehm Apr 28, 2020
d296366
Fix linting error
jnboehm Apr 28, 2020
91b8c48
Accept y param for dummy fit_transform class
jnboehm Apr 28, 2020
8ed1a93
Add missing point
jnboehm Apr 28, 2020
24c8a2b
Merge remote-tracking branch 'origin/master' into pipeline/fit_transform
jnboehm Apr 28, 2020
65538f0
Remove unncessary call to fit()
jnboehm Apr 28, 2020
5bf012e
Fix type in test case for warning
jnboehm Apr 28, 2020
2496865
Empty commit to restart CI pipeline
jnboehm Apr 29, 2020
d4e5de6
Merge branch 'master' into pipeline/fit_transform
jnboehm May 16, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/whats_new/v0.23.rst
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,14 @@ Changelog
- |FIX| :func:`utils.all_estimators` now only returns public estimators.
:pr:`15380` by `Thomas Fan`_.

:mod:`sklearn.pipeline`
.......................

- |Fix| :class:`pipeline.Pipeline` can now use estimators that only
expose `fit_transform` but not a separate `fit` and `transform`,
such as `manifold.TSNE`.
:pr:`16714` by :user:`Niklas Böhm <jnboehm>`.

Miscellaneous
.............

Expand Down
19 changes: 12 additions & 7 deletions sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,19 +154,24 @@ def _validate_steps(self):
for t in transformers:
if t is None or t == 'passthrough':
continue
if (not (hasattr(t, "fit") or hasattr(t, "fit_transform")) or not
hasattr(t, "transform")):
if (not hasattr(t, "fit_transform") and
(not (hasattr(t, "fit") and hasattr(t, "transform")))):
raise TypeError("All intermediate steps should be "
"transformers and implement fit and transform "
"or be the string 'passthrough' "
"'%s' (type %s) doesn't" % (t, type(t)))
"transformers and implement fit and "
"transform, fit_transform or be the string "
"'passthrough'. '%s' (type %s) doesn't"
% (t, type(t)))
elif not hasattr(t, "transform"):
warnings.warn("Intermediate step '%s' (type %s) does not have "
"transform, pipeline is not reusable on "
"test data." % (t, type(t)))

# We allow last estimator to be None as an identity transformation
if (estimator is not None and estimator != 'passthrough'
and not hasattr(estimator, "fit")):
raise TypeError(
"Last step of Pipeline should implement fit "
"or be the string 'passthrough'. "
"Last step of Pipeline should implement fit, "
"fit_transform or be the string 'passthrough'. "
"'%s' (type %s) doesn't" % (estimator, type(estimator)))

def _iter(self, with_final=True, filter_passthrough=True):
Expand Down
64 changes: 61 additions & 3 deletions sklearn/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.dummy import DummyRegressor
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.manifold import TSNE
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.feature_extraction.text import CountVectorizer
Expand Down Expand Up @@ -71,6 +72,11 @@ def set_params(self, **params):
return self


class OnlyFitTrans(NoFit):
def fit_transform(self, X, y=None):
return X


class NoInvTransf(NoTrans):
def transform(self, X):
return X
Expand Down Expand Up @@ -164,8 +170,8 @@ def test_pipeline_init():
# Check that we can't instantiate pipelines with objects without fit
# method
assert_raises_regex(TypeError,
'Last step of Pipeline should implement fit '
'or be the string \'passthrough\''
'Last step of Pipeline should implement fit, '
'fit_transform or be the string \'passthrough\''
'.*NoFit.*',
Pipeline, [('clf', NoFit())])
# Smoke test with only an estimator
Expand Down Expand Up @@ -393,6 +399,58 @@ def test_pipeline_methods_preprocessing_svm():
pipe.score(X, y)


def test_pipeline_methods_pca_tsne():
# test that only fit_transform needs to be present in order to
# run a pipeline with fit_transform.
# Don't require transform to be present, explicitly.
pca = PCA(n_components=2, random_state=0)
tsne = TSNE(random_state=0)
separate_emb = tsne.fit_transform(pca.fit_transform(iris.data))

pca_for_pipeline = PCA(n_components=2, random_state=0)
tsne_for_pipeline = TSNE(random_state=0)
msg = ("Intermediate step '%s' (type %s) does not have "
"transform, pipeline is not reusable on test data."
% (tsne_for_pipeline, type(tsne_for_pipeline)))

pipe = make_pipeline(pca_for_pipeline, tsne_for_pipeline,
'passthrough')

with pytest.warns(UserWarning, match=re.escape(msg)):
pipeline_emb = pipe.fit_transform(iris.data)

assert_array_almost_equal(pipeline_emb, separate_emb)

error_estimator = NoTrans()
msg = ("All intermediate steps should be "
"transformers and implement fit and "
"transform, fit_transform or be the string "
"'passthrough'. '%s' (type %s) doesn't"
% (error_estimator, type(error_estimator)))
with pytest.raises(TypeError, match=re.escape(msg)):
make_pipeline(error_estimator, 'passthrough')


def test_fit_predict_on_nonreusable_pipeline():
oft = OnlyFitTrans()

km = KMeans(random_state=0)
km_for_pipeline = KMeans(random_state=0)

separate_pred = km.fit_predict(iris.data)

pipe = make_pipeline(oft, km_for_pipeline)

msg = ("Intermediate step '%s' (type %s) does not have "
"transform, pipeline is not reusable on test data."
% (oft, type(oft)))

with pytest.warns(UserWarning, match=re.escape(msg)):
pipeline_pred = pipe.fit_predict(iris.data)

assert_array_almost_equal(pipeline_pred, separate_pred)


def test_fit_predict_on_pipeline():
# test that the fit_predict method is implemented on a pipeline
# test that the fit_predict on pipeline yields same results as applying
Expand Down Expand Up @@ -492,7 +550,7 @@ def test_feature_union():
# test error if some elements do not support transform
assert_raises_regex(TypeError,
'All estimators should implement fit and '
'transform.*\\bNoTrans\\b',
'transform.*\\bNoTrans\\b.*',
FeatureUnion,
[("transform", Transf()), ("no_transform", NoTrans())])

Expand Down
0