8000 Add non regression test · scikit-learn/scikit-learn@1b6471e · GitHub
[go: up one dir, main page]

Skip to content

Commit 1b6471e

Browse files
committed
Add non regression test
1 parent a5537aa commit 1b6471e

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

sklearn/linear_model/tests/test_logistic.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
)
2424
from sklearn.cross_validation import StratifiedKFold
2525
from sklearn.datasets import load_iris, make_classification
26+
from sklearn.metrics import log_loss
2627

2728

2829
X = [[-1, 0], [0, 1], [1, 1]]
@@ -677,13 +678,14 @@ def test_logreg_cv_penalty():
677678
assert_equal(np.count_nonzero(lr_cv.coef_), np.count_nonzero(lr.coef_))
678679

679680

680-
def test_logreg_predict_proba():
681+
def test_logreg_predict_proba_multinomial():
681682
X, y = make_classification(
682683
n_samples=10, n_features=20, random_state=0, n_classes=3, n_informative=10)
683-
clf = LogisticRegression(multi_class="multinomial", solver="lbfgs")
684-
clf.fit(X, y)
685-
assert_array_almost_equal(np.sum(clf.predict_proba(X), axis=1), np.ones(10))
686-
687-
clf = LogisticRegression(multi_class="multinomial", solver="lbfgs")
688-
clf.fit(X, y)
689-
assert_array_almost_equal(np.sum(clf.predict_proba(X), axis=1), np.ones(10))
684+
clf_multi = LogisticRegression(multi_class="multinomial", solver="lbfgs")
685+
clf_multi.fit(X, y)
686+
clf_multi_loss = log_loss(y, clf_multi.predict_proba(X))
687+
688+
clf_ovr = LogisticRegression(multi_class="ovr", solver="lbfgs")
689+
clf_ovr.fit(X, y)
690+
clf_ovr_loss = log_loss(y, clf_ovr.predict_proba(X))
691+
assert_greater(clf_ovr_loss, clf_multi_loss)

0 commit comments

Comments
 (0)
0