@@ -373,7 +373,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
373
373
def __init__ (self , estimator , scoring = None ,
374
374
fit_params = None , n_jobs = 1 , iid = True ,
375
375
refit = True , cv = None , verbose = 0 , pre_dispatch = '2*n_jobs' ,
376
- error_score = 'raise' ):
376
+ error_score = 'raise' , return_train_score = False ):
377
377
378
378
self .scoring = scoring
379
379
self .estimator = estimator
@@ -385,6 +385,7 @@ def __init__(self, estimator, scoring=None,
385
385
self .verbose = verbose
386
386
self .pre_dispatch = pre_dispatch
387
387
self .error_score = error_score
388
+ self .return_train_score = return_train_score
388
389
389
390
@property
390
391
def _estimator_type (self ):
@@ -550,16 +551,28 @@ def _fit(self, X, y, labels, parameter_iterable):
550
551
pre_dispatch = pre_dispatch
551
552
)(delayed (_fit_and_score )(clone (base_estimator ), X , y , self .scorer_ ,
552
553
train , test , self .verbose , parameters ,
553
- self .fit_params , return_parameters = True ,
554
+ self .fit_params ,
555
+ return_train_score = self .return_train_score ,
556
+ return_parameters = True ,
554
557
error_score = self .error_score )
555
558
for parameters in parameter_iterable
556
559
for train , test in cv .split (X , y , labels ))
557
560
558
- test_scores , test_sample_counts , _ , parameters = zip (* out )
561
+ # if one choose to see train score, out will have train score info.
562
+ if self .return_train_score :
563
+ train_scores , test_scores , test_sample_counts , _ , parameters = \
564
+ zip (* out )
565
+ else :
566
+ test_scores , test_sample_counts , _ , parameters = zip (* out )
559
567
560
568
candidate_params = parameters [::n_splits ]
561
569
n_candidates = len (candidate_params )
562
570
571
+ # if one choose to return train score, reshape the train_scores array
572
+ if self .return_train_score :
573
+ train_scores = np .array (train_scores ,
574
+ dtype = np .float64 ).reshape (n_candidates ,
575
+ n_splits )
563
576
test_scores = np .array (test_scores ,
564
577
dtype = np .float64 ).reshape (n_candidates ,
565
578
n_splits )
@@ -569,17 +582,39 @@ def _fit(self, X, y, labels, parameter_iterable):
569
582
570
583
# Computed the (weighted) mean and std for all the candidates
571
584
weights = test_sample_counts if self .iid else None
572
- means = np .average (test_scores , axis = 1 , weights = weights )
573
- stds = np .sqrt (np .average ((test_scores - means [:, np .newaxis ]) ** 2 ,
574
- axis = 1 , weights = weights ))
585
+
586
+ time = np .array (_ , dtype = np .float64 ).reshape (n_candidates , n_splits )
587
+ time_means = np .average (time , axis = 1 , weights = weights )
588
+ time_stds = np .sqrt (
589
+ np .average ((time - time_means [:, np .newaxis ]) ** 2 ,
590
+ axis = 1 , weights = weights ))
591
+ if self .return_train_score :
592
+ train_means = np .average (train_scores , axis = 1 , weights = weights )
593
+ train_stds = np .sqrt (
594
+ np .average ((train_scores - train_means [:, np .newaxis ]) ** 2 ,
595
+ axis = 1 , weights = weights ))
596
+ test_means = np .average (test_scores , axis = 1 , weights = weights )
597
+ test_stds = np .sqrt (
598
+ np .average ((test_scores - test_means [:, np .newaxis ]) ** 2 , axis = 1 ,
599
+ weights = weights ))
575
600
576
601
results = dict ()
577
602
for split_i in range (n_splits ):
578
603
results ["test_split%d_score" % split_i ] = test_scores [:, split_i ]
579
- results ["test_mean_score" ] = means
580
- results ["test_std_score" ] = stds
604
+ results ["test_mean_score" ] = test_means
605
+ results ["test_std_score" ] = test_stds
606
+
607
+ if self .return_train_score :
608
+ for split_i in range (n_splits ):
609
+ results ["train_split%d_score" % split_i ] = \
610
+ train_scores [:, split_i ]
611
+ results ["train_mean_score" ] = train_means
612
+ results ["train_std_score" ] = train_stds
613
+
614
+ results ["test_mean_time" ] = time_means
615
+ results ["test_std_time" ] = time_stds
581
616
582
- ranks = np .asarray (rankdata (- means , method = 'min' ), dtype = np .int32 )
617
+ ranks = np .asarray(rankdata (- test_means , method = 'min' ), dtype = np .int32 )
583
618
584
619
best_index = np .flatnonzero (ranks == 1 )[0 ]
585
620
best_parameters = candidate_params [best_index ]
@@ -865,11 +900,13 @@ class GridSearchCV(BaseSearchCV):
865
900
866
901
def __init__ (self , estimator , param_grid , scoring = None , fit_params = None ,
867
902
n_jobs = 1 , iid = True , refit = True , cv = None , verbose = 0 ,
868
- pre_dispatch = '2*n_jobs' , error_score = 'raise' ):
903
+ pre_dispatch = '2*n_jobs' , error_score = 'raise' ,
904
+ return_train_score = False ):
869
905
super (GridSearchCV , self ).__init__ (
870
906
estimator = estimator , scoring = scoring , fit_params = fit_params ,
871
907
n_jobs = n_jobs , iid = iid , refit = refit , cv = cv , verbose = verbose ,
872
- pre_dispatch = pre_dispatch , error_score = error_score )
908
+ pre_dispatch = pre_dispatch , error_score = error_score ,
909
+ return_train_score = return_train_score )
873
910
self .param_grid = param_grid
874
911
_check_param_grid (param_grid )
875
912
@@ -1091,15 +1128,15 @@ class RandomizedSearchCV(BaseSearchCV):
1091
1128
def __init__ (self , estimator , param_distributions , n_iter = 10 , scoring = None ,
1092
1129
fit_params = None , n_jobs = 1 , iid = True , refit = True , cv = None ,
1093
1130
verbose = 0 , pre_dispatch = '2*n_jobs' , random_state = None ,
1094
- error_score = 'raise' ):
1095
-
1131
+ error_score = 'raise' , return_train_score = False ):
1096
1132
self .param_distributions = param_distributions
1097
1133
self .n_iter = n_iter
1098
1134
self .random_state = random_state
1099
1135
super (RandomizedSearchCV , self ).__init__ (
1100
- estimator = estimator , scoring = scoring , fit_params = fit_params ,
1101
- n_jobs = n_jobs , iid = iid , refit = refit , cv = cv , verbose = verbose ,
1102
- pre_dispatch = pre_dispatch , error_score = error_score )
1136
+ estimator = estimator , scoring = scoring , fit_params = fit_params ,
1137
+ n_jobs = n_jobs , iid = iid , refit = refit , cv = cv , verbose = verbose ,
1138
+ pre_dispatch = pre_dispatch , error_score = error_score ,
1139
+ return_train_score = return_train_score )
1103
1140
1104
1141
def fit (self , X , y = None , labels = None ):
1105
1142
"""Run fit on the estimator with randomly drawn parameters.
@@ -1121,4 +1158,4 @@ def fit(self, X, y=None, labels=None):
1121
1158
sampled_params = ParameterSampler (self .param_distributions ,
1122
1159
self .n_iter ,
1123
1160
random_state = self .random_state )
1124
- return self ._fit (X , y , labels , sampled_params )
1161
+ return self ._fit (X , y , labels , sampled_params )
0 commit comments