8000 FIX Removes validation in __init__ for Pipeline (#21888) · scikit-learn/scikit-learn@0110921 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0110921

Browse files
arisayoshiofallogrisel
authored
FIX Removes validation in __init__ for Pipeline (#21888)
Co-authored-by: arisayosh <15692997+arisayosh@users.noreply.github.com> Co-authored-by: iofall <50991099+iofall@users.noreply.github.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent d72bd02 commit 0110921

File tree

4 files changed

+16
-10
lines changed

4 files changed

+16
-10
lines changed

doc/whats_new/v1.1.rst

Copy file name to clipboard
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,10 @@ Changelog
266266
Setting a transformer to "passthrough" will pass the features unchanged.
267267
:pr:`20860` by :user:`Shubhraneel Pal <shubhraneel>`.
268268

269+
- |Fix| :class: `pipeline.Pipeline` now does not validate hyper-parameters in
270+
`__init__` but in `.fit()`.
271+
:pr:`21888` by :user:`iofall <iofall>` and :user: `Arisa Y. <arisayosh>`.
272+
269273
:mod:`sklearn.preprocessing`
270274
............................
271275

sklearn/pipeline.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ def __init__(self, steps, *, memory=None, verbose=False):
146146
self.steps = steps
147147
self.memory = memory
148148
self.verbose = verbose
149-
self._validate_steps()
150149

151150
def get_params(self, deep=True):
152151
"""Get parameters for this estimator.

sklearn/tests/test_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,6 @@ def test_transformers_get_feature_names_out(transformer):
428428
"FeatureUnion",
429429
"GridSearchCV",
430430
"HalvingGridSearchCV",
431-
"Pipeline",
432431
"SGDOneClassSVM",
433432
"TheilSenRegressor",
434433
"TweedieRegressor",

sklearn/tests/test_pipeline.py

Lines changed: 12 additions & 8 deletions
# three ways to make invalid:
Original file line numberDiff line numberDiff line change
@@ -165,20 +165,23 @@ def predict_log_proba(self, X, got_attribute=False):
165165
return self
166166

167167

168-
def test_pipeline_init():
169-
# Test the various init parameters of the pipeline.
168+
def test_pipeline_invalid_parameters():
169+
# Test the various init parameters of the pipeline in fit
170+
# method
171+
pipeline = Pipeline([(1, 1)])
170172
with pytest.raises(TypeError):
171-
Pipeline()
173+
pipeline.fit([[1]], [1])
172174

173-
# Check that we can't instantiate pipelines with objects without fit
175+
# Check that we can't fit pipelines with objects without fit
174176
# method
175177
msg = (
176178
"Last step of Pipeline should implement fit "
177179
"or be the string 'passthrough'"
178180
".*NoFit.*"
179181
)
182+
pipeline = Pipeline([("clf", NoFit())])
180183
with pytest.raises(TypeError, match=msg):
181-
Pipeline([("clf", NoFit())])
184+
pipeline.fit([[1]], [1])
182185

183186
# Smoke test with only an estimator
184187
clf = NoTrans()
@@ -203,11 +206,12 @@ def test_pipeline_init():
203206
assert pipe.named_steps["anova"] is filter1
204207
assert pipe.named_steps["svc"] is clf
205208

206-
# Check that we can't instantiate with non-transformers on the way
209+
# Check that we can't fit with non-transformers on the way
207210
# Note that NoTrans implements fit, but not transform
208211
msg = "All intermediate steps should be transformers.*\\bNoTrans\\b.*"
212+
pipeline = Pipeline([("t", NoTrans()), ("svc", clf)])
209213
with pytest.raises(TypeError, match=msg):
210-
Pipeline([("t", NoTrans()), ("svc", clf)])
214+
pipeline.fit([[1]], [1])
211215

212216
# Check that params are set
213217
pipe.set_params(svc__C=0.1)
@@ -1086,7 +1090,7 @@ def test_step_name_validation():
10861090
10871091
# - construction
10881092
with pytest.raises(ValueError, match=message):
1089-
cls(**{param: bad_steps})
1093+
cls(**{param: bad_steps}).fit([[1]], [1])
10901094

10911095
# - setattr
10921096
est = cls(**{param: [("a", Mult(1))]})

0 commit comments

Comments
 (0)
0