Bug in sparse in Ridge with sample weights · Issue #15438 · scikit-learn/scikit-learn · GitHub

Bug in sparse in Ridge with sample weights #15438


Closed
lorentzenchr opened this issue Nov 2, 2019 · 11 comments · Fixed by #22899

Comments

@lorentzenchr
Member

Description

Ridge with sample weights gives wrong results for sparse input.

Steps/Code to Reproduce

import numpy as np
from numpy.testing import assert_array_almost_equal
from pytest import approx
import scipy.sparse as sparse
from sklearn.linear_model import Ridge

rng = np.random.RandomState(123)
n_samples, n_features = 10, 3
X = rng.rand(n_samples, n_features)
y = rng.rand(n_samples)

params = dict(alpha=0.05, fit_intercept=True)

# 1. Reference model: dense and sample_weight=None
reg = Ridge(**params, solver='svd').fit(X, y)
coef, intercept = reg.coef_.copy(), reg.intercept_

# 2. sample_weight = 2 * np.ones(..) with alpha doubled, which is equivalent to 1.
sw = 2 * np.ones_like(y)
params['alpha'] = 2 * params['alpha']
reg = Ridge(**params, solver='svd').fit(X, y, sample_weight=sw)
assert_array_almost_equal(reg.coef_, coef)
assert reg.intercept_ == approx(intercept)

# 3. X sparse, same sample_weight and doubled alpha => AssertionError
X = sparse.csr_matrix(X)
reg = Ridge(**params, solver='sparse_cg').fit(X, y, sample_weight=sw)
print('intercept true = {}, intercept sparse = {}'.format(intercept, reg.intercept_))
print('coef true = {}, coef sparse = {}'.format(coef, reg.coef_))
assert_array_almost_equal(reg.coef_, coef)
assert reg.intercept_ == approx(intercept)
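The reproduction relies on the fact that multiplying every sample weight by 2 and doubling alpha leaves the ridge solution unchanged. Below is a minimal sketch of that equivalence for a single feature, using a closed-form solution in plain Python. The function `weighted_ridge_1d` and the toy data are hypothetical names introduced for illustration only; they are not part of scikit-learn.

```python
def weighted_ridge_1d(x, y, sample_weight, alpha):
    """Closed-form 1-D ridge fit with an unpenalized intercept.

    Minimizes sum_i s_i * (y_i - w * x_i - b)**2 + alpha * w**2.
    Hypothetical helper for illustration, not a scikit-learn function.
    """
    total = sum(sample_weight)
    # Weighted means play the role of the centering offsets.
    x_mean = sum(s * xi for s, xi in zip(sample_weight, x)) / total
    y_mean = sum(s * yi for s, yi in zip(sample_weight, y)) / total
    sxy = sum(
        s * (xi - x_mean) * (yi - y_mean)
        for s, xi, yi in zip(sample_weight, x, y)
    )
    sxx = sum(s * (xi - x_mean) ** 2 for s, xi in zip(sample_weight, x))
    w = sxy / (sxx + alpha)
    b = y_mean - w * x_mean
    return w, b


x = [0.0, 1.0, 2.0, 3.0, 4.0]
y = [1.0, 2.9, 5.2, 7.1, 8.8]

# Reference: unit weights, alpha = 0.05.
reference = weighted_ridge_1d(x, y, [1.0] * 5, alpha=0.05)
# Weights doubled and alpha doubled: same minimizer as the reference.
doubled = weighted_ridge_1d(x, y, [2.0] * 5, alpha=0.10)
# Weights doubled but alpha unchanged: a different (less regularized) fit.
mismatched = weighted_ridge_1d(x, y, [2.0] * 5, alpha=0.05)

print(reference, doubled, mismatched)
```

Any solver that honors sample weights correctly should reproduce this invariance, whether the input is dense or sparse.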

Expected Results

No AssertionError is thrown.

Actual Results

Last two assertion statements throw AssertionError.

Versions

System:
python: 3.7.2
sklearn: 0.22.dev0
commit 9caf835
Author: Thomas J Fan thomasjpfan@gmail.com
Date: Sun Oct 27 03:32:45 2019 -0400

DOC Fix broken link to docs (#15364)
@lorentzenchr
Member Author

In PR #14300 @rth implemented some good consistency tests for sample_weight. I copied them over to PR #15436. It might be a good idea to have them for (nearly) all linear models with sample_weights.

@lorentzenchr
Member Author

Same for LinearRegression.

@thomasjpfan thomasjpfan changed the title Bug in sparse Ridge with sample weights Bug in sparse in Ridge and LinearRegression with sample weights Jun 20, 2021
@thomasjpfan thomasjpfan changed the title Bug in sparse in Ridge and LinearRegression with sample weights Bug in sparse in Ridge with sample weights Jun 20, 2021
@murphyhopfensperger

Is anybody looking at this? It seems like linear regression not working for two years in the major python statistics package is not great.

@rth
Member
rth commented Oct 21, 2021

@murphyhopfensperger Can you still reproduce this issue with the latest version?

@murphyhopfensperger

Input:

import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder
from sklearn.linear_model import LinearRegression


def Xyw(sparse=True):
    rng = np.random.default_rng(seed=0)
    n, d = 2 * 1_000, 2
    df = pd.DataFrame({"x1": rng.choice(["a", "b", "c"], size=n),})
    enc = OneHotEncoder(drop=["c"])
    X = enc.fit_transform(df)
    beta = np.array([2, 3])
    y = X @ beta + 10
    w = np.zeros(shape=n)
    w[: n // 2] = 1
    if not sparse:
        X = X.toarray()
    return {"X": X, "y": y, "sample_weight": w}


ols = LinearRegression()
ols.fit(**Xyw())
print(ols.coef_, ols.intercept_)

ols = LinearRegression()
ols.fit(**Xyw(sparse=False))
print(ols.coef_, ols.intercept_)

Output:

[0.38 1.38] 11.0449
[2. 3.] 10.000000000000004

Maybe I'm misunderstanding something (very possible!), but I thought the result should be the same whether or not X happens to be sparse.

@rth
Member
rth commented Oct 21, 2021

Yes, that certainly doesn't look good.

If you are willing to investigate, could you check if

X, y = _rescale_data(X, y, sample_weight)

produces the same output for the dense and sparse case on your example? If not, the bug is there (but that's unlikely, since we would then have this issue for all linear models). Otherwise it means it's an issue with the solver. It would be interesting to know whether this happens for all Ridge solvers. For instance, we could try directly calling,

coef = _solve_sparse_cg(

or
coef, n_iter = _solve_lsqr(X, y, alpha, max_iter, tol)

on the rescaled dense/sparse data and checking the coefficients.

@rth rth added this to the 1.1 milestone Oct 21, 2021
@murphyhopfensperger

I tried _rescale_data() to verify that it gives the same result for sparse and nonsparse.

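# _rescale_data is a private helper; this import path and the two-value
# return are assumed to match the scikit-learn version discussed here.
from sklearn.linear_model._base import _rescale_data

X1, y1 = _rescale_data(**Xyw())
X1 = X1.toarray()
X2, y2 = _rescale_data(**Xyw(sparse=False))
print((np.array_equal(X1, X2), np.array_equal(y1, y2)))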

Gives the result (True, True).

I probably won't test the others right away.

@jeremiedbb
Member

I can confirm the issue on at least LinearRegression and Ridge. It's probably wider since it comes from forgetting to scale the X_offset by the sample weights (on sparse, X is not actually centered, the centering is implicit).
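To illustrate the kind of error described here, the following toy sketch fits a single-feature least-squares slope after rescaling rows by the square root of the sample weights. It is an assumption-laden model of the mechanism, not the actual scikit-learn code: the function `slope_after_rescaling` and its `scale_offset` flag are hypothetical. With `scale_offset=True` the centering offset is multiplied by the same square-root weight as the row it is subtracted from; with `scale_offset=False` the unscaled offset is subtracted from the already rescaled row.

```python
import math


def slope_after_rescaling(x, y, sample_weight, scale_offset):
    """Least-squares slope on rows rescaled by sqrt(sample_weight).

    Hypothetical toy model: centering of x is applied after rescaling,
    mimicking implicit centering on data that is never centered in place.
    """
    total = sum(sample_weight)
    x_offset = sum(s * xi for s, xi in zip(sample_weight, x)) / total
    y_offset = sum(s * yi for s, yi in zip(sample_weight, y)) / total
    num = 0.0
    den = 0.0
    for s, xi, yi in zip(sample_weight, x, y):
        root = math.sqrt(s)
        # Correct: sqrt(s) * (x - offset). Buggy: sqrt(s) * x - offset.
        offset_factor = root if scale_offset else 1.0
        xc = root * xi - offset_factor * x_offset
        yc = root * (yi - y_offset)
        num += xc * yc
        den += xc * xc
    return num / den


x = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
y = [2.0 * xi + 10.0 for xi in x]          # exact linear relation, slope 2
w = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0]         # half of the rows have zero weight

correct = slope_after_rescaling(x, y, w, scale_offset=True)
buggy = slope_after_rescaling(x, y, w, scale_offset=False)
print(correct, buggy)
```

In this toy setup the correctly scaled offset recovers the true slope, while the unscaled offset lets zero-weight rows contribute to the denominator and shrinks the estimate.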

@s-banach
Contributor

Jérémie,
Sorry if this is off-topic, but perhaps you will consider it while you're working on sample weight code.
I believe ElasticNetCV chooses its array of alphas without taking into account sample weights.
Is this something you would consider an issue?

Thanks.

@jeremiedbb
Member

@s-banach can you open a dedicated issue with a reproducible example and an explanation of what you'd expect?

@s-banach
Contributor

Here is the issue: #22914

Thanks for your time.
