8000 FIX multinomial logistic regression class weigths · scikit-learn/scikit-learn@ed5004e · GitHub
[go: up one dir, main page]

Skip to content

Commit ed5004e

Browse files
committed
FIX multinomial logistic regression class weigths
1 parent cb591e1 commit ed5004e

File tree

2 files changed

+63
-21
lines changed

2 files changed

+63
-21
lines changed

sklearn/linear_model/logistic.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -550,12 +550,12 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
550550
pos_class = classes[1]
551551

552552
# If class_weights is a dict (provided by the user), the weights
553-
# are assigned to the original labels. If it is "auto", then
553+
# are assigned to the original labels. If it is "balanced", then
554554
# the class_weights are assigned after masking the labels with a OvR.
555555
sample_weight = np.ones(X.shape[0])
556556
le = LabelEncoder()
557557

558-
if isinstance(class_weight, dict):
558+
if isinstance(class_weight, dict) or multi_class == 'multinomial':
559559
if solver == "liblinear":
560560
if classes.size == 2:
561561
# Reconstruct the weights with keys 1 and -1
@@ -567,7 +567,7 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
567567
"solver cannot handle multiclass with "
568568
"class_weight of type dict. Use the lbfgs, "
569569
"newton-cg solvers or set "
570-
"class_weight='auto'")
570+
"class_weight='balanced'")
571571
else:
572572
class_weight_ = compute_class_weight(class_weight, classes, y)
573573
sample_weight = class_weight_[le.fit_transform(y)]
@@ -576,13 +576,16 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
576576
# multinomial case this is not necessary.
577577
if multi_class == 'ovr':
578578
w0 = np.zeros(n_features + int(fit_intercept))
579-
mask_classes = [-1, 1]
579+
mask_classes = np.array([-1, 1])
580580
mask = (y == pos_class)
581-
y[mask] = 1
582-
y[~mask] = -1
583-
# To take care of object dtypes, i.e 1 and -1 are in the form of
584-
# strings.
585-
y = as_float_array(y, copy=False)
581+
y_bin = np.ones(y.shape, dtype=np.float64)
582+
y_bin[~mask] = -1.
583+
# for compute_class_weight
584+
585+
if class_weight in ("auto", "balanced"):
586+
class_weight_ = compute_class_weight(class_weight, mask_classes,
587+
y_bin)
588+
sample_weight = class_weight_[le.fit_transform(y_bin)]
586589

587590
else:
588591
lbin = LabelBinarizer()
@@ -591,11 +594,6 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
591594
Y_bin = np.hstack([1 - Y_bin, Y_bin])
592595
w0 = np.zeros((Y_bin.shape[1], n_features + int(fit_intercept)),
593596
order='F')
594-
mask_classes = classes
595-
596-
if class_weight == "auto":
597-
class_weight_ = compute_class_weight(class_weight, mask_classes, y)
598-
sample_weight = class_weight_[le.fit_transform(y)]
599597

600598
if coef is not None:
601599
# it must work both giving the bias term and not
@@ -632,7 +630,7 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
632630
grad = lambda x, *args: _multinomial_loss_grad(x, *args)[1]
633631
hess = _multinomial_grad_hess
634632
else:
635-
target = y
633+
target = y_bin
636634
if solver == 'lbfgs':
637635
func = _logistic_loss_and_grad
638636
elif solver == 'newton-cg':
@@ -664,7 +662,7 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
664662
tol=tol)
665663
elif solver == 'liblinear':
666664
coef_, intercept_, _, = _fit_liblinear(
667-
X, y, C, fit_intercept, intercept_scaling, class_weight,
665+
X, target, C, fit_intercept, intercept_scaling, class_weight,
668666
penalty, dual, verbose, max_iter, tol, random_state)
669667
if fit_intercept:
670668
w0 = np.concatenate([coef_.ravel(), intercept_])

sklearn/linear_model/tests/test_logistic.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sklearn.utils.testing import ignore_warnings
1515
from sklearn.utils.testing import assert_raise_message
1616
from sklearn.utils import ConvergenceWarning
17+
from sklearn.utils import compute_class_weight
1718

1819
from sklearn.linear_model.logistic import (
1920
LogisticRegression,
@@ -25,7 +26,6 @@
2526
from sklearn.datasets import load_iris, make_classification
2627
from sklearn.metrics import log_loss
2728

28-
2929
X = [[-1, 0], [0, 1], [1, 1]]
F438 3030
X_sp = sp.csr_matrix(X)
3131
Y1 = [0, 1, 1]
@@ -518,12 +518,12 @@ def test_logistic_regressioncv_class_weights():
518518
X, y = make_classification(n_samples=20, n_features=20, n_informative=10,
519519
n_classes=3, random_state=0)
520520

521-
# Test the liblinear fails when class_weight of type dict is
522-
# provided, when it is multiclass. However it can handle
523-
# binary problems.
521+
msg = ("In LogisticRegressionCV the liblinear solver cannot handle "
522+
"multiclass with class_weight of type dict. Use the lbfgs, "
523+
"newton-cg solvers or set class_weight='balanced'")
524524
clf_lib = LogisticRegressionCV(class_weight={0: 0.1, 1: 0.2},
525525
solver='liblinear')
526-
assert_raises(ValueError, clf_lib.fit, X, y)
526+
assert_raise_message(ValueError, msg, clf_lib.fit, X, y)
527527
y_ = y.copy()
528528
y_[y == 2] = 1
529529
clf_lib.fit(X, y_)
@@ -541,6 +541,50 @@ def test_logistic_regressioncv_class_weights():
541541
assert_array_almost_equal(clf_lib.coef_, clf_lbf.coef_, decimal=4)
542542

543543

544+
def _compute_class_weight_dictionary(y):
545+
# compute class_weight and return it as a dictionary
546+
classes = np.unique(y)
547+
class_weight = compute_class_weight("balanced", classes, y)
548+
549+
class_weight_dict = {}
550+
for (cw, cl) in zip(class_weight, classes):
551+
class_weight_dict[cl] = cw
552+
553+
return class_weight_dict
554+
555+
556+
def test_logistic_regression_class_weights():
557+
# Multinomial case: remove 90% of class 0
558+
X = iris.data[45:, :]
559+
y = iris.target[45:]
560+
solvers = ("lbfgs", "newton-cg")
561+
class_weight_dict = _compute_class_weight_dictionary(y)
562+
563+
for solver in solvers:
564+
clf1 = LogisticRegression(solver=solver, multi_class="multinomial",
565+
class_weight="balanced")
566+
clf2 = LogisticRegression(solver=solver, multi_class="multinomial",
567+
class_weight=class_weight_dict)
568+
clf1.fit(X, y)
569+
clf2.fit(X, y)
570+
assert_array_almost_equal(clf1.coef_, clf2.coef_, decimal=6)
571+
572+
# Binary case: remove 90% of class 0 and 100% of class 2
573+
X = iris.data[45:100, :]
574+
y = iris.target[45:100]
575+
solvers = ("lbfgs", "newton-cg", "liblinear")
576+
class_weight_dict = _compute_class_weight_dictionary(y)
577+
578+
for solver in solvers:
579+
clf1 = LogisticRegression(solver=solver, multi_class="ovr",
580+
class_weight="balanced")
581+
clf2 = LogisticRegression(solver=solver, multi_class="ovr",
582+
class_weight=class_weight_dict)
583+
clf1.fit(X, y)
584+
clf2.fit(X, y)
585+
assert_array_almost_equal(clf1.coef_, clf2.coef_, decimal=6)
586+
587+
544588
def test_logistic_regression_convergence_warnings():
545589
# Test that warnings are raised if model does not converge
546590

0 commit comments

Comments
 (0)
0