Change default solver in LogisticRegression · scikit-learn/scikit-learn@d8b5fda · GitHub

Commit d8b5fda

Change default solver in LogisticRegression
1 parent bc07078 commit d8b5fda

6 files changed: +455 -81 lines changed


benchmarks/bench_logistic_solvers.py

Lines changed: 259 additions & 0 deletions
@@ -0,0 +1,259 @@
"""
Benchmarks of sklearn solver in LogisticRegression.
"""

# Author: Tom Dupre la Tour
import time
from os.path import expanduser

import matplotlib.pyplot as plt
import scipy.sparse as sp  # noqa
import numpy as np
import pandas as pd

from sklearn.datasets import fetch_mldata
from sklearn.datasets import fetch_rcv1, load_iris, load_digits
from sklearn.datasets import fetch_20newsgroups_vectorized
from sklearn.exceptions import ConvergenceWarning
from sklearn.externals.joblib import delayed, Parallel, Memory
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model.logistic import _multinomial_loss
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
from sklearn.preprocessing import StandardScaler  # noqa
from sklearn.utils.testing import ignore_warnings
from sklearn.utils import shuffle


def get_loss(coefs, intercepts, X, y, C, multi_class, penalty):
    if multi_class == 'ovr':
        loss = 0
        for ii, (coef, intercept) in enumerate(zip(coefs, intercepts)):
            y_bin = y.copy()
            y_bin[y == ii] = 1
            y_bin[y != ii] = -1
            loss += np.sum(
                np.log(1. + np.exp(-y_bin * (X.dot(coef) + intercept))))

            if penalty == 'l2':
                loss += 0.5 / C * coef.dot(coef)
            else:
                loss += np.sum(np.abs(coef)) / C
    else:
        coefs_and_intercept = np.vstack((coefs.T, intercepts.T)).T.ravel()
        lbin = LabelBinarizer()
        Y_multi = lbin.fit_transform(y)
        if Y_multi.shape[1] == 1:
            Y_multi = np.hstack([1 - Y_multi, Y_multi])
        loss, _, _ = _multinomial_loss(coefs_and_intercept, X, Y_multi, 0,
                                       np.ones(X.shape[0]))
        coefs = coefs.ravel()
        if penalty == 'l2':
            loss += 0.5 * coefs.dot(coefs) / C
        else:
            loss += np.sum(np.abs(coefs)) / C

    loss /= X.shape[0]

    return loss


def fit_single(solver, X, y, X_shape, dataset, penalty='l2',
               multi_class='multinomial', C=1, max_iter=10):
    assert X.shape == X_shape

    # if not sp.issparse(X):
    #     X = StandardScaler().fit_transform(X)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=42, stratify=y)
    train_scores = []
    train_losses = []
    test_scores = []
    times = []

    n_repeats = None
    max_iter_range = np.unique(np.int_(np.logspace(0, np.log10(max_iter), 10)))
    for this_max_iter in max_iter_range:
        msg = ('[%s, %s, %s, %s] Max iter: %s' %
               (multi_class, dataset, penalty, solver, this_max_iter))
        lr = LogisticRegression(solver=solver, multi_class=multi_class, C=C,
                                penalty=penalty, fit_intercept=True, tol=1e-24,
                                max_iter=this_max_iter, random_state=42,
                                intercept_scaling=10000)
        t0 = time.clock()
        try:
            with ignore_warnings(category=ConvergenceWarning):
                # first time for timing
                if n_repeats is None:
                    t0 = time.clock()
                    lr.fit(X_train, y_train)
                    max_iter_duration = max_iter * (time.clock() - t0)
                    n_repeats = max(1, int(10. / max_iter_duration))

                t0 = time.clock()
                for _ in range(n_repeats):
                    lr.fit(X_train, y_train)
                train_time = (time.clock() - t0) / n_repeats
            print('%s (repeat=%d)' % (msg, n_repeats))

        except ValueError:
            train_score = np.nan
            train_loss = np.nan
            test_score = np.nan
            train_time = np.nan
            print('%s (skipped)' % (msg, ))
            continue

        train_loss = get_loss(lr.coef_, lr.intercept_, X_train, y_train, C,
                              multi_class, penalty)

        train_score = lr.score(X_train, y_train)
        test_score = lr.score(X_test, y_test)

        train_scores.append(train_score)
        train_losses.append(train_loss)
        test_scores.append(test_score)
        times.append(train_time)

    return (solver, penalty, dataset, multi_class, times, train_losses,
            train_scores, test_scores)


def load_dataset(dataset, n_samples_max):
    if dataset == 'rcv1':
        rcv1 = fetch_rcv1()
        X = rcv1.data
        y = rcv1.target

        # take only 3 categories (CCAT, ECAT, MCAT)
        y = y[:, [1, 4, 10]].astype(np.float64)
        # remove samples that have more than one category
        mask = np.asarray(y.sum(axis=1) == 1).ravel()
        y = y[mask, :].indices
        X = X[mask, :]

    elif dataset == 'mnist':
        mnist = fetch_mldata('MNIST original')
        X, y = shuffle(mnist.data, mnist.target, random_state=42)
        X = X.astype(np.float64)

    elif dataset == 'digits':
        digits = load_digits()
        X, y = digits.data, digits.target

    elif dataset == 'iris':
        iris = load_iris()
        X, y = iris.data, iris.target

    elif dataset == '20news':
        ng = fetch_20newsgroups_vectorized()
        X = ng.data
        y = ng.target

    X = X[:n_samples_max]
    y = y[:n_samples_max]

    return X, y


def run(solvers, penalties, multi_classes, n_samples_max, max_iter, datasets,
        n_jobs):
    mem = Memory(cachedir=expanduser('~/cache'), verbose=0)

    results = []
    for dataset in datasets:
        for multi_class in multi_classes:
            X, y = load_dataset(dataset, n_samples_max)

            cached_fit = mem.cache(fit_single, ignore=['X'])
            cached_fit = fit_single

            out = Parallel(n_jobs=n_jobs, mmap_mode=None)(delayed(cached_fit)(
                solver, X, y, X.shape, dataset=dataset, penalty=penalty,
                multi_class=multi_class, C=1, max_iter=max_iter)
                for solver in solvers
                for penalty in penalties)  # yapf: disable

            results.extend(out)

    columns = ("solver penalty dataset multi_class times "
               "train_losses train_scores test_scores").split()
    results_df = pd.DataFrame(out, columns=columns)
    plot(results_df)


def plot(res):
    res.set_index(['dataset', 'multi_class', 'penalty'], inplace=True)

    grouped = res.groupby(level=['dataset', 'multi_class', 'penalty'])

    colors = {
        'sag': 'red',
        'saga': 'orange',
        'liblinear': 'blue',
        'lbfgs': 'green',
        'auto': 'black',
    }

    for idx, group in grouped:
        dataset, multi_class, penalty = idx
        fig = plt.figure(figsize=(12, 4))

        # -----------------------
        ax = fig.add_subplot(131)
        train_losses = group['train_losses']
        tmp = np.sort(np.concatenate(train_losses.values))
        ref = (2 * tmp[0] - tmp[1]) * 0.999999

        for losses, times, solver in zip(group['train_losses'], group['times'],
                                         group['solver']):
            losses = losses - ref
            linestyle = '--' if solver == 'auto' else '-'
            ax.plot(times, losses, label=solver, color=colors[solver],
                    linestyle=linestyle, marker='.')
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Training objective (relative to min)')
        ax.set_yscale('log')

        # -----------------------
        ax = fig.add_subplot(132)

        for train_score, times, solver in zip(group['train_scores'],
                                              group['times'], group['solver']):
            linestyle = '--' if solver == 'auto' else '-'
            ax.plot(times, train_score, label=solver, color=colors[solver],
                    linestyle=linestyle, marker='.')
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Train score')

        # -----------------------
        ax = fig.add_subplot(133)

        for test_score, times, solver in zip(group['test_scores'],
                                             group['times'], group['solver']):
            linestyle = '--' if solver == 'auto' else '-'
            ax.plot(times, test_score, label=solver, color=colors[solver],
                    linestyle=linestyle, marker='.')
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Test score')
        ax.legend()

        # -----------------------
        name = '%s_%s_%s' % (multi_class, penalty, dataset)
        plt.suptitle(name)
        fig.tight_layout()
        fig.subplots_adjust(top=0.9)
        plt.savefig('figures/' + name + '.png')
        plt.close(fig)
        print('SAVED: ' + name)


if __name__ == '__main__':
    solvers = ['liblinear', 'saga', 'sag', 'lbfgs', 'auto']
    penalties = ['l2', 'l1']
    multi_classes = ['multinomial', 'ovr']
    datasets = ['iris', 'digits', 'mnist', '20news', 'rcv1']

    run(solvers, penalties, multi_classes, n_samples_max=None, n_jobs=5,
        datasets=datasets, max_iter=40)
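
The script above is driven entirely by its `__main__` block, which runs every solver/penalty/multi_class combination on all five datasets and writes one PNG per combination. For a quick local smoke test one can call `run()` directly with a reduced grid; the snippet below is a minimal sketch of such a call, assuming the script is importable from the benchmarks directory and that the `figures/` output directory used by `plot()` already exists (it is not part of the commit):

    # Hypothetical smoke-test driver for benchmarks/bench_logistic_solvers.py:
    # restrict the benchmark grid so a single run finishes quickly.
    from bench_logistic_solvers import run

    run(solvers=['lbfgs', 'saga', 'auto'],   # subset of the benchmarked solvers
        penalties=['l2'],                    # skip the slower l1 runs
        multi_classes=['multinomial'],
        n_samples_max=2000,                  # truncate each dataset
        max_iter=10,
        datasets=['iris', 'digits'],         # small built-in datasets only
        n_jobs=1)                            # PNG figures land in figures/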

doc/modules/linear_model.rst

Lines changed: 20 additions & 9 deletions
@@ -773,19 +773,30 @@ The "saga" solver [7]_ is a variant of "sag" that also supports the
 non-smooth `penalty="l1"` option. This is therefore the solver of choice
 for sparse multinomial logistic regression.

-In a nutshell, one may choose the solver with the following rules:
-
-================================= =====================================
-Case                              Solver
-================================= =====================================
-L1 penalty                        "liblinear" or "saga"
-Multinomial loss                  "lbfgs", "sag", "saga" or "newton-cg"
-Very Large dataset (`n_samples`)  "sag" or "saga"
-================================= =====================================
+In a nutshell, the following table summarizes the solvers characteristics:
+
+============================ =========== ======= =========== ===== ======
+solver                       'liblinear' 'lbfgs' 'newton-cg' 'sag' 'saga'
+============================ =========== ======= =========== ===== ======
+Multinomial + L2 penalty     no          yes     yes         yes   yes
+OVR + L2 penalty             yes         yes     yes         yes   yes
+Multinomial + L1 penalty     no          no      no          no    yes
+OVR + L1 penalty             yes         no      no          no    yes
+============================ =========== ======= =========== ===== ======
+Penalize the intercept (bad) yes         no      no          no    no
+Faster for large datasets    no          no      no          yes   yes
+Robust to unscaled datasets  yes         yes     yes         no    no
+============================ =========== ======= =========== ===== ======

 The "saga" solver is often the best choice. The "liblinear" solver is
 used by default for historical reasons.

+The default solver will change to "auto" in version 0.22. This option
+automatically selects a good solver based on both `penalty` and `multi_class`
+parameters, and on the size of the training set. Note that the "auto" behavior
+may change without notice in the future, leading to similar but not necessarily
+exact same solutions.
+
 For large dataset, you may also consider using :class:`SGDClassifier`
 with 'log' loss.
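
The added paragraph describes the upcoming "auto" option only in terms of its inputs (`penalty`, `multi_class`, training-set size). Purely as an illustration of the kind of rule those criteria and the table above imply, a selection heuristic could look like the sketch below; this is not the commit's actual implementation, and the `large_threshold` cut-off is invented for the example:

    # Illustrative sketch only -- not the selection logic shipped in this commit.
    def pick_solver(penalty, multi_class, n_samples, large_threshold=10000):
        if penalty == 'l1':
            # only 'liblinear' and 'saga' handle the non-smooth l1 penalty,
            # and 'liblinear' cannot optimize the multinomial loss
            return 'saga' if multi_class == 'multinomial' else 'liblinear'
        if n_samples > large_threshold:
            return 'saga'   # 'sag'/'saga' are faster on large datasets
        return 'lbfgs'      # handles both 'ovr' and 'multinomial' with l2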

doc/tutorial/statistical_inference/supervised_learning.rst

Lines changed: 2 additions & 2 deletions
@@ -372,8 +372,8 @@ function or **logistic** function:
 >>> logistic.fit(iris_X_train, iris_y_train)
 LogisticRegression(C=100000.0, class_weight=None, dual=False,
           fit_intercept=True, intercept_scaling=1, max_iter=100,
-          multi_class='ovr', n_jobs=1, penalty='l2', random_state=None,
-          solver='liblinear', tol=0.0001, verbose=0, warm_start=False)
+          multi_class='default', n_jobs=1, penalty='l2', random_state=None,
+          solver='default', tol=0.0001, verbose=0, warm_start=False)

 This is known as :class:`LogisticRegression`.
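
Because the tutorial's doctest repr now prints placeholder defaults, users who want identical behavior before and after the default switches to "auto" can pass `solver` (and `multi_class`) explicitly. A small self-contained sketch, using the same iris data as the tutorial but with its own train/test split rather than the tutorial's arrays:

    # Sketch: pin solver and multi_class so results do not change when the
    # default solver becomes 'auto' in 0.22.
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split

    X, y = load_iris(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    logistic = LogisticRegression(C=1e5, solver='lbfgs', multi_class='multinomial')
    logistic.fit(X_train, y_train)
    print(logistic.score(X_test, y_test))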

0 commit comments

Comments
 (0)
0