FIX LinearRegression sparse + intercept + sample_weight (#22891) · glemaitre/scikit-learn@a4704b4 · GitHub
[go: up one dir, main page]

Skip to content

Commit a4704b4

Browse files
jeremiedbb authored and glemaitre committed
FIX LinearRegression sparse + intercept + sample_weight (scikit-learn#22891)
1 parent 905aa1a commit a4704b4

File tree

3 files changed

+45
-41
lines changed

3 files changed

+45
-41
lines changed

doc/whats_new/v1.1.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,10 @@ Changelog
603603
:class:`linear_model.ARDRegression` now preserve float32 dtype. :pr:`9087` by
604604
:user:`Arthur Imbert <Henley13>` and :pr:`22525` by :user:`Meekail Zain <micky774>`.
605605

606+
- |Fix| The `intercept_` attribute of :class:`linear_model.LinearRegression` is now correctly
607+
computed in the presence of sample weights when the input is sparse.
608+
:pr:`22891` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
609+
606610
:mod:`sklearn.manifold`
607611
.......................
608612

sklearn/linear_model/_base.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -325,11 +325,14 @@ def _preprocess_data(
325325
# sample_weight makes the refactoring tricky.
326326

327327

328-
def _rescale_data(X, y, sample_weight):
328+
def _rescale_data(X, y, sample_weight, sqrt_sample_weight=True):
329329
"""Rescale data sample-wise by square root of sample_weight.
330330
331331
For many linear models, this enables easy support for sample_weight.
332332
333+
Set sqrt_sample_weight=False if the square root of the sample weights has already
334+
been taken before calling this function.
335+
333336
Returns
334337
-------
335338
X_rescaled : {array-like, sparse matrix}
@@ -340,7 +343,8 @@ def _rescale_data(X, y, sample_weight):
340343
sample_weight = np.asarray(sample_weight)
341344
if sample_weight.ndim == 0:
342345
sample_weight = np.full(n_samples, sample_weight, dtype=sample_weight.dtype)
343-
sample_weight = np.sqrt(sample_weight)
346+
if sqrt_sample_weight:
347+
sample_weight = np.sqrt(sample_weight)
344348
sw_matrix = sparse.dia_matrix((sample_weight, 0), shape=(n_samples, n_samples))
345349
X = safe_sparse_dot(sw_matrix, X)
346350
y = safe_sparse_dot(sw_matrix, y)
@@ -676,10 +680,9 @@ def fit(self, X, y, sample_weight=None):
676680
X, y, accept_sparse=accept_sparse, y_numeric=True, multi_output=True
677681
)
678682

679-
if sample_weight is not None:
680-
sample_weight = _check_sample_weight(
681-
sample_weight, X, dtype=X.dtype, only_non_negative=True
682-
)
683+
sample_weight = _check_sample_weight(
684+
sample_weight, X, dtype=X.dtype, only_non_negative=True
685+
)
683686

684687
X, y, X_offset, y_offset, X_scale = _preprocess_data(
685688
X,
@@ -691,9 +694,9 @@ def fit(self, X, y, sample_weight=None):
691694
return_mean=True,
692695
)
693696

694-
if sample_weight is not None:
695-
# Sample weight can be implemented via a simple rescaling.
696-
X, y = _rescale_data(X, y, sample_weight)
697+
# Sample weight can be implemented via a simple rescaling.
698+
sample_weight_sqrt = np.sqrt(sample_weight)
699+
X, y = _rescale_data(X, y, sample_weight_sqrt, sqrt_sample_weight=False)
697700

698701
if self.positive:
699702
if y.ndim < 2:
@@ -708,10 +711,10 @@ def fit(self, X, y, sample_weight=None):
708711
X_offset_scale = X_offset / X_scale
709712

710713
def matvec(b):
711-
return X.dot(b) - b.dot(X_offset_scale)
714+
return X.dot(b) - sample_weight_sqrt * b.dot(X_offset_scale)
712715

713716
def rmatvec(b):
714-
return X.T.dot(b) - X_offset_scale * np.sum(b)
717+
return X.T.dot(b) - X_offset_scale * b.dot(sample_weight_sqrt)
715718

716719
X_centered = sparse.linalg.LinearOperator(
717720
shape=X.shape, matvec=matvec, rmatvec=rmatvec

sklearn/linear_model/tests/test_base.py

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
from sklearn.utils._testing import assert_array_almost_equal
1515
from sklearn.utils._testing import assert_array_equal
16-
from sklearn.utils._testing import assert_almost_equal
1716
from sklearn.utils._testing import assert_allclose
1817
from sklearn.utils import check_random_state
1918

@@ -26,6 +25,7 @@
2625
from sklearn.datasets import make_regression
2726
from sklearn.datasets import load_iris
2827
from sklearn.preprocessing import StandardScaler
28+
from sklearn.preprocessing import add_dummy_feature
2929

3030
rng = np.random.RandomState(0)
3131
rtol = 1e-6
@@ -55,45 +55,42 @@ def test_linear_regression():
5555
assert_array_almost_equal(reg.predict(X), [0])
5656

5757

58-
def test_linear_regression_sample_weights():
59-
# TODO: loop over sparse data as well
60-
58+
@pytest.mark.parametrize("array_constr", [np.array, sparse.csr_matrix])
59+
@pytest.mark.parametrize("fit_intercept", [True, False])
60+
def test_linear_regression_sample_weights(array_constr, fit_intercept):
6161
rng = np.random.RandomState(0)
6262

6363
# It would not work with under-determined systems
64-
for n_samples, n_features in ((6, 5),):
64+
n_samples, n_features = 6, 5
6565

66-
y = rng.randn(n_samples)
67-
X = rng.randn(n_samples, n_features)
68-
sample_weight = 1.0 + rng.rand(n_samples)
66+
X = array_constr(rng.normal(size=(n_samples, n_features)))
67+
y = rng.normal(size=n_samples)
6968

70-
for intercept in (True, False):
69+
sample_weight = 1.0 + rng.uniform(size=n_samples)
7170

72-
# LinearRegression with explicit sample_weight
73-
reg = LinearRegression(fit_intercept=intercept)
74-
reg.fit(X, y, sample_weight=sample_weight)
75-
coefs1 = reg.coef_
76-
inter1 = reg.intercept_
71+
# LinearRegression with explicit sample_weight
72+
reg = LinearRegression(fit_intercept=fit_intercept)
73+
reg.fit(X, y, sample_weight=sample_weight)
74+
coefs1 = reg.coef_
75+
inter1 = reg.intercept_
7776

78-
assert reg.coef_.shape == (X.shape[1],) # sanity checks
79-
assert reg.score(X, y) > 0.5
77+
assert reg.coef_.shape == (X.shape[1],) # sanity checks
78+
assert reg.score(X, y) > 0.5
8079

81-
# Closed form of the weighted least square
82-
# theta = (X^T W X)^(-1) * X^T W y
83-
W = np.diag(sample_weight)
84-
if intercept is False:
85-
X_aug = X
86-
else:
87-
dummy_column = np.ones(shape=(n_samples, 1))
88-
X_aug = np.concatenate((dummy_column, X), axis=1)
80+
# Closed form of the weighted least squares
81+
# theta = (X^T W X)^(-1) @ X^T W y
82+
W = np.diag(sample_weight)
83+
X_aug = X if not fit_intercept else add_dummy_feature(X)
8984

90-
coefs2 = linalg.solve(X_aug.T.dot(W).dot(X_aug), X_aug.T.dot(W).dot(y))
85+
Xw = X_aug.T @ W @ X_aug
86+
yw = X_aug.T @ W @ y
87+
coefs2 = linalg.solve(Xw, yw)
9188

92-
if intercept is False:
93-
assert_array_almost_equal(coefs1, coefs2)
94-
else:
95-
assert_array_almost_equal(coefs1, coefs2[1:])
96-
assert_almost_equal(inter1, coefs2[0])
89+
if not fit_intercept:
90+
assert_allclose(coefs1, coefs2)
91+
else:
92+
assert_allclose(coefs1, coefs2[1:])
93+
assert_allclose(inter1, coefs2[0])
9794

9895

9996
def test_raises_value_error_if_positive_and_sparse():

0 commit comments

Comments
 (0)
0