 from .sag_fast import sag, get_max_squared_sum


-def get_auto_step_size(max_squared_sum, alpha, loss, fit_intercept):
+def get_auto_step_size(max_squared_sum, alpha_scaled, loss, fit_intercept):
     """Compute automatic step size for SAG solver

-    The step size is set to 1 / (alpha + L + fit_intercept) where L is
+    The step size is set to 1 / (alpha_scaled + L + fit_intercept) where L is
     the max sum of squares over all samples.

     Parameters
     ----------
     max_squared_sum : float
         Maximum squared sum of X over samples.

-    alpha : float
-        Constant that multiplies the regularization term. Defaults to 0.0001
+    alpha_scaled : float
+        Constant that multiplies the regularization term, scaled by
+        1. / n_samples, the number of samples.

     loss : string, in {"log", "squared"}
         The loss function used in SAG solver.
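
A quick numeric sanity check between these two hunks may help reviewers: a minimal sketch (not part of the diff) of the step sizes the formulas in the next hunk produce, assuming row-normalized data so that max_squared_sum == 1, an intercept, and alpha = 1. scaled over n_samples = 100:

# Sketch only: plugs representative values into the auto step size formulas.
max_squared_sum = 1.0          # assumed: rows of X normalized to unit norm
fit_intercept = True
alpha_scaled = 1.0 / 100       # alpha = 1. scaled by n_samples = 100

# log loss: inverse Lipschitz constant 4 / (L + 1 + 4 * alpha_scaled)
step_log = 4.0 / (max_squared_sum + int(fit_intercept) + 4.0 * alpha_scaled)

# squared loss: inverse Lipschitz constant 1 / (L + 1 + alpha_scaled)
step_sq = 1.0 / (max_squared_sum + int(fit_intercept) + alpha_scaled)

print(step_log)  # ~1.9608
print(step_sq)   # ~0.4975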
@@ -43,16 +44,17 @@ def get_auto_step_size(max_squared_sum, alpha, loss, fit_intercept):
     """
     if loss == 'log':
         # inverse Lipschitz constant for log loss
-        return 4.0 / (max_squared_sum + int(fit_intercept) + 4.0 * alpha)
+        return 4.0 / (max_squared_sum + int(fit_intercept)
+                      + 4.0 * alpha_scaled)
     elif loss == 'squared':
         # inverse Lipschitz constant for squared loss
-        return 1.0 / (max_squared_sum + int(fit_intercept) + alpha)
+        return 1.0 / (max_squared_sum + int(fit_intercept) + alpha_scaled)
     else:
         raise ValueError("Unknown loss function for SAG solver, got %s "
                          "instead of 'log' or 'squared'" % loss)


-def sag_solver(X, y, sample_weight=None, loss='log', alpha=1e-4,
+def sag_solver(X, y, sample_weight=None, loss='log', alpha=1.,
                max_iter=1000, tol=0.001, verbose=0, random_state=None,
                check_input=True, max_squared_sum=None,
                warm_start_mem=dict()):
@@ -91,7 +93,7 @@ def sag_solver(X, y, sample_weight=None, loss='log', alpha=1e-4,
         'squared' is used for regression, like in Ridge.

     alpha : float, optional
-        Constant that multiplies the regularization term. Defaults to 0.0001
+        Constant that multiplies the regularization term. Defaults to 1.

     max_iter: int, optional
         The max number of passes over the training data if the stopping
@@ -177,7 +179,8 @@ def sag_solver(X, y, sample_weight=None, loss='log', alpha=1e-4,
     y = check_array(y, dtype=np.float64, ensure_2d=False, order='C')

     n_samples, n_features = X.shape[0], X.shape[1]
-    alpha = float(alpha) / n_samples
+    # As in SGD, the alpha is scaled by n_samples.
+    alpha_scaled = float(alpha) / n_samples

     # initialization
     if sample_weight is None:
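
To make the convention this hunk pins down concrete, a small sketch with invented values: the caller keeps passing the Ridge-style alpha, and the solver divides by n_samples exactly once before anything else uses it:

n_samples = 200
alpha = 1.                                # caller-facing value (new default)
alpha_scaled = float(alpha) / n_samples   # per-sample value used internally
print(alpha_scaled)                       # 0.005

Keeping the two names distinct means the caller-facing docstring and the internal updates can no longer drift apart, which is the point of the rename.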
@@ -226,19 +229,19 @@ def sag_solver(X, y, sample_weight=None, loss='log', alpha=1e-4,

     if max_squared_sum is None:
         max_squared_sum = get_max_squared_sum(X)
-    step_size = get_auto_step_size(max_squared_sum, alpha, loss,
+    step_size = get_auto_step_size(max_squared_sum, alpha_scaled, loss,
                                    fit_intercept)

-    if step_size * alpha == 1:
+    if step_size * alpha_scaled == 1:
         raise ZeroDivisionError("Current sag implementation does not handle "
-                                "the case step_size * alpha == 1")
+                                "the case step_size * alpha_scaled == 1")

     if loss == 'log':
         class_loss = Log()
     elif loss == 'squared':
         class_loss = SquaredLoss()
     else:
-        raise ValueError("Invalid sparseness parameter: got %r instead of "
+        raise ValueError("Invalid loss parameter: got %r instead of "
                          "one of ('log', 'squared')" % loss)

     intercept_, num_seen, n_iter_, intercept_sum_gradient = \
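
A note on the guard renamed in this hunk: the error message suggests the implementation divides by a quantity proportional to 1 - step_size * alpha_scaled, so equality with 1 is fatal. A small check (a reasoning aid, not from the diff) of when that can happen if step_size comes from get_auto_step_size:

# For squared loss, step_size * alpha_scaled
#   = alpha_scaled / (max_squared_sum + int(fit_intercept) + alpha_scaled),
# which equals 1 only when max_squared_sum + int(fit_intercept) == 0,
# i.e. X is all zeros and fit_intercept is False.  Log loss is analogous.
max_squared_sum, fit_intercept = 0.0, False   # degenerate inputs
alpha_scaled = 0.5
step_size = 1.0 / (max_squared_sum + int(fit_intercept) + alpha_scaled)
assert step_size * alpha_scaled == 1.0        # the case the solver rejects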
@@ -247,7 +250,7 @@ def sag_solver(X, y, sample_weight=None, loss='log', alpha=1e-4,
         n_features, tol,
         max_iter,
         class_loss,
-        step_size, alpha,
+        step_size, alpha_scaled,
         sum_gradient_init.ravel(),
         gradient_memory_init.ravel(),
         seen_init.ravel(),
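
Finally, a smoke test one could run against this branch. The import path and the exact return value of sag_solver are not shown in this diff, so both are assumptions here; only the keyword names come from the signature above:

import numpy as np
from sklearn.linear_model.sag import sag_solver   # path assumed, not in diff

rng = np.random.RandomState(0)
X = np.ascontiguousarray(rng.normal(size=(50, 3)))
y = X.dot(np.array([1., -2., 0.5])) + rng.normal(scale=0.1, size=50)

# alpha is now the unscaled, Ridge-style strength (default 1.); the solver
# divides it by n_samples internally before computing the step size.
result = sag_solver(X, y, loss='squared', alpha=1., max_iter=100)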