8000 FIX multinomial logistic regression class weigths · scikit-learn/scikit-learn@6d3af2f · GitHub
[go: up one dir, main page]

Skip to content

Commit 6d3af2f

Browse files
committed
FIX multinomial logistic regression class weigths
1 parent b099a59 commit 6d3af2f

File tree

2 files changed

+59
-15
lines changed

2 files changed

+59
-15
lines changed

sklearn/linear_model/logistic.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -569,12 +569,12 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
569569
pos_class = classes[1]
570570

571571
# If class_weights is a dict (provided by the user), the weights
572-
# are assigned to the original labels. If it is "auto", then
572+
# are assigned to the original labels. If it is "balanced", then
573573
# the class_weights are assigned after masking the labels with a OvR.
574574
sample_weight = np.ones(X.shape[0])
575575
le = LabelEncoder()
576576

577-
if isinstance(class_weight, dict):
577+
if isinstance(class_weight, dict) or multi_class == 'multinomial':
578578
if solver == "liblinear":
579579
if classes.size == 2:
580580
# Reconstruct the weights with keys 1 and -1
@@ -585,8 +585,8 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
585585
raise ValueError("In LogisticRegressionCV the liblinear "
586586
"solver cannot handle multiclass with "
587587
"class_weight of type dict. Use the lbfgs, "
588-
"newton-cg or sag solvers or set "
589-
"class_weight='auto'")
588+
"newton-cg solvers or set "
589+
"class_weight='balanced'")
590590
else:
591591
class_weight_ = compute_class_weight(class_weight, classes, y)
592592
sample_weight = class_weight_[le.fit_transform(y)]
@@ -599,6 +599,12 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
599599
mask = (y == pos_class)
600600
y_bin = np.ones(y.shape, dtype=np.float64)
601601
y_bin[~mask] = -1.
602+
# for compute_class_weight
603+
604+
if class_weight in ("auto", "balanced"):
605+
class_weight_ = compute_class_weight(class_weight, mask_classes,
606+
y_bin)
607+
sample_weight = class_weight_[le.fit_transform(y_bin)]
602608

603609
else:
604610
lbin = LabelBinarizer()
@@ -607,12 +613,6 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
607613
Y_bin = np.hstack([1 - Y_bin, Y_bin])
608614
w0 = np.zeros((Y_bin.shape[1], n_features + int(fit_intercept)),
609615
order='F')
610-
mask_classes = classes
611-
612-
if class_weight == "auto":
613-
class_weight_ = compute_class_weight(class_weight, mask_classes,
614-
y_bin)
615-
sample_weight = class_weight_[le.fit_transform(y_bin)]
616616

617617
if coef is not None:
618618
# it must work both giving the bias term and not

sklearn/linear_model/tests/test_logistic.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from sklearn.utils.testing import ignore_warnings
1616
from sklearn.utils.testing import assert_raise_message
1717
from sklearn.utils import ConvergenceWarning
18+
from sklearn.utils import compute_class_weight
1819

1920
from sklearn.linear_model.logistic import (
2021
LogisticRegression,
@@ -26,7 +27,6 @@
2627
from sklearn.datasets import load_iris, make_classification
2728
from sklearn.metrics import log_loss
2829

29-
3030
X = [[-1, 0], [0, 1], [1, 1]]
3131
X_sp = sp.csr_matrix(X)
3232
Y1 = [0, 1, 1]
@@ -542,12 +542,12 @@ def test_logistic_regressioncv_class_weights():
542542
X, y = make_classification(n_samples=20, n_features=20, n_informative=10,
543543
n_classes=3, random_state=0)
544544

545-
# Test the liblinear fails when class_weight of type dict is
546-
# provided, when it is multiclass. However it can handle
547-
# binary problems.
545+
msg = ("In LogisticRegressionCV the liblinear solver cannot handle "
546+
"multiclass with class_weight of type dict. Use the lbfgs, "
547+
"newton-cg solvers or set class_weight='balanced'")
548548
clf_lib = LogisticRegressionCV(class_weight={0: 0.1, 1: 0.2},
549549
solver='liblinear')
550-
assert_raises(ValueError, clf_lib.fit, X, y)
550+
assert_raise_message(ValueError, msg, clf_lib.fit, X, y)
551551
y_ = y.copy()
552552
y_[y == 2] = 1
553553
clf_lib.fit(X, y_)
@@ -570,6 +570,50 @@ def test_logistic_regressioncv_class_weights():
570570
assert_array_almost_equal(clf_lib.coef_, clf_sag.coef_, decimal=4)
571571

572572

573+
def _compute_class_weight_dictionary(y):
574+
# compute class_weight and return it as a dictionary
575+
classes = np.unique(y)
576+
class_weight = compute_class_weight("balanced", classes, y)
577+
578+
class_weight_dict = {}
579+
for (cw, cl) in zip(class_weight, classes):
580+
class_weight_dict[cl] = cw
581+
582+
return class_weight_dict
583+
584+
585+
def test_logistic_regression_class_weights():
586+
# Multinomial case: remove 90% of class 0
587+
X = iris.data[45:, :]
588+
y = iris.target[45:]
589+
solvers = ("lbfgs", "newton-cg")
590+
class_weight_dict = _compute_class_weight_dictionary(y)
591+
592+
for solver in solvers:
593+
clf1 = LogisticRegression(solver=solver, multi_class="multinomial",
594+
class_weight="balanced")
595+
clf2 = LogisticRegression(solver=solver, multi_class="multinomial",
596+
class_weight=class_weight_dict)
597+
clf1.fit(X, y)
598+
clf2.fit(X, y)
599+
assert_array_almost_equal(clf1.coef_, clf2.coef_, decimal=6)
600+
601+
# Binary case: remove 90% of class 0 and 100% of class 2
602+
X = iris.data[45:100, :]
603+
y = iris.target[45:100]
604+
solvers = ("lbfgs", "newton-cg", "liblinear")
605+
class_weight_dict = _compute_class_weight_dictionary(y)
606+
607+
for solver in solvers:
608+
clf1 = LogisticRegression(solver=solver, multi_class="ovr",
609+
class_weight="balanced")
610+
clf2 = LogisticRegression(solver=solver, multi_class="ovr",
611+
class_weight=class_weight_dict)
612+
clf1.fit(X, y)
613+
clf2.fit(X, y)
614+
assert_array_almost_equal(clf1.coef_, clf2.coef_, decimal=6)
615+
616+
573617
def test_logistic_regression_convergence_warnings():
574618
# Test that warnings are raised if model does not converge
575619

0 commit comments

Comments
 (0)
0