Fix Ridge sparse + sample_weight + intercept (#22899) · scikit-learn/scikit-learn@d76f87c · GitHub
[go: up one dir, main page]

Skip to content

Commit d76f87c

Browse files
jeremiedbb and ogrisel authored
Fix Ridge sparse + sample_weight + intercept (#22899)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent c6395d2 commit d76f87c

File tree

6 files changed

+90
-53
lines changed

6 files changed

+90
-53
lines changed

doc/whats_new/v1.1.rst

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -618,9 +618,14 @@ Changelog
618618
:class:`linear_model.ARDRegression` now preserve float32 dtype. :pr:`9087` by
619619
:user:`Arthur Imbert <Henley13>` and :pr:`22525` by :user:`Meekail Zain <micky774>`.
620620

621-
- |Fix| The `intercept_` attribute of :class:`LinearRegression` is now correctly
622-
computed in the presence of sample weights when the input is sparse.
623-
:pr:`22891` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
621+
- |Fix| The `coef_` and `intercept_` attributes of :class:`LinearRegression` are now
622+
correctly computed in the presence of sample weights when the input is sparse.
623+
:pr:`22891` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
624+
625+
- |Fix| The `coef_` and `intercept_` attributes of :class:`Ridge` with
626+
`solver="sparse_cg"` and `solver="lbfgs"` are now correctly computed in the presence
627+
of sample weights when the input is sparse.
628+
:pr:`22899` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
624629

625630
:mod:`sklearn.manifold`
626631
.......................

sklearn/linear_model/_base.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -325,14 +325,11 @@ def _preprocess_data(
325325
# sample_weight makes the refactoring tricky.
326326

327327

328-
def _rescale_data(X, y, sample_weight, sqrt_sample_weight=True):
328+
def _rescale_data(X, y, sample_weight):
329329
"""Rescale data sample-wise by square root of sample_weight.
330330
331331
For many linear models, this enables easy support for sample_weight.
332332
333-
Set sqrt_sample_weight=False if the square root of the sample weights has already
334-
been done prior to calling this function.
335-
336333
Returns
337334
-------
338335
X_rescaled : {array-like, sparse matrix}
@@ -343,12 +340,11 @@ def _rescale_data(X, y, sample_weight, sqrt_sample_weight=True):
343340
sample_weight = np.asarray(sample_weight)
344341
if sample_weight.ndim == 0:
345342
sample_weight = np.full(n_samples, sample_weight, dtype=sample_weight.dtype)
346-
if sqrt_sample_weight:
347-
sample_weight = np.sqrt(sample_weight)
348-
sw_matrix = sparse.dia_matrix((sample_weight, 0), shape=(n_samples, n_samples))
343+
sample_weight_sqrt = np.sqrt(sample_weight)
344+
sw_matrix = sparse.dia_matrix((sample_weight_sqrt, 0), shape=(n_samples, n_samples))
349345
X = safe_sparse_dot(sw_matrix, X)
350346
y = safe_sparse_dot(sw_matrix, y)
351-
return X, y
347+
return X, y, sample_weight_sqrt
352348

353349

354350
class LinearModel(BaseEstimator, metaclass=ABCMeta):
@@ -695,8 +691,7 @@ def fit(self, X, y, sample_weight=None):
695691
)
696692

697693
# Sample weight can be implemented via a simple rescaling.
698-
sample_weight_sqrt = np.sqrt(sample_weight)
699-
X, y = _rescale_data(X, y, sample_weight_sqrt, sqrt_sample_weight=False)
694+
X, y, sample_weight_sqrt = _rescale_data(X, y, sample_weight)
700695

701696
if self.positive:
702697
if y.ndim < 2:
@@ -844,7 +839,7 @@ def _pre_fit(
844839
sample_weight=sample_weight,
845840
)
846841
if sample_weight is not None:
847-
X, y = _rescale_data(X, y, sample_weight=sample_weight)
842+
X, y, _ = _rescale_data(X, y, sample_weight=sample_weight)
848843

849844
# FIXME: 'normalize' to be removed in 1.2
850845
if hasattr(precompute, "__array__"):

sklearn/linear_model/_bayes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def fit(self, X, y, sample_weight=None):
253253

254254
if sample_weight is not None:
255255
# Sample weight can be implemented via a simple rescaling.
256-
X, y = _rescale_data(X, y, sample_weight)
256+
X, y, _ = _rescale_data(X, y, sample_weight)
257257

258258
self.X_offset_ = X_offset_
259259
self.X_scale_ = X_scale_

sklearn/linear_model/_ridge.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,28 @@
4141

4242

4343
def _solve_sparse_cg(
44-
X, y, alpha, max_iter=None, tol=1e-3, verbose=0, X_offset=None, X_scale=None
44+
X,
45+
y,
46+
alpha,
47+
max_iter=None,
48+
tol=1e-3,
49+
verbose=0,
50+
X_offset=None,
51+
X_scale=None,
52+
sample_weight_sqrt=None,
4553
):
54+
if sample_weight_sqrt is None:
55+
sample_weight_sqrt = np.ones(X.shape[0], dtype=X.dtype)
56+
4657
def _get_rescaled_operator(X):
4758

4859
X_offset_scale = X_offset / X_scale
4960

5061
def matvec(b):
51-
return X.dot(b) - b.dot(X_offset_scale)
62+
return X.dot(b) - sample_weight_sqrt * b.dot(X_offset_scale)
5263

5364
def rmatvec(b):
54-
return X.T.dot(b) - X_offset_scale * np.sum(b)
65+
return X.T.dot(b) - X_offset_scale * b.dot(sample_weight_sqrt)
5566

5667
X1 = sparse.linalg.LinearOperator(shape=X.shape, matvec=matvec, rmatvec=rmatvec)
5768
return X1
@@ -241,7 +252,15 @@ def _solve_svd(X, y, alpha):
241252

242253

243254
def _solve_lbfgs(
244-
X, y, alpha, positive=True, max_iter=None, tol=1e-3, X_offset=None, X_scale=None
255+
X,
256+
y,
257+
alpha,
258+
positive=True,
259+
max_iter=None,
260+
tol=1e-3,
261+
X_offset=None,
262+
X_scale=None,
263+
sample_weight_sqrt=None,
245264
):
246265
"""Solve ridge regression with LBFGS.
247266
@@ -269,6 +288,9 @@ def _solve_lbfgs(
269288
else:
270289
X_offset_scale = None
271290

291+
if sample_weight_sqrt is None:
292+
sample_weight_sqrt = np.ones(X.shape[0], dtype=X.dtype)
293+
272294
coefs = np.empty((y.shape[1], n_features), dtype=X.dtype)
273295

274296
for i in range(y.shape[1]):
@@ -278,11 +300,11 @@ def _solve_lbfgs(
278300
def func(w):
279301
residual = X.dot(w) - y_column
280302
if X_offset_scale is not None:
281-
residual -= w.dot(X_offset_scale)
303+
residual -= sample_weight_sqrt * w.dot(X_offset_scale)
282304
f = 0.5 * residual.dot(residual) + 0.5 * alpha[i] * w.dot(w)
283305
grad = X.T @ residual + alpha[i] * w
284306
if X_offset_scale is not None:
285-
grad -= X_offset_scale * np.sum(residual)
307+
grad -= X_offset_scale * residual.dot(sample_weight_sqrt)
286308

287309
return f, grad
288310

@@ -568,7 +590,7 @@ def _ridge_regression(
568590
if solver not in ["sag", "saga"]:
569591
# SAG supports sample_weight directly. For other solvers,
570592
# we implement sample_weight via a simple rescaling.
571-
X, y = _rescale_data(X, y, sample_weight)
593+
X, y, sample_weight_sqrt = _rescale_data(X, y, sample_weight)
572594

573595
# Some callers of this method might pass alpha as single
574596
# element array which already has been validated.
@@ -603,6 +625,7 @@ def _ridge_regression(
603625
verbose=verbose,
604626
X_offset=X_offset,
605627
X_scale=X_scale,
628+
sample_weight_sqrt=sample_weight_sqrt if has_sw else None,
606629
)
607630

608631
elif solver == "lsqr":
@@ -673,6 +696,7 @@ def _ridge_regression(
673696
max_iter=max_iter,
674697
X_offset=X_offset,
675698
X_scale=X_scale,
699+
sample_weight_sqrt=sample_weight_sqrt if has_sw else None,
676700
)
677701

678702
if solver == "svd":
@@ -804,7 +828,7 @@ def fit(self, X, y, sample_weight=None):
804828

805829
else:
806830
if sparse.issparse(X) and self.fit_intercept:
807-
# required to fit intercept with sparse_cg solver
831+
# required to fit intercept with sparse_cg and lbfgs solver
808832
params = {"X_offset": X_offset, "X_scale": X_scale}
809833
else:
810834
# for dense matrices or when intercept is set to 0
@@ -1910,8 +1934,7 @@ def fit(self, X, y, sample_weight=None):
19101934
n_samples = X.shape[0]
19111935

19121936
if sample_weight is not None:
1913-
X, y = _rescale_data(X, y, sample_weight)
1914-
sqrt_sw = np.sqrt(sample_weight)
1937+
X, y, sqrt_sw = _rescale_data(X, y, sample_weight)
19151938
else:
19161939
sqrt_sw = np.ones(n_samples, dtype=X.dtype)
19171940

sklearn/linear_model/tests/test_base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -692,12 +692,12 @@ def test_rescale_data_dense(n_targets):
692692
y = rng.rand(n_samples)
693693
else:
694694
y = rng.rand(n_samples, n_targets)
695-
rescaled_X, rescaled_y = _rescale_data(X, y, sample_weight)
696-
rescaled_X2 = X * np.sqrt(sample_weight)[:, np.newaxis]
695+
rescaled_X, rescaled_y, sqrt_sw = _rescale_data(X, y, sample_weight)
696+
rescaled_X2 = X * sqrt_sw[:, np.newaxis]
697697
if n_targets is None:
698-
rescaled_y2 = y * np.sqrt(sample_weight)
698+
rescaled_y2 = y * sqrt_sw
699699
else:
700-
rescaled_y2 = y * np.sqrt(sample_weight)[:, np.newaxis]
700+
rescaled_y2 = y * sqrt_sw[:, np.newaxis]
701701
assert_array_almost_equal(rescaled_X, rescaled_X2)
702702
assert_array_almost_equal(rescaled_y, rescaled_y2)
703703

sklearn/linear_model/tests/test_ridge.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,33 +1363,41 @@ def test_n_iter():
13631363

13641364

13651365
@pytest.mark.parametrize("solver", ["sparse_cg", "lbfgs", "auto"])
1366-
def test_ridge_fit_intercept_sparse(solver):
1366+
@pytest.mark.parametrize("with_sample_weight", [True, False])
1367+
def test_ridge_fit_intercept_sparse(solver, with_sample_weight, global_random_seed):
1368+
"""Check that ridge finds the same coefs and intercept on dense and sparse input
1369+
in the presence of sample weights.
1370+
1371+
For now only sparse_cg and lbfgs can correctly fit an intercept
1372+
with sparse X with default tol and max_iter.
1373+
'sag' is tested separately in test_ridge_fit_intercept_sparse_sag because it
1374+
requires more iterations and should raise a warning if default max_iter is used.
1375+
Other solvers raise an exception, as checked in
1376+
test_ridge_fit_intercept_sparse_error
1377+
"""
13671378
positive = solver == "lbfgs"
13681379
X, y = _make_sparse_offset_regression(
1369-
n_features=20, random_state=0, positive=positive
1380+
n_features=20, random_state=global_random_seed, positive=positive
13701381
)
1371-
X_csr = sp.csr_matrix(X)
13721382

1373-
# for now only sparse_cg and lbfgs can correctly fit an intercept
1374-
# with sparse X with default tol and max_iter.
1375-
# sag is tested separately in test_ridge_fit_intercept_sparse_sag
1376-
# because it requires more iterations and should raise a warning if default
1377-
# max_iter is used.
1378-
# other solvers raise an exception, as checked in
1379-
# test_ridge_fit_intercept_sparse_error
1380-
#
1383+
sample_weight = None
1384+
if with_sample_weight:
1385+
rng = np.random.RandomState(global_random_seed)
1386+
sample_weight = 1.0 + rng.uniform(size=X.shape[0])
1387+
13811388
# "auto" should switch to "sparse_cg" when X is sparse
13821389
# so the reference we use for both ("auto" and "sparse_cg") is
13831390
# Ridge(solver="sparse_cg"), fitted using the dense representation (note
13841391
# that "sparse_cg" can fit sparse or dense data)
1385-
dense_ridge = Ridge(solver="sparse_cg", tol=1e-12)
1392+
dense_solver = "sparse_cg" if solver == "auto" else solver
1393+
dense_ridge = Ridge(solver=dense_solver, tol=1e-12, positive=positive)
13861394
sparse_ridge = Ridge(solver=solver, tol=1e-12, positive=positive)
1387-
dense_ridge.fit(X, y)
1388-
with warnings.catch_warnings():
1389-
warnings.simplefilter("error", UserWarning)
1390-
sparse_ridge.fit(X_csr, y)
1391-
assert np.allclose(dense_ridge.intercept_, sparse_ridge.intercept_)
1392-
assert np.allclose(dense_ridge.coef_, sparse_ridge.coef_)
1395+
1396+
dense_ridge.fit(X, y, sample_weight=sample_weight)
1397+
sparse_ridge.fit(sp.csr_matrix(X), y, sample_weight=sample_weight)
1398+
1399+
assert_allclose(dense_ridge.intercept_, sparse_ridge.intercept_)
1400+
assert_allclose(dense_ridge.coef_, sparse_ridge.coef_)
13931401

13941402

13951403
@pytest.mark.parametrize("solver", ["saga", "lsqr", "svd", "cholesky"])
@@ -1402,23 +1410,29 @@ def test_ridge_fit_intercept_sparse_error(solver):
14021410
sparse_ridge.fit(X_csr, y)
14031411

14041412

1405-
def test_ridge_fit_intercept_sparse_sag():
1413+
@pytest.mark.parametrize("with_sample_weight", [True, False])
1414+
def test_ridge_fit_intercept_sparse_sag(with_sample_weight, global_random_seed):
14061415
X, y = _make_sparse_offset_regression(
1407-
n_features=5, n_samples=20, random_state=0, X_offset=5.0
1416+
n_features=5, n_samples=20, random_state=global_random_seed, X_offset=5.0
14081417
)
1418+
if with_sample_weight:
1419+
rng = np.random.RandomState(global_random_seed)
1420+
sample_weight = 1.0 + rng.uniform(size=X.shape[0])
1421+
else:
1422+
sample_weight = None
14091423
X_csr = sp.csr_matrix(X)
14101424

14111425
params = dict(
14121426
alpha=1.0, solver="sag", fit_intercept=True, tol=1e-10, max_iter=100000
14131427
)
14141428
dense_ridge = Ridge(**params)
14151429
sparse_ridge = Ridge(**params)
1416-
dense_ridge.fit(X, y)
1430+
dense_ridge.fit(X, y, sample_weight=sample_weight)
14171431
with warnings.catch_warnings():
14181432
warnings.simplefilter("error", UserWarning)
1419-
sparse_ridge.fit(X_csr, y)
1420-
assert np.allclose(dense_ridge.intercept_, sparse_ridge.intercept_, rtol=1e-4)
1421-
assert np.allclose(dense_ridge.coef_, sparse_ridge.coef_, rtol=1e-4)
1433+
sparse_ridge.fit(X_csr, y, sample_weight=sample_weight)
1434+
assert_allclose(dense_ridge.intercept_, sparse_ridge.intercept_, rtol=1e-4)
1435+
assert_allclose(dense_ridge.coef_, sparse_ridge.coef_, rtol=1e-4)
14221436
with pytest.warns(UserWarning, match='"sag" solver requires.*'):
14231437
Ridge(solver="sag").fit(X_csr, y)
14241438

0 commit comments

Comments (0)