FIX class_weight in LogisticRegression and LogisticRegressionCV · scikit-learn/scikit-learn@289c0a3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 289c0a3

Browse files
committed
FIX class_weight in LogisticRegression and LogisticRegressionCV
1 parent 00996a2 commit 289c0a3

File tree

3 files changed

+82
-27
lines changed

3 files changed

+82
-27
lines changed

sklearn/linear_model/logistic.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -593,11 +593,11 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
593593
sample_weight = np.ones(X.shape[0])
594594

595595
# If class_weights is a dict (provided by the user), the weights
596-
# are assigned to the original labels. If it is "auto", then
596+
# are assigned to the original labels. If it is "balanced", then
597597
# the class_weights are assigned after masking the labels with a OvR.
598598
le = LabelEncoder()
599599

600-
if isinstance(class_weight, dict):
600+
if isinstance(class_weight, dict) or multi_class == 'multinomial':
601601
if solver == "liblinear":
602602
if classes.size == 2:
603603
# Reconstruct the weights with keys 1 and -1
@@ -609,7 +609,7 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
609609
"solver cannot handle multiclass with "
610610
"class_weight of type dict. Use the lbfgs, "
611611
"newton-cg or sag solvers or set "
612-
"class_weight='auto'")
612+
"class_weight='balanced'")
613613
else:
614614
class_weight_ = compute_class_weight(class_weight, classes, y)
615615
sample_weight *= class_weight_[le.fit_transform(y)]
@@ -622,20 +622,21 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
622622
mask = (y == pos_class)
623623
y_bin = np.ones(y.shape, dtype=np.float64)
624624
y_bin[~mask] = -1.
625+
# for compute_class_weight
626+
627+
# 'auto' is deprecated and will be removed in 0.19
628+
if class_weight in ("auto", "balanced"):
629+
class_weight_ = compute_class_weight(class_weight, mask_classes,
630+
y_bin)
631+
sample_weight *= class_weight_[le.fit_transform(y_bin)]
625632

626633
else:
627634
lbin = LabelBinarizer()
628-
Y_bin = lbin.fit_transform(y)
629-
if Y_bin.shape[1] == 1:
630-
Y_bin = np.hstack([1 - Y_bin, Y_bin])
631-
w0 = np.zeros((Y_bin.shape[1], n_features + int(fit_intercept)),
635+
Y_binarized = lbin.fit_transform(y)
636+
if Y_binarized.shape[1] == 1:
637+
Y_binarized = np.hstack([1 - Y_binarized, Y_binarized])
638+
w0 = np.zeros((Y_binarized.shape[1], n_features + int(fit_intercept)),
632639
order='F')
633-
mask_classes = classes
634-
635-
if class_weight == "auto":
636-
class_weight_ = compute_class_weight(class_weight, mask_classes,
637-
y_bin)
638-
sample_weight *= class_weight_[le.fit_transform(y_bin)]
639640

640641
if coef is not None:
641642
# it must work both giving the bias term and not
@@ -664,7 +665,7 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
664665
if multi_class == 'multinomial':
665666
# fmin_l_bfgs_b and newton-cg accepts only ravelled parameters.
666667
w0 = w0.ravel()
667-
target = Y_bin
668+
target = Y_binarized
668669
if solver == 'lbfgs':
669670
func = lambda x, *args: _multinomial_loss_grad(x, *args)[0:2]
670671
elif solver == 'newton-cg':
@@ -1534,9 +1535,18 @@ def fit(self, X, y, sample_weight=None):
15341535
if self.class_weight and not(isinstance(self.class_weight, dict) or
15351536
self.class_weight in
15361537
['balanced', 'auto']):
1538+
# 'auto' is deprecated and will be removed in 0.19
15371539
raise ValueError("class_weight provided should be a "
15381540
"dict or 'balanced'")
15391541

1542+
# compute the class weights for the entire dataset y
1543+
if self.class_weight in ("auto", "balanced"):
1544+
classes = np.unique(y)
1545+
class_weight = compute_class_weight(self.class_weight, classes, y)
1546+
class_weight = dict(zip(classes, class_weight))
1547+
else:
1548+
class_weight = self.class_weight
1549+
15401550
path_func = delayed(_log_reg_scoring_path)
15411551

15421552
# The SAG solver releases the GIL so it's more efficient to use
@@ -1548,7 +1558,7 @@ def fit(self, X, y, sample_weight=None):
15481558
fit_intercept=self.fit_intercept, penalty=self.penalty,
15491559
dual=self.dual, solver=self.solver, tol=self.tol,
15501560
max_iter=self.max_iter, verbose=self.verbose,
1551-
class_weight=self.class_weight, scoring=self.scoring,
1561+
class_weight=class_weight, scoring=self.scoring,
15521562
multi_class=self.multi_class,
15531563
intercept_scaling=self.intercept_scaling,
15541564
random_state=self.random_state,
@@ -1620,7 +1630,7 @@ def fit(self, X, y, sample_weight=None):
16201630
fit_intercept=self.fit_intercept, coef=coef_init,
16211631
max_iter=self.max_iter, tol=self.tol,
16221632
penalty=self.penalty, copy=False,
1623-
class_weight=self.class_weight,
1633+
class_weight=class_weight,
16241634
multi_class=self.multi_class,
16251635
verbose=max(0, self.verbose - 1),
16261636
random_state=self.random_state,

sklearn/linear_model/tests/test_logistic.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
from sklearn.utils.testing import assert_raises
1212
from sklearn.utils.testing import assert_true
1313
from sklearn.utils.testing import assert_warns
14+
from sklearn.utils.testing import assert_warns_message
1415
from sklearn.utils.testing import raises
1516
from sklearn.utils.testing import ignore_warnings
1617
from sklearn.utils.testing import assert_raise_message
1718
from sklearn.utils import ConvergenceWarning
19+
from sklearn.utils import compute_class_weight
1820

1921
from sklearn.linear_model.logistic import (
2022
LogisticRegression,
@@ -26,7 +28,6 @@
2628
from sklearn.datasets import load_iris, make_classification
2729
from sklearn.metrics import log_loss
2830

29-
3031
X = [[-1, 0], [0, 1], [1, 1]]
3132
X_sp = sp.csr_matrix(X)
3233
Y1 = [0, 1, 1]
@@ -542,12 +543,12 @@ def test_logistic_regressioncv_class_weights():
542543
X, y = make_classification(n_samples=20, n_features=20, n_informative=10,
543544
n_classes=3, random_state=0)
544545

545-
# Test the liblinear fails when class_weight of type dict is
546-
# provided, when it is multiclass. However it can handle
547-
# binary problems.
546+
msg = ("In LogisticRegressionCV the liblinear solver cannot handle "
547+
"multiclass with class_weight of type dict. Use the lbfgs, "
548+
"newton-cg or sag solvers or set class_weight='balanced'")
548549
clf_lib = LogisticRegressionCV(class_weight={0: 0.1, 1: 0.2},
549550
solver='liblinear')
550-
assert_raises(ValueError, clf_lib.fit, X, y)
551+
assert_raise_message(ValueError, msg, clf_lib.fit, X, y)
551552
y_ = y.copy()
552553
y_[y == 2] = 1
553554
clf_lib.fit(X, y_)
@@ -613,6 +614,56 @@ def test_logistic_regression_sample_weights():
613614
assert_array_almost_equal(clf_cw_12.coef_, clf_sw_12.coef_, decimal=4)
614615

615616

617+
def _compute_class_weight_dictionary(y):
618+
# helper for returning a dictionary instead of an array
619+
classes = np.unique(y)
620+
class_weight = compute_class_weight("balanced", classes, y)
621+
class_weight_dict = dict(zip(classes, class_weight))
622+
return class_weight_dict
623+
624+
625+
def test_logistic_regression_class_weights():
626+
# Multinomial case: remove 90% of class 0
627+
X = iris.data[45:, :]
628+
y = iris.target[45:]
629+
solvers = ("lbfgs", "newton-cg")
630+
class_weight_dict = _compute_class_weight_dictionary(y)
631+
632+
for solver in solvers:
633+
clf1 = LogisticRegression(solver=solver, multi_class="multinomial",
634+
class_weight="balanced")
635+
clf2 = LogisticRegression(solver=solver, multi_class="multinomial",
636+
class_weight=class_weight_dict)
637+
clf1.fit(X, y)
638+
clf2.fit(X, y)
639+
assert_array_almost_equal(clf1.coef_, clf2.coef_, decimal=6)
640+
641+
# Binary case: remove 90% of class 0 and 100% of class 2
642+
X = iris.data[45:100, :]
643+
y = iris.target[45:100]
644+
solvers = ("lbfgs", "newton-cg", "liblinear")
645+
class_weight_dict = _compute_class_weight_dictionary(y)
646+
647+
for solver in solvers:
648+
clf1 = LogisticRegression(solver=solver, multi_class="ovr",
649+
class_weight="balanced")
650+
clf2 = LogisticRegression(solver=solver, multi_class="ovr",
651+
class_weight=class_weight_dict)
652+
clf1.fit(X, y)
653+
clf2.fit(X, y)
654+
assert_array_almost_equal(clf1.coef_, clf2.coef_, decimal=6)
655+
656+
657+
def test_multinomial_logistic_regression_with_classweight_auto():
658+
X, y = iris.data, iris.target
659+
model = LogisticRegression(multi_class='multinomial',
660+
class_weight='auto', solver='lbfgs')
661+
# 'auto' is deprecated and will be removed in 0.19
662+
assert_warns_message(DeprecationWarning,
663+
"class_weight='auto' heuristic is deprecated",
664+
model.fit, X, y)
665+
666+
616667
def test_logistic_regression_convergence_warnings():
617668
# Test that warnings are raised if model does not converge
618669

sklearn/tests/test_common.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,6 @@ def test_class_weight_balanced_linear_classifiers():
113113
and issubclass(clazz, LinearClassifierMixin)]
114114

115115
for name, Classifier in linear_classifiers:
116-
if name == "LogisticRegressionCV":
117-
# Contrary to RidgeClassifierCV, LogisticRegressionCV use actual
118-
# CV folds and fit a model for each CV iteration before averaging
119-
# the coef. Therefore it is expected to not behave exactly as the
120-
# other linear model.
121-
continue
122116
yield check_class_weight_balanced_linear_classifier, name, Classifier
123117

124118

0 commit comments

Comments
 (0)
0