FIX: Allow `pipeline` to handle `fit` and `transform` and allows adding `tSNE` by ParthSolanki1 · Pull Request #3 · navn-r/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

FIX: Allow pipeline to handle fit and transform and allows adding tSNE #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from collections import defaultdict
from itertools import islice

import warnings
import numpy as np
from scipy import sparse
from joblib import Parallel
Expand Down Expand Up @@ -201,15 +201,19 @@ def _validate_steps(self):
for t in transformers:
if t is None or t == "passthrough":
continue
if not (hasattr(t, "fit") or hasattr(t, "fit_transform")) or not hasattr(
t, "transform"
if not (hasattr(t, "fit") and hasattr(t, "transform")) and not hasattr(
t, "fit_transform"
):
raise TypeError(
"All intermediate steps should be "
"transformers and implement fit and transform "
"or be the string 'passthrough' "
"transformers and implement fit and transform, "
"fit_transform, or be the string 'passthrough' "
"'%s' (type %s) doesn't" % (t, type(t))
)
elif not hasattr(t, "transform"):
warnings.warn("Intermediate step '%s' (type %s) does not have "
"transform, pipeline is not reusable on "
"test data." % (t, type(t)))

# We allow last estimator to be None as an identity transformation
if (
Expand All @@ -218,8 +222,8 @@ def _validate_steps(self):
and not hasattr(estimator, "fit")
):
raise TypeError(
"Last step of Pipeline should implement fit "
"or be the string 'passthrough'. "
"Last step of Pipeline should implement fit, "
"fit_transform, or be the string 'passthrough'. "
"'%s' (type %s) doesn't" % (estimator, type(estimator))
)

Expand Down
42 changes: 34 additions & 8 deletions sklearn/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.impute import SimpleImputer
from sklearn.manifold import TSNE

iris = load_iris()

Expand Down Expand Up @@ -171,14 +172,16 @@ def test_pipeline_invalid_parameters():
with pytest.raises(TypeError):
pipeline.fit([[1]], [1])

nofit = NoFit()
# Check that we can't fit pipelines with objects without fit
# method
msg = (
"Last step of Pipeline should implement fit "
"or be the string 'passthrough'"
".*NoFit.*"
msg = re.escape(
"Last step of Pipeline should implement fit, "
"fit_transform, or be the string 'passthrough'. "
"'%s' (type %s) doesn't" % (nofit, type(nofit))
)
pipeline = Pipeline([("clf", NoFit())])

pipeline = Pipeline([("clf", nofit)])
with pytest.raises(TypeError, match=msg):
pipeline.fit([[1]], [1])

Expand Down Expand Up @@ -653,9 +656,8 @@ def test_set_pipeline_steps():

# With invalid data
pipeline.set_params(steps=[("junk", ())])
msg = re.escape(
"Last step of Pipeline should implement fit or be the string 'passthrough'."
)
msg = re.escape("Last step of Pipeline should implement fit, fit_transform, or be the string 'passthrough'. '()' (type <class 'tuple'>) doesn't")

with pytest.raises(TypeError, match=msg):
pipeline.fit([[1]], [1])

Expand Down Expand Up @@ -1545,3 +1547,27 @@ def get_feature_names_out(self, input_features=None):
feature_names_out = pipe.get_feature_names_out(input_names)

assert_array_equal(feature_names_out, [f"my_prefix_{name}" for name in input_names])

def test_pipeline_tscne():
    """Check that TSNE works as an intermediate pipeline step, with a warning.

    The pipeline is expected to emit a ``UserWarning`` stating that the TSNE
    step does not have ``transform``, and its ``fit_transform`` output must
    match the embedding obtained by chaining PCA and TSNE by hand.

    Test for #16710
    """
    data = load_iris().data

    # Reference embedding: apply PCA, then TSNE, outside of any pipeline.
    reference_pca = PCA(n_components=2, random_state=0)
    reference_tsne = TSNE(random_state=0)
    reduced = reference_pca.fit_transform(data)
    expected_embedding = reference_tsne.fit_transform(reduced)

    # Fresh, identically configured estimators for the pipeline under test.
    pipeline_pca = PCA(n_components=2, random_state=0)
    pipeline_tsne = TSNE(random_state=0)
    expected_warning = (
        "Intermediate step '%s' (type %s) does not have "
        "transform, pipeline is not reusable on test data."
        % (pipeline_tsne, type(pipeline_tsne))
    )

    # The trailing 'passthrough' makes TSNE an intermediate (not final) step.
    pipe = make_pipeline(pipeline_pca, pipeline_tsne, "passthrough")

    with pytest.warns(UserWarning, match=re.escape(expected_warning)):
        embedding = pipe.fit_transform(data)

    assert_array_almost_equal(embedding, expected_embedding)
0