diff --git a/doc/tutorial/statistical_inference/model_selection.rst b/doc/tutorial/statistical_inference/model_selection.rst index 65ae1c431466e..ec651f3f1dd09 100644 --- a/doc/tutorial/statistical_inference/model_selection.rst +++ b/doc/tutorial/statistical_inference/model_selection.rst @@ -144,7 +144,7 @@ estimator during the construction and exposes an estimator API:: >>> clf = GridSearchCV(estimator=svc, param_grid=dict(gamma=gammas), ... n_jobs=-1) >>> clf.fit(X_digits[:1000], y_digits[:1000]) # doctest: +ELLIPSIS - GridSearchCV(cv=None,... + GridSearchCV(compute_training_score=False,... >>> clf.best_score_ 0.98899999999999999 >>> clf.best_estimator_.gamma diff --git a/examples/svm/plot_rbf_parameters.py b/examples/svm/plot_rbf_parameters.py index f298ebf01205c..664c13bb5f6fc 100644 --- a/examples/svm/plot_rbf_parameters.py +++ b/examples/svm/plot_rbf_parameters.py @@ -14,10 +14,30 @@ the decision surface smooth, while a high C aims at classifying all training examples correctly. -Two plots are generated. The first is a visualization of the -decision function for a variety of parameter values, and the second -is a heatmap of the classifier's cross-validation accuracy as -a function of `C` and `gamma`. +Two plots are generated. The first is a visualization of the decision function +for a variety of parameter values, and the second is a heatmap of the +classifier's cross-validation accuracy and training time as a function of `C` +and `gamma`. + +An interesting observation on overfitting can be made when comparing validation +and training error: a higher C always results in a lower training error, as it +increases the complexity of the classifier. + +For the validation set, on the other hand, there is a tradeoff between goodness +of fit and generalization. + +We can observe that the lower right half of the parameter grid (below the +diagonal, with high C and gamma values) is characteristic of parameters that +yield an overfitting model: the training score is very high but there is a wide +gap between training and validation scores. The top and left parts of the +parameter plots show underfitting models: the C and gamma values can, +individually or in conjunction, constrain the model too much, leading to low +training scores (and hence low validation scores too, as validation scores are +on average upper bounded by training scores). + + +We can also see that the training time is quite sensitive to the parameter +setting, while the prediction time is not impacted very much. This is probably +a consequence of the small size of the data set. 
''' print(__doc__) @@ -65,7 +85,8 @@ gamma_range = 10.0 ** np.arange(-5, 4) param_grid = dict(gamma=gamma_range, C=C_range) cv = StratifiedKFold(y=Y, n_folds=3) -grid = GridSearchCV(SVC(), param_grid=param_grid, cv=cv) +grid = GridSearchCV(SVC(), param_grid=param_grid, cv=cv, + compute_training_score=True) grid.fit(X, Y) print("The best classifier is: ", grid.best_estimator_) @@ -108,18 +129,28 @@ -# cv_scores_ contains parameter settings and scores -score_dict = grid.cv_scores_ -# We extract just the scores -scores = [x[1] for x in score_dict] -scores = np.array(scores).reshape(len(C_range), len(gamma_range)) - -# draw heatmap of accuracy as a function of gamma and C -pl.figure(figsize=(8, 6)) -pl.subplots_adjust(left=0.05, right=0.95, bottom=0.15, top=0.95) -pl.imshow(scores, interpolation='nearest', cmap=pl.cm.spectral) -pl.xlabel('gamma') -pl.ylabel('C') -pl.colorbar() -pl.xticks(np.arange(len(gamma_range)), gamma_range, rotation=45) -pl.yticks(np.arange(len(C_range)), C_range) +# grid_results_ and fold_results_ contain the scores and timings computed for +# each parameter setting and each cross-validation fold +grid_results = grid.grid_results_ +fold_results = grid.fold_results_ + +# We extract validation and training scores, as well as mean training and +# prediction times per parameter setting +val_scores = grid_results['test_score'] +train_scores = grid_results['train_score'] +train_time = fold_results['train_time'].mean(axis=1) +pred_time = fold_results['test_time'].mean(axis=1) + +arrays = [val_scores, train_scores, train_time, pred_time] +titles = ["Validation Score", "Training Score", "Training Time", + "Prediction Time"] + +# for each value draw heatmap as a function of gamma and C +pl.figure(figsize=(12, 8)) +for i, (arr, title) in enumerate(zip(arrays, titles)): + pl.subplot(2, 2, i + 1) + arr = np.array(arr).reshape(len(C_range), len(gamma_range)) + pl.title(title) + pl.imshow(arr, interpolation='nearest', cmap=pl.cm.spectral) + pl.xlabel('gamma') + pl.ylabel('C') + pl.colorbar() + pl.xticks(np.arange(len(gamma_range)), ["%.e" % g for g in gamma_range], + rotation=45) + pl.yticks(np.arange(len(C_range)), ["%.e" % C for C in C_range]) + +pl.subplots_adjust(top=.95, hspace=.35, left=.0, right=.8, wspace=.05) pl.show() diff --git a/sklearn/grid_search.py b/sklearn/grid_search.py index aa51e267ac32b..ab1fcb2e31e8f 100644 --- a/sklearn/grid_search.py +++ b/sklearn/grid_search.py @@ -8,11 +8,10 @@ # Gael Varoquaux # License: BSD Style. -import time import warnings import numbers +from time import time from itertools import product -from collections import namedtuple from abc import ABCMeta, abstractmethod import numpy as np @@ -21,10 +20,10 @@ from .base import MetaEstimatorMixin from .cross_validation import check_cv from .externals.joblib import Parallel, delayed, logger -from .externals.six import string_types +from .externals.six import string_types, iterkeys from .utils import safe_mask, check_random_state from .utils.validation import _num_samples, check_arrays -from .metrics import SCORERS, Scorer +from .metrics import SCORERS, Scorer, EstimatorScorer, WrapScorer __all__ = ['GridSearchCV', 'ParameterGrid', 'fit_grid_point', 'ParameterSampler', 'RandomizedSearchCV'] @@ -170,8 +169,8 @@ def __iter__(self): yield params -def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer, - verbose, loss_func=None, **fit_params): +def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer, verbose, + loss_func=None, compute_training_score=False, **fit_params): """Run fit on one set of parameters. Parameters ---------- @@ -198,6 +197,9 @@ def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer, If provided must be a scoring object / function with signature ``scorer(estimator, X, y)``. + compute_training_score : bool, default=False + Whether to also compute the score on the training split. If False, + no training score is computed. 
+ verbose : int Verbosity level. @@ -207,8 +209,18 @@ def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer, Returns ------- - score : float - Score of this parameter setting on given training / test split. + test_score : float + Test score of this parameter setting on given training / test split. + + training_score : float or None + Training score of this parameter setting or None if + ``compute_training_score=False`` (default). + + training_time : float + Training time for this parameter setting in seconds. + + prediction_time : float + Prediction time for the given test set in seconds. estimator : estimator object Estimator object of type base_clf that was fitted using clf_params @@ -218,7 +230,7 @@ def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer, Number of test samples in this split. """ if verbose > 1: - start_time = time.time() + start_time = time() msg = '%s' % (', '.join('%s=%s' % (k, v) for k, v in clf_params.items())) print("[GridSearchCV] %s %s" % (msg, (64 - len(msg)) * '.')) @@ -249,34 +261,51 @@ def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer, X_train = X[safe_mask(X, train)] X_test = X[safe_mask(X, test)] + results = {'test_n_samples': _num_samples(X_test)} + if scorer is None: + scorer = EstimatorScorer(clf.score) + elif not hasattr(scorer, 'calc_scores'): + scorer = WrapScorer(scorer) + if y is not None: y_test = y[safe_mask(y, test)] y_train = y[safe_mask(y, train)] - clf.fit(X_train, y_train, **fit_params) - - if scorer is not None: - this_score = scorer(clf, X_test, y_test) - else: - this_score = clf.score(X_test, y_test) + fit_args = (X_train, y_train) + score_args = (X_test, y_test) else: - clf.fit(X_train, **fit_params) - if scorer is not None: - this_score = scorer(clf, X_test) - else: - this_score = clf.score(X_test) - - if not isinstance(this_score, numbers.Number): + fit_args = (X_train,) + score_args = (X_test,) + + start = time() + # do actual fitting + clf.fit(*fit_args, **fit_params) + results['train_time'] = time() - start + start = time() + results.update(('test_' + name, score) + for name, score in scorer.calc_scores(clf, *score_args)) + results['test_time'] = time() - start + + if compute_training_score: + results.update(('train_' + name, score) + for name, score in scorer.calc_scores(clf, *fit_args)) + + try: + test_score = results['test_score'] + except KeyError: + raise ValueError("Scorer.calc_scores must return a score named 'score'." + " Got %s instead." % (results)) + if not isinstance(test_score, numbers.Number): raise ValueError("scoring must return a number, got %s (%s)" - " instead." % (str(this_score), type(this_score))) + " instead." 
% (str(test_score), type(test_score))) if verbose > 2: - msg += ", score=%f" % this_score + msg += ", score=%f" % test_score if verbose > 1: end_msg = "%s -%s" % (msg, - logger.short_format_time(time.time() - + logger.short_format_time(time() - start_time)) print("[GridSearchCV] %s %s" % ((64 - len(end_msg)) * '.', end_msg)) - return this_score, clf_params, _num_samples(X_test) + return clf_params, results def _check_param_grid(param_grid): @@ -317,8 +346,10 @@ class BaseSearchCV(BaseEstimator, MetaEstimatorMixin): @abstractmethod def __init__(self, estimator, scoring=None, loss_func=None, score_func=None, fit_params=None, n_jobs=1, iid=True, - refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs'): + refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs', + compute_training_score=False): + self.compute_training_score = compute_training_score self.scoring = scoring self.estimator = estimator self.loss_func = loss_func @@ -357,6 +388,26 @@ def score(self, X, y=None): y_predicted = self.predict(X) return self.scorer(y, y_predicted) + @property + def grid_scores_(self): + warnings.warn("grid_scores_ is deprecated and will be removed in 0.15." + " Use grid_results_ and fold_results_ instead.", DeprecationWarning) + return zip(self.grid_results_['parameters'], + self.grid_results_['test_score'], + self.fold_results_['test_score']) + + @property + def best_score_(self): + if not hasattr(self, 'best_index_'): + raise AttributeError('Call fit() to calculate best_score_') + return self.grid_results_['test_score'][self.best_index_] + + @property + def best_params_(self): + if not hasattr(self, 'best_index_'): + raise AttributeError('Call fit() to calculate best_params_') + return self.grid_results_['parameters'][self.best_index_] + def _check_estimator(self): """Check that estimator can be fitted and score can be computed.""" if (not hasattr(self.estimator, 'fit') or @@ -381,6 +432,29 @@ def _set_methods(self): if hasattr(self.best_estimator_, 'predict_proba'): self.predict_proba = self.best_estimator_.predict_proba + def _aggregate_scores(self, scores, n_samples): + """Take 2d arrays of scores and samples and calculate weighted + means/sums of each row""" + if self.iid: + scores = scores * n_samples + scores = scores.sum(axis=1) / n_samples.sum(axis=1) + else: + scores = scores.sum(axis=1) + return scores + + def _merge_result_dicts(self, result_dicts): + """ + From a result dict for each fold, produce a single dict with an array + for each key. 
+ For example [[{'score': 1}, {'score': 2}], [{'score': 3}, {'score': 4}]] + -> {'score': np.array([[1, 2], [3, 4]])}""" + # assume keys are same throughout + result_keys = list(iterkeys(result_dicts[0][0])) + arrays = ([[fold_results[key] for fold_results in point] + for point in result_dicts] + for key in result_keys) + return np.rec.fromarrays(arrays, names=result_keys) + def _fit(self, X, y, parameter_iterator, **params): """Actual fitting, performing the search over parameters.""" estimator = self.estimator @@ -425,33 +499,31 @@ def _fit(self, X, y, parameter_iterator, **params): pre_dispatch=pre_dispatch)( delayed(fit_grid_point)( X, y, base_clf, clf_params, train, test, scorer, - self.verbose, **self.fit_params) for clf_params in + self.verbose, + compute_training_score=self.compute_training_score, + **self.fit_params) for clf_params in parameter_iterator for train, test in cv) - # Out is a list of triplet: score, estimator, n_test_samples n_param_points = len(list(parameter_iterator)) n_fits = len(out) n_folds = n_fits // n_param_points - scores = list() - cv_scores = list() - for start in range(0, n_fits, n_folds): - n_test_samples = 0 - mean_validation_score = 0 - these_points = list() - for this_score, clf_params, this_n_test_samples in \ - out[start:start + n_folds]: - these_points.append(this_score) - if self.iid: - this_score *= this_n_test_samples - mean_validation_score += this_score - n_test_samples += this_n_test_samples - if self.iid: - mean_validation_score /= float(n_test_samples) - scores.append((mean_validation_score, clf_params)) - cv_scores.append(these_points) - - cv_scores = np.asarray(cv_scores) + cv_results = self._merge_result_dicts([ + [fold_results for clf_params, fold_results in out[start:start + n_folds]] + for start in range(0, n_fits, n_folds) + ]) + + field_defs = [('parameters', 'object'), ('test_score', cv_results['test_score'].dtype)] + if self.compute_training_score: + field_defs.append(('train_score', cv_results['train_score'].dtype)) + grid_results = np.zeros(n_param_points, dtype=field_defs) + grid_results['parameters'] = list(parameter_iterator) + grid_results['test_score'] = self._aggregate_scores( + cv_results['test_score'], cv_results['test_n_samples']) + if self.compute_training_score: + grid_results['train_score'] = self._aggregate_scores( + cv_results['train_score'], + n_samples - cv_results['test_n_samples']) # Note: we do not use max(out) to make ties deterministic even if # comparison on estimator instances is not deterministic @@ -465,19 +537,21 @@ def _fit(self, X, y, parameter_iterator, **params): else: best_score = np.inf - for score, params in scores: + for i, score in enumerate(grid_results['test_score']): if ((score > best_score and greater_is_better) - or (score < best_score and not greater_is_better)): + or (score < best_score + and not greater_is_better)): best_score = score - best_params = params + best_index = i - self.best_params_ = best_params - self.best_score_ = best_score + self.best_index_ = best_index + self.fold_results_ = cv_results + self.grid_results_ = grid_results if self.refit: # fit the best estimator using the entire dataset # clone first to work around broken estimators - best_estimator = clone(base_clf).set_params(**best_params) + best_estimator = clone(base_clf).set_params(**self.best_params_) if y is not None: best_estimator.fit(X, y, **self.fit_params) else: @@ -485,14 +559,6 @@ def _fit(self, X, y, parameter_iterator, **params): self.best_estimator_ = best_estimator self._set_methods() - # Store the 
computed scores - CVScoreTuple = namedtuple('CVScoreTuple', ('parameters', - 'mean_validation_score', - 'cv_validation_scores')) - self.cv_scores_ = [ - CVScoreTuple(clf_params, score, all_scores) - for clf_params, (score, _), all_scores - in zip(parameter_iterator, scores, cv_scores)] return self @@ -572,7 +638,7 @@ class GridSearchCV(BaseSearchCV): >>> clf = grid_search.GridSearchCV(svr, parameters) >>> clf.fit(iris.data, iris.target) ... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS - GridSearchCV(cv=None, + GridSearchCV(compute_training_score=False, cv=None, estimator=SVC(C=1.0, cache_size=..., coef0=..., degree=..., gamma=..., kernel='rbf', max_iter=-1, probability=False, shrinking=True, tol=...), @@ -582,20 +648,34 @@ class GridSearchCV(BaseSearchCV): Attributes ---------- - `cv_scores_` : list of named tuples - Contains scores for all parameter combinations in param_grid. - Each entry corresponds to one parameter setting. - Each named tuple has the attributes: + `grid_results_` : structured array of shape [# param combinations] + For each parameter combination in ``param_grid``, this array includes + the fields: - * ``parameters``, a dict of parameter settings - * ``mean_validation_score``, the mean score over the + * ``parameters``, dict of parameter settings + * ``test_score``, the mean score over the cross-validation folds - * ``cv_validation_scores``, the list of scores for each fold + * ``train_score``, the mean training score over the + cross-validation folds, if ``compute_training_score`` is True + + `fold_results_` : structured array of shape [# param combinations, # folds] + For each parameter setting and cross-validation fold, this array + includes the fields: + + * ``test_time``, the elapsed prediction and scoring time + * ``train_time``, the elapsed training time + * ``test_score``, the score for this fold + * ``train_score``, the training score for this fold + * ``test_n_samples``, the number of samples in testing + * ``test_*``, other scores from `scorer.calc_scores()` + * ``train_*``, other training scores from `scorer.calc_scores()` `best_estimator_` : estimator - Estimator that was choosen by grid search, i.e. estimator + Estimator that was chosen by grid search, i.e. estimator which gave highest score (or smallest loss if specified) - on the left out data. + on the left out data. Available only if refit=True. + + `best_index_` : int + The index of the best parameter setting in ``grid_results_`` and + ``fold_results_``. `best_score_` : float score of best_estimator on the left out data. @@ -603,6 +683,10 @@ class GridSearchCV(BaseSearchCV): `best_params_` : dict Parameter setting that gave the best results on the hold out data. + `grid_scores_` : list of tuples (deprecated) + Contains scores for all parameter combinations in ``param_grid``: + each tuple is (parameters, mean score, fold scores). 
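As a rough sketch of how these new result attributes might be consumed (the attribute and field names follow the documentation above; the example itself is illustrative and not part of the patch):

>>> from sklearn import svm, grid_search, datasets
>>> iris = datasets.load_iris()
>>> search = grid_search.GridSearchCV(svm.SVC(), {'C': [1, 10]},
...                                   compute_training_score=True)
>>> search = search.fit(iris.data, iris.target)
>>> mean_test_scores = search.grid_results_['test_score']   # one mean score per setting
>>> overfit_gap = search.grid_results_['train_score'] - mean_test_scores
>>> fit_times = search.fold_results_['train_time']          # shape (n_settings, n_folds)
>>> best_parameters = search.grid_results_['parameters'][search.best_index_]

Indexing ``grid_results_`` with ``best_index_`` is exactly what the ``best_score_`` and ``best_params_`` properties do internally, so the three stay consistent.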
+ Notes ------ The parameters selected are those that maximize the score of the left out @@ -630,19 +714,14 @@ class GridSearchCV(BaseSearchCV): def __init__(self, estimator, param_grid, scoring=None, loss_func=None, score_func=None, fit_params=None, n_jobs=1, iid=True, - refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs'): + refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs', + compute_training_score=False): super(GridSearchCV, self).__init__( estimator, scoring, loss_func, score_func, fit_params, n_jobs, iid, - refit, cv, verbose, pre_dispatch) + refit, cv, verbose, pre_dispatch, compute_training_score) self.param_grid = param_grid _check_param_grid(param_grid) - @property - def grid_scores_(self): - warnings.warn("grid_scores_ is deprecated and will be removed in 0.15." - " Use estimator_scores_ instead.", DeprecationWarning) - return self.cv_scores_ - def fit(self, X, y=None, **params): """Run fit with all sets of parameters. @@ -754,30 +833,47 @@ class RandomizedSearchCV(BaseSearchCV): verbose : integer Controls the verbosity: the higher, the more messages. - Attributes ---------- - `cv_scores_` : list of named tuples - Contains scores for all parameter combinations in param_grid. - Each entry corresponds to one parameter setting. - Each named tuple has the attributes: + `grid_results_` : structured array of shape [# param settings] + For each parameter setting sampled from ``param_distributions``, this + array includes the fields: - * ``parameters``, a dict of parameter settings - * ``mean_validation_score``, the mean score over the + * ``parameters``, dict of parameter settings + * ``test_score``, the mean score over the cross-validation folds - * ``cv_validation_scores``, the list of scores for each fold + * ``train_score``, the mean training score over the + cross-validation folds, if ``compute_training_score`` is True + + `fold_results_` : structured array of shape [# param settings, # folds] + For each parameter setting and cross-validation fold, this array + includes the fields: + + * ``test_time``, the elapsed prediction and scoring time + * ``train_time``, the elapsed training time + * ``test_score``, the score for this fold + * ``train_score``, the training score for this fold + * ``test_n_samples``, the number of samples in testing + * ``test_*``, other scores from `scorer.calc_scores()` + * ``train_*``, other training scores from `scorer.calc_scores()` `best_estimator_` : estimator - Estimator that was choosen by search, i.e. estimator + Estimator that was chosen by the search, i.e. estimator which gave highest score (or smallest loss if specified) - on the left out data. + on the left out data. Available only if refit=True. + + `best_index_` : int + The index of the best parameter setting in ``grid_results_`` and + ``fold_results_``. `best_score_` : float - Score of best_estimator on the left out data. + score of best_estimator on the left out data. `best_params_` : dict Parameter setting that gave the best results on the hold out data. + `grid_scores_` : list of tuples (deprecated) + Contains scores for all parameter settings that were evaluated: + each tuple is (parameters, mean score, fold scores). 
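The ``test_*`` / ``train_*`` fields above come from the ``calc_scores`` protocol introduced in ``sklearn/metrics/scorer.py`` below. A minimal sketch of a custom scorer that reports one extra named score (the class name and choice of metrics are hypothetical, for illustration only):

>>> from sklearn.metrics import Scorer, accuracy_score, zero_one_loss
>>> class AccuracyWithErrorScorer(Scorer):
...     def __init__(self):
...         super(AccuracyWithErrorScorer, self).__init__(accuracy_score)
...     def calc_scores(self, estimator, X, y=None):
...         y_pred = estimator.predict(X)
...         # exactly one pair must be named 'score'; any extra names become
...         # 'test_<name>' / 'train_<name>' fields in fold_results_
...         return [('score', accuracy_score(y, y_pred)),
...                 ('error', zero_one_loss(y, y_pred))]

Passing an instance of such a scorer as ``scoring`` would populate a ``test_error`` field (and, with ``compute_training_score=True``, a ``train_error`` field) in ``fold_results_``, mirroring what ``PRFScorer`` does for precision and recall.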
+ Notes ----- The parameters selected are those that maximize the score of the left out @@ -807,13 +903,13 @@ class RandomizedSearchCV(BaseSearchCV): def __init__(self, estimator, param_distributions, n_iter=10, scoring=None, loss_func=None, score_func=None, fit_params=None, n_jobs=1, iid=True, refit=True, cv=None, verbose=0, - pre_dispatch='2*n_jobs'): + pre_dispatch='2*n_jobs', compute_training_score=False): self.param_distributions = param_distributions self.n_iter = n_iter super(RandomizedSearchCV, self).__init__( estimator, scoring, loss_func, score_func, fit_params, n_jobs, iid, - refit, cv, verbose, pre_dispatch) + refit, cv, verbose, pre_dispatch, compute_training_score) def fit(self, X, y=None, **params): """Run fit on the estimator with randomly drawn parameters. diff --git a/sklearn/manifold/locally_linear.py b/sklearn/manifold/locally_linear.py index dcc29548a6640..fa1e3cfcdd7cc 100644 --- a/sklearn/manifold/locally_linear.py +++ b/sklearn/manifold/locally_linear.py @@ -533,11 +533,11 @@ class LocallyLinearEmbedding(BaseEstimator, TransformerMixin): maximum number of iterations for the arpack solver. Not used if eigen_solver=='dense'. - method : string ('standard', 'hessian', 'modified' or 'ltsa') - standard : use the standard locally linear embedding algorithm. see - reference [1] - hessian : use the Hessian eigenmap method. This method requires - ``n_neighbors > n_components * (1 + (n_components + 1) / 2`` + method : string ['standard' | 'hessian' | 'modified' |'ltsa'] + standard : use the standard locally linear embedding algorithm. + see reference [1] + hessian : use the Hessian eigenmap method. This method requires + n_neighbors > n_components * (1 + (n_components + 1) / 2. see reference [2] modified : use the modified locally linear embedding algorithm. see reference [3] @@ -546,11 +546,11 @@ class LocallyLinearEmbedding(BaseEstimator, TransformerMixin): hessian_tol : float, optional Tolerance for Hessian eigenmapping method. - Only used if ``method == 'hessian'`` + Only used if method == 'hessian' modified_tol : float, optional Tolerance for modified LLE method. - Only used if ``method == 'modified'`` + Only used if method == 'modified' neighbors_algorithm : string ['auto'|'brute'|'kd_tree'|'ball_tree'] algorithm to use for nearest neighbors search, diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index 9527636d04ff1..6abbb38c600e0 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -29,7 +29,7 @@ from .metrics import zero_one from .metrics import zero_one_score -from .scorer import Scorer, SCORERS +from .scorer import Scorer, PRFScorer, EstimatorScorer, WrapScorer, SCORERS from . import cluster from .cluster import (adjusted_rand_score, @@ -83,5 +83,6 @@ 'silhouette_samples', 'v_measure_score', 'zero_one_loss', + 'PRFScorer', 'Scorer', 'SCORERS'] diff --git a/sklearn/metrics/scorer.py b/sklearn/metrics/scorer.py index 0f01a0b8439d0..476bf779247aa 100644 --- a/sklearn/metrics/scorer.py +++ b/sklearn/metrics/scorer.py @@ -16,16 +16,77 @@ # Authors: Andreas Mueller # Liscence: Simplified BSD +from abc import ABCMeta, abstractmethod + import numpy as np -from . import (r2_score, mean_squared_error, accuracy_score, f1_score, +from . 
import (r2_score, mean_squared_error, accuracy_score, auc_score, average_precision_score, precision_score, - recall_score) + recall_score, precision_recall_fscore_support) from .cluster import adjusted_rand_score +class BaseScorer(object): + __metaclass__ = ABCMeta + + def __init__(self, greater_is_better=True): + self.greater_is_better = greater_is_better + + def calc_scores(self, estimator, X, y=None): + """Calculate one or more scores for X against y using the provided + estimator. While __call__ calculates a single score, this may return + multiple. + + Parameters + ---------- + estimator : object + Trained estimator to use for scoring. + If ``needs_threshold`` is True, estimator needs + to provide ``decision_function`` or ``predict_proba``. + Otherwise, estimator needs to provide ``predict``. + + X : array-like or sparse matrix + Test data that will be scored by the estimator. + + y : array-like + True prediction for X. + + Returns + ------- + scores : iterable of (name, score) pairs + Scores of the estimator's predictions of X with respect to y. + Names must be distinct, and exactly one name must be 'score', whose + score corresponds to the result of `__call__`. + """ + yield ('score', self(estimator, X, y)) + + @abstractmethod + def __call__(self, estimator, X, y=None): + """Score X and y using the provided estimator. + + Parameters + ---------- + estimator : object + Trained estimator to use for scoring. + If ``needs_threshold`` is True, estimator needs + to provide ``decision_function`` or ``predict_proba``. + Otherwise, estimator needs to provide ``predict``. + + X : array-like or sparse matrix + Test data that will be scored by the estimator. + + y : array-like + True prediction for X. + + Returns + ------- + score : float + The score of estimator's prediction of X. + """ + pass + -class Scorer(object): +class Scorer(BaseScorer): """Flexible scores for any estimator. This class wraps estimator scoring functions for the use in GridSearchCV @@ -63,8 +124,8 @@ class Scorer(object): """ def __init__(self, score_func, greater_is_better=True, needs_threshold=False, **kwargs): + super(Scorer, self).__init__(greater_is_better) self.score_func = score_func - self.greater_is_better = greater_is_better self.needs_threshold = needs_threshold self.kwargs = kwargs @@ -75,7 +136,7 @@ def __repr__(self): "%s%s)" % (self.score_func.__name__, self.greater_is_better, self.needs_threshold, kwargs_string)) - def __call__(self, estimator, X, y): + def __call__(self, estimator, X, y=None): """Score X and y using the provided estimator. Parameters @@ -105,10 +166,100 @@ def __call__(self, estimator, X, y): y_pred = estimator.decision_function(X).ravel() except (NotImplementedError, AttributeError): y_pred = estimator.predict_proba(X)[:, 1] - return self.score_func(y, y_pred, **self.kwargs) else: y_pred = estimator.predict(X) + if y is not None: return self.score_func(y, y_pred, **self.kwargs) + else: + return self.score_func(y_pred, **self.kwargs) + + +class PRFScorer(Scorer): + """Scorer to optimise F score while also providing precision and recall. + + Parameters + ---------- + **kwargs : additional arguments + Additional parameters to be passed to + `metrics.precision_recall_fscore_support`. 
+ """ + + def __init__(self, **kwargs): + if 'average' not in kwargs: + kwargs['average'] = 'weighted' + super(PRFScorer, self).__init__(precision_recall_fscore_support, **kwargs) + + def __repr__(self): + kwargs_string = "".join([", %s=%s" % (str(k), str(v)) + for k, v in self.kwargs.items()]) + return 'PRFScorer(%s)' % kwargs_string + + def calc_scores(self, estimator, X, y): + """ + Calculates F score, precision and recall + + Parameters + ---------- + estimator : object + Trained estimator to use for scoring. + If ``needs_threshold`` is True, estimator needs + to provide ``decision_function`` or ``predict_proba``. + Otherwise, estimator needs to provide ``predict``. + + X : array-like or sparse matrix + Test data that will be scored by the estimator. + + y : array-like + True prediction for X. + + Returns + ------- + scores : list of (name, score) pairs + providing names 'score', 'precision' and 'recall' + """ + p, r, f, support = super(PRFScorer, self).__call__(estimator, X, y) + return [ + ('score', f), + ('precision', p), + ('recall', r), + ] + + def __call__(self, estimator, X, y): + p, r, f, support = super(PRFScorer, self).__call__(estimator, X, y) + return f + + +class WrapScorer(BaseScorer): + """Scores by passing the estimator and data to a given function + + Parameters + ---------- + score_fn : function with signature of `Scorer.__call__` + A function which returns a score given an estimator, instances and + ground truth if available. + + greater_is_better : boolean, default=True + Whether score_func is a score function (default), meaning high is good, + or a loss function, meaning low is good. + """ + + def __init__(self, score_fn, greater_is_better=True): + super(WrapScorer, self).__init__(greater_is_better) + self.score_fn = score_fn + + def __call__(self, estimator, X, y=None): + if y is None: + return self.score_fn(estimator, X) + return self.score_fn(estimator, X, y) + + +class EstimatorScorer(BaseScorer): + """Scores by calling the estimator's score method.""" + + def __call__(self, estimator, X, y=None): + if y is None: + return estimator.score(X) + return estimator.score(X, y) # Standard regression scores @@ -117,7 +268,7 @@ def __call__(self, estimator, X, y): # Standard Classification Scores accuracy_scorer = Scorer(accuracy_score) -f1_scorer = Scorer(f1_score) +f1_scorer = PRFScorer() # Score functions that need decision values auc_scorer = Scorer(auc_score, greater_is_better=True, needs_threshold=True) @@ -130,7 +281,7 @@ def __call__(self, estimator, X, y): ari_scorer = Scorer(adjusted_rand_score) SCORERS = dict(r2=r2_scorer, mse=mse_scorer, accuracy=accuracy_scorer, - f1=f1_scorer, roc_auc=auc_scorer, + f1=PRFScorer(), roc_auc=auc_scorer, average_precision=average_precision_scorer, precision=precision_scorer, recall=recall_scorer, ari=ari_scorer) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 7777f15de1a7e..58c7edb6bd509 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -1,11 +1,13 @@ import pickle from sklearn.utils.testing import assert_almost_equal +from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_raises +from sklearn.utils.testing import assert_true from sklearn.metrics import f1_score, r2_score, auc_score, fbeta_score from sklearn.metrics.cluster import adjusted_rand_score -from sklearn.metrics import SCORERS, Scorer +from sklearn.metrics import SCORERS, Scorer, PRFScorer, WrapScorer, 
EstimatorScorer from sklearn.svm import LinearSVC from sklearn.cluster import KMeans from sklearn.linear_model import Ridge, LogisticRegression @@ -14,6 +16,8 @@ from sklearn.cross_validation import train_test_split, cross_val_score from sklearn.grid_search import GridSearchCV +# TODO: test scorers without ground truth + def test_classification_scores(): X, y = make_blobs(random_state=0) @@ -97,3 +101,61 @@ def test_raises_on_score_list(): grid_search = GridSearchCV(clf, scoring=f1_scorer_no_average, param_grid={'max_depth': [1, 2]}) assert_raises(ValueError, grid_search.fit, X, y) + + +def test_calc_scores(): + """Test that the score returned by __call__ is named 'score' by calc_scores""" + scorer = SCORERS['roc_auc'] + X, y = make_blobs(random_state=0, centers=2) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + clf = LinearSVC(random_state=0) + clf.fit(X_train, y_train) + score = scorer(clf, X_test, y_test) + scores = dict(scorer.calc_scores(clf, X_test, y_test)) + assert_true('score' in scores) + assert_equal(score, scores['score']) + + +def test_prf_scorer(): + X, y = make_blobs(random_state=0, centers=2) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + clf = LinearSVC(random_state=0) + clf.fit(X_train, y_train) + + f1_scorer = PRFScorer() + f1_score = f1_scorer(clf, X_test, y_test) + f1_scores = dict(f1_scorer.calc_scores(clf, X_test, y_test)) + + f2_scorer = PRFScorer(beta=2.) + f2_score = f2_scorer(clf, X_test, y_test) + f2_scores = dict(f2_scorer.calc_scores(clf, X_test, y_test)) + + def F(p, r, beta): + return (1 + beta * beta) * p * r / (beta * beta * p + r) + + assert_equal(f1_score, f1_scores['score']) + assert_equal(f2_score, f2_scores['score']) + assert_almost_equal(f1_score, F(f1_scores['precision'], f1_scores['recall'], 1.)) + assert_almost_equal(f2_score, F(f2_scores['precision'], f2_scores['recall'], 2.)) + + +def test_estimator_scorer(): + scorer = EstimatorScorer() + X, y = make_blobs(random_state=0) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + clf = LinearSVC(random_state=0) + clf.fit(X_train, y_train) + score = scorer(clf, X_test, y_test) + assert_equal(clf.score(X_test, y_test), score) + assert_equal(score, dict(scorer.calc_scores(clf, X_test, y_test))['score']) + + +def test_wrap_scorer(): + scorer = WrapScorer(lambda clf, X, y: clf.score(X, y) * 100) + X, y = make_blobs(random_state=0) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + clf = LinearSVC(random_state=0) + clf.fit(X_train, y_train) + score = scorer(clf, X_test, y_test) + assert_equal(clf.score(X_test, y_test) * 100, score) + assert_equal(score, dict(scorer.calc_scores(clf, X_test, y_test))['score']) diff --git a/sklearn/tests/test_grid_search.py b/sklearn/tests/test_grid_search.py index ce3252022558b..13f0b50980437 100644 --- a/sklearn/tests/test_grid_search.py +++ b/sklearn/tests/test_grid_search.py @@ -12,6 +12,7 @@ from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_raises +from sklearn.utils.testing import assert_greater from sklearn.utils.testing import assert_true from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_almost_equal @@ -24,7 +25,7 @@ ParameterSampler) from sklearn.svm import LinearSVC, SVC from sklearn.cluster import KMeans, MeanShift -from sklearn.metrics import f1_score +from sklearn.metrics import f1_score, precision_score, recall_score from sklearn.metrics import Scorer from 
sklearn.cross_validation import KFold, StratifiedKFold @@ -87,14 +88,40 @@ def test_grid_search(): grid_search.fit(X, y) sys.stdout = old_stdout assert_equal(grid_search.best_estimator_.foo_param, 2) + assert_equal(grid_search.best_params_, {'foo_param': 2}) + assert_equal(grid_search.best_score_, 1.) for i, foo_i in enumerate([1, 2, 3]): - assert_true(grid_search.cv_scores_[i][0] + assert_true(grid_search.grid_results_['parameters'][i] == {'foo_param': foo_i}) # Smoke test the score: grid_search.score(X, y) +def test_grid_scores(): + """Test that GridSearchCV.grid_scores_ is filled in the correct format""" + clf = MockClassifier() + grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, verbose=3) + # make sure it selects the smallest parameter in case of ties + old_stdout = sys.stdout + sys.stdout = StringIO() + grid_search.fit(X, y) + sys.stdout = old_stdout + assert_equal(grid_search.best_estimator_.foo_param, 2) + + n_folds = 3 + with warnings.catch_warnings(record=True): + for i, foo_i in enumerate([1, 2, 3]): + assert_true(grid_search.grid_scores_[i][0] + == {'foo_param': foo_i}) + # mean score + assert_almost_equal(grid_search.grid_scores_[i][1], + (1. if foo_i > 1 else 0.)) + # all fold scores + assert_array_equal(grid_search.grid_scores_[i][2], + [1. if foo_i > 1 else 0.] * n_folds) + + def test_no_refit(): """Test that grid search can be used for model selection only""" clf = MockClassifier() @@ -281,6 +308,20 @@ def test_grid_search_precomputed_kernel_error_kernel_function(): assert_raises(ValueError, cv.fit, X_, y_) +def test_grid_search_training_score(): + # test that the training score contains sensible numbers + X, y = make_classification(n_samples=200, n_features=100, random_state=0) + clf = LinearSVC(random_state=0) + cv = GridSearchCV(clf, {'C': [0.1, 1.0]}, compute_training_score=True) + cv.fit(X, y) + for i, (grid_data, fold_data) in enumerate(zip(cv.grid_results_, cv.fold_results_)): + assert_greater(grid_data['train_score'], grid_data['test_score']) + # hacky greater-equal + assert_greater(1 + 1e-10, grid_data['train_score']) + assert_greater(fold_data['train_time'].mean(), 0) + assert_greater(fold_data['test_time'].mean(), 0) + + class BrokenClassifier(BaseEstimator): """Broken classifier that cannot be fit twice""" @@ -365,7 +406,7 @@ def test_randomized_search(): params = dict(C=distributions.expon()) search = RandomizedSearchCV(LinearSVC(), param_distributions=params) search.fit(X, y) - assert_equal(len(search.cv_scores_), 10) + assert_equal(len(search.grid_results_['test_score']), 10) def test_grid_search_score_consistency(): @@ -378,9 +419,8 @@ def test_grid_search_score_consistency(): grid_search = GridSearchCV(clf, {'C': Cs}, scoring=score) grid_search.fit(X, y) cv = StratifiedKFold(n_folds=3, y=y) - for C, scores in zip(Cs, grid_search.cv_scores_): + for C, scores in zip(Cs, grid_search.fold_results_['test_score']): clf.set_params(C=C) - scores = scores[2] # get the separate runs from grid scores i = 0 for train, test in cv: clf.fit(X[train], y[train]) @@ -391,3 +431,24 @@ def test_grid_search_score_consistency(): clf.decision_function(X[test])) assert_almost_equal(correct_score, scores[i]) i += 1 + +def test_composite_scores(): + """Test that precision and recall are output when using f1""" + clf = LinearSVC(random_state=0) + X, y = make_blobs(random_state=0, centers=2) + Cs = [.1, 1, 10] + grid_search = GridSearchCV(clf, {'C': Cs}, scoring='f1', compute_training_score=True) + grid_search.fit(X, y) + cv = StratifiedKFold(n_folds=3, y=y) + for C, scores 
in zip(Cs, grid_search.fold_results_): + clf.set_params(C=C) + for fold, (train, test) in enumerate(cv): + clf.fit(X[train], y[train]) + for prefix, mask in [('test_', test), ('train_', train)]: + fold_scores = scores[fold] + correct_score = f1_score(y[mask], clf.predict(X[mask])) + correct_precision = precision_score(y[mask], clf.predict(X[mask])) + correct_recall = recall_score(y[mask], clf.predict(X[mask])) + assert_almost_equal(correct_score, fold_scores[prefix + 'score']) + assert_almost_equal(correct_precision, fold_scores[prefix + 'precision']) + assert_almost_equal(correct_recall, fold_scores[prefix + 'recall']) diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index 1f6d06f322ae4..106219b6efb72 100644 --- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -438,9 +438,9 @@ def predict(self, X): def feature_importances_(self): """Return the feature importances. - The importance of a feature is computed as the (normalized) total - reduction of the criterion brought by that feature. - It is also known as the Gini importance [4]_. + The importance of a feature is computed as the + (normalized) total reduction of the criterion brought by that + feature. It is also known as the Gini importance [4]_. Returns ------- @@ -515,10 +515,10 @@ class DecisionTreeClassifier(BaseDecisionTree, ClassifierMixin): output (for multi-output problems). `feature_importances_` : array of shape = [n_features] - The feature importances. The higher, the more important the - feature. The importance of a feature is computed as the (normalized) - total reduction of the criterion brought by that feature. It is also - known as the Gini importance [4]_. + The feature importances. The higher, the more important the feature. + The importance of a feature is computed as the + (normalized) total reduction of the criterion brought by that + feature. It is also known as the Gini importance [4]_. See also -------- @@ -702,11 +702,10 @@ class DecisionTreeRegressor(BaseDecisionTree, RegressorMixin): The underlying Tree object. `feature_importances_` : array of shape = [n_features] - The feature importances. - The higher, the more important the feature. + The feature importances. The higher, the more important the feature. The importance of a feature is computed as the - (normalized)total reduction of the criterion brought - by that feature. It is also known as the Gini importance [4]_. + (normalized) total reduction of the criterion brought by that + feature. It is also known as the Gini importance [4]_. See also --------