ENH implement sag_logistic · scikit-learn/scikit-learn@38af560 · GitHub

Commit 38af560

ENH implement sag_logistic

1 parent 8d7526c commit 38af560

File tree

4 files changed: +263 -182 lines changed

sklearn/linear_model/logistic.py

Lines changed: 22 additions & 12 deletions
@@ -16,6 +16,7 @@
 from scipy import optimize, sparse
 
 from .base import LinearClassifierMixin, SparseCoefMixin, BaseEstimator
+from .sag import sag_logistic
 from ..feature_selection.from_model import _LearntSelectorMixin
 from ..preprocessing import LabelEncoder, LabelBinarizer
 from ..svm.base import _fit_liblinear
@@ -402,7 +403,8 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
                              max_iter=100, tol=1e-4, verbose=0,
                              solver='lbfgs', coef=None, copy=True,
                              class_weight=None, dual=False, penalty='l2',
-                             intercept_scaling=1., multi_class='ovr'):
+                             intercept_scaling=1., multi_class='ovr',
+                             random_state=None):
     """Compute a Logistic Regression model for a list of regularization
     parameters.
@@ -488,7 +490,11 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
         chosen is 'ovr', then a binary problem is fit for each label. Else
         the loss minimised is the multinomial loss fit across
         the entire probability distribution. Works only for the 'lbfgs'
-        solver.
+        and 'newton-cg' solvers.
+
+    random_state : int seed, RandomState instance, or None (default)
+        The seed of the pseudo random number generator to use when
+        shuffling the data.
 
     Returns
     -------
@@ -512,20 +518,20 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
         raise ValueError("multi_class can be either 'multinomial' or 'ovr'"
                          "got %s" % multi_class)
 
-    if solver not in ['liblinear', 'newton-cg', 'lbfgs']:
+    if solver not in ['liblinear', 'newton-cg', 'lbfgs', 'sag']:
         raise ValueError("Logistic Regression supports only liblinear,"
-                         " newton-cg and lbfgs solvers. got %s" % solver)
+                         " newton-cg, lbfgs, and sag solvers. got %s" % solver)
 
     if multi_class == 'multinomial' and solver == 'liblinear':
         raise ValueError("Solver %s cannot solve problems with "
                          "a multinomial backend." % solver)
 
     if solver != 'liblinear':
         if penalty != 'l2':
-            raise ValueError("newton-cg and lbfgs solvers support only "
+            raise ValueError("newton-cg, lbfgs and sag solvers support only "
                              "l2 penalties, got %s penalty." % penalty)
         if dual:
-            raise ValueError("newton-cg and lbfgs solvers support only "
+            raise ValueError("newton-cg, lbfgs and sag solvers support only "
                              "dual=False, got dual=%s" % dual)
     # Preprocessing.
     X = check_array(X, accept_sparse='csr', dtype=np.float64)
@@ -667,17 +673,21 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
                 w0 = np.concatenate([coef_.ravel(), intercept_])
             else:
                 w0 = coef_.ravel()
+        elif solver == 'sag':
+            w0 = sag_logistic(X, y, w0, 1. / C, sample_weight,
+                              max_iter=max_iter, tol=tol, verbose=verbose,
+                              random_state=random_state)
         else:
             raise ValueError("solver must be one of {'liblinear', 'lbfgs', "
-                             "'newton-cg'}, got '%s' instead" % solver)
+                             "'newton-cg', 'sag'}, got '%s' instead" % solver)
 
         if multi_class == 'multinomial':
             multi_w0 = np.reshape(w0, (classes.size, -1))
             if classes.size == 2:
                 multi_w0 = multi_w0[1][np.newaxis, :]
             coefs.append(multi_w0)
         else:
-            coefs.append(w0)
+            coefs.append(np.copy(w0))
     return coefs, np.array(Cs)
@@ -1016,10 +1026,10 @@ def fit(self, X, y):
 
         self.classes_ = np.unique(y)
-        if self.solver not in ['liblinear', 'newton-cg', 'lbfgs']:
+        if self.solver not in ['liblinear', 'newton-cg', 'lbfgs', 'sag']:
             raise ValueError(
-                "Logistic Regression supports only liblinear, newton-cg and "
-                "lbfgs solvers, Got solver=%s" % self.solver
+                "Logistic Regression supports only liblinear, newton-cg, "
+                "lbfgs and sag solvers, Got solver=%s" % self.solver
             )
 
         if self.solver == 'liblinear' and self.multi_class == 'multinomial':
@@ -1061,7 +1071,7 @@ def fit(self, X, y):
                 fit_intercept=self.fit_intercept, tol=self.tol,
                 verbose=self.verbose, solver=self.solver,
                 multi_class=self.multi_class, max_iter=self.max_iter,
-                class_weight=self.class_weight)
+                class_weight=self.class_weight, random_state=self.random_state)
             self.coef_.append(coef_[0])
 
         self.coef_ = np.squeeze(self.coef_)
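
Taken together, these hunks route solver='sag' through both logistic_regression_path and LogisticRegression.fit, with random_state controlling the shuffling. A minimal smoke-test sketch of the new option (my own example against a build of this branch, not code from the commit; the data and tolerance are illustrative):

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(n_samples=200, n_features=10, random_state=0)

    # New stochastic solver; random_state seeds its sample shuffling.
    clf_sag = LogisticRegression(solver='sag', fit_intercept=False,
                                 max_iter=1000, random_state=0).fit(X, y)
    # Existing deterministic solver on the same l2-penalised objective.
    clf_lbfgs = LogisticRegression(solver='lbfgs', fit_intercept=False).fit(X, y)

    # Same convex objective, so the coefficients should agree to a few
    # decimals, mirroring what test_logistic_regression_solvers asserts below.
    print(np.abs(clf_sag.coef_ - clf_lbfgs.coef_).max())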

sklearn/linear_model/sag.py

Lines changed: 78 additions & 3 deletions
@@ -1,18 +1,18 @@
 import numpy as np
 import scipy.sparse as sp
+import warnings
 
 from abc import ABCMeta
-import warnings
 
 from .base import LinearClassifierMixin, LinearModel, SparseCoefMixin
 from ..base import RegressorMixin, BaseEstimator
 from ..utils import check_X_y, compute_class_weight, check_random_state
 from ..utils import ConvergenceWarning
 from ..utils.seq_dataset import ArrayDataset, CSRDataset
-from ..externals import six
-from ..externals.joblib import Parallel, delayed
 from .sag_fast import Log, SquaredLoss
 from .sag_fast import sag_sparse, get_auto_eta
+from ..externals import six
+from ..externals.joblib import Parallel, delayed
 
 MAX_INT = np.iinfo(np.int32).max
 
@@ -21,6 +21,78 @@
 SPARSE_INTERCEPT_DECAY = 0.01
 
 
+def sag_logistic(X, y, coef_init, alpha=1e-4, sample_weight=None,
+                 max_iter=1000, tol=0.001, verbose=0, random_state=None):
+    """SAG solver for LogisticRegression"""
+
+    n_samples, n_features = X.shape[0], X.shape[1]
+
+    alpha = alpha / n_samples
+
+    # initialize all parameters if there is no init
+    if sample_weight is None:
+        sample_weight = np.ones(n_samples, dtype=np.float64, order='C')
+
+    # coef_init possibly contains the intercept_init at the end.
+    fit_intercept = coef_init.size == (n_features + 1)
+    if fit_intercept:
+        intercept_init = coef_init[-1]
+        coef_init = coef_init[:-1]
+    else:
+        intercept_init = 0.0
+
+    # TODO: *_init (with a boolean warm-start) as parameters?
+    intercept_sum_gradient_init = 0.0
+    sum_gradient_init = np.zeros(n_features, dtype=np.float64, order='C')
+    gradient_memory_init = np.zeros(n_samples, dtype=np.float64, order='C')
+    seen_init = np.zeros(n_samples, dtype=np.int32, order='C')
+    num_seen_init = 0
+    weight_pos = 1
+    weight_neg = 1
+
+    random_state = check_random_state(random_state)
+
+    # check which type of SequentialDataset is needed
+    if sp.issparse(X):
+        dataset = CSRDataset(X.data, X.indptr, X.indices,
+                             y, sample_weight,
+                             seed=random_state.randint(MAX_INT))
+        intercept_decay = SPARSE_INTERCEPT_DECAY
+    else:
+        dataset = ArrayDataset(X, y, sample_weight,
+                               seed=random_state.randint(MAX_INT))
+        intercept_decay = 1.0
+
+    # set eta0 to 1 / 4L, where L is the max sum of squares
+    # over all samples
+    step_size = get_auto_eta(dataset, alpha, n_samples, Log(), fit_intercept)
+
+    intercept_, num_seen, max_iter_reached, intercept_sum_gradient = \
+        sag_sparse(dataset, coef_init.ravel(),
+                   intercept_init, n_samples,
+                   n_features, tol,
+                   max_iter,
+                   Log(),
+                   step_size, alpha,
+                   sum_gradient_init.ravel(),
+                   gradient_memory_init.ravel(),
+                   seen_init.ravel(),
+                   num_seen_init,
+                   weight_pos, weight_neg,
+                   fit_intercept,
+                   intercept_sum_gradient_init,
+                   intercept_decay,
+                   verbose)
+
+    if max_iter_reached:
+        warnings.warn("The max_iter was reached which means "
+                      "the coef_ did not converge", ConvergenceWarning)
+    if fit_intercept:
+        return np.append(coef_init, intercept_)
+    else:
+        return coef_init
+
+
 # taken from http://stackoverflow.com/questions/1816958
 # useful for passing instance methods to Parallel
 def multiprocess_method(instance, name, args=()):
@@ -549,3 +621,6 @@ def fit(self, X, y, sample_weight=None):
                            intercept_sum_gradient_init)
 
         return self
+
+
+
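
For readers new to SAG: sag_sparse (the Cython kernel called above) keeps one stored gradient per sample and, at each iteration, refreshes a single entry and steps along the running average, so every update costs O(n_features) while still using information from all samples seen so far. A plain-NumPy sketch of that recursion under simplifying assumptions (dense data, labels in {-1, +1}, no intercept; the function name and structure are mine, not the Cython code's):

    import numpy as np

    def sag_logistic_dense(X, y, alpha=1e-4, n_iter=50, random_state=0):
        # Illustrative SAG loop for l2-penalised logistic regression.
        rng = np.random.RandomState(random_state)
        n_samples, n_features = X.shape

        w = np.zeros(n_features)
        grad_memory = np.zeros(n_samples)  # last scalar loss gradient per sample
        grad_sum = np.zeros(n_features)    # sum of all memorised gradients
        seen = np.zeros(n_samples, dtype=bool)
        n_seen = 0

        # eta0 = 1 / (4L), L = max squared sample norm, as in get_auto_eta
        step = 1.0 / (4 * np.max(np.sum(X ** 2, axis=1)))

        for _ in range(n_iter * n_samples):
            i = rng.randint(n_samples)
            if not seen[i]:
                seen[i] = True
                n_seen += 1

            # scalar part of the logistic loss gradient at sample i
            p = -y[i] / (1.0 + np.exp(y[i] * np.dot(X[i], w)))

            # swap sample i's old stored gradient for the fresh one
            grad_sum += (p - grad_memory[i]) * X[i]
            grad_memory[i] = p

            # step along the average stored gradient plus the l2 term
            w -= step * (grad_sum / n_seen + alpha * w)
        return w

The lazy-update bookkeeping and the just-in-time handling of sparse rows are what the Cython version adds on top of this recursion.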

sklearn/linear_model/tests/test_logistic.py

Lines changed: 36 additions & 15 deletions
@@ -88,7 +88,9 @@ def test_predict_iris():
             LogisticRegression(C=len(iris.data), solver='lbfgs',
                                multi_class='multinomial'),
             LogisticRegression(C=len(iris.data), solver='newton-cg',
-                               multi_class='multinomial')]:
+                               multi_class='multinomial'),
+            LogisticRegression(C=len(iris.data), solver='sag',
+                               multi_class='ovr')]:
         clf.fit(iris.data, target)
         assert_array_equal(np.unique(target), clf.classes_)
@@ -199,17 +201,17 @@ def test_consistency_path():
     f = ignore_warnings
     # can't test with fit_intercept=True since LIBLINEAR
     # penalizes the intercept
-    for method in ('lbfgs', 'newton-cg', 'liblinear'):
+    for method in ('lbfgs', 'newton-cg', 'liblinear', 'sag'):
         coefs, Cs = f(logistic_regression_path)(
-            X, y, Cs=Cs, fit_intercept=False, tol=1e-16, solver=method)
+            X, y, Cs=Cs, fit_intercept=False, tol=1e-5, solver=method)
         for i, C in enumerate(Cs):
-            lr = LogisticRegression(C=C, fit_intercept=False, tol=1e-16)
+            lr = LogisticRegression(C=C, fit_intercept=False, tol=1e-5)
             lr.fit(X, y)
             lr_coef = lr.coef_.ravel()
             assert_array_almost_equal(lr_coef, coefs[i], decimal=4)
 
     # test for fit_intercept=True
-    for method in ('lbfgs', 'newton-cg', 'liblinear'):
+    for method in ('lbfgs', 'newton-cg', 'liblinear', 'sag'):
         Cs = [1e3]
         coefs, Cs = f(logistic_regression_path)(
             X, y, Cs=Cs, fit_intercept=True, tol=1e-4, solver=method)
@@ -434,29 +436,43 @@ def test_ovr_multinomial_iris():
 
 def test_logistic_regression_solvers():
     X, y = make_classification(n_features=10, n_informative=5, random_state=0)
-    clf_n = LogisticRegression(solver='newton-cg', fit_intercept=False)
-    clf_n.fit(X, y)
+    clf_new = LogisticRegression(solver='newton-cg', fit_intercept=False)
+    clf_new.fit(X, y)
     clf_lbf = LogisticRegression(solver='lbfgs', fit_intercept=False)
     clf_lbf.fit(X, y)
+    clf_sag = LogisticRegression(solver='sag', fit_intercept=False)
+    clf_sag.fit(X, y)
     clf_lib = LogisticRegression(fit_intercept=False)
     clf_lib.fit(X, y)
-    assert_array_almost_equal(clf_n.coef_, clf_lib.coef_, decimal=3)
+    assert_array_almost_equal(clf_new.coef_, clf_lib.coef_, decimal=3)
     assert_array_almost_equal(clf_lib.coef_, clf_lbf.coef_, decimal=3)
-    assert_array_almost_equal(clf_n.coef_, clf_lbf.coef_, decimal=3)
+    assert_array_almost_equal(clf_new.coef_, clf_lbf.coef_, decimal=3)
+    assert_array_almost_equal(clf_sag.coef_, clf_lib.coef_, decimal=3)
+    assert_array_almost_equal(clf_sag.coef_, clf_new.coef_, decimal=3)
+    assert_array_almost_equal(clf_sag.coef_, clf_lbf.coef_, decimal=3)
 
 
 def test_logistic_regression_solvers_multiclass():
+    tol = 1e-6
     X, y = make_classification(n_samples=20, n_features=20, n_informative=10,
                                n_classes=3, random_state=0)
-    clf_n = LogisticRegression(solver='newton-cg', fit_intercept=False)
-    clf_n.fit(X, y)
-    clf_lbf = LogisticRegression(solver='lbfgs', fit_intercept=False)
+    clf_new = LogisticRegression(solver='newton-cg', fit_intercept=False,
+                                 tol=tol)
+    clf_new.fit(X, y)
+    clf_lbf = LogisticRegression(solver='lbfgs', fit_intercept=False,
+                                 tol=tol)
     clf_lbf.fit(X, y)
-    clf_lib = LogisticRegression(fit_intercept=False)
+    clf_sag = LogisticRegression(solver='sag', fit_intercept=False,
+                                 tol=tol, max_iter=1000)
+    clf_sag.fit(X, y)
+    clf_lib = LogisticRegression(fit_intercept=False, tol=tol)
     clf_lib.fit(X, y)
-    assert_array_almost_equal(clf_n.coef_, clf_lib.coef_, decimal=4)
+    assert_array_almost_equal(clf_new.coef_, clf_lib.coef_, decimal=4)
     assert_array_almost_equal(clf_lib.coef_, clf_lbf.coef_, decimal=4)
-    assert_array_almost_equal(clf_n.coef_, clf_lbf.coef_, decimal=4)
+    assert_array_almost_equal(clf_new.coef_, clf_lbf.coef_, decimal=4)
+    assert_array_almost_equal(clf_sag.coef_, clf_lib.coef_, decimal=4)
+    assert_array_almost_equal(clf_sag.coef_, clf_new.coef_, decimal=4)
+    assert_array_almost_equal(clf_sag.coef_, clf_lbf.coef_, decimal=4)
 
 
 def test_logistic_regressioncv_class_weights():
@@ -483,7 +499,12 @@ def test_logistic_regressioncv_class_weights():
     clf_lib = LogisticRegressionCV(solver='liblinear', fit_intercept=False,
                                    class_weight='auto')
     clf_lib.fit(X, y)
+    clf_sag = LogisticRegressionCV(solver='sag', fit_intercept=False,
+                                   class_weight='auto')
+    clf_sag.fit(X, y)
     assert_array_almost_equal(clf_lib.coef_, clf_lbf.coef_, decimal=4)
+    assert_array_almost_equal(clf_sag.coef_, clf_lbf.coef_, decimal=4)
+    assert_array_almost_equal(clf_lib.coef_, clf_sag.coef_, decimal=4)
 
 
 def test_logistic_regression_convergence_warnings():
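
Note how the tests pin a shared tol (and a larger max_iter for 'sag') before comparing coefficients: a stochastic average-gradient run only matches the exact solvers up to the convergence tolerance. The coefs.append(np.copy(w0)) change in logistic.py matters for the path test in particular, since sag_logistic can return its input buffer and the path entries would otherwise all alias one vector. A sketch of the path-level entry point that test_consistency_path exercises (my own example assuming this branch is installed; the import path mirrors the test module's):

    from sklearn.datasets import make_classification
    from sklearn.linear_model.logistic import logistic_regression_path

    X, y = make_classification(n_features=10, random_state=0)
    coefs, Cs = logistic_regression_path(X, y, Cs=[0.1, 1.0, 10.0],
                                         fit_intercept=False, tol=1e-5,
                                         solver='sag', random_state=0)
    print(len(coefs), Cs)  # one coefficient vector per regularisation value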
