@@ -426,6 +426,10 @@ def _partial_fit(self, X, y, n_iter, classes=None, sample_weight=None,
426
426
self ._allocate_parameter_mem (n_classes , n_features ,
427
427
coef_init , intercept_init )
428
428
429
+ self .loss_function = self .get_loss_function (self .loss )
430
+ if self .t_ is None :
431
+ self ._init_t (self .loss_function )
432
+
429
433
# delegate to concrete training procedure
430
434
if n_classes > 2 :
431
435
self ._fit_multiclass (X , y , sample_weight , n_iter )
@@ -526,8 +530,7 @@ def fit(self, X, y, coef_init=None, intercept_init=None,
526
530
self .coef_ = None
527
531
self .intercept_ = None
528
532
529
- # Need to re-initialize in case of multiple call to fit.
530
- #self._init_t()
533
+ # Clear iteration count for multiple call to fit.
531
534
self .t_ = None
532
535
533
536
self ._partial_fit (X , y , self .n_iter , classes ,
@@ -666,7 +669,7 @@ def _fit_multiclass(self, X, y, sample_weight, n_iter):
666
669
667
670
668
671
def _prepare_fit_binary (est , y , i ):
669
- """Common initialization for _fit_binary_{dense,sparse} .
672
+ """Initialization for fit_binary .
670
673
671
674
Returns y, coef, intercept.
672
675
"""
@@ -693,14 +696,10 @@ def fit_binary(est, i, X, y, n_iter, pos_weight, neg_weight,
693
696
assert y_i .shape [0 ] == y .shape [0 ] == sample_weight .shape [0 ]
694
697
dataset , intercept_decay = _make_dataset (X , y_i , sample_weight )
695
698
696
- loss_function = est .get_loss_function (est .loss )
697
699
penalty_type = est .get_penalty_type (est .penalty )
698
700
learning_rate_type = est .get_learning_rate_type (est .learning_rate )
699
701
700
- if est .t_ is None :
701
- est ._init_t (loss_function )
702
-
703
- return plain_sgd (coef , intercept , loss_function ,
702
+ return plain_sgd (coef , intercept , est .loss_function ,
704
703
penalty_type , est .alpha , est .rho ,
705
704
dataset , n_iter , est .fit_intercept ,
706
705
est .verbose , est .shuffle , est .seed ,
0 commit comments