From b70ffb69d513153efd5a123c9ea320a9084ae027 Mon Sep 17 00:00:00 2001
From: Raghav RV
Date: Fri, 25 Nov 2016 17:28:09 +0100
Subject: [PATCH 1/6] ENH Parallelize by candidates first then by splits.

---
 sklearn/model_selection/_search.py | 103 ++++++++++++-----------------
 1 file changed, 42 insertions(+), 61 deletions(-)

diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index d2f5542ebd32f..8cdc143a8d8d2 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -12,7 +12,7 @@
 # License: BSD 3 clause
 
 from abc import ABCMeta, abstractmethod
-from collections import Mapping, namedtuple, Sized, defaultdict, Sequence
+from collections import Mapping, namedtuple, defaultdict, Sequence
 from functools import partial, reduce
 from itertools import product
 import operator
@@ -532,17 +532,44 @@ def inverse_transform(self, Xt):
         self._check_is_fitted('inverse_transform')
         return self.best_estimator_.transform(Xt)
 
-    def _fit(self, X, y, groups, parameter_iterable):
-        """Actual fitting, performing the search over parameters."""
+    @property
+    def _param_iterable(self):
+        """To generate parameter iterables for multiple iterations"""
+        if hasattr(self, 'param_grid'):
+            return ParameterGrid(self.param_grid)
+        else:
+            return ParameterSampler(
+                self.param_distributions, self.n_iter,
+                random_state=self.random_state)
+
+    def fit(self, X, y=None, groups=None):
+        """Run fit with all sets of parameters.
+
+        Parameters
+        ----------
+        X : array-like, shape = [n_samples, n_features]
+            Training vector, where n_samples is the number of samples and
+            n_features is the number of features.
+
+        y : array-like, shape = [n_samples] or [n_samples, n_output], optional
+            Target relative to X for classification or regression;
+            None for unsupervised learning.
+
+        groups : array-like, with shape (n_samples,), optional
+            Group labels for the samples used while splitting the dataset into
+            train/test set.
+ """ estimator = self.estimator cv = check_cv(self.cv, y, classifier=is_classifier(estimator)) self.scorer_ = check_scoring(self.estimator, scoring=self.scoring) X, y, groups = indexable(X, y, groups) n_splits = cv.get_n_splits(X, y, groups) - if self.verbose > 0 and isinstance(parameter_iterable, Sized): - n_candidates = len(parameter_iterable) + # Regenerate parameter iterable for each fit + candidate_params = list(self._param_iterable) + n_candidates = len(candidate_params) + if self.verbose > 0: print("Fitting {0} folds for each of {1} candidates, totalling" " {2} fits".format(n_splits, n_candidates, n_candidates * n_splits)) @@ -550,7 +577,6 @@ def _fit(self, X, y, groups, parameter_iterable): base_estimator = clone(self.estimator) pre_dispatch = self.pre_dispatch - cv_iter = list(cv.split(X, y, groups)) out = Parallel( n_jobs=self.n_jobs, verbose=self.verbose, pre_dispatch=pre_dispatch @@ -559,28 +585,25 @@ def _fit(self, X, y, groups, parameter_iterable): fit_params=self.fit_params, return_train_score=self.return_train_score, return_n_test_samples=True, - return_times=True, return_parameters=True, + return_times=True, return_parameters=False, error_score=self.error_score) - for parameters in parameter_iterable - for train, test in cv_iter) + for train, test in cv.split(X, y, groups) + for parameters in candidate_params) # if one choose to see train score, "out" will contain train score info if self.return_train_score: - (train_scores, test_scores, test_sample_counts, - fit_time, score_time, parameters) = zip(*out) + (train_scores, test_scores, test_sample_counts, fit_time, + score_time) = zip(*out) else: - (test_scores, test_sample_counts, - fit_time, score_time, parameters) = zip(*out) - - candidate_params = parameters[::n_splits] - n_candidates = len(candidate_params) + (test_scores, test_sample_counts, fit_time, score_time) = zip(*out) results = dict() def _store(key_name, array, weights=None, splits=False, rank=False): """A small helper to store the scores/times to the cv_results_""" - array = np.array(array, dtype=np.float64).reshape(n_candidates, - n_splits) + # When iterated first by splits, then by parameters + array = np.array(array, dtype=np.float64).reshape(n_splits, + n_candidates).T if splits: for split_i in range(n_splits): results["split%d_%s" @@ -600,7 +623,7 @@ def _store(key_name, array, weights=None, splits=False, rank=False): # Computed the (weighted) mean and std for test scores alone # NOTE test_sample counts (weights) remain the same for all candidates - test_sample_counts = np.array(test_sample_counts[:n_splits], + test_sample_counts = np.array(test_sample_counts[::n_candidates], dtype=np.int) _store('test_score', test_scores, splits=True, rank=True, @@ -924,26 +947,6 @@ def __init__(self, estimator, param_grid, scoring=None, fit_params=None, self.param_grid = param_grid _check_param_grid(param_grid) - def fit(self, X, y=None, groups=None): - """Run fit with all sets of parameters. - - Parameters - ---------- - - X : array-like, shape = [n_samples, n_features] - Training vector, where n_samples is the number of samples and - n_features is the number of features. - - y : array-like, shape = [n_samples] or [n_samples, n_output], optional - Target relative to X for classification or regression; - None for unsupervised learning. - - groups : array-like, with shape (n_samples,), optional - Group labels for the samples used while splitting the dataset into - train/test set. 
- """ - return self._fit(X, y, groups, ParameterGrid(self.param_grid)) - class RandomizedSearchCV(BaseSearchCV): """Randomized search on hyper parameters. @@ -1166,25 +1169,3 @@ def __init__(self, estimator, param_distributions, n_iter=10, scoring=None, n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose, pre_dispatch=pre_dispatch, error_score=error_score, return_train_score=return_train_score) - - def fit(self, X, y=None, groups=None): - """Run fit on the estimator with randomly drawn parameters. - - Parameters - ---------- - X : array-like, shape = [n_samples, n_features] - Training vector, where n_samples in the number of samples and - n_features is the number of features. - - y : array-like, shape = [n_samples] or [n_samples, n_output], optional - Target relative to X for classification or regression; - None for unsupervised learning. - - groups : array-like, with shape (n_samples,), optional - Group labels for the samples used while splitting the dataset into - train/test set. - """ - sampled_params = ParameterSampler(self.param_distributions, - self.n_iter, - random_state=self.random_state) - return self._fit(X, y, groups, sampled_params) From 740d3d214f97c5c4c9d3e455f4a41a566cb4ff1e Mon Sep 17 00:00:00 2001 From: Raghav RV Date: Fri, 25 Nov 2016 16:48:43 +0100 Subject: [PATCH 2/6] ENH do not materialize a cv iterator to avoid memory blow ups. --- sklearn/model_selection/_validation.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index 23db2a9cebc77..88c3922f99363 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -128,7 +128,6 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None, X, y, groups = indexable(X, y, groups) cv = check_cv(cv, y, classifier=is_classifier(estimator)) - cv_iter = list(cv.split(X, y, groups)) scorer = check_scoring(estimator, scoring=scoring) # We clone the estimator to make sure that all the folds are # independent, and that it is pickle-able. 
@@ -137,7 +136,7 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
     scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y, scorer,
                                               train, test, verbose, None,
                                               fit_params)
-                      for train, test in cv_iter)
+                      for train, test in cv.split(X, y, groups))
     return np.array(scores)[:, 0]
@@ -385,7 +384,6 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
     X, y, groups = indexable(X, y, groups)
 
     cv = check_cv(cv, y, classifier=is_classifier(estimator))
-    cv_iter = list(cv.split(X, y, groups))
 
     # Ensure the estimator has implemented the passed decision function
     if not callable(getattr(estimator, method)):
@@ -398,7 +396,7 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
                         pre_dispatch=pre_dispatch)
     prediction_blocks = parallel(delayed(_fit_and_predict)(
         clone(estimator), X, y, train, test, verbose, fit_params, method)
-        for train, test in cv_iter)
+        for train, test in cv.split(X, y, groups))
 
     # Concatenate the predictions
     predictions = [pred_block_i for pred_block_i, _ in prediction_blocks]
@@ -752,8 +750,9 @@ def learning_curve(estimator, X, y, groups=None,
     X, y, groups = indexable(X, y, groups)
 
     cv = check_cv(cv, y, classifier=is_classifier(estimator))
-    # Make a list since we will be iterating multiple times over the folds
+    # Store it as list as we will be iterating over the list multiple times
     cv_iter = list(cv.split(X, y, groups))
+
     scorer = check_scoring(estimator, scoring=scoring)
 
     n_max_training_samples = len(cv_iter[0][0])
@@ -961,8 +960,6 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
     X, y, groups = indexable(X, y, groups)
 
     cv = check_cv(cv, y, classifier=is_classifier(estimator))
-    cv_iter = list(cv.split(X, y, groups))
-
    scorer = check_scoring(estimator, scoring=scoring)
 
     parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch,
@@ -970,7 +967,8 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
     out = parallel(delayed(_fit_and_score)(
         estimator, X, y, scorer, train, test, verbose,
         parameters={param_name: v}, fit_params=None, return_train_score=True)
-        for train, test in cv_iter for v in param_range)
+        # NOTE do not change order of iteration to allow one time cv splitters
+        for train, test in cv.split(X, y, groups) for v in param_range)
 
     out = np.asarray(out)
     n_params = len(param_range)

From de43c9266510835677941144f61f8e2eccf6f2e0 Mon Sep 17 00:00:00 2001
From: Raghav RV
Date: Fri, 2 Dec 2016 18:13:52 +0100
Subject: [PATCH 3/6] Use real polymorphism

---
 sklearn/model_selection/_search.py | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index 8cdc143a8d8d2..76bd9bc0e00cb 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -532,16 +532,6 @@ def inverse_transform(self, Xt):
         self._check_is_fitted('inverse_transform')
         return self.best_estimator_.transform(Xt)
 
-    @property
-    def _param_iterable(self):
-        """To generate parameter iterables for multiple iterations"""
-        if hasattr(self, 'param_grid'):
-            return ParameterGrid(self.param_grid)
-        else:
-            return ParameterSampler(
-                self.param_distributions, self.n_iter,
-                random_state=self.random_state)
-
     def fit(self, X, y=None, groups=None):
         """Run fit with all sets of parameters.
@@ -947,6 +937,11 @@ def __init__(self, estimator, param_grid, scoring=None, fit_params=None,
         self.param_grid = param_grid
         _check_param_grid(param_grid)
 
+    @property
+    def _param_iterable(self):
+        """To generate parameter iterables for multiple iterations"""
+        return ParameterGrid(self.param_grid)
+
 
 class RandomizedSearchCV(BaseSearchCV):
     """Randomized search on hyper parameters.
@@ -1169,3 +1164,10 @@ def __init__(self, estimator, param_distributions, n_iter=10, scoring=None,
             n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
             pre_dispatch=pre_dispatch, error_score=error_score,
             return_train_score=return_train_score)
+
+    @property
+    def _param_iterable(self):
+        """To generate parameter iterables for multiple iterations"""
+        return ParameterSampler(
+            self.param_distributions, self.n_iter,
+            random_state=self.random_state)

From 932891c6311c6f5e2a6440d8663a409156bb57e3 Mon Sep 17 00:00:00 2001
From: Raghav RV
Date: Mon, 5 Dec 2016 15:59:29 +0100
Subject: [PATCH 4/6] Use getter instead of property

---
 sklearn/model_selection/_search.py | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index 76bd9bc0e00cb..cfb0b92ec67cc 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -557,7 +557,7 @@ def fit(self, X, y=None, groups=None):
         X, y, groups = indexable(X, y, groups)
         n_splits = cv.get_n_splits(X, y, groups)
         # Regenerate parameter iterable for each fit
-        candidate_params = list(self._param_iterable)
+        candidate_params = list(self.get_param_iterable())
         n_candidates = len(candidate_params)
         if self.verbose > 0:
             print("Fitting {0} folds for each of {1} candidates, totalling"
@@ -937,10 +937,9 @@ def __init__(self, estimator, param_grid, scoring=None, fit_params=None,
         self.param_grid = param_grid
         _check_param_grid(param_grid)
 
-    @property
-    def _param_iterable(self):
-        """To generate parameter iterables for multiple iterations"""
-        return ParameterGrid(self.param_grid)
+    def get_param_iterable(self):
+        """Return ParameterGrid instance for the given param_grid"""
+        return ParameterGrid(self.param_grid)
 
 
 class RandomizedSearchCV(BaseSearchCV):
@@ -1165,9 +1164,8 @@ def __init__(self, estimator, param_distributions, n_iter=10, scoring=None,
             pre_dispatch=pre_dispatch, error_score=error_score,
             return_train_score=return_train_score)
 
-    @property
-    def _param_iterable(self):
-        """To generate parameter iterables for multiple iterations"""
-        return ParameterSampler(
-            self.param_distributions, self.n_iter,
-            random_state=self.random_state)
+    def get_param_iterable(self):
+        """Return ParameterSampler instance for the given distributions"""
+        return ParameterSampler(
+            self.param_distributions, self.n_iter,
+            random_state=self.random_state)

From 2bd348fda6c75fe5420f8c278786f235cb1ce48b Mon Sep 17 00:00:00 2001
From: Raghav RV
Date: Tue, 6 Dec 2016 12:56:30 +0100
Subject: [PATCH 5/6] param_iterable --> param_iterator

---
 sklearn/model_selection/_search.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index cfb0b92ec67cc..d0e21c9ed87db 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -557,7 +557,7 @@ def fit(self, X, y=None, groups=None):
         X, y, groups = indexable(X, y, groups)
         n_splits = cv.get_n_splits(X, y, groups)
         # Regenerate parameter iterable for each fit
-        candidate_params = list(self.get_param_iterable())
+        candidate_params = list(self.get_param_iterator())
         n_candidates = len(candidate_params)
         if self.verbose > 0:
             print("Fitting {0} folds for each of {1} candidates, totalling"
@@ -937,7 +937,7 @@ def __init__(self, estimator, param_grid, scoring=None, fit_params=None,
         self.param_grid = param_grid
         _check_param_grid(param_grid)
 
-    def get_param_iterable(self):
+    def get_param_iterator(self):
         """Return ParameterGrid instance for the given param_grid"""
         return ParameterGrid(self.param_grid)
 
@@ -1164,7 +1164,7 @@ def __init__(self, estimator, param_distributions, n_iter=10, scoring=None,
             pre_dispatch=pre_dispatch, error_score=error_score,
             return_train_score=return_train_score)
 
-    def get_param_iterable(self):
+    def get_param_iterator(self):
         """Return ParameterSampler instance for the given distributions"""
         return ParameterSampler(
             self.param_distributions, self.n_iter,
             random_state=self.random_state)

From 1efe6482c540d635f9f560a40c342888b94ef620 Mon Sep 17 00:00:00 2001
From: Raghav RV
Date: Wed, 7 Dec 2016 00:06:02 +0100
Subject: [PATCH 6/6] Make get_param_iterator private.

---
 sklearn/model_selection/_search.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index d0e21c9ed87db..e1d744ceab6ca 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -557,7 +557,7 @@ def fit(self, X, y=None, groups=None):
         X, y, groups = indexable(X, y, groups)
         n_splits = cv.get_n_splits(X, y, groups)
         # Regenerate parameter iterable for each fit
-        candidate_params = list(self.get_param_iterator())
+        candidate_params = list(self._get_param_iterator())
         n_candidates = len(candidate_params)
         if self.verbose > 0:
             print("Fitting {0} folds for each of {1} candidates, totalling"
@@ -937,7 +937,7 @@ def __init__(self, estimator, param_grid, scoring=None, fit_params=None,
         self.param_grid = param_grid
         _check_param_grid(param_grid)
 
-    def get_param_iterator(self):
+    def _get_param_iterator(self):
         """Return ParameterGrid instance for the given param_grid"""
         return ParameterGrid(self.param_grid)
 
@@ -1164,7 +1164,7 @@ def __init__(self, estimator, param_distributions, n_iter=10, scoring=None,
             pre_dispatch=pre_dispatch, error_score=error_score,
             return_train_score=return_train_score)
 
-    def get_param_iterator(self):
+    def _get_param_iterator(self):
         """Return ParameterSampler instance for the given distributions"""
         return ParameterSampler(
             self.param_distributions, self.n_iter,
             random_state=self.random_state)
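
The ordering invariant behind patch 1: with the splits in the outer loop and the candidates in the inner loop, the flat `out` list is split-major, so `_store` recovers the old candidate-major layout with a `(n_splits, n_candidates)` reshape followed by a transpose, and the per-split test-set sizes can be read off with a `[::n_candidates]` stride. A minimal NumPy sketch (toy sizes and scores, not taken from the patch series) that checks both claims:

    import numpy as np

    # Toy sizes, purely for illustration.
    n_candidates, n_splits = 3, 4

    # Simulate the flat `out` list produced by the new loop order:
    # splits vary slowest, candidates fastest.
    pairs = [(split_i, cand_i)
             for split_i in range(n_splits)
             for cand_i in range(n_candidates)]
    scores = np.array([10.0 * split_i + cand_i for split_i, cand_i in pairs])

    # The reshape + transpose used by `_store`: row i now holds candidate
    # i's score on every split, matching the old candidate-major layout.
    per_candidate = scores.reshape(n_splits, n_candidates).T
    assert per_candidate.shape == (n_candidates, n_splits)
    assert np.array_equal(per_candidate[1], [1.0, 11.0, 21.0, 31.0])

    # Test-set sizes repeat every n_candidates entries, hence the
    # test_sample_counts[::n_candidates] stride in the patch.
    test_sample_counts = [100 + split_i for split_i, _ in pairs]
    assert test_sample_counts[::n_candidates] == [100, 101, 102, 103]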
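Patch 2 and the NOTE added in validation_curve rely on the same property: when the splits are consumed lazily and kept in the outer loop, a splitter whose split() can be iterated only once still works, and the full list of (train, test) index arrays never has to be held in memory at once. A minimal sketch of that point (the `one_shot_splits` generator below is a hypothetical stand-in, not a scikit-learn API):

    import numpy as np

    def one_shot_splits(n_samples, n_splits=3):
        """Toy stand-in for a cv splitter whose splits yield only once."""
        indices = np.arange(n_samples)
        for test in np.array_split(indices, n_splits):
            yield np.setdiff1d(indices, test), test

    param_range = [0.1, 1.0, 10.0]
    splits = one_shot_splits(12)

    # Splits in the outer loop: the generator is traversed exactly once
    # and every (split, parameter) pair is still visited. With parameters
    # in the outer loop instead, the second pass over `splits` would
    # yield nothing.
    work = [(len(test), v) for train, test in splits for v in param_range]
    assert len(work) == 3 * len(param_range)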