8000 FIX seed in test_ridge_sample_weight_consistency [all random seeds] (… · scikit-learn/scikit-learn@c0d4f75 · GitHub
[go: up one dir, main page]

Skip to content

Commit c0d4f75

Browse files
authored
FIX seed in test_ridge_sample_weight_consistency [all random seeds] (#26589)
1 parent 0800fb3 commit c0d4f75

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

sklearn/linear_model/tests/test_ridge.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1978,7 +1978,9 @@ def test_lbfgs_solver_error():
19781978
@pytest.mark.parametrize("sparseX", [False, True])
19791979
@pytest.mark.parametrize("data", ["tall", "wide"])
19801980
@pytest.mark.parametrize("solver", SOLVERS + ["lbfgs"])
1981-
def test_ridge_sample_weight_consistency(fit_intercept, sparseX, data, solver):
1981+
def test_ridge_sample_weight_consistency(
1982+
fit_intercept, sparseX, data, solver, global_random_seed
1983+
):
19821984
"""Test that the impact of sample_weight is consistent.
19831985
19841986
Note that this test is stricter than the common test
@@ -1989,6 +1991,9 @@ def test_ridge_sample_weight_consistency(fit_intercept, sparseX, data, solver):
19891991
if solver == "svd" or (solver in ("cholesky", "saga") and fit_intercept):
19901992
pytest.skip("unsupported configuration")
19911993

1994+
# XXX: this test is quite sensitive to the seed used to generate the data:
1995+
# ideally we would like the test to pass for any global_random_seed but this is not
1996+
# the case at the moment.
19921997
rng = np.random.RandomState(42)
19931998
n_samples = 12
19941999
if data == "tall":
@@ -2005,6 +2010,7 @@ def test_ridge_sample_weight_consistency(fit_intercept, sparseX, data, solver):
20052010
alpha=1.0,
20062011
solver=solver,
20072012
positive=(solver == "lbfgs"),
2013+
random_state=global_random_seed, # for sag/saga
20082014
tol=1e-12,
20092015
)
20102016

0 commit comments

Comments
 (0)
0