8000 FIX: make LinearRegression perfectly consistent across sparse or dens… · scikit-learn/scikit-learn@66899ed · GitHub
[go: up one dir, main page]

Skip to content

Commit 66899ed

Browse files
agramfortGaelVaroquaux
authored andcommitted
FIX: make LinearRegression perfectly consistent across sparse or dense (#13279)
* FIX : make LinearRegression perfectly consistent across sparse or dense * comments * review
1 parent c3415b8 commit 66899ed

File tree

3 files changed

+41
-3
lines changed

3 files changed

+41
-3
lines changed

doc/whats_new/v0.21.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,11 @@ Support for Python 3.4 and below has been officially dropped.
200200
parameter value ``copy_X=True`` in ``fit``.
201201
:issue:`12972` by :user:`Lucio Fernandez-Arjona <luk-f-a>`
202202

203+
- |Fix| Fixed a bug in :class:`linear_model.LinearRegression` that
204+
was not returning the same coeffecients and intercepts with
205+
``fit_intercept=True`` in sparse and dense case.
206+
:issue:`13279` by `Alexandre Gramfort`_
207+
203208
:mod:`sklearn.manifold`
204209
............................
205210

sklearn/linear_model/base.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -467,21 +467,34 @@ def fit(self, X, y, sample_weight=None):
467467

468468
X, y, X_offset, y_offset, X_scale = self._preprocess_data(
469469
X, y, fit_intercept=self.fit_intercept, normalize=self.normalize,
470-
copy=self.copy_X, sample_weight=sample_weight)
470+
copy=self.copy_X, sample_weight=sample_weight,
471+
return_mean=True)
471472

472473
if sample_weight is not None:
473474
# Sample weight can be implemented via a simple rescaling.
474475
X, y = _rescale_data(X, y, sample_weight)
475476

476477
if sp.issparse(X):
478+
X_offset_scale = X_offset / X_scale
479+
480+
def matvec(b):
481+
return X.dot(b) - b.dot(X_offset_scale)
482+
483+
def rmatvec(b):
484+
return X.T.dot(b) - X_offset_scale * np.sum(b)
485+
486+
X_centered = sparse.linalg.LinearOperator(shape=X.shape,
487+
matvec=matvec,
488+
rmatvec=rmatvec)
489+
477490
if y.ndim < 2:
478-
out = sparse_lsqr(X, y)
491+
out = sparse_lsqr(X_centered, y)
479492
self.coef_ = out[0]
480493
self._residues = out[3]
481494
else:
482495
# sparse_lstsq cannot handle y with shape (M, K)
483496
outs = Parallel(n_jobs=n_jobs_)(
484-
delayed(sparse_lsqr)(X, y[:, j].ravel())
497+
delayed(sparse_lsqr)(X_centered, y[:, j].ravel())
485498
for j in range(y.shape[1]))
486499
self.coef_ = np.vstack([out[0] for out in outs])
487500
self._residues = np.vstack([out[3] for out in outs])

sklearn/linear_model/tests/test_base.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,26 @@ def test_linear_regression_sparse(random_state=0):
154154
assert_array_almost_equal(ols.predict(X) - y.ravel(), 0)
155155

156156

157+
@pytest.mark.parametrize('normalize', [True, False])
158+
@pytest.mark.parametrize('fit_intercept', [True, False])
159+
def test_linear_regression_sparse_equal_dense(normalize, fit_intercept):
160+
# Test that linear regression agrees between sparse and dense
161+
rng = check_random_state(0)
162+
n_samples = 200
163+
n_features = 2
164+
X = rng.randn(n_samples, n_features)
165+
X[X < 0.1] = 0.
166+
Xcsr = sparse.csr_matrix(X)
167+
y = rng.rand(n_samples)
168+
params = dict(normalize=normalize, fit_intercept=fit_intercept)
169+
clf_dense = LinearRegression(**params)
170+
clf_sparse = LinearRegression(**params)
171+
clf_dense.fit(X, y)
172+
clf_sparse.fit(Xcsr, y)
173+
assert clf_dense.intercept_ == pytest.approx(clf_sparse.intercept_)
174+
assert_allclose(clf_dense.coef_, clf_sparse.coef_)
175+
176+
157177
def test_linear_regression_multiple_outcome(random_state=0):
158178
# Test multiple-outcome linear regressions
159179
X, y = make_regression(random_state=random_state)

0 commit comments

Comments
 (0)
0