8000 FIX random failure in `test_solver_consistency` with SAG/SAGA solvers… · scikit-learn/scikit-learn@b257fbb · GitHub
[go: up one dir, main page]

Skip to content

Commit b257fbb

Browse files
authored
FIX random failure in test_solver_consistency with SAG/SAGA solvers (#31434)
1 parent cbe8648 commit b257fbb

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

sklearn/linear_model/tests/test_ridge.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -750,9 +750,8 @@ def _make_sparse_offset_regression(
750750
"n_samples,dtype,proportion_nonzero",
751751
[(20, "float32", 0.1), (40, "float32", 1.0), (20, "float64", 0.2)],
752752
)
753-
@pytest.mark.parametrize("seed", np.arange(3))
754753
def test_solver_consistency(
755-
solver, proportion_nonzero, n_samples, dtype, sparse_container, seed
754+
solver, proportion_nonzero, n_samples, dtype, sparse_container, global_random_seed
756755
):
757756
alpha = 1.0
758757
noise = 50.0 if proportion_nonzero > 0.9 else 500.0
@@ -761,10 +760,9 @@ def test_solver_consistency(
761760
n_features=30,
762761
proportion_nonzero=proportion_nonzero,
763762
noise=noise,
764-
random_state=seed,
763+
random_state=global_random_seed,
765764
n_samples=n_samples,
766765
)
767-
768766
# Manually scale the data to avoid pathological cases. We use
769767
# minmax_scale to deal with the sparse case without breaking
770768
# the sparsity pattern.
@@ -778,7 +776,21 @@ def test_solver_consistency(
778776
if solver == "ridgecv":
779777
ridge = RidgeCV(alphas=[alpha])
780778
else:
781-
ridge = Ridge(solver=solver, tol=1e-10, alpha=alpha)
779+
if solver.startswith("sag"):
780+
# Avoid ConvergenceWarning for sag and saga solvers.
781+
tol = 1e-7
782+
max_iter = 100_000
783+
else:
784+
tol = 1e-10
785+
max_iter = None
786+
787+
ridge = Ridge(
788+
alpha=alpha,
789+
solver=solver,
790+
max_iter=max_iter,
791+
tol=tol,
792+
random_state=global_random_seed,
793+
)
782794
ridge.fit(X, y)
783795
assert_allclose(ridge.coef_, svd_ridge.coef_, atol=1e-3, rtol=1e-3)
784796
assert_allclose(ridge.intercept_, svd_ridge.intercept_, atol=1e-3, rtol=1e-3)

0 commit comments

Comments
 (0)
0