LogisticRegression convert to float64 (for SAG solver) (#13243) · scikit-learn/scikit-learn@f02ef9f · GitHub
Commit f02ef9f

massich authored and GaelVaroquaux committed
LogisticRegression convert to float64 (for SAG solver) (#13243)
* Remove unused code
* Squash all the PR 9040 commits:
    initial PR commit
    seq_dataset.pyx generated from template
    seq_dataset.pyx generated from template #2
    rename variables
    fused types consistency test for seq_dataset
    a sklearn/utils/tests/test_seq_dataset.py
    new if statement
    add doc sklearn/utils/seq_dataset.pyx.tp
    minor changes
    minor changes
    typo fix
    check numeric accuracy only up 5th decimal
    Address oliver's request for changing test name
    add test for make_dataset and rename a variable in test_seq_dataset
* FIX tests
* TST more numerically stable test_sgd.test_tol_parameter
* Added benchmarks to compare SAGA 32b and 64b
* Fixing gael's comments
* fix
* solve some issues
* PEP8
* Address lesteve comments
* fix merging
* avoid using assert_equal
* use all_close
* use explicit ArrayDataset64 and CSRDataset64
* fix: remove unused import
* Use parametrized to cover ArrayDaset-CSRDataset-32-64 matrix
* for consistency use 32 first then 64 + add 64 suffix to variables
* it would be cool if this worked !!!
* more verbose version
* revert SGD changes as much as possible.
* Add solvers back to bench_saga
* make 64 explicit in the naming
* remove checking native python type + add comparison between 32 64
* Add whatsnew with everyone with commits
* simplify a bit the testing
* simplify the parametrize
* update whatsnew
* fix pep8
1 parent adc1e59 commit f02ef9f

17 files changed: +789, -396 lines
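
What the change means in practice: with this commit the sag/saga path of LogisticRegression no longer casts float32 input up to float64, so a 32-bit fit stays in 32-bit end to end. A minimal sketch of that behavior on synthetic data (dataset and parameters here are illustrative, not taken from the PR):

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=500, n_features=20, random_state=0)
X32 = X.astype(np.float32)

# After this commit the saga solver works in the precision of the input,
# so a float32 X is no longer silently upcast to float64 before fitting.
clf = LogisticRegression(solver='saga', max_iter=500, random_state=0)
clf.fit(X32, y)
print(clf.coef_.dtype)  # expected: float32 (was float64 before this change)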

.gitignore

Lines changed: 5 additions & 0 deletions

@@ -71,3 +71,8 @@ _configtest.o.d
 
 # Used by mypy
 .mypy_cache/
+
+# files generated from a template
+sklearn/utils/seq_dataset.pyx
+sklearn/utils/seq_dataset.pxd
+sklearn/linear_model/sag_fast.pyx

benchmarks/bench_saga.py

Lines changed: 106 additions & 45 deletions

@@ -1,11 +1,11 @@
-"""Author: Arthur Mensch
+"""Author: Arthur Mensch, Nelle Varoquaux
 
 Benchmarks of sklearn SAGA vs lightning SAGA vs Liblinear. Shows the gain
 in using multinomial logistic regression in term of learning time.
 """
 import json
 import time
-from os.path import expanduser
+import os
 
 from joblib import delayed, Parallel, Memory
 import matplotlib.pyplot as plt
@@ -21,7 +21,7 @@
 
 
 def fit_single(solver, X, y, penalty='l2', single_target=True, C=1,
-               max_iter=10, skip_slow=False):
+               max_iter=10, skip_slow=False, dtype=np.float64):
     if skip_slow and solver == 'lightning' and penalty == 'l1':
         print('skip_slowping l1 logistic regression with solver lightning.')
         return
@@ -37,7 +37,8 @@ def fit_single(solver, X, y, penalty='l2', single_target=True, C=1,
         multi_class = 'ovr'
     else:
         multi_class = 'multinomial'
-
+    X = X.astype(dtype)
+    y = y.astype(dtype)
     X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42,
                                                          stratify=y)
     n_samples = X_train.shape[0]
@@ -69,11 +70,15 @@ def fit_single(solver, X, y, penalty='l2', single_target=True, C=1,
                                 multi_class=multi_class,
                                 C=C,
                                 penalty=penalty,
-                                fit_intercept=False, tol=1e-24,
+                                fit_intercept=False, tol=0,
                                 max_iter=this_max_iter,
                                 random_state=42,
                                 )
+
+        # Makes cpu cache even for all fit calls
+        X_train.max()
         t0 = time.clock()
+
         lr.fit(X_train, y_train)
         train_time = time.clock() - t0
 
@@ -106,9 +111,13 @@ def _predict_proba(lr, X):
     return softmax(pred)
 
 
-def exp(solvers, penalties, single_target, n_samples=30000, max_iter=20,
+def exp(solvers, penalty, single_target,
+        n_samples=30000, max_iter=20,
         dataset='rcv1', n_jobs=1, skip_slow=False):
-    mem = Memory(cachedir=expanduser('~/cache'), verbose=0)
+    dtypes_mapping = {
+        "float64": np.float64,
+        "float32": np.float32,
+    }
 
     if dataset == 'rcv1':
         rcv1 = fetch_rcv1()
@@ -151,21 +160,24 @@ def exp(solvers, penalties, single_target, n_samples=30000, max_iter=20,
         X = X[:n_samples]
         y = y[:n_samples]
 
-    cached_fit = mem.cache(fit_single)
     out = Parallel(n_jobs=n_jobs, mmap_mode=None)(
-        delayed(cached_fit)(solver, X, y,
+        delayed(fit_single)(solver, X, y,
                             penalty=penalty, single_target=single_target,
+                            dtype=dtype,
                             C=1, max_iter=max_iter, skip_slow=skip_slow)
        for solver in solvers
-        for penalty in penalties)
+        for dtype in dtypes_mapping.values())
 
     res = []
    idx = 0
-    for solver in solvers:
-        for penalty in penalties:
-            if not (skip_slow and solver == 'lightning' and penalty == 'l1'):
+    for dtype_name in dtypes_mapping.keys():
+        for solver in solvers:
+            if not (skip_slow and
+                    solver == 'lightning' and
+                    penalty == 'l1'):
                 lr, times, train_scores, test_scores, accuracies = out[idx]
                 this_res = dict(solver=solver, penalty=penalty,
+                                dtype=dtype_name,
                                 single_target=single_target,
                                 times=times, train_scores=train_scores,
                                 test_scores=test_scores,
@@ -177,68 +189,117 @@ def exp(solvers, penalties, single_target, n_samples=30000, max_iter=20,
         json.dump(res, f)
 
 
-def plot():
+def plot(outname=None):
     import pandas as pd
     with open('bench_saga.json', 'r') as f:
         f = json.load(f)
     res = pd.DataFrame(f)
-    res.set_index(['single_target', 'penalty'], inplace=True)
+    res.set_index(['single_target'], inplace=True)
 
-    grouped = res.groupby(level=['single_target', 'penalty'])
+    grouped = res.groupby(level=['single_target'])
 
-    colors = {'saga': 'blue', 'liblinear': 'orange', 'lightning': 'green'}
+    colors = {'saga': 'C0', 'liblinear': 'C1', 'lightning': 'C2'}
+    linestyles = {"float32": "--", "float64": "-"}
+    alpha = {"float64": 0.5, "float32": 1}
 
     for idx, group in grouped:
-        single_target, penalty = idx
-        fig = plt.figure(figsize=(12, 4))
-        ax = fig.add_subplot(131)
-
-        train_scores = group['train_scores'].values
-        ref = np.min(np.concatenate(train_scores)) * 0.999
-
-        for scores, times, solver in zip(group['train_scores'], group['times'],
-                                         group['solver']):
-            scores = scores / ref - 1
-            ax.plot(times, scores, label=solver, color=colors[solver])
+        single_target = idx
+        fig, axes = plt.subplots(figsize=(12, 4), ncols=4)
+        ax = axes[0]
+
+        for scores, times, solver, dtype in zip(group['train_scores'],
                                                 group['times'],
                                                 group['solver'],
                                                 group["dtype"]):
+            ax.plot(times, scores, label="%s - %s" % (solver, dtype),
                    color=colors[solver],
                    alpha=alpha[dtype],
                    marker=".",
                    linestyle=linestyles[dtype])
+            ax.axvline(times[-1], color=colors[solver],
                       alpha=alpha[dtype],
                       linestyle=linestyles[dtype])
         ax.set_xlabel('Time (s)')
         ax.set_ylabel('Training objective (relative to min)')
         ax.set_yscale('log')
 
-        ax = fig.add_subplot(132)
+        ax = axes[1]
 
-        test_scores = group['test_scores'].values
-        ref = np.min(np.concatenate(test_scores)) * 0.999
+        for scores, times, solver, dtype in zip(group['test_scores'],
                                                 group['times'],
                                                 group['solver'],
                                                 group["dtype"]):
+            ax.plot(times, scores, label=solver, color=colors[solver],
                    linestyle=linestyles[dtype],
                    marker=".",
                    alpha=alpha[dtype])
+            ax.axvline(times[-1], color=colors[solver],
                       alpha=alpha[dtype],
                       linestyle=linestyles[dtype])
 
-        for scores, times, solver in zip(group['test_scores'], group['times'],
-                                         group['solver']):
-            scores = scores / ref - 1
-            ax.plot(times, scores, label=solver, color=colors[solver])
         ax.set_xlabel('Time (s)')
         ax.set_ylabel('Test objective (relative to min)')
         ax.set_yscale('log')
 
-        ax = fig.add_subplot(133)
+        ax = axes[2]
+        for accuracy, times, solver, dtype in zip(group['accuracies'],
                                                   group['times'],
                                                   group['solver'],
                                                   group["dtype"]):
+            ax.plot(times, accuracy, label="%s - %s" % (solver, dtype),
                    alpha=alpha[dtype],
                    marker=".",
                    color=colors[solver], linestyle=linestyles[dtype])
+            ax.axvline(times[-1], color=colors[solver],
                       alpha=alpha[dtype],
                       linestyle=linestyles[dtype])
 
-        for accuracy, times, solver in zip(group['accuracies'], group['times'],
-                                           group['solver']):
-            ax.plot(times, accuracy, label=solver, color=colors[solver])
         ax.set_xlabel('Time (s)')
         ax.set_ylabel('Test accuracy')
         ax.legend()
         name = 'single_target' if single_target else 'multi_target'
         name += '_%s' % penalty
         plt.suptitle(name)
-        name += '.png'
+        if outname is None:
+            outname = name + '.png'
         fig.tight_layout()
         fig.subplots_adjust(top=0.9)
-        plt.savefig(name)
-        plt.close(fig)
+
+        ax = axes[3]
+        for scores, times, solver, dtype in zip(group['train_scores'],
                                                 group['times'],
                                                 group['solver'],
                                                 group["dtype"]):
+            ax.plot(np.arange(len(scores)),
                    scores, label="%s - %s" % (solver, dtype),
                    marker=".",
                    alpha=alpha[dtype],
                    color=colors[solver], linestyle=linestyles[dtype])
+
+        ax.set_yscale("log")
+        ax.set_xlabel('# iterations')
+        ax.set_ylabel('Objective function')
+        ax.legend()
+
+        plt.savefig(outname)
 
 
 if __name__ == '__main__':
     solvers = ['saga', 'liblinear', 'lightning']
     penalties = ['l1', 'l2']
+    n_samples = [100000, 300000, 500000, 800000, None]
     single_target = True
-    exp(solvers, penalties, single_target, n_samples=None, n_jobs=1,
-        dataset='20newspaper', max_iter=20)
-    plot()
+    for penalty in penalties:
+        for n_sample in n_samples:
+            exp(solvers, penalty, single_target,
+                n_samples=n_sample, n_jobs=1,
+                dataset='rcv1', max_iter=10)
+            if n_sample is not None:
+                outname = "figures/saga_%s_%d.png" % (penalty, n_sample)
+            else:
+                outname = "figures/saga_%s_all.png" % (penalty,)
+            try:
+                os.makedirs("figures")
+            except OSError:
+                pass
+            plot(outname)
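
For a quick sense of what the reworked benchmark measures, here is a much smaller, self-contained sketch of the same 64-bit vs 32-bit timing comparison (synthetic data and parameters are assumptions; the real script above runs rcv1/20newsgroups and also benchmarks liblinear and lightning):

import time

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=5000, n_features=100, random_state=42)

for dtype in (np.float64, np.float32):
    X_dt = X.astype(dtype)
    X_dt.max()  # touch the data first so the CPU cache state is comparable across fits
    lr = LogisticRegression(solver='saga', fit_intercept=False, tol=1e-3,
                            max_iter=20, random_state=42)
    t0 = time.time()
    lr.fit(X_dt, y)
    print("%s: %.3f s" % (dtype.__name__, time.time() - t0))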

doc/whats_new/v0.21.rst

Lines changed: 5 additions & 0 deletions

@@ -162,6 +162,11 @@ Support for Python 3.4 and below has been officially dropped.
 :mod:`sklearn.linear_model`
 ...........................
 
+- |Enhancement| :class:`linear_model.make_dataset` now preserves
+  ``float32`` and ``float64`` dtypes. :issues:`8769` and :issues:`11000` by
+  :user:`Nelle Varoquaux`_, :user:`Arthur Imbert <Henley13>`,
+  :user:`Guillaume Lemaitre <glemaitre>`, and :user:`Joan Massich <massich>`
+
 - |Feature| :class:`linear_model.LogisticRegression` and
   :class:`linear_model.LogisticRegressionCV` now support Elastic-Net penalty,
   with the 'saga' solver. :issue:`11646` by :user:`Nicolas Hug <NicolasHug>`.

sklearn/linear_model/base.py

Lines changed: 13 additions & 5 deletions

@@ -32,7 +32,8 @@
 from ..utils.extmath import safe_sparse_dot
 from ..utils.sparsefuncs import mean_variance_axis, inplace_column_scale
 from ..utils.fixes import sparse_lsqr
-from ..utils.seq_dataset import ArrayDataset, CSRDataset
+from ..utils.seq_dataset import ArrayDataset32, CSRDataset32
+from ..utils.seq_dataset import ArrayDataset64, CSRDataset64
 from ..utils.validation import check_is_fitted
 from ..exceptions import NotFittedError
 from ..preprocessing.data import normalize as f_normalize
@@ -76,15 +77,22 @@ def make_dataset(X, y, sample_weight, random_state=None):
     """
 
     rng = check_random_state(random_state)
-    # seed should never be 0 in SequentialDataset
+    # seed should never be 0 in SequentialDataset64
     seed = rng.randint(1, np.iinfo(np.int32).max)
 
+    if X.dtype == np.float32:
+        CSRData = CSRDataset32
+        ArrayData = ArrayDataset32
+    else:
+        CSRData = CSRDataset64
+        ArrayData = ArrayDataset64
+
     if sp.issparse(X):
-        dataset = CSRDataset(X.data, X.indptr, X.indices, y, sample_weight,
-                             seed=seed)
+        dataset = CSRData(X.data, X.indptr, X.indices, y, sample_weight,
+                          seed=seed)
         intercept_decay = SPARSE_INTERCEPT_DECAY
     else:
-        dataset = ArrayDataset(X, y, sample_weight, seed=seed)
+        dataset = ArrayData(X, y, sample_weight, seed=seed)
         intercept_decay = 1.0
 
     return dataset, intercept_decay
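
A small sketch of the dispatch that make_dataset now performs. Note that make_dataset lives in the private sklearn.linear_model.base module at this version, so the import path is an implementation detail; the expected class names are assumptions based on the diff above:

import numpy as np
from sklearn.linear_model.base import make_dataset  # private helper at this version

rng = np.random.RandomState(0)
X64 = rng.rand(20, 3)   # float64 input
y64 = rng.rand(20)
w64 = np.ones(20)
X32, y32, w32 = (a.astype(np.float32) for a in (X64, y64, w64))

dataset64, _ = make_dataset(X64, y64, w64, random_state=0)
dataset32, _ = make_dataset(X32, y32, w32, random_state=0)

print(type(dataset64).__name__)  # expected: ArrayDataset64
print(type(dataset32).__name__)  # expected: ArrayDataset32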

sklearn/linear_model/logistic.py

Lines changed: 8 additions & 4 deletions

@@ -964,7 +964,7 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
 
     elif solver in ['sag', 'saga']:
         if multi_class == 'multinomial':
-            target = target.astype(np.float64)
+            target = target.astype(X.dtype, copy=False)
             loss = 'multinomial'
         else:
             loss = 'log'
@@ -1486,6 +1486,10 @@ def fit(self, X, y, sample_weight=None):
         Returns
         -------
         self : object
+
+        Notes
+        -----
+        The SAGA solver supports both float64 and float32 bit arrays.
         """
         solver = _check_solver(self.solver, self.penalty, self.dual)
 
@@ -1520,10 +1524,10 @@ def fit(self, X, y, sample_weight=None):
             raise ValueError("Tolerance for stopping criteria must be "
                              "positive; got (tol=%r)" % self.tol)
 
-        if solver in ['newton-cg']:
-            _dtype = [np.float64, np.float32]
-        else:
+        if solver in ['lbfgs', 'liblinear']:
             _dtype = np.float64
+        else:
+            _dtype = [np.float64, np.float32]
 
         X, y = check_X_y(X, y, accept_sparse='csr', dtype=_dtype, order="C",
                          accept_large_sparse=solver != 'liblinear')
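
The _dtype selection above is what lets fit keep 32-bit input for the solvers that can handle it: passing a list of dtypes to check_X_y keeps the input dtype when it is already in the list, while a single dtype forces a cast. A standalone sketch of that logic (a reimplementation for illustration, not the library code itself):

import numpy as np
from sklearn.utils import check_X_y

def select_dtype(solver):
    # lbfgs and liblinear still require float64; the other solvers accept both.
    if solver in ['lbfgs', 'liblinear']:
        return np.float64
    return [np.float64, np.float32]

X32 = np.ones((4, 2), dtype=np.float32)
y = np.array([0, 1, 0, 1])

for solver in ['saga', 'liblinear']:
    X_checked, _ = check_X_y(X32, y, dtype=select_dtype(solver), order="C")
    print(solver, '->', X_checked.dtype)  # saga -> float32, liblinear -> float64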

0 commit comments