8000 [MRG] Implement fitting intercept with `sparse_cg` solver in Ridge re… · scikit-learn/scikit-learn@c8e757d · GitHub
[go: up one dir, main page]

Skip to content

Commit c8e757d

Browse files
btelGaelVaroquaux
authored andcommitted
[MRG] Implement fitting intercept with sparse_cg solver in Ridge regression (#13336)
* add skeleton for fit_intercept with sparse_cg * fix sparse_cg solver with fit_intercept=True * fix test * linting * add what's new entry * remove X_scale and X_offset from public interface of ridge_regression * reformat if clause * fixed linting issues * add comments on about the conditions of different code branches * update warning * remove whitespace * add extra checks in the test of ridge with fit_intercept * remove unused argument
1 parent b73a51b commit c8e757d

File tree

3 files changed

+79
-14
lines changed

3 files changed

+79
-14
lines changed

doc/whats_new/v0.21.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,10 @@ Support for Python 3.4 and below has been officially dropped.
230230
in version 0.21 and will be removed in version 0.23.
231231
:issue:`12821` by :user:`Nicolas Hug <NicolasHug>`.
232232

233+
- |Enhancement| `sparse_cg` solver in :class:`linear_model.ridge.Ridge`
234+
now supports fitting the intercept (i.e. ``fit_intercept=True``) when
235+
inputs are sparse . :issue:`13336` by :user:`Bartosz Telenczuk <btel>`
236+
233237
:mod:`sklearn.manifold`
234238
............................
235239

sklearn/linear_model/ridge.py

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,31 @@
3333
from ..exceptions import ConvergenceWarning
3434

3535

36-
def _solve_sparse_cg(X, y, alpha, max_iter=None, tol=1e-3, verbose=0):
36+
def _solve_sparse_cg(X, y, alpha, max_iter=None, tol=1e-3, verbose=0,
37+
X_offset=None, X_scale=None):
38+
39+
def _get_rescaled_operator(X):
40+
41+
X_offset_scale = X_offset / X_scale
42+
43+
def matvec(b):
44+
return X.dot(b) - b.dot(X_offset_scale)
45+
46+
def rmatvec(b):
47+
return X.T.dot(b) - X_offset_scale * np.sum(b)
48+
49+
X1 = sparse.linalg.LinearOperator(shape=X.shape,
50+
matvec=matvec,
51+
rmatvec=rmatvec)
52+
return X1
53+
3754
n_samples, n_features = X.shape
38-
X1 = sp_linalg.aslinearoperator(X)
55+
56+
if X_offset is None or X_scale is None:
57+
X1 = sp_linalg.aslinearoperator(X)
58+
else:
59+
X1 = _get_rescaled_operator(X)
60+
3961
coefs = np.empty((y.shape[1], n_features), dtype=X.dtype)
4062

4163
if n_features > n_samples:
@@ -326,6 +348,25 @@ def ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
326348
-----
327349
This function won't compute the intercept.
328350
"""
351+
352+
return _ridge_regression(X, y, alpha,
353+
sample_weight=sample_weight,
354+
solver=solver,
355+
max_iter=max_iter,
356+
tol=tol,
357+
verbose=verbose,
358+
random_state=random_state,
359+
return_n_iter=return_n_iter,
360+
return_intercept=return_intercept,
361+
X_scale=None,
362+
X_offset=None)
363+
364+
365+
def _ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
366+
max_iter=None, tol=1e-3, verbose=0, random_state=None,
367+
return_n_iter=False, return_intercept=False,
368+
X_scale=None, X_offset=None):
369+
329370
if return_intercept and sparse.issparse(X) and solver != 'sag':
330371
if solver != 'auto':
331372
warnings.warn("In Ridge, only 'sag' solver can currently fit the "
@@ -395,7 +436,12 @@ def ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
395436

396437
n_iter = None
397438
if solver == 'sparse_cg':
398-
coef = _solve_sparse_cg(X, y, alpha, max_iter, tol, verbose)
439+
coef = _solve_sparse_cg(X, y, alpha,
440+
max_iter=max_iter,
441+
tol=tol,
442+
verbose=verbose,
443+
X_offset=X_offset,
444+
X_scale=X_scale)
399445

400446
elif solver == 'lsqr':
401447
coef, n_iter = _solve_lsqr(X, y, alpha, max_iter, tol)
@@ -492,24 +538,35 @@ def fit(self, X, y, sample_weight=None):
492538
np.atleast_1d(sample_weight).ndim > 1):
493539
raise ValueError("Sample weights must be 1D array or scalar")
494540

541+
# when X is sparse we only remove offset from y
495542
X, y, X_offset, y_offset, X_scale = self._preprocess_data(
496543
X, y, self.fit_intercept, self.normalize, self.copy_X,
497-
sample_weight=sample_weight)
544+
sample_weight=sample_weight, return_mean=True)
498545

499546
# temporary fix for fitting the intercept with sparse data using 'sag'
500-
if sparse.issparse(X) and self.fit_intercept:
501-
self.coef_, self.n_iter_, self.intercept_ = ridge_regression(
547+
if (sparse.issparse(X) and self.fit_intercept and
548+
self.solver != 'sparse_cg'):
549+
self.coef_, self.n_iter_, self.intercept_ = _ridge_regression(
502550
X, y, alpha=self.alpha, sample_weight=sample_weight,
503551
max_iter=self.max_iter, tol=self.tol, solver=self.solver,
504552
random_state=self.random_state, return_n_iter=True,
505553
return_intercept=True)
554+
# add the offset which was subtracted by _preprocess_data
506555
self.intercept_ += y_offset
507556
else:
508-
self.coef_, self.n_iter_ = ridge_regression(
557+
if sparse.issparse(X):
558+
# required to fit intercept with sparse_cg solver
559+
params = {'X_offset': X_offset, 'X_scale': X_scale}
560+
else:
561+
# for dense matrices or when intercept is set to 0
562+
params = {}
563+
564+
self.coef_, self.n_iter_ = _ridge_regression(
509565
X, y, alpha=self.alpha, sample_weight=sample_weight,
510566
max_iter=self.max_iter, tol=self.tol, solver=self.solver,
511567
random_state=self.random_state, return_n_iter=True,
512-
return_intercept=False)
568+
return_intercept=False, **params)
569+
513570
self._set_intercept(X_offset, y_offset, X_scale)
514571

515572
return self

sklearn/linear_model/tests/test_ridge.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -815,21 +815,25 @@ def test_n_iter():
815815
def test_ridge_fit_intercept_sparse():
816816
X, y = make_regression(n_samples=1000, n_features=2, n_informative=2,
817817
bias=10., random_state=42)
818+
818819
X_csr = sp.csr_matrix(X)
819820

820-
for solver in ['saga', 'sag']:
821+
for solver in ['sag', 'sparse_cg']:
821822
dense = Ridge(alpha=1., tol=1.e-15, solver=solver, fit_intercept=True)
822823
sparse = Ridge(alpha=1., tol=1.e-15, solver=solver, fit_intercept=True)
823824
dense.fit(X, y)
824-
sparse.fit(X_csr, y)
825+
with pytest.warns(None) as record:
826+
sparse.fit(X_csr, y)
827+
assert len(record) == 0
825828
assert_almost_equal(dense.intercept_, sparse.intercept_)
826829
assert_array_almost_equal(dense.coef_, sparse.coef_)
827830

828831
# test the solver switch and the corresponding warning
829-
sparse = Ridge(alpha=1., tol=1.e-15, solver='lsqr', fit_intercept=True)
830-
assert_warns(UserWarning, sparse.fit, X_csr, y)
831-
assert_almost_equal(dense.intercept_, sparse.intercept_)
832-
assert_array_almost_equal(dense.coef_, sparse.coef_)
832+
for solver in ['saga', 'lsqr']:
833+
sparse = Ridge(alpha=1., tol=1.e-15, solver=solver, fit_intercept=True)
834+
assert_warns(UserWarning, sparse.fit, X_csr, y)
835+
assert_almost_equal(dense.intercept_, sparse.intercept_)
836+
assert_array_almost_equal(dense.coef_, sparse.coef_)
833837

834838

835839
def test_errors_and_values_helper():

0 commit comments

Comments
 (0)
0