|
11 | 11 | from sklearn.utils.testing import assert_raises
|
12 | 12 | from sklearn.utils.testing import assert_true
|
13 | 13 | from sklearn.utils.testing import assert_warns
|
| 14 | +from sklearn.utils.testing import assert_warns_message |
14 | 15 | from sklearn.utils.testing import raises
|
15 | 16 | from sklearn.utils.testing import ignore_warnings
|
16 | 17 | from sklearn.utils.testing import assert_raise_message
|
17 | 18 | from sklearn.utils import ConvergenceWarning
|
| 19 | +from sklearn.utils import compute_class_weight |
18 | 20 |
|
19 | 21 | from sklearn.linear_model.logistic import (
|
20 | 22 | LogisticRegression,
|
|
26 | 28 | from sklearn.datasets import load_iris, make_classification
|
27 | 29 | from sklearn.metrics import log_loss
|
28 | 30 |
|
29 |
| - |
30 | 31 | X = [[-1, 0], [0, 1], [1, 1]]
|
31 | 32 | X_sp = sp.csr_matrix(X)
|
32 | 33 | Y1 = [0, 1, 1]
|
@@ -542,12 +543,12 @@ def test_logistic_regressioncv_class_weights():
|
542 | 543 | X, y = make_classification(n_samples=20, n_features=20, n_informative=10,
|
543 | 544 | n_classes=3, random_state=0)
|
544 | 545 |
|
545 |
| - # Test the liblinear fails when class_weight of type dict is |
546 |
| - # provided, when it is multiclass. However it can handle |
547 |
| - # binary problems. |
| 546 | + msg = ("In LogisticRegressionCV the liblinear solver cannot handle " |
| 547 | + "multiclass with class_weight of type dict. Use the lbfgs, " |
| 548 | + "newton-cg or sag solvers or set class_weight='balanced'") |
548 | 549 | clf_lib = LogisticRegressionCV(class_weight={0: 0.1, 1: 0.2},
|
549 | 550 | solver='liblinear')
|
550 |
| - assert_raises(ValueError, clf_lib.fit, X, y) |
| 551 | + assert_raise_message(ValueError, msg, clf_lib.fit, X, y) |
551 | 552 | y_ = y.copy()
|
552 | 553 | y_[y == 2] = 1
|
553 | 554 | clf_lib.fit(X, y_)
|
@@ -613,6 +614,55 @@ def test_logistic_regression_sample_weights():
|
613 | 614 | assert_array_almost_equal(clf_cw_12.coef_, clf_sw_12.coef_, decimal=4)
|
614 | 615 |
|
615 | 616 |
|
| 617 | +def _compute_class_weight_dictionary(y): |
| 618 | + # helper for returning a dictionary instead of an array |
| 619 | + classes = np.unique(y) |
| 620 | + class_weight = compute_class_weight("balanced", classes, y) |
| 621 | + class_weight_dict = {cl: cw for (cl, cw) in zip(classes, class_weight)} |
| 622 | + return class_weight_dict |
| 623 | + |
| 624 | + |
| 625 | +def test_logistic_regression_class_weights(): |
| 626 | + # Multinomial case: remove 90% of class 0 |
| 627 | + X = iris.data[45:, :] |
| 628 | + y = iris.target[45:] |
| 629 | + solvers = ("lbfgs", "newton-cg") |
| 630 | + class_weight_dict = _compute_class_weight_dictionary(y) |
| 631 | + |
| 632 | + for solver in solvers: |
| 633 | + clf1 = LogisticRegression(solver=solver, multi_class="multinomial", |
| 634 | + class_weight="balanced") |
| 635 | + clf2 = LogisticRegression(solver=solver, multi_class="multinomial", |
| 636 | + class_weight=class_weight_dict) |
| 637 | + clf1.fit(X, y) |
| 638 | + clf2.fit(X, y) |
| 639 | + assert_array_almost_equal(clf1.coef_, clf2.coef_, decimal=6) |
| 640 | + |
| 641 | + # Binary case: remove 90% of class 0 and 100% of class 2 |
| 642 | + X = iris.data[45:100, :] |
| 643 | + y = iris.target[45:100] |
| 644 | + solvers = ("lbfgs", "newton-cg", "liblinear") |
| 645 | + class_weight_dict = _compute_class_weight_dictionary(y) |
| 646 | + |
| 647 | + for solver in solvers: |
| 648 | + clf1 = LogisticRegression(solver=solver, multi_class="ovr", |
| 649 | + class_weight="balanced") |
| 650 | + clf2 = LogisticRegression(solver=solver, multi_class="ovr", |
| 651 | + class_weight=class_weight_dict) |
| 652 | + clf1.fit(X, y) |
| 653 | + clf2.fit(X, y) |
| 654 | + assert_array_almost_equal(clf1.coef_, clf2.coef_, decimal=6) |
| 655 | + |
| 656 | + |
| 657 | +def test_multinomial_logistic_regression_with_classweight_auto(): |
| 658 | + X, y = iris.data, iris.target |
| 659 | + model = LogisticRegression(multi_class='multinomial', |
| 660 | + class_weight='auto', solver='lbfgs') |
| 661 | + assert_warns_message(DeprecationWarning, |
| 662 | + "class_weight='auto' heuristic is deprecated", |
| 663 | + model.fit, X, y) |
| 664 | + |
| 665 | + |
616 | 666 | def test_logistic_regression_convergence_warnings():
|
617 | 667 | # Test that warnings are raised if model does not converge
|
618 | 668 |
|
|
0 commit comments