50
50
check_estimators_overwrite_params ,
51
51
check_estimators_partial_fit_n_features ,
52
52
check_cluster_overwrite_params ,
53
- check_sparsify_binary_classifier ,
54
53
check_sparsify_multiclass_classifier ,
55
54
check_classifier_data_not_an_array ,
56
55
check_regressor_data_not_an_array ,
@@ -82,6 +81,25 @@ def test_all_estimators():
82
81
yield check_parameters_default_constructible , name , Estimator
83
82
84
83
84
+ def test_non_meta_estimators ():
85
+ # input validation etc for non-meta estimators
86
+ # FIXME these should be done also for non-mixin estimators!
87
+ estimators = all_estimators (type_filter = ['classifier' , 'regressor' ,
88
+ 'transformer' , 'cluster' ])
89
+ for name , Estimator in estimators :
90
+ if name not in CROSS_DECOMPOSITION + ['Imputer' ]:
91
+ # Test that all estimators check their input for NaN's and infs
92
+ yield check_estimators_nan_inf , name , Estimator
93
+
94
+ if (name not in ['CCA' , '_CCA' , 'PLSCanonical' , 'PLSRegression' ,
95
+ 'PLSSVD' , 'GaussianProcess' ]):
96
+ # FIXME!
97
+ # in particular GaussianProcess!
98
+ yield check_estimators_overwrite_params , name , Estimator
99
+ if hasattr (Estimator , 'sparsify' ):
100
+ yield check_sparsify_multiclass_classifier , name , Estimator
101
+
102
+
85
103
def test_estimators_sparse_data ():
86
104
# All estimators should either deal with sparse data or raise an
87
105
# exception with type TypeError and an intelligible error message
@@ -108,15 +126,6 @@ def test_transformers():
108
126
yield check_transformer , name , Transformer
109
127
110
128
111
- def test_estimators_nan_inf ():
112
- # Test that all estimators check their input for NaN's and infs
113
- estimators = all_estimators (type_filter = ['classifier' , 'regressor' ,
114
- 'transformer' , 'cluster' ])
115
- for name , Estimator in estimators :
116
- if name not in CROSS_DECOMPOSITION + ['Imputer' ]:
117
- yield check_estimators_nan_inf , name , Estimator
118
-
119
-
120
129
def test_clustering ():
121
130
# test if clustering algorithms do something sensible
122
131
# also test all shapes / shape errors
@@ -279,18 +288,6 @@ def test_class_weight_auto_linear_classifiers():
279
288
yield check_class_weight_auto_linear_classifier , name , Classifier
280
289
281
290
282
- def test_estimators_overwrite_params ():
283
- # test whether any classifier overwrites his init parameters during fit
284
- for est_type in ["classifier" , "regressor" , "transformer" ]:
285
- estimators = all_estimators (type_filter = est_type )
286
- for name , Estimator in estimators :
287
- if (name not in ['CCA' , '_CCA' , 'PLSCanonical' , 'PLSRegression' ,
288
- 'PLSSVD' , 'GaussianProcess' ]):
289
- # FIXME!
290
- # in particular GaussianProcess!
291
- yield check_estimators_overwrite_params , name , Estimator
292
-
293
-
294
291
@ignore_warnings
295
292
def test_import_all_consistency ():
296
293
# Smoke test to check that any name in a __all__ list is actually defined
@@ -318,29 +315,6 @@ def test_root_import_all_completeness():
318
315
assert_in (modname , sklearn .__all__ )
319
316
320
317
321
- def test_sparsify_estimators ():
322
- #Test if predict with sparsified estimators works.
323
- #Tests regression, binary classification, and multi-class classification.
324
- estimators = all_estimators ()
325
-
326
- # test regression and binary classification
327
- for name , Estimator in estimators :
328
- try :
329
- Estimator .sparsify
330
- yield check_sparsify_binary_classifier , name , Estimator
331
- except :
332
- pass
333
-
334
- # test multiclass classification
335
- classifiers = all_estimators (type_filter = 'classifier' )
336
- for name , Classifier in classifiers :
337
- try :
338
- Classifier .sparsify
339
- yield check_sparsify_multiclass_classifier , name , Classifier
340
- except :
341
- pass
342
-
343
-
344
318
def test_non_transformer_estimators_n_iter ():
345
319
# Test that all estimators of type which are non-transformer
346
320
# and which have an attribute of max_iter, return the attribute
0 commit comments