Fix LogisticRegression with warm_start, multinomial and binary data (scikit-learn#10986) · wdevazelhes/scikit-learn@97a15db · GitHub

Commit 97a15db

aishgrt1 authored and jnothman committed
Fix LogisticRegression with warm_start, multinomial and binary data (scikit-learn#10986)
1 parent 62301aa commit 97a15db

File tree

3 files changed, +32 -1 lines changed

doc/whats_new/v0.20.rst

Lines changed: 4 additions & 0 deletions

@@ -437,6 +437,10 @@ Classifiers and regressors
 - Fixed a bug in :class:`sklearn.linear_model.Lasso`
   where the coefficient had wrong shape when ``fit_intercept=False``.
   :issue:`10687` by :user:`Martin Hahn <martin-hahn>`.
+
+- Fixed a bug in :class:`sklearn.linear_model.LogisticRegression` where
+  ``multi_class='multinomial'`` with binary output and ``warm_start=True``
+  converged poorly. :issue:`10836` by :user:`Aishwarya Srinivasan <aishgrt1>`.
 
 - Fixed a bug in :class:`linear_model.RidgeCV` where using integer ``alphas``
   raised an error. :issue:`10393` by :user:`Mabel Villalba-Jiménez <mabelvj>`.
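
For illustration (my sketch, not part of the commit): the scenario this entry describes can be exercised with the same data and estimator settings as the regression test added below. With a 0.20-era scikit-learn containing this fix, every warm-started refit reports essentially the same loss as a cold fit:

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

# Two Gaussian blobs with binary labels (mirrors the test added below).
rng = np.random.RandomState(0)
X = np.concatenate((rng.randn(100, 2) + [1, 1], rng.randn(100, 2)))
y = np.array([1] * 100 + [-1] * 100)

clf = LogisticRegression(multi_class='multinomial', solver='sag',
                         warm_start=True)
for _ in range(5):
    clf.fit(X, y)
    # Before the fix, warm-started refits began from a bad initialization;
    # after it, the loss is stable across refits.
    print(log_loss(y, clf.predict_proba(X)))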

sklearn/linear_model/logistic.py

Lines changed: 7 additions & 1 deletion

@@ -676,7 +676,13 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
                                  'shape (%d, %d) or (%d, %d)' % (
                                      coef.shape[0], coef.shape[1], classes.size,
                                      n_features, classes.size, n_features + 1))
-            w0[:, :coef.shape[1]] = coef
+
+            if n_classes == 1:
+                w0[0, :coef.shape[1]] = -coef
+                w0[1, :coef.shape[1]] = coef
+            else:
+                w0[:, :coef.shape[1]] = coef
+
 
     if multi_class == 'multinomial':
         # fmin_l_bfgs_b and newton-cg accepts only ravelled parameters.
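
For context (my explanation, not text from the commit): with multi_class='multinomial' the solver keeps one weight row per class, so on a binary problem w0 has two rows while the stored coef_ has a single row. The old line broadcast coef into both rows; identical rows cancel in the softmax, so the warm start was effectively discarded. The fix seeds the rows antisymmetrically as -coef and coef, the form the penalized two-class multinomial optimum takes. A minimal numpy sketch of the identity this relies on (variable names are illustrative):

import numpy as np

coef = np.array([[0.5, -1.0]])   # stored coef_, shape (1, n_features)
w0 = np.zeros((2, 2))            # solver's per-class weight rows
w0[0, :] = -coef                 # negative-class row
w0[1, :] = coef                  # positive-class row

# Softmax over the two antisymmetric rows equals the sigmoid of twice
# the stored row's score, so the binary decision is preserved.
x = np.array([1.0, 2.0])
scores = w0 @ x
softmax = np.exp(scores) / np.exp(scores).sum()
sigmoid = 1.0 / (1.0 + np.exp(-2.0 * (coef[0] @ x)))
print(np.allclose(softmax[1], sigmoid))  # True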

sklearn/linear_model/tests/test_logistic.py

Lines changed: 21 additions & 0 deletions

@@ -10,6 +10,7 @@
 from sklearn.preprocessing import LabelEncoder
 from sklearn.utils import compute_class_weight
 from sklearn.utils.testing import assert_almost_equal
+from sklearn.utils.testing import assert_allclose
 from sklearn.utils.testing import assert_array_almost_equal
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_equal
@@ -1238,3 +1239,23 @@ def test_dtype_match():
     lr_64.fit(X_64, y_64)
     assert_equal(lr_64.coef_.dtype, X_64.dtype)
     assert_almost_equal(lr_32.coef_, lr_64.coef_.astype(np.float32))
+
+
+def test_warm_start_converge_LR():
+    # Test that logistic regression converges on warm start,
+    # with multi_class='multinomial'. Non-regression test for #10836
+
+    rng = np.random.RandomState(0)
+    X = np.concatenate((rng.randn(100, 2) + [1, 1], rng.randn(100, 2)))
+    y = np.array([1] * 100 + [-1] * 100)
+    lr_no_ws = LogisticRegression(multi_class='multinomial',
+                                  solver='sag', warm_start=False)
+    lr_ws = LogisticRegression(multi_class='multinomial',
+                               solver='sag', warm_start=True)
+
+    lr_no_ws_loss = log_loss(y, lr_no_ws.fit(X, y).predict_proba(X))
+    lr_ws_loss = [log_loss(y, lr_ws.fit(X, y).predict_proba(X))
+                  for _ in range(5)]
+
+    for i in range(5):
+        assert_allclose(lr_no_ws_loss, lr_ws_loss[i], rtol=1e-5)
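
A note on the test's design (my reading, not part of the commit): warm start should only change the solver's starting point, never the optimum it reaches, so the test compares the log-loss of five consecutive warm-started refits against a single cold-start fit and requires agreement within rtol=1e-5.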
