Commit 8283126

pprett authored and GaelVaroquaux committed
refactored input validation; special loss function factory for huber and epsilon insensitive loss
1 parent 767ec5d commit 8283126

1 file changed (+78 −62)

sklearn/linear_model/stochastic_gradient.py

Lines changed: 78 additions & 62 deletions
@@ -36,6 +36,9 @@
 """For sparse data intercept updates are scaled by this decay factor to avoid
 intercept oscillation."""
 
+DEFAULT_EPSILON = 0.1
+"""Default value of ``epsilon`` parameter. """
+
 
 class BaseSGD(BaseEstimator):
     """Base class for SGD classification and regression."""
@@ -46,44 +49,32 @@ def __init__(self, loss, penalty='l2', alpha=0.0001,
                  rho=0.85, fit_intercept=True, n_iter=5, shuffle=False,
                  verbose=0, epsilon=0.1, seed=0, learning_rate="optimal",
                  eta0=0.0, power_t=0.5, warm_start=False):
-        self.loss = str(loss)
-        self.penalty = str(penalty).lower()
-        self.learning_rate = str(learning_rate)
-
-        # raises ValueError if not registered
-        self.get_penalty_type(self.penalty)
-        self.get_learning_rate_type(self.learning_rate)
-
-        self.epsilon = float(epsilon)
-        self.alpha = float(alpha)
-        if self.alpha < 0.0:
-            raise ValueError("alpha must be greater than zero")
-        self.rho = float(rho)
-        if self.rho < 0.0 or self.rho > 1.0:
-            raise ValueError("rho must be in [0, 1]")
-        self.fit_intercept = bool(fit_intercept)
-        self.n_iter = int(n_iter)
-        if self.n_iter <= 0:
-            raise ValueError("n_iter must be greater than zero")
-        if not isinstance(shuffle, bool):
-            raise ValueError("shuffle must be either True or False")
-        self.shuffle = bool(shuffle)
+        self.loss = loss
+        self.penalty = penalty
+        self.learning_rate = learning_rate
+        self.epsilon = epsilon
+        self.alpha = alpha
+        self.rho = rho
+        self.fit_intercept = fit_intercept
+        self.n_iter = n_iter
+        self.shuffle = shuffle
         self.seed = seed
-        self.verbose = int(verbose)
-
-        self.eta0 = float(eta0)
-        self.power_t = float(power_t)
-        if self.learning_rate != "optimal":
-            if eta0 <= 0.0:
-                raise ValueError("eta0 must be greater than 0.0")
-        self.coef_ = None
+        self.verbose = verbose
+        self.eta0 = eta0
+        self.power_t = power_t
         self.warm_start = warm_start
 
-        #self._init_t()
+        self._validate_params()
+
+        self.coef_ = None
         # iteration count for learning rate schedule
         # must not be int (e.g. if ``learning_rate=='optimal'``)
         self.t_ = None
 
+    def set_params(self, *args, **kwargs):
+        super(BaseSGD, self).set_params(*args, **kwargs)
+        self._validate_params()
+
     @abstractmethod
     def fit(self, X, y):
         """Fit model."""
@@ -92,6 +83,27 @@ def fit(self, X, y):
     def predict(self, X):
         """Predict using model."""
 
+    def _validate_params(self):
+        """Validate input params. """
+        if not isinstance(self.shuffle, bool):
+            raise ValueError("shuffle must be either True or False")
+        if self.n_iter <= 0:
+            raise ValueError("n_iter must be greater than zero")
+        if not (0.0 <= self.rho <= 1.0):
+            raise ValueError("rho must be in [0, 1]")
+        if self.alpha < 0.0:
+            raise ValueError("alpha must be greater than zero")
+        if self.learning_rate != "optimal":
+            if self.eta0 <= 0.0:
+                raise ValueError("eta0 must be greater than 0.0")
+
+        # raises ValueError if not registered
+        self._get_penalty_type(self.penalty)
+        self._get_learning_rate_type(self.learning_rate)
+
+        if self.loss not in self.loss_functions:
+            raise ValueError("The loss %s is not supported. " % self.loss)
+
     def _init_t(self, loss_function):
         """Initialize iteration counter attr ``t_``.
 
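Because _validate_params is also invoked from _partial_fit (see the hunks further down), parameters mutated directly on the instance, bypassing set_params, are still checked before the next fit. A hedged illustration, reusing the hypothetical Estimator sketch above:

    est = Estimator(alpha=0.0001)
    est.alpha = -1.0    # direct mutation skips the set_params check
    # A fit method that starts with self._validate_params() would now
    # raise ValueError before touching the data.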
@@ -111,18 +123,21 @@ def get_loss_function(self, loss):
         try:
             loss_ = self.loss_functions[loss]
             loss_class, args = loss_[0], loss_[1:]
+            if loss in ('huber', 'epsilon_insensitive'):
+                args = (self.epsilon, )
             return loss_class(*args)
         except KeyError:
             raise ValueError("The loss %s is not supported. " % loss)
 
-    def get_learning_rate_type(self, learning_rate):
+    def _get_learning_rate_type(self, learning_rate):
         try:
             return LEARNING_RATE_TYPES[learning_rate]
         except KeyError:
             raise ValueError("learning rate %s"
                              "is not supported. " % learning_rate)
 
-    def get_penalty_type(self, penalty):
+    def _get_penalty_type(self, penalty):
+        penalty = str(penalty).lower()
         try:
             return PENALTY_TYPES[penalty]
         except KeyError:
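The change to get_loss_function above is the "special loss function factory" from the commit message: the loss object is instantiated from the class-level table, but for the two epsilon-parameterized losses the current value of self.epsilon is substituted at call time in place of the DEFAULT_EPSILON placeholder stored in the table. A minimal sketch of the dispatch (the Hinge/Huber classes and Model here are simplified stand-ins, not the actual sklearn.linear_model types):

    DEFAULT_EPSILON = 0.1

    class Hinge(object):
        def __init__(self, threshold):
            self.threshold = threshold

    class Huber(object):
        def __init__(self, epsilon):
            self.epsilon = epsilon

    class Model(object):
        # class-level table: (loss class, default ctor args)
        loss_functions = {"hinge": (Hinge, 1.0),
                          "huber": (Huber, DEFAULT_EPSILON)}

        def __init__(self, epsilon=DEFAULT_EPSILON):
            self.epsilon = epsilon

        def get_loss_function(self, loss):
            try:
                loss_ = self.loss_functions[loss]
                loss_class, args = loss_[0], loss_[1:]
                if loss == 'huber':          # epsilon resolved per call
                    args = (self.epsilon, )
                return loss_class(*args)
            except KeyError:
                raise ValueError("The loss %s is not supported. " % loss)

    assert Model(epsilon=0.05).get_loss_function('huber').epsilon == 0.05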
@@ -338,9 +353,20 @@ class SGDClassifier(BaseSGD, ClassifierMixin, SelectorMixin):
     LinearSVC, LogisticRegression, Perceptron
 
     """
+
+    loss_functions = {
+        "hinge": (Hinge, 1.0),
+        "perceptron": (Hinge, 0.0),
+        "log": (Log, ),
+        "modified_huber": (ModifiedHuber, ),
+        "squared_loss": (SquaredLoss, ),
+        "huber": (Huber, DEFAULT_EPSILON),
+        "epsilon_insensitive": (EpsilonInsensitive, DEFAULT_EPSILON),
+    }
+
     def __init__(self, loss="hinge", penalty='l2', alpha=0.0001,
                  rho=0.85, fit_intercept=True, n_iter=5, shuffle=False,
-                 verbose=0, epsilon=0.1, n_jobs=1, seed=0,
+                 verbose=0, epsilon=DEFAULT_EPSILON, n_jobs=1, seed=0,
                  learning_rate="optimal", eta0=0.0, power_t=0.5,
                  class_weight=None, warm_start=False):
         super(SGDClassifier, self).__init__(loss=loss, penalty=penalty,
@@ -356,18 +382,6 @@ def __init__(self, loss="hinge", penalty='l2', alpha=0.0001,
         self.classes_ = None
         self.n_jobs = int(n_jobs)
 
-        self.loss_functions = {
-            "hinge": (Hinge, 1.0),
-            "perceptron": (Hinge, 0.0),
-            "log": (Log, ),
-            "modified_huber": (ModifiedHuber, ),
-            "squared_loss": (SquaredLoss, ),
-            "huber": (Huber, self.epsilon),
-            "epsilon_insensitive": (EpsilonInsensitive, self.epsilon),
-        }
-        if loss not in self.loss_functions:
-            raise ValueError("The loss %s is not supported. " % loss)
-
     @property
     @deprecated("to be removed in v0.13; use ``classes_`` instead.")
     def classes(self):
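Hoisting loss_functions to a class attribute (and dropping the per-instance copy above) works because the table no longer captures self.epsilon; the only instance-specific argument is injected by get_loss_function at call time. The "loss not supported" check moved into _validate_params, so an unsupported loss should, presumably, still fail at construction:

    # assumed behaviour after this commit: _validate_params runs in __init__
    SGDClassifier(loss="no_such_loss")   # raises ValueError("The loss ... is not supported. ")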
@@ -406,6 +420,8 @@ def _partial_fit(self, X, y, n_iter, classes=None, sample_weight=None,
         n_samples, n_features = X.shape
         _check_fit_data(X, y)
 
+        self._validate_params()
+
         if self.classes_ is None and classes is None:
             raise ValueError("classes must be passed on the first call "
                              "to partial_fit.")
@@ -696,13 +712,13 @@ def fit_binary(est, i, X, y, n_iter, pos_weight, neg_weight,
     assert y_i.shape[0] == y.shape[0] == sample_weight.shape[0]
     dataset, intercept_decay = _make_dataset(X, y_i, sample_weight)
 
-    penalty_type = est.get_penalty_type(est.penalty)
-    learning_rate_type = est.get_learning_rate_type(est.learning_rate)
+    penalty_type = est._get_penalty_type(est.penalty)
+    learning_rate_type = est._get_learning_rate_type(est.learning_rate)
 
     return plain_sgd(coef, intercept, est.loss_function,
                      penalty_type, est.alpha, est.rho,
-                     dataset, n_iter, est.fit_intercept,
-                     est.verbose, est.shuffle, est.seed,
+                     dataset, n_iter, int(est.fit_intercept),
+                     int(est.verbose), int(est.shuffle), est.seed,
                      pos_weight, neg_weight,
                      learning_rate_type, est.eta0,
                      est.power_t, est.t_, intercept_decay)
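The int() casts at the plain_sgd call site compensate for the constructor no longer coercing fit_intercept, verbose, and shuffle; plain_sgd is compiled (Cython) code, which presumably expects integer arguments rather than Python bools. A one-line illustration of the cast:

    int(True), int(False)   # (1, 0): bools become the ints the compiled SGD loop expects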
@@ -812,6 +828,13 @@ class SGDRegressor(BaseSGD, RegressorMixin, SelectorMixin):
     Ridge, ElasticNet, Lasso, SVR
 
     """
+
+    loss_functions = {
+        "squared_loss": (SquaredLoss, ),
+        "huber": (Huber, DEFAULT_EPSILON),
+        "epsilon_insensitive": (EpsilonInsensitive, DEFAULT_EPSILON)
+    }
+
     def __init__(self, loss="squared_loss", penalty="l2", alpha=0.0001,
                  rho=0.85, fit_intercept=True, n_iter=5, shuffle=False, verbose=0,
                  epsilon=0.1, p=None, seed=0, learning_rate="invscaling", eta0=0.01,
@@ -834,14 +857,6 @@ def __init__(self, loss="squared_loss", penalty="l2", alpha=0.0001,
                                            eta0=eta0, power_t=power_t,
                                            warm_start=False)
 
-        self.loss_functions = {
-            "squared_loss": (SquaredLoss, ),
-            "huber": (Huber, self.epsilon),
-            "epsilon_insensitive": (EpsilonInsensitive, self.epsilon)
-        }
-        if loss not in self.loss_functions:
-            raise ValueError("The loss %s is not supported. " % loss)
-
     def _partial_fit(self, X, y, n_iter, sample_weight=None,
                      coef_init=None, intercept_init=None):
         X, y = check_arrays(X, y, sparse_format="csr", copy=False,
@@ -851,6 +866,8 @@ def _partial_fit(self, X, y, n_iter, sample_weight=None,
         n_samples, n_features = X.shape
         _check_fit_data(X, y)
 
+        self._validate_params()
+
         # Allocate datastructures from input arguments
         sample_weight = self._validate_sample_weight(sample_weight, n_samples)
 
@@ -919,8 +936,7 @@ def fit(self, X, y, coef_init=None, intercept_init=None,
         self.coef_ = None
         self.intercept_ = None
 
-        # Need to re-initialize in case of multiple call to fit.
-        #self._init_t()
+        # Clear iteration count for multiple call to fit.
         self.t_ = None
 
         return self._partial_fit(X, y, self.n_iter, sample_weight,
@@ -960,8 +976,8 @@ def _fit_regressor(self, X, y, sample_weight, n_iter):
         dataset, intercept_decay = _make_dataset(X, y, sample_weight)
 
         loss_function = self.get_loss_function(self.loss)
-        penalty_type = self.get_penalty_type(self.penalty)
-        learning_rate_type = self.get_learning_rate_type(self.learning_rate)
+        penalty_type = self._get_penalty_type(self.penalty)
+        learning_rate_type = self._get_learning_rate_type(self.learning_rate)
 
         if self.t_ is None:
             self._init_t(loss_function)
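Taken together, these changes appear to make epsilon a live parameter: the loss object is built from self.epsilon at fit time rather than baked into a per-instance dict at __init__, so updating it via set_params should take effect on the next fit. A hedged usage sketch (behaviour inferred from this diff, not verified against the test suite):

    from sklearn.linear_model import SGDRegressor

    reg = SGDRegressor(loss="huber", epsilon=0.1)
    reg.set_params(epsilon=0.05)    # re-runs _validate_params
    # On the next fit, get_loss_function("huber") builds Huber(0.05),
    # not the Huber(0.1) that the old per-instance table would have kept.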

0 commit comments
