diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst
index 6f504a721ec75..e3e3ec9f88816 100644
--- a/doc/whats_new/v0.21.rst
+++ b/doc/whats_new/v0.21.rst
@@ -230,6 +230,10 @@ Support for Python 3.4 and below has been officially dropped.
   in version 0.21 and will be removed in version 0.23.
   :issue:`12821` by :user:`Nicolas Hug <NicolasHug>`.
 
+- |Enhancement| `sparse_cg` solver in :class:`linear_model.ridge.Ridge`
+  now supports fitting the intercept (i.e. ``fit_intercept=True``) when
+  inputs are sparse. :issue:`13336` by :user:`Bartosz Telenczuk <btel>`.
+
 :mod:`sklearn.manifold`
 ............................
 
diff --git a/sklearn/linear_model/ridge.py b/sklearn/linear_model/ridge.py
index eed636622dcdc..e240db3f1cb06 100644
--- a/sklearn/linear_model/ridge.py
+++ b/sklearn/linear_model/ridge.py
@@ -33,9 +33,31 @@
 from ..exceptions import ConvergenceWarning
 
 
-def _solve_sparse_cg(X, y, alpha, max_iter=None, tol=1e-3, verbose=0):
+def _solve_sparse_cg(X, y, alpha, max_iter=None, tol=1e-3, verbose=0,
+                     X_offset=None, X_scale=None):
+
+    def _get_rescaled_operator(X):
+
+        X_offset_scale = X_offset / X_scale
+
+        def matvec(b):
+            return X.dot(b) - b.dot(X_offset_scale)
+
+        def rmatvec(b):
+            return X.T.dot(b) - X_offset_scale * np.sum(b)
+
+        X1 = sparse.linalg.LinearOperator(shape=X.shape,
+                                          matvec=matvec,
+                                          rmatvec=rmatvec)
+        return X1
+
     n_samples, n_features = X.shape
-    X1 = sp_linalg.aslinearoperator(X)
+
+    if X_offset is None or X_scale is None:
+        X1 = sp_linalg.aslinearoperator(X)
+    else:
+        X1 = _get_rescaled_operator(X)
+
     coefs = np.empty((y.shape[1], n_features), dtype=X.dtype)
 
     if n_features > n_samples:
@@ -326,6 +348,25 @@ def ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
     -----
     This function won't compute the intercept.
     """
+
+    return _ridge_regression(X, y, alpha,
+                             sample_weight=sample_weight,
+                             solver=solver,
+                             max_iter=max_iter,
+                             tol=tol,
+                             verbose=verbose,
+                             random_state=random_state,
+                             return_n_iter=return_n_iter,
+                             return_intercept=return_intercept,
+                             X_scale=None,
+                             X_offset=None)
+
+
+def _ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
+                      max_iter=None, tol=1e-3, verbose=0, random_state=None,
+                      return_n_iter=False, return_intercept=False,
+                      X_scale=None, X_offset=None):
+
     if return_intercept and sparse.issparse(X) and solver != 'sag':
         if solver != 'auto':
             warnings.warn("In Ridge, only 'sag' solver can currently fit the "
@@ -395,7 +436,12 @@ def ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
     n_iter = None
 
     if solver == 'sparse_cg':
-        coef = _solve_sparse_cg(X, y, alpha, max_iter, tol, verbose)
+        coef = _solve_sparse_cg(X, y, alpha,
+                                max_iter=max_iter,
+                                tol=tol,
+                                verbose=verbose,
+                                X_offset=X_offset,
+                                X_scale=X_scale)
 
     elif solver == 'lsqr':
         coef, n_iter = _solve_lsqr(X, y, alpha, max_iter, tol)
@@ -492,24 +538,35 @@ def fit(self, X, y, sample_weight=None):
                 np.atleast_1d(sample_weight).ndim > 1):
             raise ValueError("Sample weights must be 1D array or scalar")
 
+        # when X is sparse we only remove offset from y
         X, y, X_offset, y_offset, X_scale = self._preprocess_data(
             X, y, self.fit_intercept, self.normalize, self.copy_X,
-            sample_weight=sample_weight)
+            sample_weight=sample_weight, return_mean=True)
 
         # temporary fix for fitting the intercept with sparse data using 'sag'
-        if sparse.issparse(X) and self.fit_intercept:
-            self.coef_, self.n_iter_, self.intercept_ = ridge_regression(
+        if (sparse.issparse(X) and self.fit_intercept and
+                self.solver != 'sparse_cg'):
+            self.coef_, self.n_iter_, self.intercept_ = _ridge_regression(
                 X, y, alpha=self.alpha, sample_weight=sample_weight,
                 max_iter=self.max_iter, tol=self.tol, solver=self.solver,
                 random_state=self.random_state, return_n_iter=True,
                 return_intercept=True)
+            # add the offset which was subtracted by _preprocess_data
             self.intercept_ += y_offset
         else:
-            self.coef_, self.n_iter_ = ridge_regression(
+            if sparse.issparse(X):
+                # required to fit intercept with sparse_cg solver
+                params = {'X_offset': X_offset, 'X_scale': X_scale}
+            else:
+                # for dense matrices or when intercept is set to 0
+                params = {}
+
+            self.coef_, self.n_iter_ = _ridge_regression(
                 X, y, alpha=self.alpha, sample_weight=sample_weight,
                 max_iter=self.max_iter, tol=self.tol, solver=self.solver,
                 random_state=self.random_state, return_n_iter=True,
-                return_intercept=False)
+                return_intercept=False, **params)
+
         self._set_intercept(X_offset, y_offset, X_scale)
 
         return self
diff --git a/sklearn/linear_model/tests/test_ridge.py b/sklearn/linear_model/tests/test_ridge.py
index eca4a53f4f507..a5ee524e8c557 100644
--- a/sklearn/linear_model/tests/test_ridge.py
+++ b/sklearn/linear_model/tests/test_ridge.py
@@ -815,21 +815,25 @@ def test_n_iter():
 def test_ridge_fit_intercept_sparse():
     X, y = make_regression(n_samples=1000, n_features=2, n_informative=2,
                            bias=10., random_state=42)
+
     X_csr = sp.csr_matrix(X)
 
-    for solver in ['saga', 'sag']:
+    for solver in ['sag', 'sparse_cg']:
         dense = Ridge(alpha=1., tol=1.e-15, solver=solver, fit_intercept=True)
         sparse = Ridge(alpha=1., tol=1.e-15, solver=solver, fit_intercept=True)
         dense.fit(X, y)
-        sparse.fit(X_csr, y)
+        with pytest.warns(None) as record:
+            sparse.fit(X_csr, y)
+        assert len(record) == 0
         assert_almost_equal(dense.intercept_, sparse.intercept_)
         assert_array_almost_equal(dense.coef_, sparse.coef_)
 
     # test the solver switch and the corresponding warning
-    sparse = Ridge(alpha=1., tol=1.e-15, solver='lsqr', fit_intercept=True)
-    assert_warns(UserWarning, sparse.fit, X_csr, y)
-    assert_almost_equal(dense.intercept_, sparse.intercept_)
-    assert_array_almost_equal(dense.coef_, sparse.coef_)
+    for solver in ['saga', 'lsqr']:
+        sparse = Ridge(alpha=1., tol=1.e-15, solver=solver, fit_intercept=True)
+        assert_warns(UserWarning, sparse.fit, X_csr, y)
+        assert_almost_equal(dense.intercept_, sparse.intercept_)
+        assert_array_almost_equal(dense.coef_, sparse.coef_)
 
 
 def test_errors_and_values_helper():
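
Note on the approach (not part of the patch): fitting an intercept is equivalent to running ridge on a centered design matrix, but explicitly subtracting the column means from a sparse `X` would densify it. The patch instead wraps `X` in a `scipy.sparse.linalg.LinearOperator` whose `matvec`/`rmatvec` apply the centering on the fly, which is all the conjugate-gradient solver needs. Below is a minimal standalone sketch of that equivalence; the random matrix and the all-ones `X_scale` (i.e. no scaling) are made-up demo inputs, while the `X_offset`/`X_scale` names follow the patch.

```python
import numpy as np
from scipy import sparse
from scipy.sparse import linalg as sp_linalg

rng = np.random.RandomState(0)
X = sparse.random(6, 4, density=0.5, format='csr', random_state=rng)
X_offset = np.asarray(X.mean(axis=0)).ravel()  # column means, as _preprocess_data returns
X_scale = np.ones(X.shape[1])                  # demo value: no scaling (normalize=False)
X_offset_scale = X_offset / X_scale

# Same construction as _get_rescaled_operator in the patch: an operator that
# acts like the dense centered matrix (X - X_offset_scale) without ever
# materialising it. For a vector b, each row of X*b loses the scalar
# offset.dot(b); the adjoint subtracts offset * sum(b).
op = sp_linalg.LinearOperator(
    shape=X.shape,
    matvec=lambda b: X.dot(b) - b.dot(X_offset_scale),
    rmatvec=lambda b: X.T.dot(b) - X_offset_scale * np.sum(b))

X_centered = X.toarray() - X_offset_scale      # dense reference, for checking only
b = rng.randn(X.shape[1])
c = rng.randn(X.shape[0])
assert np.allclose(op.matvec(b), X_centered.dot(b))
assert np.allclose(op.rmatvec(c), X_centered.T.dot(c))
```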
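For a quick end-to-end check mirroring `test_ridge_fit_intercept_sparse`: with this patch, a sparse fit using `solver='sparse_cg'` and `fit_intercept=True` should match the dense fit, instead of warning and switching to `'sag'`. A usage sketch on the same toy regression problem the test uses:

```python
import scipy.sparse as sp
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge

X, y = make_regression(n_samples=1000, n_features=2, n_informative=2,
                       bias=10., random_state=42)

ridge_dense = Ridge(alpha=1., tol=1e-15, solver='sparse_cg',
                    fit_intercept=True).fit(X, y)
ridge_sparse = Ridge(alpha=1., tol=1e-15, solver='sparse_cg',
                     fit_intercept=True).fit(sp.csr_matrix(X), y)

# Both intercepts should agree and sit close to the true bias of 10.
print(ridge_dense.intercept_, ridge_sparse.intercept_)
print(ridge_dense.coef_, ridge_sparse.coef_)
```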