diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index f64a6bda6ea95..7786870dc523f 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -594,6 +594,10 @@ Changelog :class:`linear_model.ARDRegression` now preserve float32 dtype. :pr:`9087` by :user:`Arthur Imbert <Henley13>` and :pr:`22525` by :user:`Meekail Zain <micky774>`. +- |Fix| The `intercept_` attribute of :class:`linear_model.LinearRegression` is now correctly + computed in the presence of sample weights when the input is sparse. + :pr:`22891` by :user:`Jérémie du Boisberranger <jeremiedbb>`. + :mod:`sklearn.manifold` ....................... diff --git a/sklearn/linear_model/_base.py b/sklearn/linear_model/_base.py index f0322c3924426..65cdedafe1821 100644 --- a/sklearn/linear_model/_base.py +++ b/sklearn/linear_model/_base.py @@ -325,11 +325,14 @@ def _preprocess_data( # sample_weight makes the refactoring tricky. -def _rescale_data(X, y, sample_weight): +def _rescale_data(X, y, sample_weight, sqrt_sample_weight=True): """Rescale data sample-wise by square root of sample_weight. For many linear models, this enables easy support for sample_weight. + Set sqrt_sample_weight=False if the square root of the sample weights has already + been taken prior to calling this function. 
+ Returns ------- X_rescaled : {array-like, sparse matrix} @@ -340,7 +343,8 @@ def _rescale_data(X, y, sample_weight): sample_weight = np.asarray(sample_weight) if sample_weight.ndim == 0: sample_weight = np.full(n_samples, sample_weight, dtype=sample_weight.dtype) - sample_weight = np.sqrt(sample_weight) + if sqrt_sample_weight: + sample_weight = np.sqrt(sample_weight) sw_matrix = sparse.dia_matrix((sample_weight, 0), shape=(n_samples, n_samples)) X = safe_sparse_dot(sw_matrix, X) y = safe_sparse_dot(sw_matrix, y) @@ -676,10 +680,9 @@ def fit(self, X, y, sample_weight=None): X, y, accept_sparse=accept_sparse, y_numeric=True, multi_output=True ) - if sample_weight is not None: - sample_weight = _check_sample_weight( - sample_weight, X, dtype=X.dtype, only_non_negative=True - ) + sample_weight = _check_sample_weight( + sample_weight, X, dtype=X.dtype, only_non_negative=True + ) X, y, X_offset, y_offset, X_scale = _preprocess_data( X, @@ -691,9 +694,9 @@ def fit(self, X, y, sample_weight=None): return_mean=True, ) - if sample_weight is not None: - # Sample weight can be implemented via a simple rescaling. - X, y = _rescale_data(X, y, sample_weight) + # Sample weight can be implemented via a simple rescaling. 
+ sample_weight_sqrt = np.sqrt(sample_weight) + X, y = _rescale_data(X, y, sample_weight_sqrt, sqrt_sample_weight=False) if self.positive: if y.ndim < 2: @@ -708,10 +711,10 @@ def fit(self, X, y, sample_weight=None): X_offset_scale = X_offset / X_scale def matvec(b): - return X.dot(b) - b.dot(X_offset_scale) + return X.dot(b) - sample_weight_sqrt * b.dot(X_offset_scale) def rmatvec(b): - return X.T.dot(b) - X_offset_scale * np.sum(b) + return X.T.dot(b) - X_offset_scale * b.dot(sample_weight_sqrt) X_centered = sparse.linalg.LinearOperator( shape=X.shape, matvec=matvec, rmatvec=rmatvec diff --git a/sklearn/linear_model/tests/test_base.py b/sklearn/linear_model/tests/test_base.py index 6b89b7e25e899..f59a5f6f08c5c 100644 --- a/sklearn/linear_model/tests/test_base.py +++ b/sklearn/linear_model/tests/test_base.py @@ -13,7 +13,6 @@ from sklearn.utils._testing import assert_array_almost_equal from sklearn.utils._testing import assert_array_equal -from sklearn.utils._testing import assert_almost_equal from sklearn.utils._testing import assert_allclose from sklearn.utils import check_random_state @@ -26,6 +25,7 @@ from sklearn.datasets import make_regression from sklearn.datasets import load_iris from sklearn.preprocessing import StandardScaler +from sklearn.preprocessing import add_dummy_feature rng = np.random.RandomState(0) rtol = 1e-6 @@ -55,45 +55,42 @@ def test_linear_regression(): assert_array_almost_equal(reg.predict(X), [0]) -def test_linear_regression_sample_weights(): - # TODO: loop over sparse data as well - +@pytest.mark.parametrize("array_constr", [np.array, sparse.csr_matrix]) +@pytest.mark.parametrize("fit_intercept", [True, False]) +def test_linear_regression_sample_weights(array_constr, fit_intercept): rng = np.random.RandomState(0) # It would not work with under-determined systems - for n_samples, n_features in ((6, 5),): + n_samples, n_features = 6, 5 - y = rng.randn(n_samples) - X = rng.randn(n_samples, n_features) - sample_weight = 1.0 + 
rng.rand(n_samples) + X = array_constr(rng.normal(size=(n_samples, n_features))) + y = rng.normal(size=n_samples) - for intercept in (True, False): + sample_weight = 1.0 + rng.uniform(size=n_samples) - # LinearRegression with explicit sample_weight - reg = LinearRegression(fit_intercept=intercept) - reg.fit(X, y, sample_weight=sample_weight) - coefs1 = reg.coef_ - inter1 = reg.intercept_ + # LinearRegression with explicit sample_weight + reg = LinearRegression(fit_intercept=fit_intercept) + reg.fit(X, y, sample_weight=sample_weight) + coefs1 = reg.coef_ + inter1 = reg.intercept_ - assert reg.coef_.shape == (X.shape[1],) # sanity checks - assert reg.score(X, y) > 0.5 + assert reg.coef_.shape == (X.shape[1],) # sanity checks + assert reg.score(X, y) > 0.5 - # Closed form of the weighted least square - # theta = (X^T W X)^(-1) * X^T W y - W = np.diag(sample_weight) - if intercept is False: - X_aug = X - else: - dummy_column = np.ones(shape=(n_samples, 1)) - X_aug = np.concatenate((dummy_column, X), axis=1) + # Closed form of the weighted least square + # theta = (X^T W X)^(-1) @ X^T W y + W = np.diag(sample_weight) + X_aug = X if not fit_intercept else add_dummy_feature(X) - coefs2 = linalg.solve(X_aug.T.dot(W).dot(X_aug), X_aug.T.dot(W).dot(y)) + Xw = X_aug.T @ W @ X_aug + yw = X_aug.T @ W @ y + coefs2 = linalg.solve(Xw, yw) - if intercept is False: - assert_array_almost_equal(coefs1, coefs2) - else: - assert_array_almost_equal(coefs1, coefs2[1:]) - assert_almost_equal(inter1, coefs2[0]) + if not fit_intercept: + assert_allclose(coefs1, coefs2) + else: + assert_allclose(coefs1, coefs2[1:]) + assert_allclose(inter1, coefs2[0]) def test_raises_value_error_if_positive_and_sparse():