ENH test more shapes, test non-consecutive classes, test accuracy on test set · pfdevilliers/scikit-learn@a93e0af · GitHub
Commit a93e0af

ENH test more shapes, test non-consecutive classes, test accuracy on test set
1 parent 971d131 commit a93e0af


sklearn/tests/test_common.py

Lines changed: 70 additions & 13 deletions
@@ -3,15 +3,15 @@
 """
 import warnings
 import numpy as np
-from nose.tools import assert_raises
+from nose.tools import assert_raises, assert_equal
 from numpy.testing import assert_array_equal
 
 from sklearn.utils.testing import all_estimators
 from sklearn.utils.testing import assert_greater
 from sklearn.base import clone, ClassifierMixin, RegressorMixin
 from sklearn.utils import shuffle
 from sklearn.preprocessing import Scaler
-#from sklearn.datasets import load_digits
+from sklearn.cross_validation import train_test_split
 from sklearn.datasets import load_iris, load_boston
 from sklearn.metrics import zero_one_score
 from sklearn.lda import LDA
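
The import swap above drops the commented-out digits loader in favor of train_test_split, which the new held-out-set test later in this diff depends on. A minimal usage sketch, assuming the old sklearn.cross_validation module path used in this file (later releases moved the helper to sklearn.model_selection):

import numpy as np
from sklearn.cross_validation import train_test_split  # old path, as imported above

X = np.arange(20).reshape(10, 2)  # 10 samples, 2 features
y = np.arange(10) % 2             # alternating binary labels
# by default (in most releases) a quarter of the samples are held out for testing
X_train, X_test, y_train, y_test = train_test_split(X, y)
print(X_train.shape, X_test.shape)  # e.g. (7, 2) (3, 2)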
@@ -26,7 +26,6 @@
         OutputCodeClassifier
 from sklearn.feature_selection import RFE, RFECV
 from sklearn.naive_bayes import MultinomialNB, BernoulliNB
-from sklearn.linear_model import RidgeClassifier, RidgeClassifierCV
 
 dont_test = [Pipeline, GridSearchCV, SparseCoder]
 meta_estimators = [BaseEnsemble, OneVsOneClassifier, OutputCodeClassifier,
@@ -55,15 +54,17 @@ def test_all_estimators():
         print(w)
 
 
-def test_classifiers():
+def test_classifiers_train():
+    # test if classifiers do something sensible on training set
+    # also test all shapes / shape errors
     estimators = all_estimators()
     classifiers = [(name, E) for name, E in estimators if issubclass(E,
         ClassifierMixin)]
     iris = load_iris()
     X, y = iris.data, iris.target
     X, y = shuffle(X, y, random_state=7)
-    #digits = load_digits()
-    #X, y = digits.data, digits.target
+    n_samples, n_features = X.shape
+    n_labels = len(np.unique(y))
     X = Scaler().fit_transform(X)
     for name, Clf in classifiers:
         if Clf in dont_test or Clf in meta_estimators:
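
The two new locals feed the shape assertions added below: predict should return one label per sample, while decision_function and predict_proba should return one column per class. For iris that means (150,) and (150, 3). A standalone numpy sketch of the same bookkeeping, with the dataset sizes hard-coded for illustration:

import numpy as np

X = np.zeros((150, 4))        # iris-sized design matrix
y = np.repeat([0, 1, 2], 50)  # three balanced classes
n_samples, n_features = X.shape
n_labels = len(np.unique(y))
assert (n_samples,) == (150,)             # expected predict() output shape
assert (n_samples, n_labels) == (150, 3)  # expected predict_proba() output shape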
@@ -75,6 +76,7 @@ def test_classifiers():
         # fit
         clf.fit(X, y)
         y_pred = clf.predict(X)
+        assert_equal(y_pred.shape, (n_samples,))
         # training set performance
         assert_greater(zero_one_score(y, y_pred), 0.78)
         # raises error on malformed input for predict
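
zero_one_score, used for the 0.78 floor above, is the old scikit-learn name for plain accuracy (later renamed accuracy_score); the threshold just guards against degenerate predictors on the easy iris problem. An equivalent check in plain numpy:

import numpy as np

def accuracy(y_true, y_pred):
    # fraction of exactly matching labels, i.e. what zero_one_score returns
    return np.mean(np.asarray(y_true) == np.asarray(y_pred))

y_true = np.array([0, 1, 2, 2, 1])
y_pred = np.array([0, 1, 2, 1, 1])
print(accuracy(y_true, y_pred))  # 0.8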
@@ -84,24 +86,82 @@
         assert_raises(ValueError, clf.predict, X.T)
         if hasattr(clf, "decision_function"):
             try:
-                #raises error on malformed input for decision_function
+                # raises error on malformed input for decision_function
                 assert_raises(ValueError, clf.decision_function, X.T)
-                #decision_function agrees with predict:
+                # decision_function agrees with predict:
                 decision = clf.decision_function(X)
+                assert_equal(decision.shape, (n_samples, n_labels))
                 assert_array_equal(np.argmax(decision, axis=1), y_pred)
             except NotImplementedError:
                 pass
         if hasattr(clf, "predict_proba"):
             try:
+                # raises error on malformed input for predict_proba
                 assert_raises(ValueError, clf.predict_proba, X.T)
-                # decision_function agrees with predict:
+                # predict_proba agrees with predict:
                 y_prob = clf.predict_proba(X)
+                assert_equal(y_prob.shape, (n_samples, n_labels))
                 assert_array_equal(np.argmax(y_prob, axis=1), y_pred)
             except NotImplementedError:
                 pass
 
 
-def test_regressors():
+def test_classifiers_classes():
+    # test if classifiers can cope with non-consecutive classes
+    estimators = all_estimators()
+    classifiers = [(name, E) for name, E in estimators if issubclass(E,
+        ClassifierMixin)]
+    iris = load_iris()
+    X, y = iris.data, iris.target
+    X, y = shuffle(X, y, random_state=7)
+    X = Scaler().fit_transform(X)
+    y = 2 * y + 1
+    # TODO: make work with next line :)
+    #y = y.astype(np.str)
+    for name, Clf in classifiers:
+        if Clf in dont_test or Clf in meta_estimators:
+            continue
+        if Clf in [MultinomialNB, BernoulliNB]:
+            # TODO also test these!
+            continue
+        clf = Clf()
+        # fit
+        clf.fit(X, y)
+        y_pred = clf.predict(X)
+        # training set performance
+        assert_array_equal(np.unique(y), np.unique(y_pred))
+        assert_greater(zero_one_score(y, y_pred), 0.78)
+
+
+def test_classifiers_test():
+    # test classifier accuracy on a held-out test set
+    estimators = all_estimators()
+    classifiers = [(name, E) for name, E in estimators if issubclass(E,
+        ClassifierMixin)]
+    iris = load_iris()
+    X, y = iris.data, iris.target
+    X, y = shuffle(X, y, random_state=7)
+    X = Scaler().fit_transform(X)
+    X_train, X_test, y_train, y_test = train_test_split(X, y)
+    for name, Clf in classifiers:
+        if Clf in dont_test or Clf in meta_estimators:
+            continue
+        if Clf in [MultinomialNB, BernoulliNB]:
+            # TODO also test these!
+            continue
+        clf = Clf()
+        # fit
+        try:
+            clf.fit(X_train, y_train)
+            y_pred = clf.predict(X_test)
+            # test set performance
+            assert_greater(zero_one_score(y_test, y_pred), 0.78)
+        except Exception as ex:
+            print(ex)
+            print(clf)
+
+
+def test_regressors_train():
     estimators = all_estimators()
     regressors = [(name, E) for name, E in estimators if issubclass(E,
         RegressorMixin)]
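
The y = 2 * y + 1 line in test_classifiers_classes above remaps iris's consecutive labels {0, 1, 2} to the non-consecutive set {1, 3, 5}; the assert_array_equal on np.unique then verifies that classifiers predict from the original label set instead of silently re-indexing to 0..n_classes-1. A pure-numpy illustration of the remapping:

import numpy as np

y = np.repeat([0, 1, 2], 50)  # consecutive iris-style labels
y = 2 * y + 1                 # {0, 1, 2} -> {1, 3, 5}
print(np.unique(y))           # [1 3 5]
# a well-behaved classifier must emit predictions drawn from this set,
# which is what assert_array_equal(np.unique(y), np.unique(y_pred)) checks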
@@ -115,9 +175,6 @@ def test_regressors():
     for name, Reg in regressors:
         if Reg in dont_test or Reg in meta_estimators:
             continue
-        if Reg in [RidgeClassifier, RidgeClassifierCV]:
-            #TODO this is not a regressor!
-            continue
         reg = Reg()
         if hasattr(reg, 'alpha'):
             reg.set_params(alpha=0.01)
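
The deleted skip (and the matching import removed in the second hunk) appears to be obsolete because RidgeClassifier and RidgeClassifierCV derive from ClassifierMixin rather than RegressorMixin, so the issubclass(E, RegressorMixin) filter above never selects them and they are exercised by the classifier tests instead. A quick sanity check of that assumption against current scikit-learn:

from sklearn.base import ClassifierMixin, RegressorMixin
from sklearn.linear_model import RidgeClassifier, RidgeClassifierCV

for cls in (RidgeClassifier, RidgeClassifierCV):
    print(cls.__name__,
          issubclass(cls, ClassifierMixin),  # expected: True
          issubclass(cls, RegressorMixin))   # expected: False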
