Patch summary (text before the first "diff --git" header is ignored by patch tools):
  * sklearn/base.py: BaseEstimator.set_params now groups "__"-delimited
    parameters by their prefix and forwards each group to the nested object
    in a single set_params call; unknown keys are reported with the
    estimator itself in the error message.
  * sklearn/tests/test_base.py: adds test_set_params_passes_all_parameters
    (regression test for #9944).
  * sklearn/tests/test_pipeline.py: adds test_set_params_nested_pipeline and
    updates the expected error message in test_pipeline_raise_set_params_error.

diff --git a/sklearn/base.py b/sklearn/base.py
index d97fe92ccdd47..b653b7149c373 100644
--- a/sklearn/base.py
+++ b/sklearn/base.py
@@ -5,6 +5,7 @@
 
 import copy
 import warnings
+from collections import defaultdict
 
 import numpy as np
 from scipy import sparse
@@ -248,26 +249,24 @@ def set_params(self, **params):
             # Simple optimization to gain speed (inspect is slow)
             return self
         valid_params = self.get_params(deep=True)
-        for key, value in six.iteritems(params):
-            split = key.split('__', 1)
-            if len(split) > 1:
-                # nested objects case
-                name, sub_name = split
-                if name not in valid_params:
-                    raise ValueError('Invalid parameter %s for estimator %s. '
-                                     'Check the list of available parameters '
-                                     'with `estimator.get_params().keys()`.' %
-                                     (name, self))
-                sub_object = valid_params[name]
-                sub_object.set_params(**{sub_name: value})
+
+        nested_params = defaultdict(dict)  # grouped by prefix
+        for key, value in params.items():
+            key, delim, sub_key = key.partition('__')
+            if key not in valid_params:
+                raise ValueError('Invalid parameter %s for estimator %s. '
+                                 'Check the list of available parameters '
+                                 'with `estimator.get_params().keys()`.' %
+                                 (key, self))
+
+            if delim:
+                nested_params[key][sub_key] = value
             else:
-                # simple objects case
-                if key not in valid_params:
-                    raise ValueError('Invalid parameter %s for estimator %s. '
-                                     'Check the list of available parameters '
-                                     'with `estimator.get_params().keys()`.' %
-                                     (key, self.__class__.__name__))
                 setattr(self, key, value)
+
+        for key, sub_params in nested_params.items():
+            valid_params[key].set_params(**sub_params)
+
         return self
 
     def __repr__(self):
diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py
index 7ad0f20382657..580a4e2ecac9f 100644
--- a/sklearn/tests/test_base.py
+++ b/sklearn/tests/test_base.py
@@ -228,6 +228,24 @@ def test_set_params():
     #               bad__stupid_param=True)
 
 
+def test_set_params_passes_all_parameters():
+    # Make sure all parameters are passed together to set_params
+    # of nested estimator. Regression test for #9944
+
+    class TestDecisionTree(DecisionTreeClassifier):
+        def set_params(self, **kwargs):
+            super(TestDecisionTree, self).set_params(**kwargs)
+            # expected_kwargs is in test scope
+            assert kwargs == expected_kwargs
+            return self
+
+    expected_kwargs = {'max_depth': 5, 'min_samples_leaf': 2}
+    for est in [Pipeline([('estimator', TestDecisionTree())]),
+                GridSearchCV(TestDecisionTree(), {})]:
+        est.set_params(estimator__max_depth=5,
+                       estimator__min_samples_leaf=2)
+
+
 def test_score_sample_weight():
 
     rng = np.random.RandomState(0)
diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py
index d1d62f80e51a5..ab2108ed690f2 100644
--- a/sklearn/tests/test_pipeline.py
+++ b/sklearn/tests/test_pipeline.py
@@ -24,10 +24,11 @@
 from sklearn.base import clone, BaseEstimator
 from sklearn.pipeline import Pipeline, FeatureUnion, make_pipeline, make_union
 from sklearn.svm import SVC
-from sklearn.linear_model import LogisticRegression
+from sklearn.linear_model import LogisticRegression, Lasso
 from sklearn.linear_model import LinearRegression
 from sklearn.cluster import KMeans
 from sklearn.feature_selection import SelectKBest, f_classif
+from sklearn.dummy import DummyRegressor
 from sklearn.decomposition import PCA, TruncatedSVD
 from sklearn.datasets import load_iris
 from sklearn.preprocessing import StandardScaler
@@ -289,7 +290,7 @@ def test_pipeline_raise_set_params_error():
                  'with `estimator.get_params().keys()`.')
 
     assert_raise_message(ValueError,
-                         error_msg % ('fake', 'Pipeline'),
+                         error_msg % ('fake', pipe),
                          pipe.set_params,
                          fake='nope')
 
@@ -863,6 +864,16 @@ def test_step_name_validation():
                                      [[1]], [1])
 
 
+def test_set_params_nested_pipeline():
+    estimator = Pipeline([
+        ('a', Pipeline([
+            ('b', DummyRegressor())
+        ]))
+    ])
+    estimator.set_params(a__b__alpha=0.001, a__b=Lasso())
+    estimator.set_params(a__steps=[('b', LogisticRegression())], a__b__C=5)
+
+
 def test_pipeline_wrong_memory():
     # Test that an error is raised when memory is not a string or a Memory
     # instance