8000 [MRG] Fix various solver issues in ridge_regression and Ridge classes… · scikit-learn/scikit-learn@77b73d6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 77b73d6

Browse files
btel authored and NicolasHug committed
[MRG] Fix various solver issues in ridge_regression and Ridge classes (#13363)
1 parent 301076e commit 77b73d6

File tree

3 files changed

+92
-22
lines changed

3 files changed

+92
-22
lines changed

doc/whats_new/v0.21.rst

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,9 @@ Support for Python 3.4 and below has been officially dropped.
184184
- |Enhancement| Minimized the validation of X in
185185
:class:`ensemble.AdaBoostClassifier` and :class:`ensemble.AdaBoostRegressor`
186186
:issue:`13174` by :user:`Christos Aridas <chkoar>`.
187-
187+
188188
- |Enhancement| :class:`ensemble.IsolationForest` now exposes ``warm_start``
189-
parameter, allowing iterative addition of trees to an isolation
189+
parameter, allowing iterative addition of trees to an isolation
190190
forest. :issue:`13496` by :user:`Peter Marko <petibear>`.
191191

192192
- |Efficiency| Make :class:`ensemble.IsolationForest` more memory efficient
@@ -369,6 +369,22 @@ Support for Python 3.4 and below has been officially dropped.
369369
deterministic when trained in a multi-class setting on several threads.
370370
:issue:`13422` by :user:`Clément Doumouro <ClemDoum>`.
371371

372+
- |Fix| Fixed bug in :func:`linear_model.ridge.ridge_regression`,
373+
:class:`linear_model.ridge.Ridge` and
374+
:class:`linear_model.ridge.ridge.RidgeClassifier` that
375+
caused unhandled exception for arguments ``return_intercept=True`` and
376+
``solver=auto`` (default) or any other solver different from ``sag``.
377+
:issue:`13363` by :user:`Bartosz Telenczuk <btel>`
378+
379+
- |Fix| :func:`linear_model.ridge.ridge_regression` will now raise an exception
380+
if ``return_intercept=True`` and solver is different from ``sag``. Previously,
381+
only warning was issued. :issue:`13363` by :user:`Bartosz Telenczuk <btel>`
382+
383+
- |API| :func:`linear_model.ridge.ridge_regression` will choose ``sparse_cg``
384+
solver for sparse inputs when ``solver=auto`` and ``sample_weight``
385+
is provided (previously `cholesky` solver was selected). :issue:`13363`
386+
by :user:`Bartosz Telenczuk <btel>`
387+
372388
:mod:`sklearn.manifold`
373389
............................
374390

sklearn/linear_model/ridge.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -368,12 +368,25 @@ def _ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
368368
return_n_iter=False, return_intercept=False,
369369
X_scale=None, X_offset=None):
370370

371-
if return_intercept and sparse.issparse(X) and solver != 'sag':
372-
if solver != 'auto':
373-
warnings.warn("In Ridge, only 'sag' solver can currently fit the "
374-
"intercept when X is sparse. Solver has been "
375-
"automatically changed into 'sag'.")
376-
solver = 'sag'
371+
has_sw = sample_weight is not None
372+
373+
if solver == 'auto':
374+
if return_intercept:
375+
# only sag supports fitting intercept directly
376+
solver = "sag"
377+
elif not sparse.issparse(X):
378+
solver = "cholesky"
379+
else:
380+
solver = "sparse_cg"
381+
382+
if solver not in ('sparse_cg', 'cholesky', 'svd', 'lsqr', 'sag', 'saga'):
383+
raise ValueError("Known solvers are 'sparse_cg', 'cholesky', 'svd'"
384+
" 'lsqr', 'sag' or 'saga'. Got %s." % solver)
385+
386+
if return_intercept and solver != 'sag':
387+
raise ValueError("In Ridge, only 'sag' solver can directly fit the "
388+
"intercept. Please change solver to 'sag' or set "
389+
"return_intercept=False.")
377390

378391
_dtype = [np.float64, np.float32]
379392

@@ -404,14 +417,7 @@ def _ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
404417
raise ValueError("Number of samples in X and y does not correspond:"
405418
" %d != %d" % (n_samples, n_samples_))
406419

407-
has_sw = sample_weight is not None
408420

409-
if solver == 'auto':
410-
# cholesky if it's a dense array and cg in any other case
411-
if not sparse.issparse(X) or has_sw:
412-
solver = 'cholesky'
413-
else:
414-
solver = 'sparse_cg'
415421

416422
if has_sw:
417423
if np.atleast_1d(sample_weight).ndim > 1:
@@ -432,8 +438,6 @@ def _ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
432438
if alpha.size == 1 and n_targets > 1:
433439
alpha = np.repeat(alpha, n_targets)
434440

435-
if solver not in ('sparse_cg', 'cholesky', 'svd', 'lsqr', 'sag', 'saga'):
436-
raise ValueError('Solver %s not understood' % solver)
437441

438442
n_iter = None
439443
if solver == 'sparse_cg':
@@ -555,7 +559,7 @@ def fit(self, X, y, sample_weight=None):
555559
# add the offset which was subtracted by _preprocess_data
556560
self.intercept_ += y_offset
557561
else:
558-
if sparse.issparse(X):
562+
if sparse.issparse(X) and self.solver == 'sparse_cg':
559563
# required to fit intercept with sparse_cg solver
560564
params = {'X_offset': X_offset, 'X_scale': X_scale}
561565
else:

sklearn/linear_model/tests/test_ridge.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from sklearn.utils.testing import assert_almost_equal
99
from sklearn.utils.testing import assert_array_almost_equal
10+
from sklearn.utils.testing import assert_allclose
1011
from sklearn.utils.testing import assert_equal
1112
from sklearn.utils.testing import assert_array_equal
1213
from sklearn.utils.testing import assert_greater
@@ -778,7 +779,8 @@ def test_raises_value_error_if_solver_not_supported():
778779
wrong_solver = "This is not a solver (MagritteSolveCV QuantumBitcoin)"
779780

780781
exception = ValueError
781-
message = "Solver %s not understood" % wrong_solver
782+
message = ("Known solvers are 'sparse_cg', 'cholesky', 'svd'"
783+
" 'lsqr', 'sag' or 'saga'. Got %s." % wrong_solver)
782784

783785
def func():
784786
X = np.eye(3)
@@ -832,9 +834,57 @@ def test_ridge_fit_intercept_sparse():
832834
# test the solver switch and the corresponding warning
833835
for solver in ['saga', 'lsqr']:
834836
sparse = Ridge(alpha=1., tol=1.e-15, solver=solver, fit_intercept=True)
835-
assert_warns(UserWarning, sparse.fit, X_csr, y)
836-
assert_almost_equal(dense.intercept_, sparse.intercept_)
837-
assert_array_almost_equal(dense.coef_, sparse.coef_)
837+
assert_raises_regex(ValueError, "In Ridge,", sparse.fit, X_csr, y)
838+
839+
840+
@pytest.mark.parametrize('return_intercept', [False, True])
841+
@pytest.mark.parametrize('sample_weight', [None, np.ones(1000)])
842+
@pytest.mark.parametrize('arr_type', [np.array, sp.csr_matrix])
843+
@pytest.mark.parametrize('solver', ['auto', 'sparse_cg', 'cholesky', 'lsqr',
844+
'sag', 'saga'])
845+
def test_ridge_regression_check_arguments_validity(return_intercept,
846+
sample_weight, arr_type,
847+
solver):
848+
"""check if all combinations of arguments give valid estimations"""
849+
850+
# test excludes 'svd' solver because it raises exception for sparse inputs
851+
852+
rng = check_random_state(42)
853+
X = rng.rand(1000, 3)
854+
true_coefs = [1, 2, 0.1]
855+
y = np.dot(X, true_coefs)
856+
true_intercept = 0.
857+
if return_intercept:
858+
true_intercept = 10000.
859+
y += true_intercept
860+
X_testing = arr_type(X)
861+
862+
alpha, atol, tol = 1e-3, 1e-4, 1e-6
863+
864+
if solver not in ['sag', 'auto'] and return_intercept:
865+
assert_raises_regex(ValueError,
866+
"In Ridge, only 'sag' solver",
867+
ridge_regression, X_testing, y,
868+
alpha=alpha,
869+
solver=solver,
870+
sample_weight=sample_weight,
871+
return_intercept=return_intercept,
872+
tol=tol)
873+
return
874+
875+
out = ridge_regression(X_testing, y, alpha=alpha,
876+
solver=solver,
877+
sample_weight=sample_weight,
878+
return_intercept=return_intercept,
879+
tol=tol,
880+
)
881+
882+
if return_intercept:
883+
coef, intercept = out
884+
assert_allclose(coef, true_coefs, rtol=0, atol=atol)
885+
assert_allclose(intercept, true_intercept, rtol=0, atol=atol)
886+
else:
887+
assert_allclose(out, true_coefs, rtol=0, atol=atol)
838888

839889

840890
def test_errors_and_values_helper():

0 commit comments

Comments (0)