@@ -212,15 +212,9 @@ def _boston_subset(n_samples=200):
212
212
BOSTON = X , y
213
213
return BOSTON
214
214
215
- def get_kwargs (estimator_class ):
216
- "Get special kwargs that might be required for an estimator."
217
215
218
- if issubclass (estimator_class , BaseSVC ):
219
- return {'decision_function_shape' :'ovo' }
220
- return {}
221
-
222
- def set_fast_parameters (estimator ):
223
- # speed up some estimators
216
+ def set_optimal_parameters (estimator ):
217
+ # speed up some estimators and avoid deprecated behaviour
224
218
params = estimator .get_params ()
225
219
if ("n_iter" in params
226
220
and estimator .__class__ .__name__ != "TSNE" ):
@@ -247,6 +241,9 @@ def set_fast_parameters(estimator):
247
241
if "n_init" in params :
248
242
# K-Means
249
243
estimator .set_params (n_init = 2 )
244
+ if "decision_function_shape" in params :
245
+ # SVC
246
+ estimator .set_params (decision_function_shape = 'ovo' )
250
247
251
248
if estimator .__class__ .__name__ == "SelectFdr" :
252
249
# be tolerant of noisy datasets (not actually speed)
@@ -297,7 +294,7 @@ def check_estimator_sparse_data(name, Estimator):
297
294
estimator = Estimator (with_mean = False )
298
295
else :
299
296
estimator = Estimator ()
300
- set_fast_parameters (estimator )
297
+ set_optimal_parameters (estimator )
301
298
# fit and predict
302
299
try :
303
300
estimator .fit (X , y )
@@ -329,7 +326,7 @@ def check_dtype_object(name, Estimator):
329
326
y = multioutput_estimator_convert_y_2d (name , y )
330
327
with warnings .catch_warnings ():
331
328
estimator = Estimator ()
332
- set_fast_parameters (estimator )
329
+ set_optimal_parameters (estimator )
333
330
334
331
estimator .fit (X , y )
335
332
if hasattr (estimator , "predict" ):
@@ -356,9 +353,8 @@ def check_fit2d_predict1d(name, Estimator):
356
353
X = 3 * rnd .uniform (size = (20 , 3 ))
357
354
y = X [:, 0 ].astype (np .int )
358
355
y = multioutput_estimator_convert_y_2d (name , y )
359
- kwargs = get_kwargs (Estimator )
360
- estimator = Estimator (** kwargs )
361
- set_fast_parameters (estimator )
356
+ estimator = Estimator ()
357
+ set_optimal_parameters (estimator )
362
358
363
359
if hasattr (estimator , "n_components" ):
364
360
estimator .n_components = 1
@@ -386,7 +382,7 @@ def check_fit2d_1sample(name, Estimator):
386
382
y = X [:, 0 ].astype (np .int )
387
383
y = multioutput_estimator_convert_y_2d (name , y )
388
384
estimator = Estimator ()
389
- set_fast_parameters (estimator )
385
+ set_optimal_parameters (estimator )
390
386
391
387
if hasattr (estimator , "n_components" ):
392
388
estimator .n_components = 1
@@ -408,7 +404,7 @@ def check_fit2d_1feature(name, Estimator):
408
404
y = X [:, 0 ].astype (np .int )
409
405
y = multioutput_estimator_convert_y_2d (name , y )
410
406
estimator = Estimator ()
411
- set_fast_parameters (estimator )
407
+ set_optimal_parameters (estimator )
412
408
413
409
if hasattr (estimator , "n_components" ):
414
410
estimator .n_components = 1
@@ -430,7 +426,7 @@ def check_fit1d_1feature(name, Estimator):
430
426
y = X .astype (np .int )
431
427
y = multioutput_estimator_convert_y_2d (name , y )
432
428
estimator = Estimator ()
433
- set_fast_parameters (estimator )
429
+ set_optimal_parameters (estimator )
434
430
435
431
if hasattr (estimator , "n_components" ):
436
432
estimator .n_components = 1
@@ -453,7 +449,7 @@ def check_fit1d_1sample(name, Estimator):
453
449
y = np .array ([1 ])
454
450
y = multioutput_estimator_convert_y_2d (name , y )
455
451
estimator = Estimator ()
456
- set_fast_parameters (estimator )
452
+ set_optimal_parameters (estimator )
457
453
458
454
if hasattr (estimator , "n_components" ):
459
455
estimator .n_components = 1
@@ -512,7 +508,7 @@ def _check_transformer(name, Transformer, X, y):
512
508
with warnings .catch_warnings (record = True ):
513
509
transformer = Transformer ()
514
510
set_random_state (transformer )
515
- set_fast_parameters (transformer )
511
+ set_optimal_parameters (transformer )
516
512
517
513
# fit
518
514
@@ -583,7 +579,7 @@ def check_pipeline_consistency(name, Estimator):
583
579
X -= X .min ()
584
580
y = multioutput_estimator_convert_y_2d (name , y )
585
581
estimator = Estimator ()
586
- set_fast_parameters (estimator )
582
+ set_optimal_parameters (estimator )
587
583
set_random_state (estimator )
588
584
pipeline = make_pipeline (estimator )
589
585
estimator .fit (X , y )
@@ -607,7 +603,7 @@ def check_fit_score_takes_y(name, Estimator):
607
603
y = np .arange (10 ) % 3
608
604
y = multioutput_estimator_convert_y_2d (name , y )
609
605
estimator = Estimator ()
610
- set_fast_parameters (estimator )
606
+ set_optimal_parameters (estimator )
611
607
set_random_state (estimator )
612
608
funcs = ["fit" , "score" , "partial_fit" , "fit_predict" , "fit_transform" ]
613
609
@@ -633,9 +629,8 @@ def check_estimators_dtypes(name, Estimator):
633
629
y = multioutput_estimator_convert_y_2d (name , y )
634
630
for X_train in [X_train_32 , X_train_64 , X_train_int_64 , X_train_int_32 ]:
635
631
with warnings .catch_warnings (record = True ):
636
- kwargs = get_kwargs (Estimator )
637
- estimator = Estimator (** kwargs )
638
- set_fast_parameters (estimator )
632
+ estimator = Estimator ()
633
+ set_optimal_parameters (estimator )
639
634
set_random_state (estimator , 1 )
640
635
estimator .fit (X_train , y )
641
636
@@ -647,7 +642,7 @@ def check_estimators_dtypes(name, Estimator):
647
642
648
643
def check_estimators_empty_data_messages (name , Estimator ):
649
644
e = Estimator ()
650
- set_fast_parameters (e )
645
+ set_optimal_parameters (e )
651
646
set_random_state (e , 1 )
652
647
653
648
X_zero_samples = np .empty (0 ).reshape (0 , 3 )
@@ -682,7 +677,7 @@ def check_estimators_nan_inf(name, Estimator):
682
677
# catch deprecation warnings
683
678
with warnings .catch_warnings (record = True ):
684
679
estimator = Estimator ()
685
- set_fast_parameters (estimator )
680
+ set_optimal_parameters (estimator )
686
681
set_random_state (estimator , 1 )
687
682
# try to fit
688
683
try :
@@ -751,7 +746,7 @@ def check_estimators_pickle(name, Estimator):
751
746
estimator = Estimator ()
752
747
753
748
set_random_state (estimator )
754
- set_fast_parameters (estimator )
749
+ set_optimal_parameters (estimator )
755
750
estimator .fit (X , y )
756
751
757
752
result = dict ()
@@ -776,7 +771,7 @@ def check_estimators_partial_fit_n_features(name, Alg):
776
771
X -= X .min ()
777
772
with warnings .catch_warnings (record = True ):
778
773
alg = Alg ()
779
- set_fast_parameters (alg )
774
+ set_optimal_parameters (alg )
780
775
if isinstance (alg , ClassifierMixin ):
781
776
classes = np .unique (y )
782
777
alg .partial_fit (X , y , classes = classes )
@@ -794,7 +789,7 @@ def check_clustering(name, Alg):
794
789
# catch deprecation and neighbors warnings
795
790
with warnings .catch_warnings (record = True ):
796
791
alg = Alg ()
797
- set_fast_parameters (alg )
792
+ set_optimal_parameters (alg )
798
793
if hasattr (alg , "n_clusters" ):
799
794
alg .set_params (n_clusters = 3 )
800
795
set_random_state (alg )
@@ -847,7 +842,7 @@ def check_classifiers_one_label(name, Classifier):
847
842
# catch deprecation warnings
848
843
with warnings .catch_warnings (record = True ):
849
844
classifier = Classifier ()
850
- set_fast_parameters (classifier )
845
+ set_optimal_parameters (classifier )
851
846
# try to fit
852
847
try :
853
848
classifier .fit (X_train , y )
@@ -883,11 +878,10 @@ def check_classifiers_train(name, Classifier):
883
878
n_classes = len (classes )
884
879
n_samples , n_features = X .shape
885
880
with warnings .catch_warnings (record = True ):
886
- kwargs = get_kwargs (Classifier )
887
- classifier = Classifier (** kwargs )
881
+ classifier = Classifier ()
888
882
if name in ['BernoulliNB' , 'MultinomialNB' ]:
889
883
X -= X .min ()
890
- set_fast_parameters (classifier )
884
+ set_optimal_parameters (classifier )
891
885
set_random_state (classifier )
892
886
# raises error on malformed input for fit
893
887
assert_raises (ValueError , classifier .fit , X , y [:- 1 ])
@@ -950,7 +944,7 @@ def check_estimators_fit_returns_self(name, Estimator):
950
944
951
945
estimator = Estimator ()
952
946
953
- set_fast_parameters (estimator )
947
+ set_optimal_parameters (estimator )
954
948
set_random_state (estimator )
955
949
956
950
assert_true (estimator .fit (X , y ) is estimator )
@@ -999,7 +993,7 @@ def check_supervised_y_2d(name, Estimator):
999
993
# catch deprecation warnings
1000
994
with warnings .catch_warnings (record = True ):
1001
995
estimator = Estimator ()
1002
- set_fast_parameters (estimator )
996
+ set_optimal_parameters (estimator )
1003
997
set_random_state (estimator )
1004
998
# fit
1005
999
estimator .fit (X , y )
@@ -1045,7 +1039,7 @@ def check_classifiers_classes(name, Classifier):
1045
1039
classifier = Classifier ()
1046
1040
if name == 'BernoulliNB' :
1047
1041
classifier .set_params (binarize = X .mean ())
1048
- set_fast_parameters (classifier )
1042
+ set_optimal_parameters (classifier )
1049
1043
set_random_state (classifier )
1050
1044
# fit
1051
1045
classifier .fit (X , y_ )
@@ -1071,8 +1065,8 @@ def check_regressors_int(name, Regressor):
1071
1065
# separate estimators to control random seeds
1072
1066
regressor_1 = Regressor ()
1073
1067
regressor_2 = Regressor ()
1074
- set_fast_parameters (regressor_1 )
1075
- set_fast_parameters (regressor_2 )
1068
+ set_optimal_parameters (regressor_1 )
1069
+ set_optimal_parameters (regressor_2 )
1076
1070
set_random_state (regressor_1 )
1077
1071
set_random_state (regressor_2 )
1078
1072
@@ -1099,7 +1093,7 @@ def check_regressors_train(name, Regressor):
1099
1093
# catch deprecation warnings
1100
1094
with warnings .catch_warnings (record = True ):
1101
1095
regressor = Regressor ()
1102
- set_fast_parameters (regressor )
1096
+ set_optimal_parameters (regressor )
1103
1097
if not hasattr (regressor , 'alphas' ) and hasattr (regressor , 'alpha' ):
1104
1098
# linear regressors need to set alpha, but not generalized CV ones
1105
1099
regressor .alpha = 0.01
@@ -1136,7 +1130,7 @@ def check_regressors_no_decision_function(name, Regressor):
1136
1130
y = multioutput_estimator_convert_y_2d (name , X [:, 0 ])
1137
1131
regressor = Regressor ()
1138
1132
1139
- set_fast_parameters (regressor )
1133
+ set_optimal_parameters (regressor )
1140
1134
if hasattr (regressor , "n_components" ):
1141
1135
# FIXME CCA, PLS is not robust to rank 1 effects
1142
1136
regressor .n_components = 1
@@ -1244,7 +1238,7 @@ def check_estimators_overwrite_params(name, Estimator):
1244
1238
# catch deprecation warnings
1245
1239
estimator = Estimator ()
1246
1240
1247
- set_fast_parameters (estimator )
1241
+ set_optimal_parameters (estimator )
1248
1242
set_random_state (estimator )
1249
1243
1250
1244
# Make a physical copy of the orginal estimator parameters before fitting.
@@ -1315,8 +1309,8 @@ def check_estimators_data_not_an_array(name, Estimator, X, y):
1315
1309
# separate estimators to control random seeds
1316
1310
estimator_1 = Estimator ()
1317
1311
estimator_2 = Estimator ()
1318
- set_fast_parameters (estimator_1 )
1319
- set_fast_parameters (estimator_2 )
1312
+ set_optimal_parameters (estimator_1 )
1313
+ set_optimal_parameters (estimator_2 )
1320
1314
set_random_state (estimator_1 )
1321
1315
set_random_state (estimator_2 )
1322
1316
0 commit comments