ENH add sag solver in LogisticRegression and Ridge · scikit-learn/scikit-learn@94eb619 · GitHub
Commit 94eb619

TomDLT authored and amueller committed

ENH add sag solver in LogisticRegression and Ridge

1 parent 4ceffe0 · commit 94eb619

31 files changed: +6919 −6114 lines changed
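For orientation, here is a minimal usage sketch of the feature this commit introduces; the toy data and parameter values (e.g. max_iter=100) are illustrative rather than taken from the diff:

import numpy as np
from sklearn.linear_model import LogisticRegression, Ridge

rng = np.random.RandomState(42)
X = rng.randn(500, 20)
y_clf = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)
y_reg = X[:, 0] + 0.1 * rng.randn(500)

# The new 'sag' solver for L2-penalized logistic regression.
log_reg = LogisticRegression(solver='sag', max_iter=100).fit(X, y_clf)

# The same solver is also exposed in Ridge.
ridge = Ridge(solver='sag').fit(X, y_reg)

print(log_reg.score(X, y_clf), ridge.score(X, y_reg))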

.gitattributes

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 /sklearn/feature_extraction/_hashing.c -diff
 /sklearn/linear_model/cd_fast.c -diff
 /sklearn/linear_model/sgd_fast.c -diff
+/sklearn/linear_model/sag_fast.c -diff
 /sklearn/metrics/pairwise_fast.c -diff
 /sklearn/neighbors/ball_tree.c -diff
 /sklearn/neighbors/kd_tree.c -diff

benchmarks/bench_covertype.py

Lines changed: 3 additions & 2 deletions
@@ -53,7 +53,7 @@
 
 from sklearn.datasets import fetch_covtype, get_data_home
 from sklearn.svm import LinearSVC
-from sklearn.linear_model import SGDClassifier
+from sklearn.linear_model import SGDClassifier, LogisticRegression
 from sklearn.naive_bayes import GaussianNB
 from sklearn.tree import DecisionTreeClassifier
 from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier
@@ -105,7 +105,8 @@ def load_data(dtype=np.float32, order='C', random_state=13):
     'SGD': SGDClassifier(alpha=0.001, n_iter=2),
     'GaussianNB': GaussianNB(),
     'liblinear': LinearSVC(loss="l2", penalty="l2", C=1000, dual=False,
-                           tol=1e-3)
+                           tol=1e-3),
+    'SAG': LogisticRegression(solver='sag', max_iter=2, C=1000)
 }
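Side note on the hunk above: LogisticRegression is parametrized by C (the inverse regularization strength) while SGDClassifier uses alpha, and the benchmark script added later in this commit converts between the two with alpha = 1.0 / C / n_samples. A small sketch of that conversion; the sample count below is illustrative only:

# Convert between LogisticRegression's C and SGDClassifier's alpha, following
# the alpha = 1.0 / C / n_samples convention used elsewhere in this commit.
def c_to_alpha(C, n_samples):
    return 1.0 / (C * n_samples)

def alpha_to_c(alpha, n_samples):
    return 1.0 / (alpha * n_samples)

n_samples = 500000  # illustrative; roughly the order of the covertype training set
print(c_to_alpha(1000, n_samples))   # alpha matching the new 'SAG' entry's C=1000
print(alpha_to_c(0.001, n_samples))  # C matching the existing 'SGD' entry's alpha=0.001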

benchmarks/bench_mnist.py

Lines changed: 4 additions & 2 deletions
@@ -47,6 +47,7 @@
 from sklearn.svm import LinearSVC
 from sklearn.tree import DecisionTreeClassifier
 from sklearn.utils import check_array
+from sklearn.linear_model import LogisticRegression
 
 # Memoize the data extraction and memory map the resulting
 # train / test splits in readonly mode
@@ -86,7 +87,8 @@ def load_data(dtype=np.float32, order='F'):
     'Nystroem-SVM':
     make_pipeline(Nystroem(gamma=0.015, n_components=1000), LinearSVC(C=100)),
     'SampledRBF-SVM':
-    make_pipeline(RBFSampler(gamma=0.015, n_components=1000), LinearSVC(C=100))
+    make_pipeline(RBFSampler(gamma=0.015, n_components=1000), LinearSVC(C=100)),
+    'LinearRegression-SAG': LogisticRegression(solver='sag', tol=1e-1, C=1e4)
 }
@@ -120,7 +122,7 @@ def load_data(dtype=np.float32, order='F'):
 print("%s %d (size=%dMB)" % ("number of train samples:".ljust(25),
                              X_train.shape[0], int(X_train.nbytes / 1e6)))
 print("%s %d (size=%dMB)" % ("number of test samples:".ljust(25),
-                             X_test.shape[0], int(X_test.nbytes / 1e6)))
+                              X_test.shape[0], int(X_test.nbytes / 1e6)))
 
 print()
 print("Training Classifiers")
Lines changed: 236 additions & 0 deletions
@@ -0,0 +1,236 @@
# Authors: Tom Dupre la Tour <tom.dupre-la-tour@m4x.org>
#          Olivier Grisel <olivier.grisel@ensta.org>
#
# License: BSD 3 clause

import matplotlib.pyplot as plt
import numpy as np
import gc
import time

from sklearn.externals.joblib import Memory
from sklearn.linear_model import (LogisticRegression, SGDClassifier)
from sklearn.datasets import fetch_rcv1
from sklearn.linear_model.sag import get_auto_step_size
from sklearn.linear_model.sag_fast import get_max_squared_sum


try:
    import lightning.classification as lightning_clf
except ImportError:
    lightning_clf = None

m = Memory(cachedir='.', verbose=0)


# compute logistic loss
def get_loss(w, intercept, myX, myy, C):
    n_samples = myX.shape[0]
    w = w.ravel()
    p = np.mean(np.log(1. + np.exp(-myy * (myX.dot(w) + intercept))))
    print("%f + %f" % (p, w.dot(w) / 2. / C / n_samples))
    p += w.dot(w) / 2. / C / n_samples
    return p


# We use joblib to cache individual fits. Note that we do not pass the dataset
# as argument as the hashing would be too slow, so we assume that the dataset
# never changes.
@m.cache()
def bench_one(name, clf_type, clf_params, n_iter):
    clf = clf_type(**clf_params)
    try:
        clf.set_params(max_iter=n_iter, random_state=42)
    except:
        clf.set_params(n_iter=n_iter, random_state=42)

    st = time.time()
    clf.fit(X, y)
    end = time.time()

    try:
        C = 1.0 / clf.alpha / n_samples
    except:
        C = clf.C

    try:
        intercept = clf.intercept_
    except:
        intercept = 0.

    train_loss = get_loss(clf.coef_, intercept, X, y, C)
    train_score = clf.score(X, y)
    test_score = clf.score(X_test, y_test)
    duration = end - st

    return train_loss, train_score, test_score, duration


def bench(clfs):
    for (name, clf, iter_range, train_losses, train_scores,
         test_scores, durations) in clfs:
        print("training %s" % name)
        clf_type = type(clf)
        clf_params = clf.get_params()

        for n_iter in iter_range:
            gc.collect()

            train_loss, train_score, test_score, duration = bench_one(
                name, clf_type, clf_params, n_iter)

            train_losses.append(train_loss)
            train_scores.append(train_score)
            test_scores.append(test_score)
            durations.append(duration)
            print("classifier: %s" % name)
            print("train_loss: %.8f" % train_loss)
            print("train_score: %.8f" % train_score)
            print("test_score: %.8f" % test_score)
            print("time for fit: %.8f seconds" % duration)
            print("")

        print("")
    return clfs


def plot_train_losses(clfs):
    plt.figure()
    for (name, _, _, train_losses, _, _, durations) in clfs:
        plt.plot(durations, train_losses, '-o', label=name)
    plt.legend(loc=0)
    plt.xlabel("seconds")
    plt.ylabel("train loss")


def plot_train_scores(clfs):
    plt.figure()
    for (name, _, _, _, train_scores, _, durations) in clfs:
        plt.plot(durations, train_scores, '-o', label=name)
    plt.legend(loc=0)
    plt.xlabel("seconds")
    plt.ylabel("train score")
    plt.ylim((0.92, 0.96))


def plot_test_scores(clfs):
    plt.figure()
    for (name, _, _, _, _, test_scores, durations) in clfs:
        plt.plot(durations, test_scores, '-o', label=name)
    plt.legend(loc=0)
    plt.xlabel("seconds")
    plt.ylabel("test score")
    plt.ylim((0.92, 0.96))


def plot_dloss(clfs):
    plt.figure()
    pobj_final = []
    for (name, _, _, train_losses, _, _, durations) in clfs:
        pobj_final.append(train_losses[-1])

    indices = np.argsort(pobj_final)
    pobj_best = pobj_final[indices[0]]

    for (name, _, _, train_losses, _, _, durations) in clfs:
        log_pobj = np.log(abs(np.array(train_losses) - pobj_best)) / np.log(10)

        plt.plot(durations, log_pobj, '-o', label=name)
    plt.legend(loc=0)
    plt.xlabel("seconds")
    plt.ylabel("log(best - train_loss)")


rcv1 = fetch_rcv1()
X = rcv1.data
n_samples, n_features = X.shape

# consider the binary classification problem 'CCAT' vs the rest
ccat_idx = rcv1.target_names.tolist().index('CCAT')
y = rcv1.target.tocsc()[:, ccat_idx].toarray().ravel().astype(np.float64)
y[y == 0] = -1

# parameters
C = 1.
fit_intercept = True
tol = 1.0e-14

# max_iter range
sgd_iter_range = list(range(1, 121, 10))
newton_iter_range = list(range(1, 25, 3))
lbfgs_iter_range = list(range(1, 242, 12))
liblinear_iter_range = list(range(1, 37, 3))
liblinear_dual_iter_range = list(range(1, 85, 6))
sag_iter_range = list(range(1, 37, 3))

clfs = [
    ("LR-liblinear",
     LogisticRegression(C=C, tol=tol,
                        solver="liblinear", fit_intercept=fit_intercept,
                        intercept_scaling=1),
     liblinear_iter_range, [], [], [], []),
    ("LR-liblinear-dual",
     LogisticRegression(C=C, tol=tol, dual=True,
                        solver="liblinear", fit_intercept=fit_intercept,
                        intercept_scaling=1),
     liblinear_dual_iter_range, [], [], [], []),
    ("LR-SAG",
     LogisticRegression(C=C, tol=tol,
                        solver="sag", fit_intercept=fit_intercept),
     sag_iter_range, [], [], [], []),
    ("LR-newton-cg",
     LogisticRegression(C=C, tol=tol, solver="newton-cg",
                        fit_intercept=fit_intercept),
     newton_iter_range, [], [], [], []),
    ("LR-lbfgs",
     LogisticRegression(C=C, tol=tol,
                        solver="lbfgs", fit_intercept=fit_intercept),
     lbfgs_iter_range, [], [], [], []),
    ("SGD",
     SGDClassifier(alpha=1.0 / C / n_samples, penalty='l2', loss='log',
                   fit_intercept=fit_intercept, verbose=0),
     sgd_iter_range, [], [], [], [])]


if lightning_clf is not None and not fit_intercept:
    alpha = 1. / C / n_samples
    # compute the same step_size than in LR-sag
    max_squared_sum = get_max_squared_sum(X)
    step_size = get_auto_step_size(max_squared_sum, alpha, "log",
                                   fit_intercept)

    clfs.append(
        ("Lightning-SVRG",
         lightning_clf.SVRGClassifier(alpha=alpha, eta=step_size,
                                      tol=tol, loss="log"),
         sag_iter_range, [], [], [], []))
    clfs.append(
        ("Lightning-SAG",
         lightning_clf.SAGClassifier(alpha=alpha, eta=step_size,
                                     tol=tol, loss="log"),
         sag_iter_range, [], [], [], []))

# We keep only 200 features, to have a dense dataset,
# and compare to lightning SAG, which seems incorrect in the sparse case.
X_csc = X.tocsc()
nnz_in_each_features = X_csc.indptr[1:] - X_csc.indptr[:-1]
X = X_csc[:, np.argsort(nnz_in_each_features)[-200:]]
X = X.toarray()
print("dataset: %.3f MB" % (X.nbytes / 1e6))


# Split training and testing. Switch train and test subset compared to
# LYRL2004 split, to have a larger training dataset.
n = 23149
X_test = X[:n, :]
y_test = y[:n]
X = X[n:, :]
y = y[n:]

clfs = bench(clfs)

plot_train_scores(clfs)
plot_test_scores(clfs)
plot_train_losses(clfs)
plot_dloss(clfs)
plt.show()
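For reference, the get_loss helper near the top of this script evaluates the mean logistic loss plus w.dot(w) / (2 * C * n_samples), i.e. scikit-learn's L2-penalized logistic objective divided by n_samples. A tiny, self-contained sanity check of that identity on arbitrary toy values:

import numpy as np

rng = np.random.RandomState(0)
X_toy = rng.randn(10, 3)
y_toy = np.sign(rng.randn(10))
w_toy, b_toy, C_toy = rng.randn(3), 0.5, 2.0

# Mean-based form, as computed by get_loss above.
mean_form = (np.mean(np.log(1. + np.exp(-y_toy * (X_toy.dot(w_toy) + b_toy))))
             + w_toy.dot(w_toy) / 2. / C_toy / X_toy.shape[0])

# Sum-based scikit-learn objective, divided by n_samples.
sum_form = (np.sum(np.log(1. + np.exp(-y_toy * (X_toy.dot(w_toy) + b_toy))))
            + w_toy.dot(w_toy) / (2. * C_toy)) / X_toy.shape[0]

assert np.isclose(mean_form, sum_form)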

doc/modules/linear_model.rst

Lines changed: 24 additions & 11 deletions
@@ -104,7 +104,7 @@ its ``coef_`` member::
     >>> clf = linear_model.Ridge (alpha = .5)
     >>> clf.fit ([[0, 0], [0, 0], [1, 1]], [0, .1, 1]) # doctest: +NORMALIZE_WHITESPACE
     Ridge(alpha=0.5, copy_X=True, fit_intercept=True, max_iter=None,
-          normalize=False, solver='auto', tol=0.001)
+          normalize=False, random_state=None, solver='auto', tol=0.001)
     >>> clf.coef_
     array([ 0.34545455,  0.34545455])
     >>> clf.intercept_ #doctest: +ELLIPSIS
@@ -670,17 +670,12 @@ Similarly, L1 regularized logistic regression solves the following optimization
 
 The solvers implemented in the class :class:`LogisticRegression`
 are "liblinear" (which is a wrapper around the C++ library,
-LIBLINEAR), "newton-cg" and "lbfgs".
+LIBLINEAR), "newton-cg", "lbfgs" and "sag".
 
-The lbfgs and newton-cg solvers only support L2 penalization and are found
+The "lbfgs" and "newton-cg" solvers only support L2 penalization and are found
 to converge faster for some high dimensional data. L1 penalization yields
 sparse predicting weights.
 
-Several estimators are available for logistic regression.
-
-:class:`LogisticRegression` has an option of using three solvers,
-"liblinear", "lbfgs" and "newton-cg".
-
 The solver "liblinear" uses a coordinate descent (CD) algorithm based on
 Liblinear. For L1 penalization :func:`sklearn.svm.l1_min_c` allows to
 calculate the lower bound for C in order to get a non "null" (all feature weights to
@@ -697,8 +692,23 @@ Setting `multi_class` to "multinomial" with the "lbfgs" or "newton-cg" solver
 in :class:`LogisticRegression` learns a true multinomial logistic
 regression model, which means that its probability estimates should
 be better calibrated than the default "one-vs-rest" setting.
-L-BFGS and newton-cg cannot optimize L1-penalized models, though,
-so the "multinomial" setting does not learn sparse models.
+The "lbfgs", "newton-cg" and "sag" solvers cannot optimize L1-penalized
+models, though, so the "multinomial" setting does not learn sparse models.
+
+The solver "sag" uses Stochastic Average Gradient descent [3]_. It does not
+handle the "multinomial" case, and is limited to L2-penalized models, yet it
+is often faster than other solvers for large datasets, when both the number
+of samples and the number of features are large.
+
+In a nutshell, one may choose the solver with the following rules:
+
+=========================== ======================
+Case                        Solver
+=========================== ======================
+Small dataset or L1 penalty "liblinear"
+Multinomial loss            "lbfgs" or "newton-cg"
+Large dataset               "sag"
+=========================== ======================
+
+For large datasets, you may also consider using :class:`SGDClassifier` with 'log' loss.
 
 .. topic:: Examples:
 
@@ -729,13 +739,16 @@ so the "multinomial" setting does not learn sparse models.
 
 :class:`LogisticRegressionCV` implements Logistic Regression with
 builtin cross-validation to find out the optimal C parameter.
-"newton-cg" and "lbfgs" solvers are found to be faster
+The "newton-cg", "sag" and "lbfgs" solvers are found to be faster
 for high-dimensional dense data, due to warm-starting.
 For the multiclass case, if `multi_class`
 option is set to "ovr", an optimal C is obtained for each class and if
 the `multi_class` option is set to "multinomial", an optimal C is
 obtained that minimizes the cross-entropy loss.
 
+.. topic:: References:
+
+    .. [3] Mark Schmidt, Nicolas Le Roux, and Francis Bach: `Minimizing Finite Sums with the Stochastic Average Gradient. <http://hal.inria.fr/hal-00860051/PDF/sag_journal.pdf>`_
 
 Stochastic Gradient Descent - SGD
 =================================
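To make the solver-choice rules documented above concrete, a short illustration in code; the dataset here is tiny and purely for demonstration, so "large dataset" is meant loosely:

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=1000, n_features=20, n_informative=5,
                           n_classes=3, random_state=0)

# Small dataset or L1 penalty -> "liblinear" (supports penalty='l1', one-vs-rest).
clf_l1 = LogisticRegression(solver='liblinear', penalty='l1', C=1.0).fit(X, y)

# Multinomial loss -> "lbfgs" or "newton-cg".
clf_multi = LogisticRegression(solver='lbfgs', multi_class='multinomial').fit(X, y)

# Large dataset -> "sag" (L2 penalty only).
clf_sag = LogisticRegression(solver='sag', max_iter=200).fit(X, y)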

doc/modules/sgd.rst

Lines changed: 10 additions & 1 deletion
@@ -161,6 +161,7 @@ further information.
 - :ref:`example_linear_model_plot_sgd_separating_hyperplane.py`,
 - :ref:`example_linear_model_plot_sgd_iris.py`
 - :ref:`example_linear_model_plot_sgd_weighted_samples.py`
+- :ref:`example_linear_model_plot_sgd_comparison.py`
 - :ref:`example_svm_plot_separating_hyperplane_unbalanced.py` (See the `Note`)
 
 :class:`SGDClassifier` supports averaged SGD (ASGD). Averaging can be enabled
@@ -169,6 +170,10 @@ of the plain SGD over each iteration over a sample. When using ASGD
 the learning rate can be larger and even constant, leading on some
 datasets to a speed up in training time.
 
+For classification with a logistic loss, another variant of SGD with an
+averaging strategy is available with the Stochastic Average Gradient (SAG)
+algorithm, available as a solver in :class:`LogisticRegression`.
+
 Regression
 ==========
 
@@ -192,7 +197,11 @@ specified via the parameter ``epsilon``. This parameter depends on the
 scale of the target variables.
 
 :class:`SGDRegressor` supports averaged SGD as :class:`SGDClassifier`.
-Averaging can be enabled by setting ``average=True``
+Averaging can be enabled by setting ``average=True``.
+
+For regression with a squared loss and an l2 penalty, another variant of
+SGD with an averaging strategy is available with the Stochastic Average
+Gradient (SAG) algorithm, available as a solver in :class:`Ridge`.
 
 
 Stochastic Gradient Descent for sparse data
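A hedged sketch contrasting the two averaged approaches mentioned above, ASGD in SGDClassifier and SAG in LogisticRegression; parameter values are illustrative, and loss='log' is the spelling used in this era of scikit-learn (newer releases call it 'log_loss'):

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression, SGDClassifier

X, y = make_classification(n_samples=5000, n_features=50, random_state=0)

# Averaged SGD (ASGD) with a logistic loss.
asgd = SGDClassifier(loss='log', average=True, alpha=1e-4, random_state=0).fit(X, y)

# Stochastic Average Gradient, via the solver added in this commit.
# C is chosen to match alpha through C = 1 / (alpha * n_samples).
sag = LogisticRegression(solver='sag', C=1.0 / (1e-4 * X.shape[0]),
                         max_iter=100).fit(X, y)

print("ASGD accuracy: %.3f, SAG accuracy: %.3f" % (asgd.score(X, y), sag.score(X, y)))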

0 commit comments