10000 move get_loss_function to _partial_fit · jwchennlp/scikit-learn@ee838a2 · GitHub
[go: up one dir, main page]

Skip to content

Commit ee838a2

Browse files
pprettGaelVaroquaux
authored andcommitted
move get_loss_function to _partial_fit
1 parent 1791575 commit ee838a2

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

sklearn/linear_model/stochastic_gradient.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,10 @@ def _partial_fit(self, X, y, n_iter, classes=None, sample_weight=None,
426426
self._allocate_parameter_mem(n_classes, n_features,
427427
coef_init, intercept_init)
428428

429+
self.loss_function = self.get_loss_function(self.loss)
430+
if self.t_ is None:
431+
self._init_t(self.loss_function)
432+
429433
# delegate to concrete training procedure
430434
if n_classes > 2:
431435
self._fit_multiclass(X, y, sample_weight, n_iter)
@@ -526,8 +530,7 @@ def fit(self, X, y, coef_init=None, intercept_init=None,
526530
self.coef_ = None
527531
self.intercept_ = None
528532

529-
# Need to re-initialize in case of multiple call to fit.
530-
#self._init_t()
533+
# Clear iteration count for multiple call to fit.
531534
self.t_ = None
532535

533536
self._partial_fit(X, y, self.n_iter, classes,
@@ -666,7 +669,7 @@ def _fit_multiclass(self, X, y, sample_weight, n_iter):
666669

667670

668671
def _prepare_fit_binary(est, y, i):
669-
"""Common initialization for _fit_binary_{dense,sparse}.
672+
"""Initialization for fit_binary.
670673
671674
Returns y, coef, intercept.
672675
"""
@@ -693,14 +696,10 @@ def fit_binary(est, i, X, y, n_iter, pos_weight, neg_weight,
693696
assert y_i.shape[0] == y.shape[0] == sample_weight.shape[0]
694697
dataset, intercept_decay = _make_dataset(X, y_i, sample_weight)
695698

696-
loss_function = est.get_loss_function(est.loss)
697699
penalty_type = est.get_penalty_type(est.penalty)
698700
learning_rate_type = est.get_learning_rate_type(est.learning_rate)
699701

700-
if est.t_ is None:
701-
est._init_t(loss_function)
702-
703-
return plain_sgd(coef, intercept, loss_function,
702+
return plain_sgd(coef, intercept, est.loss_function,
704703
penalty_type, est.alpha, est.rho,
705704
dataset, n_iter, est.fit_intercept,
706705
est.verbose, est.shuffle, est.seed,

0 commit comments

Comments
 (0)
0