@@ -15,6 +15,7 @@
 from sklearn.utils.testing import ignore_warnings
 from sklearn.utils.testing import assert_raise_message
 from sklearn.utils import ConvergenceWarning
+from sklearn.utils import compute_class_weight

 from sklearn.linear_model.logistic import (
     LogisticRegression,
@@ -26,7 +27,6 @@
 from sklearn.datasets import load_iris, make_classification
 from sklearn.metrics import log_loss

-
 X = [[-1, 0], [0, 1], [1, 1]]
 X_sp = sp.csr_matrix(X)
 Y1 = [0, 1, 1]
@@ -542,12 +542,12 @@ def test_logistic_regressioncv_class_weights():
     X, y = make_classification(n_samples=20, n_features=20, n_informative=10,
                                n_classes=3, random_state=0)

-    # Test the liblinear fails when class_weight of type dict is
-    # provided, when it is multiclass. However it can handle
-    # binary problems.
+    msg = ("In LogisticRegressionCV the liblinear solver cannot handle "
+           "multiclass with class_weight of type dict. Use the lbfgs, "
+           "newton-cg solvers or set class_weight='balanced'")
     clf_lib = LogisticRegressionCV(class_weight={0: 0.1, 1: 0.2},
                                    solver='liblinear')
-    assert_raises(ValueError, clf_lib.fit, X, y)
+    assert_raise_message(ValueError, msg, clf_lib.fit, X, y)
     y_ = y.copy()
     y_[y == 2] = 1
     clf_lib.fit(X, y_)
@@ -570,6 +570,50 @@ def test_logistic_regressioncv_class_weights():
     assert_array_almost_equal(clf_lib.coef_, clf_sag.coef_, decimal=4)


+def _compute_class_weight_dictionary(y):
+    # compute class_weight and return it as a dictionary
+    classes = np.unique(y)
+    class_weight = compute_class_weight("balanced", classes, y)
+
+    class_weight_dict = {}
+    for (cw, cl) in zip(class_weight, classes):
+        class_weight_dict[cl] = cw
+
+    return class_weight_dict
+
+
+def test_logistic_regression_class_weights():
+    # Multinomial case: remove 90% of class 0
+    X = iris.data[45:, :]
+    y = iris.target[45:]
+    solvers = ("lbfgs", "newton-cg")
+    class_weight_dict = _compute_class_weight_dictionary(y)
+
+    for solver in solvers:
+        clf1 = LogisticRegression(solver=solver, multi_class="multinomial",
+                                  class_weight="balanced")
+        clf2 = LogisticRegression(solver=solver, multi_class="multinomial",
+                                  class_weight=class_weight_dict)
+        clf1.fit(X, y)
+        clf2.fit(X, y)
+        assert_array_almost_equal(clf1.coef_, clf2.coef_, decimal=6)
+
+    # Binary case: remove 90% of class 0 and 100% of class 2
+    X = iris.data[45:100, :]
+    y = iris.target[45:100]
+    solvers = ("lbfgs", "newton-cg", "liblinear")
+    class_weight_dict = _compute_class_weight_dictionary(y)
+
+    for solver in solvers:
+        clf1 = LogisticRegression(solver=solver, multi_class="ovr",
+                                  class_weight="balanced")
+        clf2 = LogisticRegression(solver=solver, multi_class="ovr",
+                                  class_weight=class_weight_dict)
+        clf1.fit(X, y)
+        clf2.fit(X, y)
+        assert_array_almost_equal(clf1.coef_, clf2.coef_, decimal=6)
+
+
 def test_logistic_regression_convergence_warnings():
     # Test that warnings are raised if model does not converge

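For reference (not part of the commit), a minimal sketch of what the "balanced" weighting used by _compute_class_weight_dictionary evaluates to: compute_class_weight("balanced", classes, y) returns n_samples / (n_classes * count(class)) for each class, so the dictionary the new test passes as class_weight can be reproduced by hand. The labels below are illustrative only and are not taken from the test.

import numpy as np

# Illustrative labels with an under-represented class 0, similar in spirit
# to the truncated iris slices used in the new test.
y = np.array([0, 1, 1, 1, 2, 2, 2, 2])
classes = np.unique(y)

# "balanced" heuristic: n_samples / (n_classes * count(class)).
counts = np.bincount(y)[classes]
weights = len(y) / (len(classes) * counts)

class_weight_dict = dict(zip(classes, weights))
# -> approximately {0: 2.67, 1: 0.89, 2: 0.67}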