@@ -374,7 +374,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
374
374
def __init__ (self , estimator , scoring = None ,
375
375
fit_params = None , n_jobs = 1 , iid = True ,
376
376
refit = True , cv = None , verbose = 0 , pre_dispatch = '2*n_jobs' ,
377
- error_score = 'raise' ):
377
+ error_score = 'raise' , return_train_score = False ):
378
378
379
379
self .scoring = scoring
380
380
self .estimator = estimator
@@ -386,6 +386,7 @@ def __init__(self, estimator, scoring=None,
386
386
self .verbose = verbose
387
387
self .pre_dispatch = pre_dispatch
388
388
self .error_score = error_score
389
+ self .return_train_score = return_train_score
389
390
390
391
@property
391
392
def _estimator_type (self ):
@@ -551,16 +552,28 @@ def _fit(self, X, y, groups, parameter_iterable):
551
552
pre_dispatch = pre_dispatch
552
553
)(delayed (_fit_and_score )(clone (base_estimator ), X , y , self .scorer_ ,
553
554
train , test , self .verbose , parameters ,
554
- self .fit_params , return_parameters = True ,
555
+ self .fit_params ,
556
+ return_train_score = self .return_train_score ,
557
+ return_parameters = True ,
555
558
error_score = self .error_score )
556
559
for parameters in parameter_iterable
557
560
for train , test in cv .split (X , y , groups ))
558
561
559
- test_scores , test_sample_counts , _ , parameters = zip (* out )
562
+ # if one choose to see train score, out will have train score info.
563
+ if self .return_train_score :
564
+ train_scores , test_scores , test_sample_counts , _ , parameters = \
565
+ zip (* out )
566
+ else :
567
+ test_scores , test_sample_counts , _ , parameters = zip (* out )
560
568
561
569
candidate_params = parameters [::n_splits ]
562
570
n_candidates = len (candidate_params )
563
571
572
+ # if one choose to return train score, reshape the train_scores array
573
+ if self .return_train_score :
574
+ train_scores = np .array (train_scores ,
575
+ dtype = np .float64 ).reshape (n_candidates ,
576
+ n_splits )
564
577
test_scores = np .array (test_scores ,
565
578
dtype = np .float64 ).reshape (n_candidates ,
566
579
n_splits )
@@ -570,9 +583,21 @@ def _fit(self, X, y, groups, parameter_iterable):
570
583
571
584
# Computed the (weighted) mean and std for all the candidates
572
585
weights = test_sample_counts if self .iid else None
573
- means = np .average (test_scores , axis = 1 , weights = weights )
574
- stds = np .sqrt (np .average ((test_scores - means [:, np .newaxis ]) ** 2 ,
575
- axis = 1 , weights = weights ))
586
+
587
+ time = np .array (_ , dtype = np .float64 ).reshape (n_candidates , n_splits )
588
+ time_means = np .average (time , axis = 1 , weights = weights )
589
+ time_stds = np .sqrt (
590
+ np .average ((time - time_means [:, np .newaxis ]) ** 2 ,
591
+ axis = 1 , weights = weights ))
592
+ if self .return_train_score :
593
+ train_means = np .average (train_scores , axis = 1 , weights = weights )
594
+ train_stds = np .sqrt (
595
+ np .average ((train_scores - train_means [:, np .newaxis ]) ** 2 ,
596
+ axis = 1 , weights = weights ))
597
+ test_means = np .average (test_scores , axis = 1 , weights = weights )
598
+ test_stds = np .sqrt (
599
+ np .average ((test_scores - test_means [:, np .newaxis ]) ** 2 , axis = 1 ,
600
+ weights = weights ))
576
601
577
602
cv_results = dict ()
578
603
for split_i in range (n_splits ):
@@ -581,7 +606,17 @@ def _fit(self, X, y, groups, parameter_iterable):
581
606
cv_results ["mean_test_score" ] = means
582
607
cv_results ["std_test_score" ] = stds
583
608
584
- ranks = np .asarray (rankdata (- means , method = 'min' ), dtype = np .int32 )
609
+ if self .return_train_score :
610
+ for split_i in range (n_splits ):
611
+ results ["train_split%d_score" % split_i ] = (
612
+ train_scores [:, split_i ])
613
+ results ["mean_train_score" ] = train_means
614
+ results ["std_train_scores" ] = train_stds
615
+
616
+ results ["mean_test_time" ] = time_means
617
+ results ["std_test_time" ] = time_stds
618
+
619
+ ranks = np .asarray (rankdata (- test_means , method = 'min' ), dtype = np .int32 )
585
620
586
621
best_index = np .flatnonzero (ranks == 1 )[0 ]
587
622
best_parameters = candidate_params [best_index ]
@@ -868,11 +903,13 @@ class GridSearchCV(BaseSearchCV):
868
903
869
904
def __init__ (self , estimator , param_grid , scoring = None , fit_params = None ,
870
905
n_jobs = 1 , iid = True , refit = True , cv = None , verbose = 0 ,
871
- pre_dispatch = '2*n_jobs' , error_score = 'raise' ):
906
+ pre_dispatch = '2*n_jobs' , error_score = 'raise' ,
907
+ return_train_score = False ):
872
908
super (GridSearchCV , self ).__init__ (
873
909
estimator = estimator , scoring = scoring , fit_params = fit_params ,
874
910
n_jobs = n_jobs , iid = iid , refit = refit , cv = cv , verbose = verbose ,
875
- pre_dispatch = pre_dispatch , error_score = error_score )
911
+ pre_dispatch = pre_dispatch , error_score = error_score ,
912
+ return_train_score = return_train_score )
876
913
self .param_grid = param_grid
877
914
_check_param_grid (param_grid )
878
915
@@ -1094,15 +1131,15 @@ class RandomizedSearchCV(BaseSearchCV):
1094
1131
def __init__ (self , estimator , param_distributions , n_iter = 10 , scoring = None ,
1095
1132
fit_params = None , n_jobs = 1 , iid = True , refit = True , cv = None ,
1096
1133
verbose = 0 , pre_dispatch = '2*n_jobs' , random_state = None ,
1097
- error_score = 'raise' ):
1098
-
1134
+ error_score = 'raise' , return_train_score = False ):
1099
1135
self .param_distributions = param_distributions
1100
1136
self .n_iter = n_iter
1101
1137
self .random_state = random_state
1102
1138
super (RandomizedSearchCV , self ).__init__ (
1103
- estimator = estimator , scoring = scoring , fit_params = fit_params ,
1104
- n_jobs = n_jobs , iid = iid , refit = refit , cv = cv , verbose = verbose ,
1105
- pre_dispatch = pre_dispatch , error_score = error_score )
1139
+ estimator = estimator , scoring = scoring , fit_params = fit_params ,
1140
+ n_jobs = n_jobs , iid = iid , refit = refit , cv = cv , verbose = verbose ,
1141
+ pre_dispatch = pre_dispatch , error_score = error_score ,
1142
+ return_train_score = return_train_score )
1106
1143
1107
1144
def fit (self , X , y = None , groups = None ):
1108
1145
"""Run fit on the estimator with randomly drawn parameters.
0 commit comments