@@ -273,8 +273,12 @@ def _initialize(self, y, layer_units):
273
273
self .intercepts_ ]
274
274
self ._coef_velocity = [np .zeros_like (coefs ) for coefs in
275
275
self .coefs_ ]
276
+ self ._no_improvement_count = 0
276
277
if self .early_stopping :
277
278
self .validation_scores_ = []
279
+ self .best_validation_score_ = - np .inf
280
+ else :
281
+ self .best_loss_ = np .inf
278
282
279
283
def _init_coef (self , fan_in , fan_out , rng ):
280
284
if self .activation == 'logistic' :
@@ -424,7 +428,9 @@ def _fit_sgd(self, X, y, activations, deltas, coef_grads, intercept_grads,
424
428
# early_stopping in partial_fit doesn't make sense
425
429
early_stopping = self .early_stopping and not incremental
426
430
if early_stopping :
427
- X , X_val , y , y_val = train_test_split (X , y , random_state = self .random_state )
431
+ X , X_val , y , y_val = train_test_split (X , y ,
432
+ random_state = self .random_state ,
433
+ test_size = .1 )
428
434
y_val = self .label_binarizer_ .inverse_transform (y_val )
429
435
430
436
n_samples = X .shape [0 ]
@@ -476,30 +482,44 @@ def _fit_sgd(self, X, y, activations, deltas, coef_grads, intercept_grads,
476
482
if self .learning_rate == 'invscaling' :
477
483
self .learning_rate_ = (self .learning_rate_init /
478
484
(self .t_ + 1 ) ** self .power_t )
479
- # stopping criteria
485
+ # validation set evaluation
480
486
if early_stopping :
481
487
# compute validation score, use that for stopping
482
488
self .validation_scores_ .append (self .score (X_val , y_val ))
483
489
484
490
if self .verbose :
485
491
print ("Validation score: %f" % (self .validation_scores_ [- 1 ]))
492
+ # update best parameters
486
493
# use validation_scores_, not loss_curve_
487
494
# let's hope no-one overloads .score with mse
488
- sign = - 1
489
- losses = self .validation_scores_
495
+ if self .validation_scores_ [- 1 ] > self .best_validation_score_ :
496
+ self .best_validation_score_ = self .validation_scores_ [- 1 ]
497
+ self ._best_coefs = [c for c in self .coefs_ ]
498
+ self ._best_intercepts = [i for i in self .intercepts_ ]
499
+
500
+ if self .validation_scores_ [- 1 ] < self .best_validation_score_ + self .tol :
501
+ self ._no_improvement_count += 1
502
+ else :
503
+ self ._no_improvement_count = 0
504
+
490
505
else :
491
- sign = 1
492
- losses = self .loss_curve_
506
+ if self .loss_curve_ [- 1 ] < self .best_loss_ :
507
+ self .best_loss_ = self .loss_curve_ [- 1 ]
508
+ if self .loss_curve_ [- 1 ] > self .best_loss_ - self .tol :
509
+ self ._no_improvement_count += 1
510
+ else :
511
+ self ._no_improvement_count = 0
493
512
494
- if len ( losses ) > 3 and np . all ( sign * np . array ( losses [ - 6 : - 1 ])
495
- < sign * losses [ - 1 ] + self .tol ) :
513
+ # stopping criteria
514
+ if self ._no_improvement_count > 2 :
496
515
# not better than last two iterations by tol.
497
516
# stop or decrease learning rate
498
- msg = ("Training loss did not improve more than tol for five "
517
+ msg = ("Training loss did not improve more than tol for two "
499
518
" consecutive epochs." )
500
519
if self .learning_rate == 'adaptive' :
501
520
if self .learning_rate_ > 1e-6 :
502
521
self .learning_rate_ /= 5
522
+ self ._no_improvement_count = 0
503
523
if self .verbose :
504
524
print (msg + " Setting learning rate to %f" % self .learning_rate_ )
505
525
else :
@@ -522,6 +542,11 @@ def _fit_sgd(self, X, y, activations, deltas, coef_grads, intercept_grads,
522
542
except KeyboardInterrupt :
523
543
pass
524
544
545
+ if early_stopping :
546
+ # restore best weights
547
+ self .coefs_ = self ._best_coefs
548
+ self .intercepts_ = self ._best_intercepts
549
+
525
550
def fit (self , X , y ):
526
551
"""Fit the model to the data X and target y.
527
552
0 commit comments