8000 ENH Smoke test for invalid parameters in __init__ and get_params (#21… · scikit-learn/scikit-learn@6e7fc0e · GitHub
[go: up one dir, main page]

Skip to content

Commit 6e7fc0e

Browse files
authored
ENH Smoke test for invalid parameters in __init__ and get_params (#21355)
1 parent 1a7eec8 commit 6e7fc0e

File tree

2 files changed

+64
-8
lines changed

2 files changed

+64
-8
lines changed

sklearn/tests/test_common.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from functools import partial
1717

1818
import pytest
19+
import numpy as np
1920

2021
from sklearn.utils import all_estimators
2122
from sklearn.utils._testing import ignore_warnings
@@ -402,3 +403,47 @@ def test_transformers_get_feature_names_out(transformer):
402403
check_transformer_get_feature_names_out_pandas(
403404
transformer.__class__.__name__, transformer
404405
)
406+
407+
408+
VALIDATE_ESTIMATOR_INIT = [
409+
"ColumnTransformer",
410+
"FactorAnalysis",
411+
"FastICA",
412+
"FeatureHasher",
413+
"FeatureUnion",
414+
"GridSearchCV",
415+
"HalvingGridSearchCV",
416+
"KernelDensity",
417+
"KernelPCA",
418+
"LabelBinarizer",
419+
"NuSVC",
420+
"NuSVR",
421+
"OneClassSVM",
422+
"Pipeline",
423+
"RadiusNeighborsClassifier",
424+
"SGDOneClassSVM",
425+
"SVC",
426+
"SVR",
427+
"TheilSenRegressor",
428+
"TweedieRegressor",
429+
]
430+
VALIDATE_ESTIMATOR_INIT = set(VALIDATE_ESTIMATOR_INIT)
431+
432+
433+
@pytest.mark.parametrize(
434+
"Estimator",
435+
[est for name, est in all_estimators() if name not in VALIDATE_ESTIMATOR_INIT],
436+
)
437+
def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
438+
"""Check that init or set_param does not raise errors."""
439+
params = signature(Estimator).parameters
440+
441+
smoke_test_values = [-1, 3.0, "helloworld", np.array([1.0, 4.0]), {}, []]
442+
for value in smoke_test_values:
443+
new_params = {key: value for key in params}
444+
445+
# Does not raise
446+
est = Estimator(**new_params)
447+
448+
# Also do does not raise
449+
est.set_params(**new_params)

sklearn/utils/metaestimators.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,18 @@ def _get_params(self, attr, deep=True):
2929
out = super().get_params(deep=deep)
3030
if not deep:
3131
return out
32+
3233
estimators = getattr(self, attr)
33-
out.update(estimators)
34+
try:
35+
out.update(estimators)
36+
except (TypeError, ValueError):
37+
# Ignore TypeError for cases where estimators is not a list of
38+
# (name, estimator) and ignore ValueError when the list is not
39+
# formated correctly. This is to prevent errors when calling
40+
# `set_params`. `BaseEstimator.set_params` calls `get_params` which
41+
# can error for invalid values for `estimators`.
42+
return out
43+
3444
for name, estimator in estimators:
3545
if hasattr(estimator, "get_params"):
3646
for key, value in estimator.get_params(deep=True).items():
@@ -42,14 +52,15 @@ def _set_params(self, attr, **params):
4252
# 1. All steps
4353
if attr in params:
4454
setattr(self, attr, params.pop(attr))
45-
# 2. Step replacement
55+
# 2. Replace items with estimators in params
4656
items = getattr(self, attr)
47-
names = []
48-
if items:
49-
names, _ = zip(*items)
50-
for name in list(params.keys()):
51-
if "__" not in name and name in names:
52-
self._replace_estimator(attr, name, params.pop(name))
57+
if isinstance(items, list) and items:
58+
# Get item names used to identify valid names in params
59+
item_names, _ = zip(*items)
60+
for name in list(params.keys()):
61+
if "__" not in name and name in item_names:
62+
self._replace_estimator(attr, name, params.pop(name))
63+
5364
# 3. Step parameters and other initialisation arguments
5465
super().set_params(**params)
5566
return self

0 commit comments

Comments
 (0)
0