moved functionality from get_kwargs to get_optimal_parameters · scikit-learn/scikit-learn@8f1ec1c · GitHub
[go: up one dir, main page]

Skip to content

Commit 8f1ec1c

Browse files
moved functionality from get_kwargs to get_optimal_parameters
1 parent 8d28c32 commit 8f1ec1c

File tree

1 file changed

+36
-42
lines changed

1 file changed

+36
-42
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 36 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -212,15 +212,9 @@ def _boston_subset(n_samples=200):
212212
BOSTON = X, y
213213
return BOSTON
214214

215-
def get_kwargs(estimator_class):
216-
"Get special kwargs that might be required for an estimator."
217215

218-
if issubclass(estimator_class, BaseSVC):
219-
return {'decision_function_shape':'ovo'}
220-
return {}
221-
222-
def set_fast_parameters(estimator):
223-
# speed up some estimators
216+
def set_optimal_parameters(estimator):
217+
# speed up some estimators and avoid deprecated behaviour
224218
params = estimator.get_params()
225219
if ("n_iter" in params
226220
and estimator.__class__.__name__ != "TSNE"):
@@ -247,6 +241,9 @@ def set_fast_parameters(estimator):
247241
if "n_init" in params:
248242
# K-Means
249243
estimator.set_params(n_init=2)
244+
if "decision_function_shape" in params:
245+
# SVC
246+
estimator.set_params(decision_function_shape='ovo')
250247

251248
if estimator.__class__.__name__ == "SelectFdr":
252249
# be tolerant of noisy datasets (not actually speed)
@@ -297,7 +294,7 @@ def check_estimator_sparse_data(name, Estimator):
297294
estimator = Estimator(with_mean=False)
298295
else:
299296
estimator = Estimator()
300-
set_fast_parameters(estimator)
297+
set_optimal_parameters(estimator)
301298
# fit and predict
302299
try:
303300
estimator.fit(X, y)
@@ -329,7 +326,7 @@ def check_dtype_object(name, Estimator):
329326
y = multioutput_estimator_convert_y_2d(name, y)
330327
with warnings.catch_warnings():
331328
estimator = Estimator()
332-
set_fast_parameters(estimator)
329+
set_optimal_parameters(estimator)
333330

334331
estimator.fit(X, y)
335332
if hasattr(estimator, "predict"):
@@ -356,9 +353,8 @@ def check_fit2d_predict1d(name, Estimator):
356353
X = 3 * rnd.uniform(size=(20, 3))
357354
y = X[:, 0].astype(np.int)
358355
y = multioutput_estimator_convert_y_2d(name, y)
359-
kwargs = get_kwargs(Estimator)
360-
estimator = Estimator(**kwargs)
361-
set_fast_parameters(estimator)
356+
estimator = Estimator()
357+
set_optimal_parameters(estimator)
362358

363359
if hasattr(estimator, "n_components"):
364360
estimator.n_components = 1
@@ -386,7 +382,7 @@ def check_fit2d_1sample(name, Estimator):
386382
y = X[:, 0].astype(np.int)
387383
y = multioutput_estimator_convert_y_2d(name, y)
388384
estimator = Estimator()
389-
set_fast_parameters(estimator)
385+
set_optimal_parameters(estimator)
390386

391387
if hasattr(estimator, "n_components"):
392388
estimator.n_components = 1
@@ -408,7 +404,7 @@ def check_fit2d_1feature(name, Estimator):
408404
y = X[:, 0].astype(np.int)
409405
y = multioutput_estimator_convert_y_2d(name, y)
410406
estimator = Estimator()
411-
set_fast_parameters(estimator)
407+
set_optimal_parameters(estimator)
412408

413409
if hasattr(estimator, "n_components"):
414410
estimator.n_components = 1
@@ -430,7 +426,7 @@ def check_fit1d_1feature(name, Estimator):
430426
y = X.astype(np.int)
431427
y = multioutput_estimator_convert_y_2d(name, y)
432428
estimator = Estimator()
433-
set_fast_parameters(estimator)
429+
set_optimal_parameters(estimator)
434430

435431
if hasattr(estimator, "n_components"):
436432
estimator.n_components = 1
@@ -453,7 +449,7 @@ def check_fit1d_1sample(name, Estimator):
453449
y = np.array([1])
454450
y = multioutput_estimator_convert_y_2d(name, y)
455451
estimator = Estimator()
456-
set_fast_parameters(estimator)
452+
set_optimal_parameters(estimator)
457453

458454
if hasattr(estimator, "n_components"):
459455
estimator.n_components = 1
@@ -512,7 +508,7 @@ def _check_transformer(name, Transformer, X, y):
512508
with warnings.catch_warnings(record=True):
513509
transformer = Transformer()
514510
set_random_state(transformer)
515-
set_fast_parameters(transformer)
511+
set_optimal_parameters(transformer)
516512

517513
# fit
518514

@@ -583,7 +579,7 @@ def check_pipeline_consistency(name, Estimator):
583579
X -= X.min()
584580
y = multioutput_estimator_convert_y_2d(name, y)
585581
estimator = Estimator()
586-
set_fast_parameters(estimator)
582+
set_optimal_parameters(estimator)
587583
set_random_state(estimator)
588584
pipeline = make_pipeline(estimator)
589585
estimator.fit(X, y)
@@ -607,7 +603,7 @@ def check_fit_score_takes_y(name, Estimator):
607603
y = np.arange(10) % 3
608604
y = multioutput_estimator_convert_y_2d(name, y)
609605
estimator = Estimator()
610-
set_fast_parameters(estimator)
606+
set_optimal_parameters(estimator)
611607
set_random_state(estimator)
612608
funcs = ["fit", "score", "partial_fit", "fit_predict", "fit_transform"]
613609

@@ -633,9 +629,8 @@ def check_estimators_dtypes(name, Estimator):
633629
y = multioutput_estimator_convert_y_2d(name, y)
634630
for X_train in [X_train_32, X_train_64, X_train_int_64, X_train_int_32]:
635631
with warnings.catch_warnings(record=True):
636-
kwargs = get_kwargs(Estimator)
637-
estimator = Estimator(**kwargs)
638-
set_fast_parameters(estimator)
632+
estimator = Estimator()
633+
set_optimal_parameters(estimator)
639634
set_random_state(estimator, 1)
640635
estimator.fit(X_train, y)
641636

@@ -647,7 +642,7 @@ def check_estimators_dtypes(name, Estimator):
647642

648643
def check_estimators_empty_data_messages(name, Estimator):
649644
e = Estimator()
650-
set_fast_parameters(e)
645+
set_optimal_parameters(e)
651646
set_random_state(e, 1)
652647

653648
X_zero_samples = np.empty(0).reshape(0, 3)
@@ -682,7 +677,7 @@ def check_estimators_nan_inf(name, Estimator):
682677
# catch deprecation warnings
683678
with warnings.catch_warnings(record=True):
684679
estimator = Estimator()
685-
set_fast_parameters(estimator)
680+
set_optimal_parameters(estimator)
686681
set_random_state(estimator, 1)
687682
# try to fit
688683
try:
@@ -751,7 +746,7 @@ def check_estimators_pickle(name, Estimator):
751746
estimator = Estimator()
752747

753748
set_random_state(estimator)
754-
set_fast_parameters(estimator)
749+
set_optimal_parameters(estimator)
755750
estimator.fit(X, y)
756751

757752
result = dict()
@@ -776,7 +771,7 @@ def check_estimators_partial_fit_n_features(name, Alg):
776771
X -= X.min()
777772
with warnings.catch_warnings(record=True):
778773
alg = Alg()
779-
set_fast_parameters(alg)
774+
set_optimal_parameters(alg)
780775
if isinstance(alg, ClassifierMixin):
781776
classes = np.unique(y)
782777
alg.partial_fit(X, y, classes=classes)
@@ -794,7 +789,7 @@ def check_clustering(name, Alg):
794789
# catch deprecation and neighbors warnings
795790
with warnings.catch_warnings(record=True):
796791
alg = Alg()
797-
set_fast_parameters(alg)
792+
set_optimal_parameters(alg)
798793
if hasattr(alg, "n_clusters"):
799794
alg.set_params(n_clusters=3)
800795
set_random_state(alg)
@@ -847,7 +842,7 @@ def check_classifiers_one_label(name, Classifier):
847842
# catch deprecation warnings
848843
with warnings.catch_warnings(record=True):
849844
classifier = Classifier()
850-
set_fast_parameters(classifier)
845+
set_optimal_parameters(classifier)
851846
# try to fit
852847
try:
853848
classifier.fit(X_train, y)
@@ -883,11 +878,10 @@ def check_classifiers_train(name, Classifier):
883878
n_classes = len(classes)
884879
n_samples, n_features = X.shape
885880
with warnings.catch_warnings(record=True):
886-
kwargs = get_kwargs(Classifier)
887-
classifier = Classifier(**kwargs)
881+
classifier = Classifier()
888882
if name in ['BernoulliNB', 'MultinomialNB']:
889883
X -= X.min()
890-
set_fast_parameters(classifier)
884+
set_optimal_parameters(classifier)
891885
set_random_state(classifier)
892886
# raises error on malformed input for fit
893887
assert_raises(ValueError, classifier.fit, X, y[:-1])
@@ -950,7 +944,7 @@ def check_estimators_fit_returns_self(name, Estimator):
950944

951945
estimator = Estimator()
952946

953-
set_fast_parameters(estimator)
947+
set_optimal_parameters(estimator)
954948
set_random_state(estimator)
955949

956950
assert_true(estimator.fit(X, y) is estimator)
@@ -999,7 +993,7 @@ def check_supervised_y_2d(name, Estimator):
999993
# catch deprecation warnings
1000994
with warnings.catch_warnings(record=True):
1001995
estimator = Estimator()
1002-
set_fast_parameters(estimator)
996+
set_optimal_parameters(estimator)
1003997
set_random_state(estimator)
1004998
# fit
1005999
estimator.fit(X, y)
@@ -1045,7 +1039,7 @@ def check_classifiers_classes(name, Classifier):
10451039
classifier = Classifier()
10461040
if name == 'BernoulliNB':
10471041
classifier.set_params(binarize=X.mean())
1048-
set_fast_parameters(classifier)
1042+
set_optimal_parameters(classifier)
10491043
set_random_state(classifier)
10501044
# fit
10511045
classifier.fit(X, y_)
@@ -1071,8 +1065,8 @@ def check_regressors_int(name, Regressor):
10711065
# separate estimators to control random seeds
10721066
regressor_1 = Regressor()
10731067
regressor_2 = Regressor()
1074-
set_fast_parameters(regressor_1)
1075-
set_fast_parameters(regressor_2)
1068+
set_optimal_parameters(regressor_1)
1069+
set_optimal_parameters(regressor_2)
10761070
set_random_state(regressor_1)
10771071
set_random_state(regressor_2)
10781072

@@ -1099,7 +1093,7 @@ def check_regressors_train(name, Regressor):
10991093
# catch deprecation warnings
11001094
with warnings.catch_warnings(record=True):
11011095
regressor = Regressor()
1102-
set_fast_parameters(regressor)
1096+
set_optimal_parameters(regressor)
11031097
if not hasattr(regressor, 'alphas') and hasattr(regressor, 'alpha'):
11041098
# linear regressors need to set alpha, but not generalized CV ones
11051099
regressor.alpha = 0.01
@@ -1136,7 +1130,7 @@ def check_regressors_no_decision_function(name, Regressor):
11361130
y = multioutput_estimator_convert_y_2d(name, X[:, 0])
11371131
regressor = Regressor()
11381132

1139-
set_fast_parameters(regressor)
1133+
set_optimal_parameters(regressor)
11401134
if hasattr(regressor, "n_components"):
11411135
# FIXME CCA, PLS is not robust to rank 1 effects
11421136
regressor.n_components = 1
@@ -1244,7 +1238,7 @@ def check_estimators_overwrite_params(name, Estimator):
12441238
# catch deprecation warnings
12451239
estimator = Estimator()
12461240

1247-
set_fast_parameters(estimator)
1241+
set_optimal_parameters(estimator)
12481242
set_random_state(estimator)
12491243

12501244
# Make a physical copy of the original estimator parameters before fitting.
@@ -1315,8 +1309,8 @@ def check_estimators_data_not_an_array(name, Estimator, X, y):
13151309
# separate estimators to control random seeds
13161310
estimator_1 = Estimator()
13171311
estimator_2 = Estimator()
1318-
set_fast_parameters(estimator_1)
1319-
set_fast_parameters(estimator_2)
1312+
set_optimal_parameters(estimator_1)
1313+
set_optimal_parameters(estimator_2)
13201314
set_random_state(estimator_1)
13211315
set_random_state(estimator_2)
13221316

0 commit comments

Comments
 (0)
0