TST slight cleanup of common tests. · scikit-learn/scikit-learn@d5db44f · GitHub
[go: up one dir, main page]

Skip to content

Commit d5db44f

Browse files
committed
TST slight cleanup of common tests.
1 parent 7044728 commit d5db44f

File tree

2 files changed

+35
-71
lines changed

sklearn/tests/test_common.py

Lines changed: 19 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
check_estimators_overwrite_params,
5151
check_estimators_partial_fit_n_features,
5252
check_cluster_overwrite_params,
53-
check_sparsify_binary_classifier,
5453
check_sparsify_multiclass_classifier,
5554
check_classifier_data_not_an_array,
5655
check_regressor_data_not_an_array,
@@ -82,6 +81,25 @@ def test_all_estimators():
8281
yield check_parameters_default_constructible, name, Estimator
8382

8483

84+
def test_non_meta_estimators():
85+
# input validation etc for non-meta estimators
86+
# FIXME these should be done also for non-mixin estimators!
87+
estimators = all_estimators(type_filter=['classifier', 'regressor',
88+
'transformer', 'cluster'])
89+
for name, Estimator in estimators:
90+
if name not in CROSS_DECOMPOSITION + ['Imputer']:
91+
# Test that all estimators check their input for NaN's and infs
92+
yield check_estimators_nan_inf, name, Estimator
93+
94+
if (name not in ['CCA', '_CCA', 'PLSCanonical', 'PLSRegression',
95+
'PLSSVD', 'GaussianProcess']):
96+
# FIXME!
97+
# in particular GaussianProcess!
98+
yield check_estimators_overwrite_params, name, Estimator
99+
if hasattr(Estimator, 'sparsify'):
100+
yield check_sparsify_multiclass_classifier, name, Estimator
101+
102+
85103
def test_estimators_sparse_data():
86104
# All estimators should either deal with sparse data or raise an
87105
# exception with type TypeError and an intelligible error message
@@ -108,15 +126,6 @@ def test_transformers():
108126
yield check_transformer, name, Transformer
109127

110128

111-
def test_estimators_nan_inf():
112-
# Test that all estimators check their input for NaN's and infs
113-
estimators = all_estimators(type_filter=['classifier', 'regressor',
114-
'transformer', 'cluster'])
115-
for name, Estimator in estimators:
116-
if name not in CROSS_DECOMPOSITION + ['Imputer']:
117-
yield check_estimators_nan_inf, name, Estimator
118-
119-
120129
def test_clustering():
121130
# test if clustering algorithms do something sensible
122131
# also test all shapes / shape errors
@@ -279,18 +288,6 @@ def test_class_weight_auto_linear_classifiers():
279288
yield check_class_weight_auto_linear_classifier, name, Classifier
280289

281290

282-
def test_estimators_overwrite_params():
283-
# test whether any classifier overwrites his init parameters during fit
284-
for est_type in ["classifier", "regressor", "transformer"]:
285-
estimators = all_estimators(type_filter=est_type)
286-
for name, Estimator in estimators:
287-
if (name not in ['CCA', '_CCA', 'PLSCanonical', 'PLSRegression',
288-
'PLSSVD', 'GaussianProcess']):
289-
# FIXME!
290-
# in particular GaussianProcess!
291-
yield check_estimators_overwrite_params, name, Estimator
292-
293-
294291
@ignore_warnings
295292
def test_import_all_consistency():
296293
# Smoke test to check that any name in a __all__ list is actually defined
@@ -318,29 +315,6 @@ def test_root_import_all_completeness():
318315
assert_in(modname, sklearn.__all__)
319316

320317

321-
def test_sparsify_estimators():
322-
#Test if predict with sparsified estimators works.
323-
#Tests regression, binary classification, and multi-class classification.
324-
estimators = all_estimators()
325-
326-
# test regression and binary classification
327-
for name, Estimator in estimators:
328-
try:
329-
Estimator.sparsify
330-
yield check_sparsify_binary_classifier, name, Estimator
331-
except:
332-
pass
333-
334-
# test multiclass classification
335-
classifiers = all_estimators(type_filter='classifier')
336-
for name, Classifier in classifiers:
337-
try:
338-
Classifier.sparsify
339-
yield check_sparsify_multiclass_classifier, name, Classifier
340-
except:
341-
pass
342-
343-
344318
def test_non_transformer_estimators_n_iter():
345319
# Test that all estimators of type which are non-transformer
346320
# and which have an attribute of max_iter, return the attribute

sklearn/utils/estimator_checks.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
from sklearn.utils.testing import SkipTest
2424
from sklearn.utils.testing import check_skip_travis
2525

26-
from sklearn.base import (clone, ClusterMixin, ClassifierMixin)
26+
from sklearn.base import (clone, ClusterMixin, ClassifierMixin, RegressorMixin,
27+
TransformerMixin)
2728
from sklearn.metrics import accuracy_score, adjusted_rand_score, f1_score
2829

2930
from sklearn.lda import LDA
@@ -43,6 +44,13 @@
4344
CROSS_DECOMPOSITION = ['PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD']
4445

4546

47+
def is_supervised(estimator):
48+
return (isinstance(estimator, ClassifierMixin)
49+
or isinstance(estimator, RegressorMixin)
50+
# transformers can all take a y
51+
or isinstance(estimator, TransformerMixin))
52+
53+
4654
def _boston_subset(n_samples=200):
4755
global BOSTON
4856
if BOSTON is None:
@@ -685,9 +693,9 @@ def check_regressors_train(name, Regressor):
685693
regressor.fit(X.tolist(), y_.tolist())
686694
regressor.predict(X)
687695

688-
# TODO: find out why PLS and CCA fail. RANSAC is random
689-
# and furthermore assumes the presence of outliers, hence
690-
# skipped
696+
# TODO: find out why PLS and CCA fail. RANSAC is random
697+
# and furthermore assumes the presence of outliers, hence
698+
# skipped
691699
if name not in ('PLSCanonical', 'CCA', 'RANSACRegressor'):
692700
assert_greater(regressor.score(X, y_), 0.5)
693701

@@ -813,7 +821,10 @@ def check_estimators_overwrite_params(name, Estimator):
813821
set_random_state(estimator)
814822

815823
params = estimator.get_params()
816-
estimator.fit(X, y)
824+
if is_supervised(estimator):
825+
estimator.fit(X, y)
826+
else:
827+
estimator.fit(X)
817828
new_params = estimator.get_params()
818829
for k, v in params.items():
819830
assert_false(np.any(new_params[k] != v),
@@ -860,27 +871,6 @@ def check_sparsify_multiclass_classifier(name, Classifier):
860871
assert_array_equal(pred, pred_orig)
861872

862873

863-
def check_sparsify_binary_classifier(name, Estimator):
864-
X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]])
865-
y = [1, 1, 1, 2, 2, 2]
866-
est = Estimator()
867-
868-
est.fit(X, y)
869-
pred_orig = est.predict(X)
870-
871-
# test sparsify with dense inputs
872-
est.sparsify()
873-
assert_true(sparse.issparse(est.coef_))
874-
pred = est.predict(X)
875-
assert_array_equal(pred, pred_orig)
876-
877-
# pickle and unpickle with sparse coef_
878-
est = pickle.loads(pickle.dumps(est))
879-
assert_true(sparse.issparse(est.coef_))
880-
pred = est.predict(X)
881-
assert_array_equal(pred, pred_orig)
882-
883-
884874
def check_classifier_data_not_an_array(name, Estimator):
885875
X = np.array([[3, 0], [0, 1], [0, 2], [1, 1], [1, 2], [2, 1]])
886876
y = [1, 1, 1, 2, 2, 2]

0 commit comments

Comments
 (0)
0