8000 [MRG+1] fix _BaseComposition._set_params with nested parameters by amueller · Pull Request #9945 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[MRG+1] fix _BaseComposition._set_params with nested parameters #9945

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Oct 18, 2017
35 changes: 17 additions & 18 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import copy
import warnings
from collections import defaultdict

import numpy as np
from scipy import sparse
Expand Down Expand Up @@ -248,26 +249,24 @@ def set_params(self, **params):
# Simple optimization to gain speed (inspect is slow)
return self
valid_params = self.get_params(deep=True)
for key, value in six.iteritems(params):
split = key.split('__', 1)
if len(split) > 1:
# nested objects case
name, sub_name = split
if name not in valid_params:
raise ValueError('Invalid parameter %s for estimator %s. '
'Check the list of available parameters '
'with `estimator.get_params().keys()`.' %
(name, self))
sub_object = valid_params[name]
sub_object.set_params(**{sub_name: value})

nested_params = defaultdict(dict) # grouped by prefix
for key, value in params.items():
key, delim, sub_key = key.partition('__')
if key not in valid_params:
raise ValueError('Invalid parameter %s for estimator %s. '
'Check the list of available parameters '
'with `estimator.get_params().keys()`.' %
(key, self))

if delim:
nested_params[key][sub_key] = value
else:
# simple objects case
if key not in valid_params:
raise ValueError('Invalid parameter %s for estimator %s. '
'Check the list of available parameters '
'with `estimator.get_params().keys()`.' %
(key, self.__class__.__name__))
setattr(self, key, value)

for key, sub_params in nested_params.items():
valid_params[key].set_params(**sub_params)

return self

def __repr__(self):
Expand Down
18 changes: 18 additions & 0 deletions sklearn/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,24 @@ def test_set_params():
# bad__stupid_param=True)


def test_set_params_passes_all_parameters():
# Make sure all parameters are passed together to set_params
# of nested estimator. Regression test for #9944

class TestDecisionTree(DecisionTreeClassifier):
def set_params(self, **kwargs):
super(TestDecisionTree, self).set_params(**kwargs)
# expected_kwargs is in test scope
assert kwargs == expected_kwargs
return self

expected_kwargs = {'max_depth': 5, 'min_samples_leaf': 2}
for est in [Pipeline([('estimator', TestDecisionTree())]),
GridSearchCV(TestDecisionTree(), {})]:
est.set_params(estimator__max_depth=5,
estimator__min_samples_leaf=2)


def test_score_sample_weight():

rng = np.random.RandomState(0)
Expand Down
15 changes: 13 additions & 2 deletions sklearn/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
from sklearn.base import clone, BaseEstimator
from sklearn.pipeline import Pipeline, FeatureUnion, make_pipeline, make_union
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import LogisticRegression, Lasso
from sklearn.linear_model import LinearRegression
from sklearn.cluster import KMeans
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.dummy import DummyRegressor
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
Expand Down Expand Up @@ -289,7 +290,7 @@ def test_pipeline_raise_set_params_error():
'with `estimator.get_params().keys()`.')

assert_raise_message(ValueError,
error_msg % ('fake', 'Pipeline'),
error_msg % ('fake', pipe),
pipe.set_params,
fake='nope')

Expand Down Expand Up @@ -863,6 +864,16 @@ def test_step_name_validation():
[[1]], [1])


def test_set_params_nested_pipeline():
estimator = Pipeline([
('a', Pipeline([
('b', DummyRegressor())
]))
])
estimator.set_params(a__b__alpha=0.001, a__b=Lasso())
estimator.set_params(a__steps=[('b', LogisticRegression())], a__b__C=5)


def test_pipeline_wrong_memory():
# Test that an error is raised when memory is not a string or a Memory
# instance
Expand Down
0