[MRG] Fix selection of solver in ridge_regression when solver=='auto' #13363
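For context, a minimal sketch of the behaviour this PR targets, assuming scikit-learn's public ridge_regression API (the data and values are illustrative, not taken from the PR):

import numpy as np
import scipy.sparse as sp
from sklearn.linear_model import ridge_regression

rng = np.random.RandomState(42)
X = sp.csr_matrix(rng.rand(100, 3))
y = rng.rand(100)

# Per the error message exercised in the tests below, only the 'sag' solver
# can fit the intercept here, so solver='auto' should resolve to 'sag'
# rather than pick an incompatible solver.
coef, intercept = ridge_regression(X, y, alpha=1.0, solver='auto',
                                   return_intercept=True, tol=1e-6)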
Changes from all commits
@@ -7,6 +7,7 @@
 from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_array_almost_equal
+from sklearn.utils.testing import assert_allclose
 from sklearn.utils.testing import assert_equal
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_greater
@@ -778,7 +779,8 @@ def test_raises_value_error_if_solver_not_supported():
     wrong_solver = "This is not a solver (MagritteSolveCV QuantumBitcoin)"

     exception = ValueError
-    message = "Solver %s not understood" % wrong_solver
+    message = ("Known solvers are 'sparse_cg', 'cholesky', 'svd'"
+               " 'lsqr', 'sag' or 'saga'. Got %s." % wrong_solver)

     def func():
         X = np.eye(3)
@@ -832,9 +834,57 @@ def test_ridge_fit_intercept_sparse():
     # test the solver switch and the corresponding warning
     for solver in ['saga', 'lsqr']:
         sparse = Ridge(alpha=1., tol=1.e-15, solver=solver, fit_intercept=True)
-        assert_warns(UserWarning, sparse.fit, X_csr, y)
-        assert_almost_equal(dense.intercept_, sparse.intercept_)
-        assert_array_almost_equal(dense.coef_, sparse.coef_)
+        assert_raises_regex(ValueError, "In Ridge,", sparse.fit, X_csr, y)
+
+
+@pytest.mark.parametrize('return_intercept', [False, True])
+@pytest.mark.parametrize('sample_weight', [None, np.ones(1000)])
+@pytest.mark.parametrize('arr_type', [np.array, sp.csr_matrix])
+@pytest.mark.parametrize('solver', ['auto', 'sparse_cg', 'cholesky', 'lsqr',
+                                    'sag', 'saga'])
+def test_ridge_regression_check_arguments_validity(return_intercept,
+                                                   sample_weight, arr_type,
+                                                   solver):
+    """check if all combinations of arguments give valid estimations"""
+
+    # test excludes 'svd' solver because it raises exception for sparse inputs
+
+    rng = check_random_state(42)
+    X = rng.rand(1000, 3)
+    true_coefs = [1, 2, 0.1]
+    y = np.dot(X, true_coefs)
+    true_intercept = 0.
+    if return_intercept:
+        true_intercept = 10000.
+    y += true_intercept
+    X_testing = arr_type(X)
+
+    alpha, atol, tol = 1e-3, 1e-4, 1e-6
+
+    if solver not in ['sag', 'auto'] and return_intercept:
+        assert_raises_regex(ValueError,
+                            "In Ridge, only 'sag' solver",
+                            ridge_regression, X_testing, y,
+                            alpha=alpha,
+                            solver=solver,
+                            sample_weight=sample_weight,
+                            return_intercept=return_intercept,
+                            tol=tol)
+        return
+
+    out = ridge_regression(X_testing, y, alpha=alpha,
+                           solver=solver,
+                           sample_weight=sample_weight,
+                           return_intercept=return_intercept,
+                           tol=tol,
+                           )
+
+    if return_intercept:
+        coef, intercept = out
+        assert_allclose(coef, true_coefs, rtol=0, atol=atol)
+        assert_allclose(intercept, true_intercept, rtol=0, atol=atol)
+    else:
+        assert_allclose(out, true_coefs, rtol=0, atol=atol)


 def test_errors_and_values_helper():
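To make the behavioural change in test_ridge_fit_intercept_sparse concrete, a hedged sketch (the data here are illustrative): fitting a sparse matrix with fit_intercept=True and a solver other than 'sag' now raises instead of warning and switching solvers.

import numpy as np
import scipy.sparse as sp
from sklearn.linear_model import Ridge

rng = np.random.RandomState(0)
X_csr = sp.csr_matrix(rng.rand(10, 3))
y = rng.rand(10)

# Previously this emitted a UserWarning and silently fell back to another
# solver; per the updated test, it now raises a ValueError whose message
# starts with "In Ridge,".
try:
    Ridge(alpha=1.0, solver='lsqr', fit_intercept=True).fit(X_csr, y)
except ValueError as exc:
    print(exc)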
Review discussion on the tolerance used in the assertions above:

- "is an absolute tol of 0.1 necessary ?"
- "Needs to be atol when comparing to 0 but 0.1 seems big for checking equality"
- "it's true, but the differences in the estimations are around 0.02, so I can change to atol=0.03. The true coefs are: 1, 2, 0.1, intercept 0"
- "I guess the default tol of 1e-3 might be the reason of this poor comparison. Could you try with a zero tol ?"
- "@btel I don't understand why you cannot use a lower tolerance as you have no noise added to data."
- "what do you mean? the data are randomly generated, so I don't get exactly the coefficients I put in. I can freeze the seed and test against the coefficients that I get after a test run, but still I might get some small differences between the solvers."
- "What @agramfort meant is that you could pass a smaller tolerance to …"
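For readers following the thread: with assert_allclose, rtol=0 disables the relative criterion so only the absolute tolerance applies, which is what makes a comparison against an expected value of exactly 0 meaningful. A small illustration using numpy.testing (the values are made up):

from numpy.testing import assert_allclose

# A relative tolerance is useless when the expected value is exactly 0,
# because rtol scales with the magnitude of the expected value.
assert_allclose(5e-5, 0.0, rtol=0, atol=1e-4)   # passes: |diff| <= atol
assert_allclose(1.0002, 1.0, rtol=0, atol=1e-3) # passes: |diff| <= atol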