8000 ENH implement sag_logistic · scikit-learn/scikit-learn@578bd5b · GitHub
[go: up one dir, main page]

Skip to content

Commit 578bd5b

Browse files
committed
ENH implement sag_logistic
1 parent a7b5f8a commit 578bd5b

File tree

4 files changed

+257
-179
lines changed

4 files changed

+257
-179
lines changed

sklearn/linear_model/logistic.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from scipy import optimize, sparse
1717

1818
from .base import LinearClassifierMixin, SparseCoefMixin, BaseEstimator
19+
from .sag import sag_logistic
1920
from ..feature_selection.from_model import _LearntSelectorMixin
2021
from ..preprocessing import LabelEncoder, LabelBinarizer
2122
from ..svm.base import _fit_liblinear
@@ -395,7 +396,8 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
395396
max_iter=100, tol=1e-4, verbose=0,
396397
solver='lbfgs', coef=None, copy=True,
397398
class_weight=None, dual=False, penalty='l2',
398-
intercept_scaling=1., multi_class='ovr'):
399+
intercept_scaling=1., multi_class='ovr',
400+
random_state=None):
399401
"""Compute a Logistic Regression model for a list of regularization
400402
parameters.
401403
@@ -481,7 +483,11 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
481483
chosen is 'ovr', then a binary problem is fit for each label. Else
482484
the loss minimised is the multinomial loss fit across
483485
the entire probability distribution. Works only for the 'lbfgs'
484-
solver.
486+
and 'newton-cg' solvers.
487+
488+
random_state : int seed, RandomState instance, or None (default)
489+
The seed of the pseudo random number generator to use when
490+
shuffling the data.
485491
486492
Returns
487493
-------
@@ -505,20 +511,20 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
505511
raise ValueError("multi_class can be either 'multinomial' or 'ovr'"
506512
"got %s" % multi_class)
507513

508-
if solver not in ['liblinear', 'newton-cg', 'lbfgs']:
514+
if solver not in ['liblinear', 'newton-cg', 'lbfgs', 'sag']:
509515
raise ValueError("Logistic Regression supports only liblinear,"
510-
" newton-cg and lbfgs solvers. got %s" % solver)
516+
" newton-cg, lbfgs, and sag solvers. got %s" % solver)
511517

512518
if multi_class == 'multinomial' and solver == 'liblinear':
513519
raise ValueError("Solver %s cannot solve problems with "
514520
"a multinomial backend." % solver)
515521

516522
if solver != 'liblinear':
517523
if penalty != 'l2':
518-
raise ValueError("newton-cg and lbfgs solvers support only "
524+
raise ValueError("newton-cg, lbfgs and sag solvers support only "
519525
"l2 penalties, got %s penalty." % penalty)
520526
if dual:
521-
raise ValueError("newton-cg and lbfgs solvers support only "
527+
raise ValueError("newton-cg, lbfgs and sag solvers support only "
522528
"dual=False, got dual=%s" % dual)
523529
# Preprocessing.
524530
X = check_array(X, accept_sparse='csr', dtype=np.float64)
@@ -660,17 +666,21 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
660666
w0 = np.concatenate([coef_.ravel(), intercept_])
661667
else:
662668
w0 = coef_.ravel()
669+
elif solver == 'sag':
670+
w0 = sag_logistic(X, y, w0, 1. / C, sample_weight,
671+
max_iter=max_iter, tol=tol, verbose=verbose,
672+
random_state=random_state)
663673
else:
664674
raise ValueError("solver must be one of {'liblinear', 'lbfgs', "
665-
"'newton-cg'}, got '%s' instead" % solver)
675+
"'newton-cg', 'sag'}, got '%s' instead" % solver)
666676

667677
if multi_class == 'multinomial':
668678
multi_w0 = np.reshape(w0, (classes.size, -1))
669679
if classes.size == 2:
670680
multi_w0 = multi_w0[1][np.newaxis, :]
671681
coefs.append(multi_w0)
672682
else:
673-
coefs.append(w0)
683+
coefs.append(np.copy(w0))
674684
return coefs, np.array(Cs)
675685

676686

@@ -1015,10 +1025,10 @@ def fit(self, X, y):
10151025

10161026
X, y = check_X_y(X, y, accept_sparse='csr', dtype=np.float64, order="C")
10171027
self.classes_ = np.unique(y)
1018-
if self.solver not in ['liblinear', 'newton-cg', 'lbfgs']:
1028+
if self.solver not in ['liblinear', 'newton-cg', 'lbfgs', 'sag']:
10191029
raise ValueError(
1020-
"Logistic Regression supports only liblinear, newton-cg and "
1021-
"lbfgs solvers, Got solver=%s" % self.solver
1030+
"Logistic Regression supports only liblinear, newton-cg, "
1031+
"lbfgs and sag solvers, Got solver=%s" % self.solver
10221032
)
10231033

10241034
if self.solver == 'liblinear' and self.multi_class == 'multinomial':
@@ -1060,7 +1070,7 @@ def fit(self, X, y):
10601070
fit_intercept=self.fit_intercept, tol=self.tol,
10611071
verbose=self.verbose, solver=self.solver,
10621072
multi_class=self.multi_class, max_iter=self.max_iter,
1063-
class_weight=self.class_weight)
1073+
class_weight=self.class_weight, random_state=self.random_state)
10641074
self.coef_.append(coef_[0])
10651075

10661076
self.coef_ = np.squeeze(self.coef_)

sklearn/linear_model/sag.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,78 @@
2121
SPARSE_INTERCEPT_DECAY = 0.01
2222

2323

24+
def sag_logistic(X, y, coef_init, alpha=1e-4, sample_weight=None,
25+
max_iter=1000, tol=0.001, verbose=0, random_state=None):
26+
"""SAG solver for LogisticRegression"""
27+
28+
n_samples, n_features = X.shape[0], X.shape[1]
29+
30+
alpha = alpha / n_samples
31+
32+
# initialize all parameters if there is no init
33+
if sample_weight is None:
34+
sample_weight = np.ones(n_samples, dtype=np.float64, order='C')
35+
36+
# coef_init contains eventually the intercept_init at the end.
37+
fit_intercept = coef_init.size == (n_features + 1)
38+
if fit_intercept:
39+
intercept_init = coef_init[-1]
40+
coef_init = coef_init[:-1]
41+
else:
42+
intercept_init = 0.0
43+
44+
# TODO: *_init (with a boolean warm-start) as parameters ?
45+
intercept_sum_gradient_init = 0.0
46+
sum_gradient_init = np.zeros(n_features, dtype=np.float64, order='C')
47+
gradient_memory_init = np.zeros(n_samples, dtype=np.float64, order='C')
48+
seen_init = np.zeros(n_samples, dtype=np.int32, order='C')
49+
num_seen_init = 0
50+
weight_pos = 1
51+
weight_neg = 1
52+
53+
random_state = check_random_state(random_state)
54+
55+
# check which type of Sequential Dataset is needed
56+
if sp.issparse(X):
57+
dataset = CSRDataset(X.data, X.indptr, X.indices,
58+
y, sample_weight,
59+
seed=random_state.randint(MAX_INT))
60+
intercept_decay = SPARSE_INTERCEPT_DECAY
61+
else:
62+
dataset = ArrayDataset(X, y, sample_weight,
63+
seed=random_state.randint(MAX_INT))
64+
intercept_decay = 1.0
65+
66+
# set the eta0 at 1 / 4L where L is the max sum of
67+
# squares for over all samples
68+
step_size = get_auto_eta(dataset, alpha, n_samples, Log(), fit_intercept)
69+
70+
intercept_, num_seen, max_iter_reached, intercept_sum_gradient = \
71+
sag_sparse(dataset, coef_init.ravel(),
72+
intercept_init, n_samples,
73+
n_features, tol,
74+
max_iter,
75+
Log(),
76+
step_size, alpha,
77+
sum_gradient_init.ravel(),
78+
gradient_memory_init.ravel(),
79+
seen_init.ravel(),
80+
num_seen_init,
81+
weight_pos, weight_neg,
82+
fit_intercept,
83+
intercept_sum_gradient_init,
84+
intercept_decay,
85+
verbose)
86+
87+
if max_iter_reached:
88+
warnings.warn("The max_iter was reached which means "
89+
"the coef_ did not converge", ConvergenceWarning)
90+
if fit_intercept:
91+
return np.append(coef_init, intercept_)
92+
else:
93+
return coef_init
94+
95+
2496
# taken from http://stackoverflow.com/questions/1816958
2597
# useful for passing instance methods to Parallel
2698
def multiprocess_method(instance, name, args=()):

sklearn/linear_model/tests/test_logistic.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ def test_predict_iris():
105105
LogisticRegression(C=len(iris.data), solver='lbfgs',
106106
multi_class='multinomial'),
107107
LogisticRegression(C=len(iris.data), solver='newton-cg',
108-
multi_class='multinomial')]:
108+
multi_class='multinomial'),
109+
LogisticRegression(C=len(iris.data), solver='sag',
110+
multi_class='ovr')]:
109111
clf.fit(iris.data, target)
110112
assert_array_equal(np.unique(target), clf.classes_)
111113

@@ -216,17 +218,17 @@ def test_consistency_path():
216218
f = ignore_warnings
217219
# can't test with fit_intercept=True since LIBLINEAR
218220
# penalizes the intercept
219-
for method in ('lbfgs', 'newton-cg', 'liblinear'):
221+
for method in ('lbfgs', 'newton-cg', 'liblinear', 'sag'):
220222
coefs, Cs = f(logistic_regression_path)(
221-
X, y, Cs=Cs, fit_intercept=False, tol=1e-16, solver=method)
223+
X, y, Cs=Cs, fit_intercept=False, tol=1e-5, solver=method)
222224
for i, C in enumerate(Cs):
223-
lr = LogisticRegression(C=C, fit_intercept=False, tol=1e-16)
225+
lr = LogisticRegression(C=C, fit_intercept=False, tol=1e-5)
224226
lr.fit(X, y)
225227
lr_coef = lr.coef_.ravel()
226228
assert_array_almost_equal(lr_coef, coefs[i], decimal=4)
227229

228230
# test for fit_intercept=True
229-
for method in ('lbfgs', 'newton-cg', 'liblinear'):
231+
for method in ('lbfgs', 'newton-cg', 'liblinear', 'sag'):
230232
Cs = [1e3]
231233
coefs, Cs = f(logistic_regression_path)(
232234
X, y, Cs=Cs, fit_intercept=True, tol=1e-4, solver=method)
@@ -450,29 +452,43 @@ def test_ovr_multinomial_iris():
450452

451453
def test_logistic_regression_solvers():
452454
X, y = make_classification(n_features=10, n_informative=5, random_state=0)
453-
clf_n = LogisticRegression(solver='newton-cg', fit_intercept=False)
454-
clf_n.fit(X, y)
455+
clf_new = LogisticRegression(solver='newton-cg', fit_intercept=False)
456+
clf_new.fit(X, y)
455457
clf_lbf = LogisticRegression(solver='lbfgs', fit_intercept=False)
456458
clf_lbf.fit(X, y)
459+
clf_sag = LogisticRegression(solver='sag', fit_intercept=False)
460+
clf_sag.fit(X, y)
457461
clf_lib = LogisticRegression(fit_intercept=False)
458462
clf_lib.fit(X, y)
459-
assert_array_almost_equal(clf_n.coef_, clf_lib.coef_, decimal=3)
463+
assert_array_almost_equal(clf_new.coef_, clf_lib.coef_, decimal=3)
460464
assert_array_almost_equal(clf_lib.coef_, clf_lbf.coef_, decimal=3)
461-
assert_array_almost_equal(clf_n.coef_, clf_lbf.coef_, decimal=3)
465+
assert_array_almost_equal(clf_new.coef_, clf_lbf.coef_, decimal=3)
466+
assert_array_almost_equal(clf_sag.coef_, clf_lib.coef_, decimal=3)
467+
assert_array_almost_equal(clf_sag.coef_, clf_new.coef_, decimal=3)
468+
assert_array_almost_equal(clf_sag.coef_, clf_lbf.coef_, decimal=3)
462469

463470

464471
def test_logistic_regression_solvers_multiclass():
472+
tol = 1e-6
465473
X, y = make_classification(n_samples=20, n_features=20, n_informative=10,
466474
n_classes=3, random_state=0)
467-
clf_n = LogisticRegression(solver='newton-cg', fit_intercept=False)
468-
clf_n.fit(X, y)
469-
clf_lbf = LogisticRegression(solver='lbfgs', fit_intercept=False)
475+
clf_new = LogisticRegression(solver='newton-cg', fit_intercept=False,
476+
tol=tol)
477+
clf_new.fit(X, y)
478+
clf_lbf = LogisticRegression(solver='lbfgs', fit_intercept=False,
479+
tol=tol)
470480
clf_lbf.fit(X, y)
471-
clf_lib = LogisticRegression(fit_intercept=False)
481+
clf_sag = LogisticRegression(solver='sag', fit_intercept=False,
482+
tol=tol, max_iter=1000)
483+
clf_sag.fit(X, y)
484+
clf_lib = LogisticRegression(fit_intercept=False, tol=tol)
472485
clf_lib.fit(X, y)
473-
assert_array_almost_equal(clf_n.coef_, clf_lib.coef_, decimal=4)
486+
assert_array_almost_equal(clf_new.coef_, clf_lib.coef_, decimal=4)
474487
assert_array_almost_equal(clf_lib.coef_, clf_lbf.coef_, decimal=4)
475-
assert_array_almost_equal(clf_n.coef_, clf_lbf.coef_, decimal=4)
488+
assert_array_almost_equal(clf_new.coef_, clf_lbf.coef_, decimal=4)
489+
assert_array_almost_equal(clf_sag.coef_, clf_lib.coef_, decimal=4)
490+
assert_array_almost_equal(clf_sag.coef_, clf_new.coef_, decimal=4)
491+
assert_array_almost_equal(clf_sag.coef_, clf_lbf.coef_, decimal=4)
476492

477493

478494
def test_logistic_regressioncv_class_weights():
@@ -499,7 +515,12 @@ def test_logistic_regressioncv_class_weights():
499515
clf_lib = LogisticRegressionCV(solver='liblinear', fit_intercept=False,
500516
class_weight='auto')
501517
clf_lib.fit(X, y)
518+
clf_sag = LogisticRegressionCV(solver='sag', fit_intercept=False,
519+
class_weight='auto')
520+
clf_sag.fit(X, y)
502521
assert_array_almost_equal(clf_lib.coef_, clf_lbf.coef_, decimal=4)
522+
assert_array_almost_equal(clf_sag.coef_, clf_lbf.coef_, decimal=4)
523+
assert_array_almost_equal(clf_lib.coef_, clf_sag.coef_, decimal=4)
503524

504525

505526
def test_logistic_regression_convergence_warnings():

0 commit comments

Comments
 (0)
0