rm instance variables learning_rate_type, loss_function, and penalty_type; create them before plain_fit · seckcoder/scikit-learn@1791575 · GitHub

Commit 1791575

pprett authored and GaelVaroquaux committed
rm instance variables learning_rate_type, loss_function, and penalty_type; create them before plain_fit
1 parent 03d28ee commit 1791575

File tree

1 file changed: +72 −46 lines changed

sklearn/linear_model/stochastic_gradient.py

Lines changed: 72 additions & 46 deletions
@@ -48,10 +48,13 @@ def __init__(self, loss, penalty='l2', alpha=0.0001,
                  eta0=0.0, power_t=0.5, warm_start=False):
         self.loss = str(loss)
         self.penalty = str(penalty).lower()
-        self.epsilon = float(epsilon)
-        self._set_loss_function(self.loss)
-        self._set_penalty_type(self.penalty)
+        self.learning_rate = str(learning_rate)
 
+        # raises ValueError if not registered
+        self.get_penalty_type(self.penalty)
+        self.get_learning_rate_type(self.learning_rate)
+
+        self.epsilon = float(epsilon)
         self.alpha = float(alpha)
         if self.alpha < 0.0:
             raise ValueError("alpha must be greater than zero")
@@ -68,8 +71,6 @@ def __init__(self, loss, penalty='l2', alpha=0.0001,
         self.seed = seed
         self.verbose = int(verbose)
 
-        self.learning_rate = str(learning_rate)
-        self._set_learning_rate(self.learning_rate)
         self.eta0 = float(eta0)
         self.power_t = float(power_t)
         if self.learning_rate != "optimal":
@@ -78,7 +79,10 @@ def __init__(self, loss, penalty='l2', alpha=0.0001,
         self.coef_ = None
         self.warm_start = warm_start
 
-        self._init_t()
+        #self._init_t()
+        # iteration count for learning rate schedule
+        # must not be int (e.g. if ``learning_rate=='optimal'``)
+        self.t_ = None
 
     @abstractmethod
     def fit(self, X, y):
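Replacing the eager ``self._init_t()`` call with a ``t_ = None`` sentinel defers the counter to the first real fit, and keeping it a float matters because the "optimal" schedule divides by ``alpha * t``. A rough numeric sketch of the relationship the ``_init_t`` hunk below encodes (the form ``eta = 1.0 / (alpha * t)`` is inferred from the initialization it inverts, not quoted from the diff):

# Sketch: choosing t so the first step size equals eta0, assuming the
# "optimal" schedule eta = 1.0 / (alpha * t) that _init_t inverts.
alpha, eta0 = 0.0001, 4.0
t = 1.0 / (eta0 * alpha)        # 2500.0 -- a float, not an int
eta_first = 1.0 / (alpha * t)   # step size at the first sample
assert abs(eta_first - eta0) < 1e-12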
@@ -88,25 +92,39 @@ def fit(self, X, y):
     def predict(self, X):
         """Predict using model."""
 
-    def _init_t(self):
+    def _init_t(self, loss_function):
+        """Initialize iteration counter attr ``t_``.
+
+        If ``self.loss=='optimal'`` initialize ``t_`` such that ``eta`` at
+        first sample equals ``self.eta0``.
+        """
         self.t_ = 1.0
         if self.learning_rate == "optimal":
             typw = np.sqrt(1.0 / np.sqrt(self.alpha))
             # computing eta0, the initial learning rate
-            eta0 = typw / max(1.0, self.loss_function.dloss(-typw, 1.0))
-            # initialize t such that eta at first example equals eta0
+            eta0 = typw / max(1.0, loss_function.dloss(-typw, 1.0))
+            # initialize t such that eta at first sample equals eta0
             self.t_ = 1.0 / (eta0 * self.alpha)
 
-    def _set_learning_rate(self, learning_rate):
+    def get_loss_function(self, loss):
+        """Get concrete ``LossFunction`` object for str ``loss``. """
         try:
-            self.learning_rate_type = LEARNING_RATE_TYPES[learning_rate]
+            loss_ = self.loss_functions[loss]
+            loss_class, args = loss_[0], loss_[1:]
+            return loss_class(*args)
+        except KeyError:
+            raise ValueError("The loss %s is not supported. " % loss)
+
+    def get_learning_rate_type(self, learning_rate):
+        try:
+            return LEARNING_RATE_TYPES[learning_rate]
         except KeyError:
             raise ValueError("learning rate %s"
                              "is not supported. " % learning_rate)
 
-    def _set_penalty_type(self, penalty):
+    def get_penalty_type(self, penalty):
         try:
-            self.penalty_type = PENALTY_TYPES[penalty]
+            return PENALTY_TYPES[penalty]
         except KeyError:
             raise ValueError("Penalty %s is not supported. " % penalty)
 
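The new ``get_loss_function`` replaces the cached instances with ``(class, *args)`` registry entries that are unpacked and instantiated on demand. A standalone sketch of that pattern, with plain-Python stubs standing in for the module's compiled loss classes:

# Standalone sketch of the (class, *args) registry; Hinge and Huber here
# are stubs, not scikit-learn's actual loss implementations.
class Hinge(object):
    def __init__(self, threshold):
        self.threshold = threshold

class Huber(object):
    def __init__(self, epsilon):
        self.epsilon = epsilon

epsilon = 0.1
loss_functions = {
    "hinge": (Hinge, 1.0),
    "perceptron": (Hinge, 0.0),
    "huber": (Huber, epsilon),
}

def get_loss_function(loss):
    try:
        entry = loss_functions[loss]
        loss_class, args = entry[0], entry[1:]
        return loss_class(*args)  # built fresh on every call, never cached
    except KeyError:
        raise ValueError("The loss %s is not supported. " % loss)

assert get_loss_function("perceptron").threshold == 0.0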
@@ -338,27 +356,23 @@ def __init__(self, loss="hinge", penalty='l2', alpha=0.0001,
         self.classes_ = None
         self.n_jobs = int(n_jobs)
 
+        self.loss_functions = {
+            "hinge": (Hinge, 1.0),
+            "perceptron": (Hinge, 0.0),
+            "log": (Log, ),
+            "modified_huber": (ModifiedHuber, ),
+            "squared_loss": (SquaredLoss, ),
+            "huber": (Huber, self.epsilon),
+            "epsilon_insensitive": (EpsilonInsensitive, self.epsilon),
+        }
+        if loss not in self.loss_functions:
+            raise ValueError("The loss %s is not supported. " % loss)
+
     @property
     @deprecated("to be removed in v0.13; use ``classes_`` instead.")
     def classes(self):
         return self.classes_
 
-    def _set_loss_function(self, loss):
-        """Set concrete LossFunction."""
-        loss_functions = {
-            "hinge": Hinge(1.0),
-            "perceptron": Hinge(0.0),
-            "log": Log(),
-            "modified_huber": ModifiedHuber(),
-            "squared_loss": SquaredLoss(),
-            "huber": Huber(self.epsilon),
-            "epsilon_insensitive": EpsilonInsensitive(self.epsilon),
-        }
-        try:
-            self.loss_function = loss_functions[loss]
-        except KeyError:
-            raise ValueError("The loss %s is not supported. " % loss)
-
     def _set_class_weight(self, class_weight, classes, y):
         """Estimate class weights for unbalanced datasets."""
         if class_weight is None or len(class_weight) == 0:
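Moving the table onto the instance as ``self.loss_functions`` lets the classifier and regressor each declare their own supported losses while inheriting a single ``get_loss_function`` from the base class, and the membership test raises the same early ``ValueError`` the old try/except did. A compressed sketch of that division of labour (stand-in entries, not the real classes):

# Compressed sketch: the base class owns the lookup, subclasses own the table.
class Base(object):
    def get_loss_function(self, loss):
        entry = self.loss_functions[loss]
        return entry[0](*entry[1:])

class Classifier(Base):
    loss_functions = {"hinge": (float, 1.0)}       # stand-in entry

class Regressor(Base):
    loss_functions = {"squared_loss": (float, )}   # stand-in entry

assert Classifier().get_loss_function("hinge") == 1.0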
@@ -513,7 +527,8 @@ def fit(self, X, y, coef_init=None, intercept_init=None,
         self.intercept_ = None
 
         # Need to re-initialize in case of multiple call to fit.
-        self._init_t()
+        #self._init_t()
+        self.t_ = None
 
         self._partial_fit(X, y, self.n_iter, classes,
                           sample_weight, coef_init, intercept_init)
@@ -678,12 +693,19 @@ def fit_binary(est, i, X, y, n_iter, pos_weight, neg_weight,
     assert y_i.shape[0] == y.shape[0] == sample_weight.shape[0]
     dataset, intercept_decay = _make_dataset(X, y_i, sample_weight)
 
-    return plain_sgd(coef, intercept, est.loss_function,
-                     est.penalty_type, est.alpha, est.rho,
+    loss_function = est.get_loss_function(est.loss)
+    penalty_type = est.get_penalty_type(est.penalty)
+    learning_rate_type = est.get_learning_rate_type(est.learning_rate)
+
+    if est.t_ is None:
+        est._init_t(loss_function)
+
+    return plain_sgd(coef, intercept, loss_function,
+                     penalty_type, est.alpha, est.rho,
                      dataset, n_iter, est.fit_intercept,
                      est.verbose, est.shuffle, est.seed,
                      pos_weight, neg_weight,
-                     est.learning_rate_type, est.eta0,
+                     learning_rate_type, est.eta0,
                      est.power_t, est.t_, intercept_decay)
 
 
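In ``fit_binary`` the three objects are now built locally from the estimator's string parameters just before ``plain_sgd`` runs, and ``t_`` is initialized lazily behind a ``None`` guard (``fit`` resets it to ``None``, as the earlier hunk shows). A self-contained sketch of that guard's behaviour:

# Sketch of the lazy-init guard: fit() resets t_ to None, and the first
# call that needs the counter creates it; later calls reuse it.
class Toy(object):
    def __init__(self):
        self.t_ = None
    def _init_t(self):
        self.t_ = 1.0           # simplified; the diff derives it from eta0

est = Toy()
for _ in range(3):              # e.g. one binary subproblem per class
    if est.t_ is None:
        est._init_t()           # runs only on the first pass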
@@ -813,16 +835,12 @@ def __init__(self, loss="squared_loss", penalty="l2", alpha=0.0001,
                          eta0=eta0, power_t=power_t,
                          warm_start=False)
 
-    def _set_loss_function(self, loss):
-        """Get concrete LossFunction"""
-        loss_functions = {
-            "squared_loss": SquaredLoss(),
-            "huber": Huber(self.epsilon),
-            "epsilon_insensitive": EpsilonInsensitive(self.epsilon),
+        self.loss_functions = {
+            "squared_loss": (SquaredLoss, ),
+            "huber": (Huber, self.epsilon),
+            "epsilon_insensitive": (EpsilonInsensitive, self.epsilon)
         }
-        try:
-            self.loss_function = loss_functions[loss]
-        except KeyError:
+        if loss not in self.loss_functions:
             raise ValueError("The loss %s is not supported. " % loss)
 
     def _partial_fit(self, X, y, n_iter, sample_weight=None,
@@ -903,7 +921,8 @@ def fit(self, X, y, coef_init=None, intercept_init=None,
         self.intercept_ = None
 
         # Need to re-initialize in case of multiple call to fit.
-        self._init_t()
+        #self._init_t()
+        self.t_ = None
 
         return self._partial_fit(X, y, self.n_iter, sample_weight,
                                  coef_init, intercept_init)
@@ -941,10 +960,17 @@ def predict(self, X):
     def _fit_regressor(self, X, y, sample_weight, n_iter):
         dataset, intercept_decay = _make_dataset(X, y, sample_weight)
 
+        loss_function = self.get_loss_function(self.loss)
+        penalty_type = self.get_penalty_type(self.penalty)
+        learning_rate_type = self.get_learning_rate_type(self.learning_rate)
+
+        if self.t_ is None:
+            self._init_t(loss_function)
+
         self.coef_, intercept = plain_sgd(self.coef_,
                                           self.intercept_[0],
-                                          self.loss_function,
-                                          self.penalty_type,
+                                          loss_function,
+                                          penalty_type,
                                           self.alpha, self.rho,
                                           dataset,
                                           n_iter,
@@ -953,7 +979,7 @@ def _fit_regressor(self, X, y, sample_weight, n_iter):
                                           int(self.shuffle),
                                           self.seed,
                                           1.0, 1.0,
-                                          self.learning_rate_type,
+                                          learning_rate_type,
                                           self.eta0, self.power_t, self.t_,
                                           intercept_decay)
 
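Taken together, the commit means a constructed estimator carries only plain strings and numbers for these settings; the concrete objects are rebuilt on demand. A usage-level sketch, assuming the post-commit SGDClassifier of this era (illustrative only, not quoted from the diff):

# Usage-level sketch assuming the post-commit SGDClassifier.
from sklearn.linear_model import SGDClassifier

clf = SGDClassifier(loss="hinge", penalty="l2", learning_rate="optimal")
# Only the raw parameters live on the instance now ...
assert clf.loss == "hinge" and clf.penalty == "l2"
assert not hasattr(clf, "loss_function")       # removed by this commit
assert not hasattr(clf, "penalty_type")
assert not hasattr(clf, "learning_rate_type")
# ... and the concrete objects are rebuilt on demand:
loss = clf.get_loss_function(clf.loss)         # a fresh Hinge(1.0)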
0 commit comments
