FIX Removes validation in __init__ for Pipeline by arisayosh · Pull Request #21888 · scikit-learn/scikit-learn · GitHub

FIX Removes validation in __init__ for Pipeline #21888


Merged 5 commits on Dec 10, 2021
4 changes: 4 additions & 0 deletions doc/whats_new/v1.1.rst
@@ -258,6 +258,10 @@ Changelog
Setting a transformer to "passthrough" will pass the features unchanged.
:pr:`20860` by :user:`Shubhraneel Pal <shubhraneel>`.

- |Fix| :class:`pipeline.Pipeline` now does not validate hyper-parameters in
`__init__` but in `.fit()`.
:pr:`21888` by :user:`iofall <iofall>` and :user:`Arisa Y. <arisayosh>`.

:mod:`sklearn.preprocessing`
............................

1 change: 0 additions & 1 deletion sklearn/pipeline.py
@@ -146,7 +146,6 @@ def __init__(self, steps, *, memory=None, verbose=False):
self.steps = steps
self.memory = memory
self.verbose = verbose
self._validate_steps()

def get_params(self, deep=True):
"""Get parameters for this estimator.
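For context, here is a minimal sketch (not part of the PR diff) of what removing the `_validate_steps()` call from `__init__` means in practice: constructing a Pipeline with an invalid step no longer raises, and the error only surfaces once `fit` is called. The step value used below is purely an illustrative placeholder.

from sklearn.pipeline import Pipeline

# With this patch, __init__ only stores the constructor arguments, so an
# invalid step is accepted at construction time...
pipe = Pipeline([("clf", object())])  # object() has no fit method

# ...and the TypeError raised by _validate_steps() only surfaces in fit().
try:
    pipe.fit([[1.0], [2.0]], [0, 1])
except TypeError as exc:
    # roughly: "Last step of Pipeline should implement fit or be the
    # string 'passthrough' ..."
    print(exc)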
1 change: 0 additions & 1 deletion sklearn/tests/test_common.py
@@ -428,7 +428,6 @@ def test_transformers_get_feature_names_out(transformer):
"FeatureUnion",
"GridSearchCV",
"HalvingGridSearchCV",
"Pipeline",
"SGDOneClassSVM",
"TheilSenRegressor",
"TweedieRegressor",
20 changes: 12 additions & 8 deletions sklearn/tests/test_pipeline.py
@@ -165,20 +165,23 @@ def predict_log_proba(self, X, got_attribute=False):
return self


def test_pipeline_init():
# Test the various init parameters of the pipeline.
def test_pipeline_invalid_parameters():
Contributor Author

Renamed the function since we are no longer validating parameters during instantiation. We decided not to split the function in two because it caused code repetition.

# Test the various init parameters of the pipeline in fit
# method
pipeline = Pipeline([(1, 1)])
with pytest.raises(TypeError):
Pipeline()
pipeline.fit([[1]], [1])

# Check that we can't instantiate pipelines with objects without fit
# Check that we can't fit pipelines with objects without fit
# method
msg = (
"Last step of Pipeline should implement fit "
"or be the string 'passthrough'"
".*NoFit.*"
)
pipeline = Pipeline([("clf", NoFit())])
with pytest.raises(TypeError, match=msg):
Pipeline([("clf", NoFit())])
pipeline.fit([[1]], [1])

# Smoke test with only an estimator
clf = NoTrans()
@@ -203,11 +206,12 @@ def test_pipeline_init():
assert pipe.named_steps["anova"] is filter1
assert pipe.named_steps["svc"] is clf

# Check that we can't instantiate with non-transformers on the way
# Check that we can't fit with non-transformers on the way
# Note that NoTrans implements fit, but not transform
msg = "All intermediate steps should be transformers.*\\bNoTrans\\b.*"
pipeline = Pipeline([("t", NoTrans()), ("svc", clf)])
with pytest.raises(TypeError, match=msg):
Pipeline([("t", NoTrans()), ("svc", clf)])
pipeline.fit([[1]], [1])

# Check that params are set
pipe.set_params(svc__C=0.1)
@@ -1086,7 +1090,7 @@ def test_step_name_validation():
# three ways to make invalid:
# - construction
with pytest.raises(ValueError, match=message):
cls(**{param: bad_steps})
cls(**{param: bad_steps}).fit([[1]], [1])

# - setattr
est = cls(**{param: [("a", Mult(1))]})
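A likely motivation, following scikit-learn's general convention that estimators defer validation to fit, is that a Pipeline built with a placeholder step can now be repaired through set_params before any data is seen. A minimal sketch, assuming this patch is applied; the placeholder string below is purely illustrative.

from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Before this change, the constructor itself raised a TypeError here,
# because "to-be-set" is not a transformer/estimator.
pipe = Pipeline([("scale", StandardScaler()), ("clf", "to-be-set")])

# The placeholder can be replaced before fit(), e.g. from a search loop.
pipe.set_params(clf=LogisticRegression())
pipe.fit([[0.0], [1.0], [2.0], [3.0]], [0, 0, 1, 1])
print(pipe.predict([[2.5]]))  # expected to print [1]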