diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index c72ef69a7d918..aaa70031aa075 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -662,6 +662,14 @@ user guide for further details. .. currentmodule:: sklearn +Model Selection Interface +------------------------- +.. autosummary:: + :toctree: generated/ + :template: class_with_call.rst + + metrics.Scorer + Classification metrics ---------------------- diff --git a/doc/modules/cross_validation.rst b/doc/modules/cross_validation.rst index f499f2f8475a7..71671162e6097 100644 --- a/doc/modules/cross_validation.rst +++ b/doc/modules/cross_validation.rst @@ -83,14 +83,15 @@ by:: By default, the score computed at each CV iteration is the ``score`` method of the estimator. It is possible to change this by passing a custom -scoring function, e.g. from the metrics module:: +scoring function:: >>> from sklearn import metrics >>> cross_validation.cross_val_score(clf, iris.data, iris.target, cv=5, - ... score_func=metrics.f1_score) + ... scoring='f1') ... # doctest: +ELLIPSIS array([ 1. ..., 0.96..., 0.89..., 0.96..., 1. ]) +See :ref:`score_func_objects` for details. In the case of the Iris dataset, the samples are balanced across target classes hence the accuracy and the F1-score are almost equal. diff --git a/doc/modules/grid_search.rst b/doc/modules/grid_search.rst index ea5a3a83a20d7..6f7d90d1a48b6 100644 --- a/doc/modules/grid_search.rst +++ b/doc/modules/grid_search.rst @@ -49,6 +49,20 @@ combinations is retained. This can be done by using the :func:`cross_validation.train_test_split` utility function. +.. currentmodule:: sklearn.grid_search + +.. _gridsearch_scoring: + +Scoring functions for GridSearchCV +---------------------------------- +By default, :class:`GridSearchCV` uses the ``score`` function of the estimator +to evaluate a parameter setting. These are the :func:`sklearn.metrics.accuracy_score` for classification +and :func:`sklearn.metrics.r2_score` for regression. +For some applications, other scoring function are better suited (for example in +unbalanced classification, the accuracy score is often non-informative). An +alternative scoring function can be specified via the ``scoring`` parameter to +:class:`GridSearchCV`. +See :ref:`score_func_objects` for more details. Examples ======== diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index 8e2a9633ec0d6..5b8ca5dbf9ebf 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -297,7 +297,7 @@ In this context, we can define the notions of precision, recall and F-measure: F_\beta = (1 + \beta^2) \frac{\text{precision} \times \text{recall}}{\beta^2 \text{precision} + \text{recall}}. -Here some small examples in binary classification: +Here some small examples in binary classification:: >>> from sklearn import metrics >>> y_pred = [0, 1, 0, 0] @@ -411,7 +411,7 @@ their support \texttt{weighted\_{}F\_{}beta}(y,\hat{y}) &= \frac{1}{n_\text{samples}} \sum_{i=0}^{n_\text{samples} - 1} (1 + \beta^2)\frac{|y_i \cap \hat{y}_i|}{\beta^2 |\hat{y}_i| + |y_i|}. 
-Here an example where ``average`` is set to ``average`` to ``macro``: +Here an example where ``average`` is set to ``average`` to ``macro``:: >>> from sklearn import metrics >>> y_true = [0, 1, 2, 0, 1, 2] @@ -427,7 +427,7 @@ Here an example where ``average`` is set to ``average`` to ``macro``: >>> metrics.precision_recall_fscore_support(y_true, y_pred, average='macro') # doctest: +ELLIPSIS (0.22..., 0.33..., 0.26..., None) -Here an example where ``average`` is set to to ``micro``: +Here an example where ``average`` is set to to ``micro``:: >>> from sklearn import metrics >>> y_true = [0, 1, 2, 0, 1, 2] @@ -443,7 +443,7 @@ Here an example where ``average`` is set to to ``micro``: >>> metrics.precision_recall_fscore_support(y_true, y_pred, average='micro') # doctest: +ELLIPSIS (0.33..., 0.33..., 0.33..., None) -Here an example where ``average`` is set to to ``weighted``: +Here an example where ``average`` is set to to ``weighted``:: >>> from sklearn import metrics >>> y_true = [0, 1, 2, 0, 1, 2] @@ -459,7 +459,7 @@ Here an example where ``average`` is set to to ``weighted``: >>> metrics.precision_recall_fscore_support(y_true, y_pred, average='weighted') # doctest: +ELLIPSIS (0.22..., 0.33..., 0.26..., None) -Here an example where ``average`` is set to ``None``: +Here an example where ``average`` is set to ``None``:: >>> from sklearn import metrics >>> y_true = [0, 1, 2, 0, 1, 2] @@ -492,7 +492,7 @@ value and :math:`w` is the predicted decisions as output by L_\text{Hinge}(y, w) = \max\left\{1 - wy, 0\right\} = \left|1 - wy\right|_+ Here a small example demonstrating the use of the :func:`hinge_loss` function -with a svm classifier: +with a svm classifier:: >>> from sklearn import svm >>> from sklearn.metrics import hinge_loss @@ -653,7 +653,8 @@ variance is estimated as follow: The best possible score is 1.0, lower values are worse. -Here a small example of usage of the :func:`explained_variance_scoreé` function: +Here a small example of usage of the :func:`explained_variance_score` +function:: >>> from sklearn.metrics import explained_variance_score >>> y_true = [3, -0.5, 2, 7] @@ -676,7 +677,7 @@ and :math:`y_i` is the corresponding true value, then the mean absolute error \text{MAE}(y, \hat{y}) = \frac{1}{n_{\text{samples}}} \sum_{i=0}^{n_{\text{samples}}-1} \left| y_i - \hat{y}_i \right|. -Here a small example of usage of the :func:`mean_absolute_error` function: +Here a small example of usage of the :func:`mean_absolute_error` function:: >>> from sklearn.metrics import mean_absolute_error >>> y_true = [3, -0.5, 2, 7] @@ -705,7 +706,8 @@ and :math:`y_i` is the corresponding true value, then the mean squared error \text{MSE}(y, \hat{y}) = \frac{1}{n_\text{samples}} \sum_{i=0}^{n_\text{samples} - 1} (y_i - \hat{y}_i)^2. -Here a small example of usage of the :func:`mean_squared_error` function: +Here a small example of usage of the :func:`mean_squared_error` +function:: >>> from sklearn.metrics import mean_squared_error >>> y_true = [3, -0.5, 2, 7] @@ -740,7 +742,7 @@ over :math:`n_{\text{samples}}` is defined as where :math:`\bar{y} = \frac{1}{n_{\text{samples}}} \sum_{i=0}^{n_{\text{samples}} - 1} y_i`. -Here a small example of usage of the :func:`r2_score` function: +Here a small example of usage of the :func:`r2_score` function:: >>> from sklearn.metrics import r2_score >>> y_true = [3, -0.5, 2, 7] @@ -765,6 +767,111 @@ Clustering metrics The :mod:`sklearn.metrics` implements several losses, scores and utility function for more information see the :ref:`clustering_evaluation` section. 
+
+.. _score_func_objects:
+
+.. currentmodule:: sklearn
+
+`Scoring` objects: defining your scoring rules
+===============================================
+While the above functions provide a simple interface for most use cases, they
+cannot be used directly for model selection and evaluation with
+:class:`grid_search.GridSearchCV` and
+:func:`cross_validation.cross_val_score`, as scoring functions have different
+signatures and might require additional parameters.
+
+Instead, :class:`grid_search.GridSearchCV` and
+:func:`cross_validation.cross_val_score` both accept callables that evaluate
+the estimator itself. This allows for very flexible evaluation of models, for
+example taking the complexity of the model into account.
+
+For scoring functions that take no additional parameters (which is the case
+for most of them), you can simply provide a string as the ``scoring``
+parameter. Possible values are:
+
+=================== ===============================================
+Scoring             Function
+=================== ===============================================
+**Classification**
+'accuracy'          :func:`sklearn.metrics.accuracy_score`
+'average_precision' :func:`sklearn.metrics.average_precision_score`
+'f1'                :func:`sklearn.metrics.f1_score`
+'precision'         :func:`sklearn.metrics.precision_score`
+'recall'            :func:`sklearn.metrics.recall_score`
+'roc_auc'           :func:`sklearn.metrics.auc_score`
+
+**Clustering**
+'ari'               :func:`sklearn.metrics.adjusted_rand_score`
+
+**Regression**
+'mse'               :func:`sklearn.metrics.mean_squared_error`
+'r2'                :func:`sklearn.metrics.r2_score`
+=================== ===============================================
+
+The corresponding scorer objects are stored in the dictionary
+``sklearn.metrics.SCORERS``.
+
+.. currentmodule:: sklearn.metrics
+
+Creating scoring objects from score functions
+---------------------------------------------
+If you want to use a scoring function that takes additional parameters, such
+as :func:`fbeta_score`, you need to generate an appropriate scoring object.
+The simplest way to generate a callable object for scoring is by using
+:class:`Scorer`, which converts score functions as above into callables that
+can be used for model evaluation.
+
+One typical use case is to wrap an existing scoring function from the library
+with a non-default value for one of its parameters, such as the ``beta``
+parameter of the :func:`fbeta_score` function::
+
+    >>> from sklearn.metrics import fbeta_score, Scorer
+    >>> ftwo_scorer = Scorer(fbeta_score, beta=2)
+    >>> from sklearn.grid_search import GridSearchCV
+    >>> from sklearn.svm import LinearSVC
+    >>> grid = GridSearchCV(LinearSVC(), param_grid={'C': [1, 10]}, scoring=ftwo_scorer)
+
+The second use case is to build a completely new and custom scorer object from
+a simple Python function::
+
+    >>> import numpy as np
+    >>> def my_custom_loss_func(ground_truth, predictions):
+    ...     diff = np.abs(ground_truth - predictions).max()
+    ...     return np.log(1 + diff)
+    ...
+    >>> my_custom_scorer = Scorer(my_custom_loss_func, greater_is_better=False)
+    >>> grid = GridSearchCV(LinearSVC(), param_grid={'C': [1, 10]}, scoring=my_custom_scorer)
+
+:class:`Scorer` takes as parameters the function you want to use, whether it
+is a score (``greater_is_better=True``) or a loss (``greater_is_better=False``),
+whether the function you provided takes predictions as input
+(``needs_threshold=False``) or needs confidence scores
+(``needs_threshold=True``), and any additional parameters, such as ``beta`` in
+the example above.
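Review note, not part of the patch: as a quick sanity check of the semantics above, here is a small self-contained sketch showing that a thresholded scorer is called on the estimator itself and is equivalent to applying the wrapped metric to decision values. It only uses names introduced or exercised elsewhere in this diff (``Scorer``, ``SCORERS``, ``auc_score``) and mirrors ``test_thresholded_scores`` from the new test module; it is an illustrative sketch, not additional documentation proposed for the PR::

    from sklearn.datasets import make_blobs
    from sklearn.cross_validation import train_test_split
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import Scorer, SCORERS, auc_score

    # A fitted binary classifier and a held-out evaluation set.
    X, y = make_blobs(centers=2, random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    clf = LogisticRegression().fit(X_train, y_train)

    # needs_threshold=True makes the scorer feed decision values (not class
    # predictions) to the wrapped metric; the call convention is
    # scorer(estimator, X, y).
    roc_auc_scorer = Scorer(auc_score, needs_threshold=True)
    score_via_scorer = roc_auc_scorer(clf, X_test, y_test)
    score_via_metric = auc_score(y_test, clf.decision_function(X_test))
    assert round(score_via_scorer, 7) == round(score_via_metric, 7)

    # The string 'roc_auc' accepted by the new ``scoring`` parameter resolves
    # to an equivalent prebuilt object in the SCORERS dictionary.
    score_via_string = SCORERS['roc_auc'](clf, X_test, y_test)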
+
+
+Implementing your own scoring object
+------------------------------------
+You can generate even more flexible model scores by constructing your own
+scoring object from scratch, without using the :class:`Scorer` helper class.
+The requirements for a callable to be used for model selection are as
+follows:
+
+- It can be called with parameters ``(estimator, X, y)``, where ``estimator``
+  is the model that should be evaluated, ``X`` is validation data and ``y`` is
+  the ground truth target for ``X`` (in the supervised case) or ``None`` (in
+  the unsupervised case).
+
+- The call returns a number indicating the quality of the estimator.
+
+- The callable has a boolean attribute ``greater_is_better`` which indicates
+  whether high or low values correspond to a better estimator.
+
+Objects that meet these conditions are said to implement the sklearn Scorer
+protocol.
+
+
 .. _dummy_estimators:
 
 Dummy estimators
@@ -772,10 +879,9 @@ Dummy estimators
 
 .. currentmodule:: sklearn.dummy
 
-When doing supervised learning, a simple sanity check consists in comparing one's
-estimator against simple rules of thumb.
-:class:`DummyClassifier` implements three such simple strategies for
-classification:
+When doing supervised learning, a simple sanity check consists in comparing
+one's estimator against simple rules of thumb. :class:`DummyClassifier`
+implements three such simple strategies for classification:
 
 - `stratified` generates randomly predictions by respecting the training
   set's class distribution,
diff --git a/doc/templates/class_with_call.rst b/doc/templates/class_with_call.rst
new file mode 100644
index 0000000000000..753579c15ba94
--- /dev/null
+++ b/doc/templates/class_with_call.rst
@@ -0,0 +1,13 @@
+{{ fullname }}
+{{ underline }}
+
+.. currentmodule:: {{ module }}
+
+.. autoclass:: {{ objname }}
+
+   {% block methods %}
+   .. automethod:: __init__
+   .. automethod:: __call__
+   {% endblock %}
+
+
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 4dc0da7f5556c..67610a321a93e 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -11,6 +11,13 @@ Changelog
    - Hyperlinks to documentation in example code on the website by
      `Martin Luessi`_.
 
+   - :class:`grid_search.GridSearchCV` and
+     :func:`cross_validation.cross_val_score` now support the use of advanced
+     scoring functions such as area under the ROC curve and f-beta scores.
+     See :ref:`score_func_objects` for details. By `Andreas Müller`_.
+     Passing a function from :mod:`sklearn.metrics` as ``score_func`` is
+     deprecated.
+
 ..
_changes_0_13: diff --git a/examples/grid_search_digits.py b/examples/grid_search_digits.py index 7c6269f6e09ef..de4f8994196ca 100644 --- a/examples/grid_search_digits.py +++ b/examples/grid_search_digits.py @@ -22,8 +22,6 @@ from sklearn.cross_validation import train_test_split from sklearn.grid_search import GridSearchCV from sklearn.metrics import classification_report -from sklearn.metrics import precision_score -from sklearn.metrics import recall_score from sklearn.svm import SVC print(__doc__) @@ -46,16 +44,13 @@ 'C': [1, 10, 100, 1000]}, {'kernel': ['linear'], 'C': [1, 10, 100, 1000]}] -scores = [ - ('precision', precision_score), - ('recall', recall_score), -] +scores = ['precision', 'recall'] -for score_name, score_func in scores: - print("# Tuning hyper-parameters for %s" % score_name) +for score in scores: + print("# Tuning hyper-parameters for %s" % score) print() - clf = GridSearchCV(SVC(C=1), tuned_parameters, score_func=score_func) + clf = GridSearchCV(SVC(C=1), tuned_parameters, scoring=score) clf.fit(X_train, y_train, cv=5) print("Best parameters set found on development set:") diff --git a/examples/plot_permutation_test_for_classification.py b/examples/plot_permutation_test_for_classification.py index 1c955c4e968cc..1a26b13f6c4b5 100644 --- a/examples/plot_permutation_test_for_classification.py +++ b/examples/plot_permutation_test_for_classification.py @@ -22,7 +22,6 @@ from sklearn.svm import SVC from sklearn.cross_validation import StratifiedKFold, permutation_test_score from sklearn import datasets -from sklearn.metrics import accuracy_score ############################################################################## @@ -43,7 +42,7 @@ cv = StratifiedKFold(y, 2) score, permutation_scores, pvalue = permutation_test_score( - svm, X, y, accuracy_score, cv=cv, n_permutations=100, n_jobs=1) + svm, X, y, scoring="accuracy", cv=cv, n_permutations=100, n_jobs=1) print("Classification score %s (pvalue : %s)" % (score, pvalue)) diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py index a5422b8fd6638..1593e5e9f1b1e 100644 --- a/sklearn/cross_validation.py +++ b/sklearn/cross_validation.py @@ -22,6 +22,7 @@ from .utils import check_arrays, check_random_state, safe_mask from .utils.fixes import unique from .externals.joblib import Parallel, delayed +from .metrics import SCORERS, Scorer __all__ = ['Bootstrap', 'KFold', @@ -1031,13 +1032,12 @@ def __len__(self): ############################################################################## -def _cross_val_score(estimator, X, y, score_func, train, test, verbose, +def _cross_val_score(estimator, X, y, scorer, train, test, verbose, fit_params): """Inner loop for cross validation""" n_samples = X.shape[0] if sp.issparse(X) else len(X) fit_params = dict([(k, np.asarray(v)[train] - if hasattr(v, '__len__') - and len(v) == n_samples else v) + if hasattr(v, '__len__') and len(v) == n_samples else v) for k, v in fit_params.items()]) if not hasattr(X, "shape"): if getattr(estimator, "_pairwise", False): @@ -1057,24 +1057,26 @@ def _cross_val_score(estimator, X, y, score_func, train, test, verbose, X_test = X[safe_mask(X, test)] if y is None: - estimator.fit(X_train, **fit_params) - if score_func is None: - score = estimator.score(X_test) - else: - score = score_func(X_test) + y_train = None + y_test = None else: - estimator.fit(X_train, y[train], **fit_params) - if score_func is None: - score = estimator.score(X_test, y[test]) - else: - score = score_func(y[test], estimator.predict(X_test)) + y_train = y[train] + y_test = 
y[test] + estimator.fit(X_train, y_train, **fit_params) + if scorer is None: + score = estimator.score(X_test, y_test) + else: + score = scorer(estimator, X_test, y_test) + if not isinstance(score, numbers.Number): + raise ValueError("scoring must return a number, got %s (%s)" + " instead." % (str(score), type(score))) if verbose > 1: print("score: %f" % score) return score -def cross_val_score(estimator, X, y=None, score_func=None, cv=None, n_jobs=1, - verbose=0, fit_params=None): +def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1, + verbose=0, fit_params=None, score_func=None): """Evaluate a score by cross-validation Parameters @@ -1089,12 +1091,11 @@ def cross_val_score(estimator, X, y=None, score_func=None, cv=None, n_jobs=1, The target variable to try to predict in the case of supervised learning. - score_func : callable, optional - Score function to use for evaluation. - Has priority over the score function in the estimator. - In a non-supervised setting, where y is None, it takes the test - data (X_test) as its only argument. In a supervised setting it takes - the test target (y_true) and the test prediction (y_pred) as arguments. + scoring : string or callable, optional + Either one of either a string ("zero_one", "f1", "roc_auc", ... for + classification, "mse", "r2", ... for regression) or a callable. + See 'Scoring objects' in the model evaluation section of the user guide + for details. cv : cross-validation generator, optional A cross-validation generator. If None, a 3-fold cross @@ -1118,10 +1119,18 @@ def cross_val_score(estimator, X, y=None, score_func=None, cv=None, n_jobs=1, """ X, y = check_arrays(X, y, sparse_format='csr', allow_lists=True) cv = check_cv(cv, X, y, classifier=is_classifier(estimator)) - if score_func is None: - if not hasattr(estimator, 'score'): + if score_func is not None: + warnings.warn("Passing function as ``score_func`` is " + "deprecated and will be removed in 0.15. " + "Either use strings or score objects.", stacklevel=2) + scorer = Scorer(score_func) + elif isinstance(scoring, basestring): + scorer = SCORERS[scoring] + else: + scorer = scoring + if scorer is None and not hasattr(estimator, 'score'): raise TypeError( - "If no score_func is specified, the estimator passed " + "If no scoring is specified, the estimator passed " "should have a 'score' method. The estimator %s " "does not." 
% estimator) # We clone the estimator to make sure that all the folds are @@ -1129,19 +1138,17 @@ def cross_val_score(estimator, X, y=None, score_func=None, cv=None, n_jobs=1, fit_params = fit_params if fit_params is not None else {} scores = Parallel(n_jobs=n_jobs, verbose=verbose)( delayed(_cross_val_score)( - clone(estimator), X, y, score_func, - train, test, verbose, fit_params) + clone(estimator), X, y, scorer, train, test, verbose, fit_params) for train, test in cv) return np.array(scores) -def _permutation_test_score(estimator, X, y, cv, score_func): +def _permutation_test_score(estimator, X, y, cv, scorer): """Auxilary function for permutation_test_score""" avg_score = [] for train, test in cv: - avg_score.append(score_func(y[test], - estimator.fit(X[train], - y[train]).predict(X[test]))) + estimator.fit(X[train], y[train]) + avg_score.append(scorer(estimator, X[test], y[test])) return np.mean(avg_score) @@ -1197,9 +1204,9 @@ def check_cv(cv, X=None, y=None, classifier=False): return cv -def permutation_test_score(estimator, X, y, score_func, cv=None, +def permutation_test_score(estimator, X, y, scoring=None, cv=None, n_permutations=100, n_jobs=1, labels=None, - random_state=0, verbose=0): + random_state=0, verbose=0, score_func=None): """Evaluate the significance of a cross-validated score with permutations Parameters @@ -1214,12 +1221,11 @@ def permutation_test_score(estimator, X, y, score_func, cv=None, The target variable to try to predict in the case of supervised learning. - score_func : callable - Callable taking as arguments the test targets (y_test) and - the predicted targets (y_pred) and returns a float. The score - functions are expected to return a bigger value for a better result - otherwise the returned value does not correspond to a p-value (see - Returns below for further details). + scoring : string or object, optional + Either one of either a string ("zero_one", "f1", "roc_auc", ... for + classification, "mse", "r2", ... for regression) or a callable. + See 'Scoring objects' in the model evaluation section of the user guide + for details. cv : integer or crossvalidation generator, optional If an integer is passed, it is the number of fold (default 3). @@ -1268,15 +1274,28 @@ def permutation_test_score(estimator, X, y, score_func, cv=None, X, y = check_arrays(X, y, sparse_format='csr') cv = check_cv(cv, X, y, classifier=is_classifier(estimator)) + if score_func is not None: + warnings.warn("Passing function as ``score_func`` is " + "deprecated and will be removed in 0.15. " + "Either use strings or score objects.") + scorer = Scorer(score_func) + elif isinstance(scoring, basestring): + scorer = SCORERS[scoring] + else: + scorer = scoring + + if scorer is None: + raise ValueError("No valid scoring provided.") + random_state = check_random_state(random_state) # We clone the estimator to make sure that all the folds are # independent, and that it is pickle-able. 
- score = _permutation_test_score(clone(estimator), X, y, cv, score_func) + score = _permutation_test_score(clone(estimator), X, y, cv, scorer) permutation_scores = Parallel(n_jobs=n_jobs, verbose=verbose)( - delayed(_permutation_test_score)(clone(estimator), X, - _shuffle(y, labels, random_state), - cv, score_func) + delayed(_permutation_test_score)( + clone(estimator), X, _shuffle(y, labels, random_state), cv, + scorer) for _ in range(n_permutations)) permutation_scores = np.array(permutation_scores) pvalue = (np.sum(permutation_scores >= score) + 1.0) / (n_permutations + 1) diff --git a/sklearn/grid_search.py b/sklearn/grid_search.py index c6ec066960464..e9b07b3356dc4 100644 --- a/sklearn/grid_search.py +++ b/sklearn/grid_search.py @@ -9,6 +9,8 @@ from itertools import product import time +import warnings +import numbers import numpy as np @@ -16,8 +18,9 @@ from .base import MetaEstimatorMixin from .cross_validation import check_cv from .externals.joblib import Parallel, delayed, logger -from .utils import safe_mask, check_arrays from .utils.validation import _num_samples +from .utils import check_arrays, safe_mask +from .metrics import SCORERS, Scorer __all__ = ['GridSearchCV', 'IterGrid', 'fit_grid_point'] @@ -68,8 +71,8 @@ def __iter__(self): yield params -def fit_grid_point(X, y, base_clf, clf_params, train, test, loss_func, - score_func, verbose, **fit_params): +def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer, + verbose, loss_func=None, **fit_params): """Run fit on one set of parameters Returns the score and the instance of the classifier @@ -77,7 +80,7 @@ def fit_grid_point(X, y, base_clf, clf_params, train, test, loss_func, if verbose > 1: start_time = time.time() msg = '%s' % (', '.join('%s=%s' % (k, v) - for k, v in clf_params.iteritems())) + for k, v in clf_params.iteritems())) print "[GridSearchCV] %s %s" % (msg, (64 - len(msg)) * '.') # update parameters of the classifier after a copy of its base structure clf = clone(base_clf) @@ -109,17 +112,21 @@ def fit_grid_point(X, y, base_clf, clf_params, train, test, loss_func, y_test = y[safe_mask(y, test)] y_train = y[safe_mask(y, train)] clf.fit(X_train, y_train, **fit_params) - if loss_func is not None: - y_pred = clf.predict(X_test) - this_score = -loss_func(y_test, y_pred) - elif score_func is not None: - y_pred = clf.predict(X_test) - this_score = score_func(y_test, y_pred) + + if scorer is not None: + this_score = scorer(clf, X_test, y_test) else: this_score = clf.score(X_test, y_test) else: clf.fit(X_train, **fit_params) - this_score = clf.score(X_test) + if scorer is not None: + this_score = scorer(clf, X_test) + else: + this_score = clf.score(X_test) + + if not isinstance(this_score, numbers.Number): + raise ValueError("scoring must return a number, got %s (%s)" + " instead." % (str(this_score), type(this_score))) if verbose > 2: msg += ", score=%f" % this_score @@ -172,24 +179,20 @@ class GridSearchCV(BaseEstimator, MetaEstimatorMixin): Parameters ---------- - estimator: object type that implements the "fit" and "predict" methods + estimator : object type that implements the "fit" and "predict" methods A object of that type is instantiated for each grid point. - param_grid: dict or list of dictionaries + param_grid : dict or list of dictionaries Dictionary with parameters names (string) as keys and lists of parameter settings to try as values, or a list of such dictionaries, in which case the grids spanned by each dictionary in the list are explored. 
- loss_func: callable, optional - function that takes 2 arguments and compares them in - order to evaluate the performance of prediciton (small is good) - if None is passed, the score of the estimator is maximized - - score_func: callable, optional - A function that takes 2 arguments and compares them in - order to evaluate the performance of prediction (high is good). - If None is passed, the score of the estimator is maximized. + scoring : string or callable, optional + Either one of either a string ("zero_one", "f1", "roc_auc", ... for + classification, "mse", "r2",... for regression) or a callable. + See 'Scoring objects' in the model evaluation section of the user guide + for details. fit_params : dict, optional parameters to pass to the fit method @@ -268,9 +271,7 @@ class GridSearchCV(BaseEstimator, MetaEstimatorMixin): Notes ------ The parameters selected are those that maximize the score of the left out - data, unless an explicit score_func is passed in which case it is used - instead. If a loss function loss_func is passed, it overrides the score - functions and is minimized. + data, unless an explicit score is passed in which case it is used instead. If `n_jobs` was set to a value higher than one, the data is copied for each point in the grid (and not `n_jobs` times). This is done for efficiency @@ -292,12 +293,13 @@ class GridSearchCV(BaseEstimator, MetaEstimatorMixin): """ - def __init__(self, estimator, param_grid, loss_func=None, score_func=None, - fit_params=None, n_jobs=1, iid=True, refit=True, cv=None, - verbose=0, pre_dispatch='2*n_jobs'): + def __init__(self, estimator, param_grid, scoring=None, loss_func=None, + score_func=None, fit_params=None, n_jobs=1, iid=True, + refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs'): if (not hasattr(estimator, 'score') and (not hasattr(estimator, 'predict') - or (loss_func is None and score_func is None))): + or (scoring is None and loss_func is None + and score_func is None))): raise TypeError("The provided estimator %s does not implement a " "score function. In this case, it needs to " "implement a predict fuction and you have to " @@ -310,6 +312,7 @@ def __init__(self, estimator, param_grid, loss_func=None, score_func=None, self.param_grid = param_grid self.loss_func = loss_func self.score_func = score_func + self.scoring = scoring self.n_jobs = n_jobs self.fit_params = fit_params if fit_params is not None else {} self.iid = iid @@ -362,14 +365,31 @@ def fit(self, X, y=None, **params): self._set_methods() return self + if self.loss_func is not None: + warnings.warn("Passing a loss function is " + "deprecated and will be removed in 0.15. " + "Either use strings or score objects.") + scorer = Scorer(self.loss_func, greater_is_better=False) + elif self.score_func is not None: + warnings.warn("Passing function as ``score_func`` is " + "deprecated and will be removed in 0.15. 
" + "Either use strings or score objects.") + scorer = Scorer(self.score_func) + elif isinstance(self.scoring, basestring): + scorer = SCORERS[self.scoring] + else: + scorer = self.scoring + + self.scorer_ = scorer + pre_dispatch = self.pre_dispatch out = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, pre_dispatch=pre_dispatch)( - delayed(fit_grid_point)( - X, y, base_clf, clf_params, train, test, - self.loss_func, self.score_func, self.verbose, - **self.fit_params) - for clf_params in grid for train, test in cv) + delayed(fit_grid_point)(X, y, base_clf, clf_params, + train, test, scorer, + self.verbose, + **self.fit_params) for + clf_params in grid for train, test in cv) # Out is a list of triplet: score, estimator, n_test_samples n_grid_points = len(list(grid)) @@ -398,9 +418,19 @@ def fit(self, X, y=None, **params): # Note: we do not use max(out) to make ties deterministic even if # comparison on estimator instances is not deterministic - best_score = -np.inf + if scorer is not None: + greater_is_better = scorer.greater_is_better + else: + greater_is_better = True + + if greater_is_better: + best_score = -np.inf + else: + best_score = np.inf + for score, params in scores: - if score > best_score: + if ((score > best_score and greater_is_better) + or (score < best_score and not greater_is_better)): best_score = score best_params = params @@ -429,9 +459,9 @@ def fit(self, X, y=None, **params): def score(self, X, y=None): if hasattr(self.best_estimator_, 'score'): return self.best_estimator_.score(X, y) - if self.score_func is None: + if self.scorer_ is None: raise ValueError("No score function explicitly defined, " "and the estimator doesn't provide one %s" % self.best_estimator_) y_predicted = self.predict(X) - return self.score_func(y, y_predicted) + return self.scorer(y, y_predicted) diff --git a/sklearn/kernel_approximation.py b/sklearn/kernel_approximation.py index 16ee09a92809e..1365656d20c5e 100644 --- a/sklearn/kernel_approximation.py +++ b/sklearn/kernel_approximation.py @@ -388,7 +388,7 @@ class Nystroem(BaseEstimator, TransformerMixin): RBFSampler : An approximation to the RBF kernel using random Fourier features. - sklearn.metric.pairwise.kernel_metrics : List of build-in kernels. + sklearn.metric.pairwise.kernel_metrics : List of built-in kernels. """ def __init__(self, kernel="rbf", gamma=None, coef0=1, degree=3, n_components=100, random_state=None): diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index 0e34751682970..f1a2762b7c957 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -28,6 +28,8 @@ from .metrics import zero_one from .metrics import zero_one_score +from .scorer import Scorer, SCORERS + from . import cluster from .cluster import (adjusted_rand_score, adjusted_mutual_info_score, @@ -78,4 +80,6 @@ 'silhouette_score', 'silhouette_samples', 'v_measure_score', - 'zero_one_loss'] + 'zero_one_loss', + 'Scorer', + 'SCORERS'] diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 81fa9e3d969cf..dec2ec7aa8902 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -526,10 +526,10 @@ def chi2_kernel(X, Y=None, gamma=1.): def distance_metrics(): - """ Valid metrics for pairwise_distances + """Valid metrics for pairwise_distances. This function simply returns the valid pairwise distance metrics. - It exists, however, to allow for a verbose description of the mapping for + It exists to allow for a description of the mapping for each of the valid strings. 
The valid distance metrics, and the function they map to, are: @@ -537,11 +537,11 @@ def distance_metrics(): ============ ==================================== metric Function ============ ==================================== - 'cityblock' sklearn.pairwise.manhattan_distances - 'euclidean' sklearn.pairwise.euclidean_distances - 'l1' sklearn.pairwise.manhattan_distances - 'l2' sklearn.pairwise.euclidean_distances - 'manhattan' sklearn.pairwise.manhattan_distances + 'cityblock' metrics.pairwise.manhattan_distances + 'euclidean' metrics.pairwise.euclidean_distances + 'l1' metrics.pairwise.manhattan_distances + 'l2' metrics.pairwise.euclidean_distances + 'manhattan' metrics.pairwise.manhattan_distances ============ ==================================== """ diff --git a/sklearn/metrics/scorer.py b/sklearn/metrics/scorer.py new file mode 100644 index 0000000000000..0f01a0b8439d0 --- /dev/null +++ b/sklearn/metrics/scorer.py @@ -0,0 +1,136 @@ +""" +The :mod:`sklearn.metrics.scorer` submodule implements a flexible +interface for model selection and evaluation using +arbitrary score functions. + +A Scorer object is a callable that can be passed to +:class:`sklearn.grid_search.GridSearchCV` or +:func:`sklearn.cross_validation.cross_val_score` as the ``scoring`` parameter, +to specify how a model should be evaluated. + +The signature of the call is ``(estimator, X, y)`` where ``estimator`` +is the model to be evaluated, ``X`` is the test data and ``y`` is the +ground truth labeling (or ``None`` in the case of unsupervised models). +""" + +# Authors: Andreas Mueller +# Liscence: Simplified BSD + +import numpy as np + +from . import (r2_score, mean_squared_error, accuracy_score, f1_score, + auc_score, average_precision_score, precision_score, + recall_score) + +from .cluster import adjusted_rand_score + + +class Scorer(object): + """Flexible scores for any estimator. + + This class wraps estimator scoring functions for the use in GridSearchCV + and cross_val_score. It takes a score function, such as ``accuracy_score``, + ``mean_squared_error``, ``adjusted_rand_index`` or ``average_precision`` + and provides a call method. + + Parameters + ---------- + score_func : callable, + Score function (or loss function) with signature + ``score_func(y, y_pred, **kwargs)``. + + greater_is_better : boolean, default=True + Whether score_func is a score function (default), meaning high is good, + or a loss function, meaning low is good. + + needs_threshold : bool, default=False + Whether score_func takes a continuous decision certainty. + For example ``average_precision`` or the area under the roc curve + can not be computed using predictions alone, but need the output of + ``decision_function`` or ``predict_proba``. + + **kwargs : additional arguments + Additional parameters to be passed to score_func. + + Examples + -------- + >>> from sklearn.metrics import fbeta_score, Scorer + >>> ftwo_scorer = Scorer(fbeta_score, beta=2) + >>> from sklearn.grid_search import GridSearchCV + >>> from sklearn.svm import LinearSVC + >>> grid = GridSearchCV(LinearSVC(), param_grid={'C': [1, 10]}, + ... 
scoring=ftwo_scorer) + """ + def __init__(self, score_func, greater_is_better=True, + needs_threshold=False, **kwargs): + self.score_func = score_func + self.greater_is_better = greater_is_better + self.needs_threshold = needs_threshold + self.kwargs = kwargs + + def __repr__(self): + kwargs_string = "".join([", %s=%s" % (str(k), str(v)) + for k, v in self.kwargs.items()]) + return ("Scorer(score_func=%s, greater_is_better=%s, needs_thresholds=" + "%s%s)" % (self.score_func.__name__, self.greater_is_better, + self.needs_threshold, kwargs_string)) + + def __call__(self, estimator, X, y): + """Score X and y using the provided estimator. + + Parameters + ---------- + estimator : object + Trained estimator to use for scoring. + If ``needs_threshold`` is True, estimator needs + to provide ``decision_function`` or ``predict_proba``. + Otherwise, estimator needs to provide ``predict``. + + X : array-like or sparse matrix + Test data that will be scored by the estimator. + + y : array-like + True prediction for X. + + Returns + ------- + score : float + Score function applied to prediction of estimator on X. + """ + if self.needs_threshold: + if len(np.unique(y)) > 2: + raise ValueError("This classification score only " + "supports binary classification.") + try: + y_pred = estimator.decision_function(X).ravel() + except (NotImplementedError, AttributeError): + y_pred = estimator.predict_proba(X)[:, 1] + return self.score_func(y, y_pred, **self.kwargs) + else: + y_pred = estimator.predict(X) + return self.score_func(y, y_pred, **self.kwargs) + + +# Standard regression scores +r2_scorer = Scorer(r2_score) +mse_scorer = Scorer(mean_squared_error, greater_is_better=False) + +# Standard Classification Scores +accuracy_scorer = Scorer(accuracy_score) +f1_scorer = Scorer(f1_score) + +# Score functions that need decision values +auc_scorer = Scorer(auc_score, greater_is_better=True, needs_threshold=True) +average_precision_scorer = Scorer(average_precision_score, + needs_threshold=True) +precision_scorer = Scorer(precision_score) +recall_scorer = Scorer(recall_score) + +# Clustering scores +ari_scorer = Scorer(adjusted_rand_score) + +SCORERS = dict(r2=r2_scorer, mse=mse_scorer, accuracy=accuracy_scorer, + f1=f1_scorer, roc_auc=auc_scorer, + average_precision=average_precision_scorer, + precision=precision_scorer, recall=recall_scorer, + ari=ari_scorer) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py new file mode 100644 index 0000000000000..7777f15de1a7e --- /dev/null +++ b/sklearn/metrics/tests/test_score_objects.py @@ -0,0 +1,99 @@ +import pickle + +from sklearn.utils.testing import assert_almost_equal +from sklearn.utils.testing import assert_raises + +from sklearn.metrics import f1_score, r2_score, auc_score, fbeta_score +from sklearn.metrics.cluster import adjusted_rand_score +from sklearn.metrics import SCORERS, Scorer +from sklearn.svm import LinearSVC +from sklearn.cluster import KMeans +from sklearn.linear_model import Ridge, LogisticRegression +from sklearn.tree import DecisionTreeClassifier +from sklearn.datasets import make_blobs, load_diabetes +from sklearn.cross_validation import train_test_split, cross_val_score +from sklearn.grid_search import GridSearchCV + + +def test_classification_scores(): + X, y = make_blobs(random_state=0) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + clf = LinearSVC(random_state=0) + clf.fit(X_train, y_train) + score1 = SCORERS['f1'](clf, X_test, y_test) + score2 = 
f1_score(y_test, clf.predict(X_test)) + assert_almost_equal(score1, score2) + + # test fbeta score that takes an argument + scorer = Scorer(fbeta_score, beta=2) + score1 = scorer(clf, X_test, y_test) + score2 = fbeta_score(y_test, clf.predict(X_test), beta=2) + assert_almost_equal(score1, score2) + + # test that custom scorer can be pickled + unpickled_scorer = pickle.loads(pickle.dumps(scorer)) + score3 = unpickled_scorer(clf, X_test, y_test) + assert_almost_equal(score1, score3) + + # smoke test the repr: + repr(fbeta_score) + + +def test_regression_scores(): + diabetes = load_diabetes() + X, y = diabetes.data, diabetes.target + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + clf = Ridge() + clf.fit(X_train, y_train) + score1 = SCORERS['r2'](clf, X_test, y_test) + score2 = r2_score(y_test, clf.predict(X_test)) + assert_almost_equal(score1, score2) + + +def test_thresholded_scores(): + X, y = make_blobs(random_state=0, centers=2) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + clf = LogisticRegression(random_state=0) + clf.fit(X_train, y_train) + score1 = SCORERS['roc_auc'](clf, X_test, y_test) + score2 = auc_score(y_test, clf.decision_function(X_test)) + score3 = auc_score(y_test, clf.predict_proba(X_test)[:, 1]) + assert_almost_equal(score1, score2) + assert_almost_equal(score1, score3) + + # same for an estimator without decision_function + clf = DecisionTreeClassifier() + clf.fit(X_train, y_train) + score1 = SCORERS['roc_auc'](clf, X_test, y_test) + score2 = auc_score(y_test, clf.predict_proba(X_test)[:, 1]) + assert_almost_equal(score1, score2) + + # Test that an exception is raised on more than two classes + X, y = make_blobs(random_state=0, centers=3) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + clf.fit(X_train, y_train) + assert_raises(ValueError, SCORERS['roc_auc'], clf, X_test, y_test) + + +def test_unsupervised_scores(): + # test clustering where there is some true y. + # We don't have any real unsupervised SCORERS yet + X, y = make_blobs(random_state=0, centers=2) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + km = KMeans(n_clusters=3) + km.fit(X_train) + score1 = SCORERS['ari'](km, X_test, y_test) + score2 = adjusted_rand_score(y_test, km.predict(X_test)) + assert_almost_equal(score1, score2) + + +def test_raises_on_score_list(): + # test that when a list of scores is returned, we raise proper errors. 
+ X, y = make_blobs(random_state=0) + f1_scorer_no_average = Scorer(f1_score, average=None) + clf = DecisionTreeClassifier() + assert_raises(ValueError, cross_val_score, clf, X, y, + scoring=f1_scorer_no_average) + grid_search = GridSearchCV(clf, scoring=f1_scorer_no_average, + param_grid={'max_depth': [1, 2]}) + assert_raises(ValueError, grid_search.fit, X, y) diff --git a/sklearn/tests/test_cross_validation.py b/sklearn/tests/test_cross_validation.py index 5928e06e8621f..97d63f1565bdd 100644 --- a/sklearn/tests/test_cross_validation.py +++ b/sklearn/tests/test_cross_validation.py @@ -1,11 +1,13 @@ """Test the cross_validation module""" -import numpy as np import warnings + +import numpy as np from scipy.sparse import coo_matrix from sklearn.utils.testing import assert_true from sklearn.utils.testing import assert_equal +from sklearn.utils.testing import assert_almost_equal from sklearn.utils.testing import assert_raises from sklearn.utils.testing import assert_greater from sklearn.utils.testing import assert_less @@ -20,9 +22,10 @@ from sklearn.datasets import load_iris from sklearn.metrics import accuracy_score from sklearn.metrics import f1_score -from sklearn.metrics import mean_squared_error -from sklearn.metrics import r2_score from sklearn.metrics import explained_variance_score +from sklearn.metrics import fbeta_score +from sklearn.metrics import Scorer + from sklearn.svm import SVC from sklearn.linear_model import Ridge @@ -254,6 +257,11 @@ def test_cross_val_score_precomputed(): svm = SVC(kernel="precomputed") assert_raises(ValueError, cval.cross_val_score, svm, X, y) + # test error is raised when the precomputed kernel is not array-like + # or sparse + assert_raises(ValueError, cval.cross_val_score, svm, + linear_kernel.tolist(), y) + def test_cross_val_score_fit_params(): clf = MockClassifier() @@ -266,24 +274,16 @@ def test_cross_val_score_fit_params(): def test_cross_val_score_score_func(): clf = MockClassifier() - _score_func1_args = [] - _score_func2_args = [] - - def score_func1(data): - _score_func1_args.append(data) - return 1.0 + _score_func_args = [] - def score_func2(y_test, y_predict): - _score_func2_args.append((y_test, y_predict)) + def score_func(y_test, y_predict): + _score_func_args.append((y_test, y_predict)) return 1.0 - score1 = cval.cross_val_score(clf, X, score_func=score_func1) - assert_array_equal(score1, [1.0, 1.0, 1.0]) - assert len(_score_func1_args) == 3 - - score2 = cval.cross_val_score(clf, X, y, score_func=score_func2) - assert_array_equal(score2, [1.0, 1.0, 1.0]) - assert len(_score_func2_args) == 3 + with warnings.catch_warnings(True): + score = cval.cross_val_score(clf, X, y, score_func=score_func) + assert_array_equal(score, [1.0, 1.0, 1.0]) + assert len(_score_func_args) == 3 def test_cross_val_score_errors(): @@ -335,13 +335,18 @@ def test_cross_val_score_with_score_func_classification(): # Correct classification score (aka. 
zero / one score) - should be the # same as the default estimator score zo_scores = cval.cross_val_score(clf, iris.data, iris.target, - score_func=accuracy_score, cv=5) + scoring="accuracy", cv=5) assert_array_almost_equal(zo_scores, [1., 0.97, 0.90, 0.97, 1.], 2) # F1 score (class are balanced so f1_score should be equal to zero/one # score f1_scores = cval.cross_val_score(clf, iris.data, iris.target, - score_func=f1_score, cv=5) + scoring="f1", cv=5) + assert_array_almost_equal(f1_scores, [1., 0.97, 0.90, 0.97, 1.], 2) + # also test deprecated old way + with warnings.catch_warnings(record=True): + f1_scores = cval.cross_val_score(clf, iris.data, iris.target, + score_func=f1_score, cv=5) assert_array_almost_equal(f1_scores, [1., 0.97, 0.90, 0.97, 1.], 2) @@ -356,18 +361,18 @@ def test_cross_val_score_with_score_func_regression(): # R2 score (aka. determination coefficient) - should be the # same as the default estimator score - r2_scores = cval.cross_val_score(reg, X, y, score_func=r2_score, cv=5) + r2_scores = cval.cross_val_score(reg, X, y, scoring="r2", cv=5) assert_array_almost_equal(r2_scores, [0.94, 0.97, 0.97, 0.99, 0.92], 2) # Mean squared error - mse_scores = cval.cross_val_score(reg, X, y, cv=5, - score_func=mean_squared_error) + mse_scores = cval.cross_val_score(reg, X, y, cv=5, scoring="mse") expected_mse = np.array([763.07, 553.16, 274.38, 273.26, 1681.99]) assert_array_almost_equal(mse_scores, expected_mse, 2) # Explained variance - ev_scores = cval.cross_val_score(reg, X, y, cv=5, - score_func=explained_variance_score) + with warnings.catch_warnings(True): + ev_scores = cval.cross_val_score(reg, X, y, cv=5, + score_func=explained_variance_score) assert_array_almost_equal(ev_scores, [0.94, 0.97, 0.97, 0.99, 0.92], 2) @@ -380,22 +385,28 @@ def test_permutation_score(): cv = cval.StratifiedKFold(y, 2) score, scores, pvalue = cval.permutation_test_score( - svm, X, y, accuracy_score, cv) - + svm, X, y, "accuracy", cv) assert_greater(score, 0.9) - np.testing.assert_almost_equal(pvalue, 0.0, 1) + assert_almost_equal(pvalue, 0.0, 1) score_label, _, pvalue_label = cval.permutation_test_score( - svm, X, y, accuracy_score, cv, labels=np.ones(y.size), random_state=0) - + svm, X, y, "accuracy", cv, labels=np.ones(y.size), random_state=0) assert_true(score_label == score) assert_true(pvalue_label == pvalue) + # test with custom scoring object + scorer = Scorer(fbeta_score, beta=2) + score_label, _, pvalue_label = cval.permutation_test_score( + svm, X, y, scoring=scorer, cv=cv, labels=np.ones(y.size), + random_state=0) + assert_almost_equal(score_label, .95, 2) + assert_almost_equal(pvalue_label, 0.01, 3) + # check that we obtain the same results with a sparse representation svm_sparse = SVC(kernel='linear') cv_sparse = cval.StratifiedKFold(y, 2, indices=True) score_label, _, pvalue_label = cval.permutation_test_score( - svm_sparse, X_sparse, y, accuracy_score, cv_sparse, + svm_sparse, X_sparse, y, "accuracy", cv_sparse, labels=np.ones(y.size), random_state=0) assert_true(score_label == score) @@ -405,8 +416,15 @@ def test_permutation_score(): y = np.mod(np.arange(len(y)), 3) score, scores, pvalue = cval.permutation_test_score(svm, X, y, - accuracy_score, cv) + "accuracy", cv) + + assert_less(score, 0.5) + assert_greater(pvalue, 0.4) + # test with deprecated interface + with warnings.catch_warnings(record=True): + score, scores, pvalue = cval.permutation_test_score( + svm, X, y, score_func=accuracy_score, cv=cv) assert_less(score, 0.5) assert_greater(pvalue, 0.4) diff --git 
a/sklearn/tests/test_grid_search.py b/sklearn/tests/test_grid_search.py index 04cf50e790710..e7a3574e95fb0 100644 --- a/sklearn/tests/test_grid_search.py +++ b/sklearn/tests/test_grid_search.py @@ -3,6 +3,7 @@ """ +import warnings from cStringIO import StringIO import sys @@ -13,15 +14,16 @@ from sklearn.utils.testing import assert_raises from sklearn.utils.testing import assert_true from sklearn.utils.testing import assert_array_equal +from sklearn.utils.testing import assert_almost_equal from sklearn.base import BaseEstimator from sklearn.grid_search import GridSearchCV from sklearn.datasets.samples_generator import make_classification, make_blobs from sklearn.svm import LinearSVC, SVC from sklearn.cluster import KMeans, MeanShift -from sklearn.metrics import f1_score, precision_score -from sklearn.metrics.cluster import adjusted_rand_score -from sklearn.cross_validation import KFold +from sklearn.metrics import f1_score +from sklearn.metrics import Scorer +from sklearn.cross_validation import KFold, StratifiedKFold class MockClassifier(BaseEstimator): @@ -156,18 +158,18 @@ def test_grid_search_sparse(): assert_equal(C, C2) -def test_grid_search_sparse_score_func(): +def test_grid_search_sparse_scoring(): X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0) clf = LinearSVC() - cv = GridSearchCV(clf, {'C': [0.1, 1.0]}, score_func=f1_score) + cv = GridSearchCV(clf, {'C': [0.1, 1.0]}, scoring="f1") cv.fit(X_[:180], y_[:180]) y_pred = cv.predict(X_[180:]) C = cv.best_estimator_.C X_ = sp.csr_matrix(X_) clf = LinearSVC() - cv = GridSearchCV(clf, {'C': [0.1, 1.0]}, score_func=f1_score) + cv = GridSearchCV(clf, {'C': [0.1, 1.0]}, scoring="f1") cv.fit(X_[:180], y_[:180]) y_pred2 = cv.predict(X_[180:]) C2 = cv.best_estimator_.C @@ -178,16 +180,54 @@ def test_grid_search_sparse_score_func(): #np.testing.assert_allclose(f1_score(cv.predict(X_[:180]), y[:180]), # cv.score(X_[:180], y[:180])) - # test loss_func + # test loss where greater is worse def f1_loss(y_true_, y_pred_): return -f1_score(y_true_, y_pred_) - cv = GridSearchCV(clf, {'C': [0.1, 1.0]}, loss_func=f1_loss) + F1Loss = Scorer(f1_loss, greater_is_better=False) + cv = GridSearchCV(clf, {'C': [0.1, 1.0]}, scoring=F1Loss) cv.fit(X_[:180], y_[:180]) y_pred3 = cv.predict(X_[180:]) C3 = cv.best_estimator_.C - assert_array_equal(y_pred, y_pred3) assert_equal(C, C3) + assert_array_equal(y_pred, y_pred3) + + +def test_deprecated_score_func(): + # test that old deprecated way of passing a score / loss function is still + # supported + X, y = make_classification(n_samples=200, n_features=100, random_state=0) + clf = LinearSVC(random_state=0) + cv = GridSearchCV(clf, {'C': [0.1, 1.0]}, scoring="f1") + cv.fit(X[:180], y[:180]) + y_pred = cv.predict(X[180:]) + C = cv.best_estimator_.C + + clf = LinearSVC(random_state=0) + cv = GridSearchCV(clf, {'C': [0.1, 1.0]}, score_func=f1_score) + with warnings.catch_warnings(record=True): + # catch deprecation warning + cv.fit(X[:180], y[:180]) + y_pred_func = cv.predict(X[180:]) + C_func = cv.best_estimator_.C + + assert_array_equal(y_pred, y_pred_func) + assert_equal(C, C_func) + + # test loss where greater is worse + def f1_loss(y_true_, y_pred_): + return -f1_score(y_true_, y_pred_) + + clf = LinearSVC(random_state=0) + cv = GridSearchCV(clf, {'C': [0.1, 1.0]}, loss_func=f1_loss) + with warnings.catch_warnings(record=True): + # catch deprecation warning + cv.fit(X[:180], y[:180]) + y_pred_loss = cv.predict(X[180:]) + C_loss = cv.best_estimator_.C + + assert_array_equal(y_pred, 
y_pred_loss) + assert_equal(C, C_loss) def test_grid_search_precomputed_kernel(): @@ -261,7 +301,7 @@ def test_refit(): y = np.array([0] * 5 + [1] * 5) clf = GridSearchCV(BrokenClassifier(), [{'parameter': [0, 1]}], - score_func=precision_score, refit=True) + scoring="precision", refit=True) clf.fit(X, y) @@ -282,9 +322,14 @@ def test_unsupervised_grid_search(): X, y = make_blobs(random_state=0) km = KMeans(random_state=0) grid_search = GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4]), - score_func=adjusted_rand_score) + scoring='ari') + grid_search.fit(X, y) + # ARI can find the right number :) + assert_equal(grid_search.best_params_["n_clusters"], 3) + + # Now without a score, and without y + grid_search = GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4])) grid_search.fit(X) - # most number of clusters should be best assert_equal(grid_search.best_params_["n_clusters"], 4) @@ -293,4 +338,29 @@ def test_bad_estimator(): ms = MeanShift() assert_raises(TypeError, GridSearchCV, ms, param_grid=dict(gamma=[.1, 1, 10]), - score_func=adjusted_rand_score) + scoring='ari') + + +def test_grid_search_score_consistency(): + # test that correct scores are used + from sklearn.metrics import auc_score + clf = LinearSVC(random_state=0) + X, y = make_blobs(random_state=0, centers=2) + Cs = [.1, 1, 10] + for score in ['f1', 'roc_auc']: + grid_search = GridSearchCV(clf, {'C': Cs}, scoring=score) + grid_search.fit(X, y) + cv = StratifiedKFold(n_folds=3, y=y) + for C, scores in zip(Cs, grid_search.grid_scores_): + clf.set_params(C=C) + scores = scores[2] # get the separate runs from grid scores + i = 0 + for train, test in cv: + clf.fit(X[train], y[train]) + if score == "f1": + correct_score = f1_score(y[test], clf.predict(X[test])) + elif score == "roc_auc": + correct_score = auc_score(y[test], + clf.decision_function(X[test])) + assert_almost_equal(correct_score, scores[i]) + i += 1
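
Review note, not part of the patch: the new documentation describes a Scorer "protocol" (callable as ``(estimator, X, y)``, returning a number, with a boolean ``greater_is_better`` attribute), but the diff only exercises it through the :class:`Scorer` helper. Below is a minimal from-scratch sketch of such an object plugged into the new ``scoring`` parameter of ``cross_val_score``, using the deprecated-era module layout assumed by this patch (``sklearn.cross_validation``). The class name ``SparsityPenalizedF1`` and the ``0.001`` penalty weight are purely illustrative and do not appear anywhere in the patch::

    from sklearn.cross_validation import cross_val_score
    from sklearn.datasets import make_classification
    from sklearn.metrics import f1_score
    from sklearn.svm import LinearSVC


    class SparsityPenalizedF1(object):
        """F1 score minus a small penalty on the number of non-zero coefficients.

        Demonstrates model-complexity-aware scoring, which plain metric
        functions taking only (y_true, y_pred) cannot express.
        """
        greater_is_better = True  # higher values mean a better model

        def __call__(self, estimator, X, y):
            # Score the fitted estimator on held-out data ...
            score = f1_score(y, estimator.predict(X))
            # ... and penalize dense models a little.
            n_nonzero = (estimator.coef_ != 0).sum()
            return score - 0.001 * n_nonzero


    X, y = make_classification(n_samples=200, n_features=100, random_state=0)
    # cross_val_score calls the object as scorer(estimator, X_test, y_test)
    # on each fold and checks that the result is a single number.
    scores = cross_val_score(LinearSVC(penalty="l1", dual=False, random_state=0),
                             X, y, scoring=SparsityPenalizedF1(), cv=5)

The same object can be passed as ``scoring`` to ``GridSearchCV``, where the ``greater_is_better`` attribute decides whether the best parameter setting maximizes or minimizes the returned value.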