8000 Make test_sag_regressor_computed_correctly deterministic (#16003) · scikit-learn/scikit-learn@4d5407c · GitHub
[go: up one dir, main page]

Skip to content

Commit 4d5407c

Browse files
authored
Make test_sag_regressor_computed_correctly deterministic (#16003)
Fix #15818.
1 parent c8c21ae commit 4d5407c

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

sklearn/linear_model/tests/test_sag.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def sag(X, y, step_size, alpha, n_iter=1, dloss=None, sparse=False,
120120

121121
def sag_sparse(X, y, step_size, alpha, n_iter=1,
122122
dloss=None, sample_weight=None, sparse=False,
123-
fit_intercept=True, saga=False):
123+
fit_intercept=True, saga=False, random_state=0):
124124
if step_size * alpha == 1.:
125125
raise ZeroDivisionError("Sparse sag does not handle the case "
126126
"step_size * alpha == 1")
@@ -130,7 +130,7 @@ def sag_sparse(X, y, step_size, alpha, n_iter=1,
130130
sum_gradient = np.zeros(n_features)
131131
last_updated = np.zeros(n_features, dtype=np.int)
132132
gradient_memory = np.zeros(n_samples)
133-
rng = np.random.RandomState(77)
133+
rng = check_random_state(random_state)
134134
intercept = 0.0
135135
intercept_sum_gradient = 0.0
136136
wscale = 1.0
@@ -368,7 +368,7 @@ def test_sag_regressor_computed_correctly():
368368
alpha = .1
369369
n_features = 10
370370
n_samples = 40
371-
max_iter = 50
371+
max_iter = 100
372372
tol = .000001
373373
fit_intercept = True
374374
rng = np.random.RandomState(0)
@@ -378,7 +378,8 @@ def test_sag_regressor_computed_correctly():
378378
step_size = get_step_size(X, alpha, fit_intercept, classification=False)
379379

380380
clf1 = Ridge(fit_intercept=fit_intercept, tol=tol, solver='sag',
381-
alpha=alpha * n_samples, max_iter=max_iter)
381+
alpha=alpha * n_samples, max_iter=max_iter,
382+
random_state=rng)
382383
clf2 = clone(clf1)
383384

384385
clf1.fit(X, y)
@@ -387,12 +388,14 @@ def test_sag_regressor_computed_correctly():
387388
spweights1, spintercept1 = sag_sparse(X, y, step_size, alpha,
388389
n_iter=max_iter,
389390
dloss=squared_dloss,
390-
fit_intercept=fit_intercept)
391+
fit_intercept=fit_intercept,
392+
random_state=rng)
391393

392394
spweights2, spintercept2 = sag_sparse(X, y, step_size, alpha,
393395
n_iter=max_iter,
394396
dloss=squared_dloss, sparse=True,
395-
fit_intercept=fit_intercept)
397+
fit_intercept=fit_intercept,
398+
random_state=rng)
396399

397400
assert_array_almost_equal(clf1.coef_.ravel(),
398401
spweights1.ravel(),

0 commit comments

Comments
 (0)
0