1212# License: BSD 3 clause
1313
1414from abc import ABCMeta , abstractmethod
15- from collections import Mapping , namedtuple , Sized , defaultdict , Sequence
15+ from collections import Mapping , namedtuple , defaultdict , Sequence
1616from functools import partial , reduce
1717from itertools import product
1818import operator
@@ -532,25 +532,41 @@ def inverse_transform(self, Xt):
532532 self ._check_is_fitted ('inverse_transform' )
533533 return self .best_estimator_ .transform (Xt )
534534
535- def _fit (self , X , y , groups , parameter_iterable ):
536- """Actual fitting, performing the search over parameters."""
535+ def fit (self , X , y = None , groups = None ):
536+ """Run fit with all sets of parameters.
537+
538+ Parameters
539+ ----------
540+
541+ X : array-like, shape = [n_samples, n_features]
542+ Training vector, where n_samples is the number of samples and
543+ n_features is the number of features.
537544
545+ y : array-like, shape = [n_samples] or [n_samples, n_output], optional
546+ Target relative to X for classification or regression;
547+ None for unsupervised learning.
548+
549+ groups : array-like, with shape (n_samples,), optional
550+ Group labels for the samples used while splitting the dataset into
551+ train/test set.
552+ """
538553 estimator = self .estimator
539554 cv = check_cv (self .cv , y , classifier = is_classifier (estimator ))
540555 self .scorer_ = check_scoring (self .estimator , scoring = self .scoring )
541556
542557 X , y , groups = indexable (X , y , groups )
543558 n_splits = cv .get_n_splits (X , y , groups )
544- if self .verbose > 0 and isinstance (parameter_iterable , Sized ):
545- n_candidates = len (parameter_iterable )
559+ # Regenerate parameter iterable for each fit
560+ candidate_params = list (self ._get_param_iterator ())
561+ n_candidates = len (candidate_params )
562+ if self .verbose > 0 :
546563 print ("Fitting {0} folds for each of {1} candidates, totalling"
547564 " {2} fits" .format (n_splits , n_candidates ,
548565 n_candidates * n_splits ))
549566
550567 base_estimator = clone (self .estimator )
551568 pre_dispatch = self .pre_dispatch
552569
553- cv_iter = list (cv .split (X , y , groups ))
554570 out = Parallel (
555571 n_jobs = self .n_jobs , verbose = self .verbose ,
556572 pre_dispatch = pre_dispatch
@@ -559,28 +575,25 @@ def _fit(self, X, y, groups, parameter_iterable):
559575 fit_params = self .fit_params ,
560576 return_train_score = self .return_train_score ,
561577 return_n_test_samples = True ,
562- return_times = True , return_parameters = True ,
578+ return_times = True , return_parameters = False ,
563579 error_score = self .error_score )
564- for parameters in parameter_iterable
565- for train , test in cv_iter )
580+ for train , test in cv . split ( X , y , groups )
581+ for parameters in candidate_params )
566582
567583 # if one choose to see train score, "out" will contain train score info
568584 if self .return_train_score :
569- (train_scores , test_scores , test_sample_counts ,
570- fit_time , score_time , parameters ) = zip (* out )
585+ (train_scores , test_scores , test_sample_counts , fit_time ,
586+ score_time ) = zip (* out )
571587 else :
572- (test_scores , test_sample_counts ,
573- fit_time , score_time , parameters ) = zip (* out )
574-
575- candidate_params = parameters [::n_splits ]
576- n_candidates = len (candidate_params )
588+ (test_scores , test_sample_counts , fit_time , score_time ) = zip (* out )
577589
578590 results = dict ()
579591
580592 def _store (key_name , array , weights = None , splits = False , rank = False ):
581593 """A small helper to store the scores/times to the cv_results_"""
582- array = np .array (array , dtype = np .float64 ).reshape (n_candidates ,
583- n_splits )
594+ # When iterated first by splits, then by parameters
595+ array = np .array (array , dtype = np .float64 ).reshape (n_splits ,
596+ n_candidates ).T
584597 if splits :
585598 for split_i in range (n_splits ):
586599 results ["split%d_%s"
@@ -600,7 +613,7 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
600613
601614 # Computed the (weighted) mean and std for test scores alone
602615 # NOTE test_sample counts (weights) remain the same for all candidates
603- test_sample_counts = np .array (test_sample_counts [:n_splits ],
616+ test_sample_counts = np .array (test_sample_counts [:: n_candidates ],
604617 dtype = np .int )
605618
606619 _store ('test_score' , test_scores , splits = True , rank = True ,
@@ -924,25 +937,9 @@ def __init__(self, estimator, param_grid, scoring=None, fit_params=None,
924937 self .param_grid = param_grid
925938 _check_param_grid (param_grid )
926939
927- def fit (self , X , y = None , groups = None ):
928- """Run fit with all sets of parameters.
929-
930- Parameters
931- ----------
932-
933- X : array-like, shape = [n_samples, n_features]
934- Training vector, where n_samples is the number of samples and
935- n_features is the number of features.
936-
937- y : array-like, shape = [n_samples] or [n_samples, n_output], optional
938- Target relative to X for classification or regression;
939- None for unsupervised learning.
940-
941- groups : array-like, with shape (n_samples,), optional
942- Group labels for the samples used while splitting the dataset into
943- train/test set.
944- """
945- return self ._fit (X , y , groups , ParameterGrid (self .param_grid ))
940+ def _get_param_iterator (self ):
941+ """Return ParameterGrid instance for the given param_grid"""
942+ return ParameterGrid (self .param_grid )
946943
947944
948945class RandomizedSearchCV (BaseSearchCV ):
@@ -1167,24 +1164,8 @@ def __init__(self, estimator, param_distributions, n_iter=10, scoring=None,
11671164 pre_dispatch = pre_dispatch , error_score = error_score ,
11681165 return_train_score = return_train_score )
11691166
1170- def fit (self , X , y = None , groups = None ):
1171- """Run fit on the estimator with randomly drawn parameters.
1172-
1173- Parameters
1174- ----------
1175- X : array-like, shape = [n_samples, n_features]
1176- Training vector, where n_samples in the number of samples and
1177- n_features is the number of features.
1178-
1179- y : array-like, shape = [n_samples] or [n_samples, n_output], optional
1180- Target relative to X for classification or regression;
1181- None for unsupervised learning.
1182-
1183- groups : array-like, with shape (n_samples,), optional
1184- Group labels for the samples used while splitting the dataset into
1185- train/test set.
1186- """
1187- sampled_params = ParameterSampler (self .param_distributions ,
1188- self .n_iter ,
1189- random_state = self .random_state )
1190- return self ._fit (X , y , groups , sampled_params )
1167+ def _get_param_iterator (self ):
1168+ """Return ParameterSampler instance for the given distributions"""
1169+ return ParameterSampler (
1170+ self .param_distributions, self .n_iter ,
1171+ random_state = self .random_state )
0 commit comments