36
36
"""For sparse data intercept updates are scaled by this decay factor to avoid
37
37
intercept oscillation."""
38
38
39
+ DEFAULT_EPSILON = 0.1
40
+ """Default value of ``epsilon`` parameter. """
41
+
39
42
40
43
class BaseSGD (BaseEstimator ):
41
44
"""Base class for SGD classification and regression."""
@@ -46,44 +49,32 @@ def __init__(self, loss, penalty='l2', alpha=0.0001,
46
49
rho = 0.85 , fit_intercept = True , n_iter = 5 , shuffle = False ,
47
50
verbose = 0 , epsilon = 0.1 , seed = 0 , learning_rate = "optimal" ,
48
51
eta0 = 0.0 , power_t = 0.5 , warm_start = False ):
49
- self .loss = str (loss )
50
- self .penalty = str (penalty ).lower ()
51
- self .learning_rate = str (learning_rate )
52
-
53
- # raises ValueError if not registered
54
- self .get_penalty_type (self .penalty )
55
- self .get_learning_rate_type (self .learning_rate )
56
-
57
- self .epsilon = float (epsilon )
58
- self .alpha = float (alpha )
59
- if self .alpha < 0.0 :
60
- raise ValueError ("alpha must be greater than zero" )
61
- self .rho = float (rho )
62
- if self .rho < 0.0 or self .rho > 1.0 :
63
- raise ValueError ("rho must be in [0, 1]" )
64
- self .fit_intercept = bool (fit_intercept )
65
- self .n_iter = int (n_iter )
66
- if self .n_iter <= 0 :
67
- raise ValueError ("n_iter must be greater than zero" )
68
- if not isinstance (shuffle , bool ):
69
- raise ValueError ("shuffle must be either True or False" )
70
- self .shuffle = bool (shuffle )
52
+ self .loss = loss
53
+ self .penalty = penalty
54
+ self .learning_rate = learning_rate
55
+ self .epsilon = epsilon
56
+ self .alpha = alpha
57
+ self .rho = rho
58
+ self .fit_intercept = fit_intercept
59
+ self .n_iter = n_iter
60
+ self .shuffle = shuffle
71
61
self .seed = seed
72
- self .verbose = int (verbose )
73
-
74
- self .eta0 = float (eta0 )
75
- self .power_t = float (power_t )
76
- if self .learning_rate != "optimal" :
77
- if eta0 <= 0.0 :
78
- raise ValueError ("eta0 must be greater than 0.0" )
79
- self .coef_ = None
62
+ self .verbose = verbose
63
+ self .eta0 = eta0
64
+ self .power_t = power_t
80
65
self .warm_start = warm_start
81
66
82
- #self._init_t()
67
+ self ._validate_params ()
68
+
69
+ self .coef_ = None
83
70
# iteration count for learning rate schedule
84
71
# must not be int (e.g. if ``learning_rate=='optimal'``)
85
72
self .t_ = None
86
73
74
+ def set_params (self , * args , ** kwargs ):
75
+ super (BaseSGD , self ).set_params (* args , ** kwargs )
76
+ self ._validate_params ()
77
+
87
78
@abstractmethod
88
79
def fit (self , X , y ):
89
80
"""Fit model."""
@@ -92,6 +83,27 @@ def fit(self, X, y):
92
83
def predict (self , X ):
93
84
"""Predict using model."""
94
85
86
+ def _validate_params (self ):
87
+ """Validate input params. """
88
+ if not isinstance (self .shuffle , bool ):
89
+ raise ValueError ("shuffle must be either True or False" )
90
+ if self .n_iter <= 0 :
91
+ raise ValueError ("n_iter must be greater than zero" )
92
+ if not (0.0 <= self .rho <= 1.0 ):
93
+ raise ValueError ("rho must be in [0, 1]" )
94
+ if self .alpha < 0.0 :
95
+ raise ValueError ("alpha must be greater than zero" )
96
+ if self .learning_rate != "optimal" :
97
+ if self .eta0 <= 0.0 :
98
+ raise ValueError ("eta0 must be greater than 0.0" )
99
+
<
9E7A
/code>
100
+ # raises ValueError if not registered
101
+ self ._get_penalty_type (self .penalty )
102
+ self ._get_learning_rate_type (self .learning_rate )
103
+
104
+ if self .loss not in self .loss_functions :
105
+ raise ValueError ("The loss %s is not supported. " % self .loss )
106
+
95
107
def _init_t (self , loss_function ):
96
108
"""Initialize iteration counter attr ``t_``.
97
109
@@ -111,18 +123,21 @@ def get_loss_function(self, loss):
111
123
try :
112
124
loss_ = self .loss_functions [loss ]
113
125
loss_class , args = loss_ [0 ], loss_ [1 :]
126
+ if loss in ('huber' , 'epsilon_insensitive' ):
127
+ args = (self .epsilon , )
114
128
return loss_class (* args )
115
129
except KeyError :
116
130
raise ValueError ("The loss %s is not supported. " % loss )
117
131
118
- def get_learning_rate_type (self , learning_rate ):
132
+ def _get_learning_rate_type (self , learning_rate ):
119
133
try :
120
134
return LEARNING_RATE_TYPES [learning_rate ]
121
135
except KeyError :
122
136
raise ValueError ("learning rate %s"
123
137
"is not supported. " % learning_rate )
124
138
125
- def get_penalty_type (self , penalty ):
139
+ def _get_penalty_type (self , penalty ):
140
+ penalty = str (penalty ).lower ()
126
141
try :
127
142
return PENALTY_TYPES [penalty ]
128
143
except KeyError :
@@ -338,9 +353,20 @@ class SGDClassifier(BaseSGD, ClassifierMixin, SelectorMixin):
338
353
LinearSVC, LogisticRegression, Perceptron
339
354
340
355
"""
356
+
357
+ loss_functions = {
358
+ "hinge" : (Hinge , 1.0 ),
359
+ "perceptron" : (Hinge , 0.0 ),
360
+ "log" : (Log , ),
361
+ "modified_huber" : (ModifiedHuber , ),
362
+ "squared_loss" : (SquaredLoss , ),
363
+ "huber" : (Huber , DEFAULT_EPSILON ),
364
+ "epsilon_insensitive" : (EpsilonInsensitive , DEFAULT_EPSILON ),
365
+ }
366
+
341
367
def __init__ (self , loss = "hinge" , penalty = 'l2' , alpha = 0.0001 ,
342
368
rho = 0.85 , fit_intercept = True , n_iter = 5 , shuffle = False ,
343
- verbose = 0 , epsilon = 0.1 , n_jobs = 1 , seed = 0 ,
369
+ verbose = 0 , epsilon = DEFAULT_EPSILON , n_jobs = 1 , seed = 0 ,
344
370
learning_rate = "optimal" , eta0 = 0.0 , power_t = 0.5 ,
345
371
class_weight = None , warm_start = False ):
346
372
super (SGDClassifier , self ).__init__ (loss = loss , penalty = penalty ,
@@ -356,18 +382,6 @@ def __init__(self, loss="hinge", penalty='l2', alpha=0.0001,
356
382
self .classes_ = None
357
383
self .n_jobs = int (n_jobs )
358
384
359
- self .loss_functions = {
360
- "hinge" : (Hinge , 1.0 ),
361
- "perceptron" : (Hinge , 0.0 ),
362
- "log" : (Log , ),
363
- "modified_huber" : (ModifiedHuber , ),
364
- "squared_loss" : (SquaredLoss , ),
365
- "huber" : (Huber , self .epsilon ),
366
- "epsilon_insensitive" : (EpsilonInsensitive , self .epsilon ),
367
- }
368
- if loss not in self .loss_functions :
369
- raise ValueError ("The loss %s is not supported. " % loss )
370
-
371
385
@property
372
386
@deprecated ("to be removed in v0.13; use ``classes_`` instead." )
373
387
def classes (self ):
@@ -406,6 +420,8 @@ def _partial_fit(self, X, y, n_iter, classes=None, sample_weight=None,
406
420
n_samples , n_features = X .shape
407
421
_check_fit_data (X , y )
408
422
423
+ self ._validate_params ()
424
+
409
425
if self .classes_ is None and classes is None :
410
426
raise ValueError ("classes must be passed on the first call "
411
427
"to partial_fit." )
@@ -696,13 +712,13 @@ def fit_binary(est, i, X, y, n_iter, pos_weight, neg_weight,
696
712
assert y_i .shape [0 ] == y .shape [0 ] == sample_weight .shape [0 ]
697
713
dataset , intercept_decay = _make_dataset (X , y_i , sample_weight )
698
714
699
- penalty_type = est .get_penalty_type (est .penalty )
700
- learning_rate_type = est .get_learning_rate_type (est .learning_rate )
715
+ penalty_type = est ._get_penalty_type (est .penalty )
716
+ learning_rate_type = est ._get_learning_rate_type (est .learning_rate )
701
717
702
718
return plain_sgd (coef , intercept , est .loss_function ,
703
719
penalty_type , est .alpha , est .rho ,
704
- dataset , n_iter , est .fit_intercept ,
705
- est .verbose , est .shuffle , est .seed ,
720
+ dataset , n_iter , int ( est .fit_intercept ) ,
721
+ int ( est .verbose ), int ( est .shuffle ) , est .seed ,
706
722
pos_weight , neg_weight ,
707
723
learning_rate_type , est .eta0 ,
708
724
est .power_t , est .t_ , intercept_decay )
@@ -812,6 +828,13 @@ class SGDRegressor(BaseSGD, RegressorMixin, SelectorMixin):
812
828
Ridge, ElasticNet, Lasso, SVR
813
829
814
830
"""
831
+
832
+ loss_functions = {
833
+ "squared_loss" : (SquaredLoss , ),
834
+ "huber" : (Huber , DEFAULT_EPSILON ),
835
+ "epsilon_insensitive" : (EpsilonInsensitive , DEFAULT_EPSILON )
836
+ }
837
+
815
838
def __init__ (self , loss = "squared_loss" , penalty = "l2" , alpha = 0.0001 ,
816
839
rho = 0.85 , fit_intercept = True , n_iter = 5 , shuffle = False , verbose = 0 ,
817
840
epsilon = 0.1 , p = None , seed = 0 , learning_rate = "invscaling" , eta0 = 0.01 ,
@@ -834,14 +857,6 @@ def __init__(self, loss="squared_loss", penalty="l2", alpha=0.0001,
834
857
eta0 = eta0 , power_t = power_t ,
835
858
warm_start = False )
836
859
837
- self .loss_functions = {
838
- "squared_loss" : (SquaredLoss , ),
839
- "huber" : (Huber , self .epsilon ),
840
- "epsilon_insensitive" : (EpsilonInsensitive , self .epsilon )
841
- }
842
- if loss not in self .loss_functions :
843
- raise ValueError ("The loss %s is not supported. " % loss )
844
-
845
860
def _partial_fit (self , X , y , n_iter , sample_weight = None ,
846
861
coef_init = None , intercept_init = None ):
847
862
X , y = check_arrays (X , y , sparse_format = "csr" , copy = False ,
@@ -851,6 +866,8 @@ def _partial_fit(self, X, y, n_iter, sample_weight=None,
851
866
n_samples , n_features = X .shape
852
867
_check_fit_data (X , y )
853
868
869
+ self ._validate_params ()
870
+
854
871
# Allocate datastructures from input arguments
855
872
sample_weight = self ._validate_sample_weight (sample_weight , n_samples )
856
873
@@ -919,8 +936,7 @@ def fit(self, X, y, coef_init=None, intercept_init=None,
919
936
self .coef_ = None
920
937
self .intercept_ = None
921
938
922
- # Need to re-initialize in case of multiple call to fit.
923
- #self._init_t()
939
+ # Clear iteration count for multiple call to fit.
924
940
self .t_ = None
925
941
926
942
return self ._partial_fit (X , y , self .n_iter , sample_weight ,
@@ -960,8 +976,8 @@ def _fit_regressor(self, X, y, sample_weight, n_iter):
960
976
dataset , intercept_decay = _make_dataset (X , y , sample_weight )
961
977
962
978
loss_function = self .get_loss_function (self .loss )
963
- penalty_type = self .get_penalty_type (self .penalty )
964
- learning_rate_type = self .get_learning_rate_type (self .learning_rate )
979
+ penalty_type = self ._get_penalty_type (self .penalty )
980
+ learning_rate_type = self ._get_learning_rate_type (self .learning_rate )
965
981
966
982
if self .t_ is None :
967
983
self ._init_t (loss_function )
0 commit comments