8000 corrected bad test in test_multiclass · satra/scikit-learn@48f0a44 · GitHub
[go: up one dir, main page]

Skip to content

Commit 48f0a44

Browse files
AWintermanamueller
authored andcommitted
corrected bad test in test_multiclass
`test_ovr_single_label_predict_proba` wasn't checking consitency between `predict_proba` and `predict` correctly. Now it is. Nose tests pass except for 1 conerning PIL.
1 parent 396a16f commit 48f0a44

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

sklearn/tests/test_multiclass.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -146,13 +146,11 @@ def test_ovr_multilabel_dataset():
146146
def test_ovr_multilabel_predict_proba():
147147
#shamelessly coppied from test_ovr_multilable_dataset.
148148
base_clf = MultinomialNB(alpha=1)
149-
n_samples = 100
150-
n_classes = 5
151149
for au in (False, True):
152150
X, Y = datasets.make_multilabel_classification(n_samples=100,
153151
n_features=20,
154152
n_classes=5,
155-
n_labels=3
153+
n_labels=3 ,
156154
length=50,
157155
allow_unlabeled=au,
158156
random_state=0)
@@ -182,7 +180,6 @@ def test_ovr_single_label_predict_proba():
182180
base_clf = MultinomialNB(alpha=1)
183181
n_samples = 100
184182
n_classes = 5
185-
multilabel=False
186183
X,Y = iris.data, iris.target
187184
X_train, Y_train = X[:80], Y[:80]
188185
X_test, Y_test = X[80:], Y[80:]
@@ -198,8 +195,8 @@ def test_ovr_single_label_predict_proba():
198195
assert_almost_equal(Y_proba.sum(axis=1), 1.0)
199196
#predict assigns a label if the probability that the
200197
#sample has the label is greater than than 0.5.
201-
pred = [tuple(l.nonzero()[0]) for l in (Y_proba > 0.5)]
202-
assert_equal(pred, Y_pred)
198+
pred = np.array([l.argmax() for l in Y_proba])
199+
assert_true(not (pred-Y_pred).any())
203200

204201
def test_ovr_gridsearch():
205202
ovr = OneVsRestClassifier(LinearSVC(random_state=0))

0 commit comments

Comments
 (0)
0