8000 More informative error message when set_params has invalid values (#2… · scikit-learn/scikit-learn@74016ab · GitHub
[go: up one dir, main page]

Skip to content

Commit 74016ab

Browse files
ogriseljeremiedbbjjerphan
authored
More informative error message when set_params has invalid values (#21542)
* More informative error message when set_params has invalid values * Update changelog * Typo [ci skip] Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com> * Typo [ci skip] Co-authored-by: Julien Jerphanion <git@jjerphan.xyz> * Simpler, more efficient code Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com> Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
1 parent 2a36ea7 commit 74016ab

File tree

3 files changed

+25
-11
lines changed

3 files changed

+25
-11
lines changed

doc/whats_new/v1.1.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ Changelog
4545
message suggests potential solutions.
4646
:pr:`21219` by :user:`Olivier Grisel <ogrisel>`.
4747

48+
- |Enhancement| All scikit-learn models now generate a more informative
49+
error message when setting invalid hyper-parameters with `set_params`.
50+
:pr:`21542` by :user:`Olivier Grisel <ogrisel>`.
51+
4852
:mod:`sklearn.calibration`
4953
..........................
5054

sklearn/base.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,7 @@ def get_params(self, deep=True):
212212
return out
213213

214214
def set_params(self, **params):
215-
"""
216-
Set the parameters of this estimator.
215+
"""Set the parameters of this estimator.
217216
218217
The method works on simple estimators as well as on nested objects
219218
(such as :class:`~sklearn.pipeline.Pipeline`). The latter have
@@ -239,10 +238,10 @@ def set_params(self, **params):
239238
for key, value in params.items():
240239
key, delim, sub_key = key.partition("__")
241240
if key not in valid_params:
241+
local_valid_params = self._get_param_names()
242242
raise ValueError(
243-
"Invalid parameter %s for estimator %s. "
244-
"Check the list of available parameters "
245-
"with `estimator.get_params().keys()`." % (key, self)
243+
f"Invalid parameter {key!r} for estimator {self}. "
244+
f"Valid parameters are: {local_valid_params!r}."
246245
)
247246

248247
if delim:

sklearn/tests/test_pipeline.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,10 @@ def test_pipeline_init():
216216
repr(pipe)
217217

218218
# Check that params are not set when naming them wrong
219-
msg = "Invalid parameter C for estimator SelectKBest"
219+
msg = re.escape(
220+
"Invalid parameter 'C' for estimator SelectKBest(). Valid parameters are: ['k',"
221+
" 'score_func']."
222+
)
220223
with pytest.raises(ValueError, match=msg):
221224
pipe.set_params(anova__C=0.1)
222225

@@ -316,18 +319,26 @@ def test_pipeline_raise_set_params_error():
316319

317320
# expected error message
318321
error_msg = re.escape(
319-
f"Invalid parameter fake for estimator {pipe}. "
320-
"Check the list of available parameters "
321-
"with `estimator.get_params().keys()`."
322+
"Invalid parameter 'fake' for estimator Pipeline(steps=[('cls',"
323+
" LinearRegression())]). Valid parameters are: ['memory', 'steps', 'verbose']."
322324
)
323-
324325
with pytest.raises(ValueError, match=error_msg):
325326
pipe.set_params(fake="nope")
326327

327-
# nested model check
328+
# invalid outer parameter name for compound parameter: the expected error message
329+
# is the same as above.
328330
with pytest.raises(ValueError, match=error_msg):
329331
pipe.set_params(fake__estimator="nope")
330332

333+
# expected error message for invalid inner parameter
334+
error_msg = re.escape(
335+
"Invalid parameter 'invalid_param' for estimator LinearRegression(). Valid"
336+
" parameters are: ['copy_X', 'fit_intercept', 'n_jobs', 'normalize',"
337+
" 'positive']."
338+
)
339+
with pytest.raises(ValueError, match=error_msg):
340+
pipe.set_params(cls__invalid_param="nope")
341+
331342

332343
def test_pipeline_methods_pca_svm():
333344
# Test the various methods of the pipeline (pca + svm).

0 commit comments

Comments
 (0)
0