@@ -120,7 +120,7 @@ def sag(X, y, step_size, alpha, n_iter=1, dloss=None, sparse=False,
120
120
121
121
def sag_sparse (X , y , step_size , alpha , n_iter = 1 ,
122
122
dloss = None , sample_weight = None , sparse = False ,
123
- fit_intercept = True , saga = False ):
123
+ fit_intercept = True , saga = False , random_state = 0 ):
124
124
if step_size * alpha == 1. :
125
125
raise ZeroDivisionError ("Sparse sag does not handle the case "
126
126
"step_size * alpha == 1" )
@@ -130,7 +130,7 @@ def sag_sparse(X, y, step_size, alpha, n_iter=1,
130
130
sum_gradient = np .zeros (n_features )
131
131
last_updated = np .zeros (n_features , dtype = np .int )
132
132
gradient_memory = np .zeros (n_samples )
133
- rng = np . random . RandomState ( 77 )
133
+ rng = check_random_state ( random_state )
134
134
intercept = 0.0
135
135
intercept_sum_gradient = 0.0
136
136
wscale = 1.0
@@ -368,7 +368,7 @@ def test_sag_regressor_computed_correctly():
368
368
alpha = .1
369
369
n_features = 10
370
370
n_samples = 40
371
- max_iter = 50
371
+ max_iter = 100
372
372
tol = .000001
373
373
fit_intercept = True
374
374
rng = np .random .RandomState (0 )
@@ -378,7 +378,8 @@ def test_sag_regressor_computed_correctly():
378
378
step_size = get_step_size (X , alpha , fit_intercept , classification = False )
379
379
380
380
clf1 = Ridge (fit_intercept = fit_intercept , tol = tol , solver = 'sag' ,
381
- alpha = alpha * n_samples , max_iter = max_iter )
381
+ alpha = alpha * n_samples , max_iter = max_iter ,
382
+ random_state = rng )
382
383
clf2 = clone (clf1 )
383
384
384
385
clf1 .fit (X , y )
@@ -387,12 +388,14 @@ def test_sag_regressor_computed_correctly():
387
388
spweights1 , spintercept1 = sag_sparse (X , y , step_size , alpha ,
388
389
n_iter = max_iter ,
389
390
dloss = squared_dloss ,
390
- fit_intercept = fit_intercept )
391
+ fit_intercept = fit_intercept ,
392
+ random_state = rng )
391
393
392
394
spweights2 , spintercept2 = sag_sparse (X , y , step_size , alpha ,
393
395
n_iter = max_iter ,
394
396
dloss = squared_dloss , sparse = True ,
395
- fit_intercept = fit_intercept )
397
+ fit_intercept = fit_intercept ,
398
+ random_state = rng )
396
399
397
400
assert_array_almost_equal (clf1 .coef_ .ravel (),
398
401
spweights1 .ravel (),
0 commit comments