diff --git a/doc/modules/calibration.rst b/doc/modules/calibration.rst
index a462ff322932c..c23b67a489e4a 100644
--- a/doc/modules/calibration.rst
+++ b/doc/modules/calibration.rst
@@ -1,4 +1,4 @@
-.. _calibration:
+.. _probability_calibration:
 
 =======================
 Probability calibration
@@ -208,3 +208,124 @@ a similar decrease in log-loss.
 .. [5] On the combination of forecast probabilities for
        consecutive precipitation periods. Wea. Forecasting, 5, 640–650.,
        Wilks, D. S., 1990a
+
+
+.. _decision_threshold_calibration:
+
+==============================
+Decision threshold calibration
+==============================
+
+.. currentmodule:: sklearn.calibration
+
+Machine learning classifiers often base their predictions on real-valued
+decision functions or probability estimates that carry the biases inherited
+from their models. Additionally, the criteria used to evaluate a model can
+differ from the objective that the model optimizes during training.
+
+When predicting between two classes it is therefore advisable to estimate an
+appropriate decision threshold based on some cutoff criterion rather than
+arbitrarily using the midpoint of the range of possible values. Estimating a
+decision threshold for a specific use case can help to increase the overall
+accuracy of the model and to handle sensitive classes more carefully.
+
+:class:`CutoffClassifier` can be used as a wrapper around a binary
+classification model in order to obtain a more appropriate decision threshold
+and use it when predicting new samples.
+
+Usage
+-----
+
+To use the :class:`CutoffClassifier` you need to provide an estimator that has
+a ``decision_function`` or a ``predict_proba`` method. The ``method``
+parameter controls which of the two is used when both are available.
+
+The wrapped estimator can be pre-trained, in which case ``cv='prefit'`` should
+be used. Otherwise, a cross-validation loop specified by the parameter ``cv``
+is used to obtain a decision threshold by averaging the thresholds found on
+the hold-out part of each cross-validation split; the model is then trained on
+all the provided data. When using ``cv='prefit'`` you need to make sure that
+the data used for calibration is held out from the data used to train the
+wrapped estimator.
+
+The strategies for finding an appropriate decision threshold, controlled by
+the parameter ``strategy``, are based either on precision and recall estimates
+or on true positive and true negative rates. Specifically:
+
+.. currentmodule:: sklearn.metrics
+
+* ``f_beta``
+  selects a decision threshold that maximizes the :func:`fbeta_score`. The
+  value of beta, given by the parameter ``beta``, determines the relative
+  weight of precision and recall. When ``beta = 1`` precision and recall get
+  the same weight, so the maximization target is the :func:`f1_score`. If
+  ``beta < 1`` more weight is given to precision whereas if ``beta > 1`` more
+  weight is given to recall (see the formula after this list).
+
+* ``roc``
+  selects the decision threshold for the point on the :func:`roc_curve` that
+  is closest to the ideal corner (0, 1).
+
+* ``max_tpr``
+  selects the decision threshold for the point that yields the highest true
+  positive rate while maintaining a minimum true negative rate, specified by
+  the parameter ``threshold``.
+
+* ``max_tnr``
+  selects the decision threshold for the point that yields the highest true
+  negative rate while maintaining a minimum true positive rate, specified by
+  the parameter ``threshold``.
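+
+For reference, the quantity maximized by the ``f_beta`` strategy is the
+standard F-beta score, computed here along the precision-recall curve of the
+wrapped estimator:
+
+.. math::
+
+   F_\beta = (1 + \beta^2)
+             \frac{\text{precision} \cdot \text{recall}}
+                  {\beta^2 \cdot \text{precision} + \text{recall}}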
+
+Here is a simple usage example::
+
+    >>> from sklearn.calibration import CutoffClassifier
+    >>> from sklearn.datasets import load_breast_cancer
+    >>> from sklearn.naive_bayes import GaussianNB
+    >>> from sklearn.metrics import precision_score
+    >>> from sklearn.model_selection import train_test_split
+
+    >>> X, y = load_breast_cancer(return_X_y=True)
+    >>> X_train, X_test, y_train, y_test = train_test_split(
+    ...     X, y, train_size=0.6, random_state=42)
+    >>> clf = CutoffClassifier(GaussianNB(), strategy='f_beta', beta=0.6,
+    ...                        cv=3).fit(X_train, y_train)
+    >>> y_pred = clf.predict(X_test)
+    >>> precision_score(y_test, y_pred)  # doctest: +ELLIPSIS
+    0.959...
+
+.. topic:: Examples:
+
+  * :ref:`sphx_glr_auto_examples_calibration_plot_decision_threshold_calibration.py`: Decision
+    threshold calibration on the breast cancer dataset
+
+.. currentmodule:: sklearn.calibration
+
+The following image shows the results of using the :class:`CutoffClassifier`
+to find a decision threshold for a
+:class:`~sklearn.linear_model.LogisticRegression` classifier and an
+:class:`~sklearn.ensemble.AdaBoostClassifier` in two use cases.
+
+.. figure:: ../auto_examples/calibration/images/sphx_glr_plot_decision_threshold_calibration_001.png
+   :target: ../auto_examples/calibration/plot_decision_threshold_calibration.html
+   :align: center
+
+In the first case we want to increase the overall f1 score of the classifiers
+on the breast cancer dataset. In the second case we want to find a decision
+threshold that yields the maximum true positive rate while maintaining a
+minimum value for the true negative rate.
+
+.. topic:: References:
+
+    * Receiver-operating characteristic (ROC) plots: a fundamental
+      evaluation tool in clinical medicine, MH Zweig, G Campbell -
+      Clinical chemistry, 1993
+
+Notes
+-----
+
+Calibrating the decision threshold of a classifier does not guarantee better
+performance. The generalization ability of the obtained decision threshold has
+to be evaluated.
diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 57ccfb5cff704..6e44071aa23e6 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -50,7 +50,7 @@ Functions
    set_config
    show_versions
 
-.. _calibration_ref:
+.. _probability_calibration_ref:
 
 :mod:`sklearn.calibration`: Probability Calibration
 ===================================================
@@ -59,7 +59,7 @@ Functions
    :no-members:
    :no-inherited-members:
 
-**User guide:** See the :ref:`calibration` section for further details.
+**User guide:** See the :ref:`probability_calibration` section for further details.
 
 .. currentmodule:: sklearn
 
@@ -76,6 +76,25 @@ Functions
 
    calibration.calibration_curve
 
+.. _decision_threshold_calibration_ref:
+
+:mod:`sklearn.calibration`: Decision Threshold Calibration
+==========================================================
+
+.. automodule:: sklearn.calibration
+   :no-members:
+   :no-inherited-members:
+
+**User guide:** See the :ref:`decision_threshold_calibration` section for further details.
+
+.. currentmodule:: sklearn
+
+.. autosummary::
+   :toctree: generated/
+   :template: class.rst
+
+   calibration.CutoffClassifier
+
 .. _cluster_ref:
 
 :mod:`sklearn.cluster`: Clustering
diff --git a/examples/calibration/README.txt b/examples/calibration/README.txt
index 5e4a31b966b50..a820b63654f98 100644
--- a/examples/calibration/README.txt
+++ b/examples/calibration/README.txt
@@ -3,4 +3,4 @@ Calibration
 -----------------------
 
-Examples illustrating the calibration of predicted probabilities of classifiers.
+Examples concerning the :mod:`sklearn.calibration` module.
diff --git a/examples/calibration/plot_decision_threshold_calibration.py b/examples/calibration/plot_decision_threshold_calibration.py
new file mode 100644
index 0000000000000..e14e680380e17
--- /dev/null
+++ b/examples/calibration/plot_decision_threshold_calibration.py
@@ -0,0 +1,167 @@
+"""
+======================================================================
+Decision threshold (cutoff point) calibration on breast cancer dataset
+======================================================================
+
+Machine learning classifiers often base their predictions on real-valued
+decision functions that don't always have accuracy as their objective.
+Moreover, the learning objective of a model can differ from the user's needs,
+so the decision threshold implicitly chosen by the model may not be ideal.
+
+The CutoffClassifier can be used to calibrate the decision threshold of a model
+in order to increase the classifier's trustworthiness. The calibration can
+target the true positive and / or the true negative rate as well as the f beta
+score.
+
+In this example the decision threshold calibration is applied to two
+classifiers trained on the breast cancer dataset. The goal in the first case is
+to maximize the f1 score of the classifiers whereas in the second the goal is
+to maximize the true positive rate while maintaining a minimum true negative
+rate.
+
+After calibration the f1 score of the LogisticRegression classifier has
+increased slightly whereas the f1 score of the AdaBoostClassifier has stayed
+the same.
+
+For the second goal, after calibration both classifiers achieve a higher true
+positive rate while their respective true negative rates have decreased
+slightly or remained stable.
+""" + +# Author: Prokopios Gryllos +# +# License: BSD 3 clause + +from __future__ import division + +import numpy as np + +from sklearn.ensemble import AdaBoostClassifier +from sklearn.metrics import confusion_matrix, f1_score +from sklearn.calibration import CutoffClassifier +from sklearn.linear_model import LogisticRegression +from sklearn.datasets import load_breast_cancer +import matplotlib.pyplot as plt +from sklearn.model_selection import train_test_split + + +print(__doc__) + +# percentage of the training set that will be used for calibration +calibration_samples_percentage = 0.2 + +X, y = load_breast_cancer(return_X_y=True) + +X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.6, + random_state=42) + +calibration_samples = int(len(X_train) * calibration_samples_percentage) + +lr = LogisticRegression().fit( + X_train[:-calibration_samples], y_train[:-calibration_samples]) + +y_pred_lr = lr.predict(X_test) +tn_lr, fp_lr, fn_lr, tp_lr = confusion_matrix(y_test, y_pred_lr).ravel() +tpr_lr = tp_lr / (tp_lr + fn_lr) +tnr_lr = tn_lr / (tn_lr + fp_lr) +f_one_lr = f1_score(y_test, y_pred_lr) + +ada = AdaBoostClassifier().fit( + X_train[:-calibration_samples], y_train[:-calibration_samples]) + +y_pred_ada = ada.predict(X_test) +tn_ada, fp_ada, fn_ada, tp_ada = confusion_matrix(y_test, y_pred_ada).ravel() +tpr_ada = tp_ada / (tp_ada + fn_ada) +tnr_ada = tn_ada / (tn_ada + fp_ada) +f_one_ada = f1_score(y_test, y_pred_ada) + +# objective 1: we want to calibrate the decision threshold in order to achieve +# better f1 score +lr_f_beta = CutoffClassifier( + lr, strategy='f_beta', method='predict_proba', beta=1, cv='prefit').fit( + X_train[calibration_samples:], y_train[calibration_samples:]) + +y_pred_lr_f_beta = lr_f_beta.predict(X_test) +f_one_lr_f_beta = f1_score(y_test, y_pred_lr_f_beta) + +ada_f_beta = CutoffClassifier( + ada, strategy='f_beta', method='predict_proba', beta=1, cv='prefit' +).fit(X_train[calibration_samples:], y_train[calibration_samples:]) + +y_pred_ada_f_beta = ada_f_beta.predict(X_test) +f_one_ada_f_beta = f1_score(y_test, y_pred_ada_f_beta) + +# objective 2: we want to maximize the true positive rate while the true +# negative rate is at least 0.7 +lr_max_tpr = CutoffClassifier( + lr, strategy='max_tpr', method='predict_proba', threshold=0.7, cv='prefit' +).fit(X_train[calibration_samples:], y_train[calibration_samples:]) + +y_pred_lr_max_tpr = lr_max_tpr.predict(X_test) +tn_lr_max_tpr, fp_lr_max_tpr, fn_lr_max_tpr, tp_lr_max_tpr = \ + confusion_matrix(y_test, y_pred_lr_max_tpr).ravel() +tpr_lr_max_tpr = tp_lr_max_tpr / (tp_lr_max_tpr + fn_lr_max_tpr) +tnr_lr_max_tpr = tn_lr_max_tpr / (tn_lr_max_tpr + fp_lr_max_tpr) + +ada_max_tpr = CutoffClassifier( + ada, strategy='max_tpr', method='predict_proba', threshold=0.7, cv='prefit' +).fit(X_train[calibration_samples:], y_train[calibration_samples:]) + +y_pred_ada_max_tpr = ada_max_tpr.predict(X_test) +tn_ada_max_tpr, fp_ada_max_tpr, fn_ada_max_tpr, tp_ada_max_tpr = \ + confusion_matrix(y_test, y_pred_ada_max_tpr).ravel() +tpr_ada_max_tpr = tp_ada_max_tpr / (tp_ada_max_tpr + fn_ada_max_tpr) +tnr_ada_max_tpr = tn_ada_max_tpr / (tn_ada_max_tpr + fp_ada_max_tpr) + +print('Calibrated threshold') +print('Logistic Regression classifier: {}'.format( + lr_max_tpr.decision_threshold_)) +print('AdaBoost classifier: {}'.format(ada_max_tpr.decision_threshold_)) +print('before calibration') +print('Logistic Regression classifier: tpr = {}, tnr = {}, f1 = {}'.format( + tpr_lr, tnr_lr, f_one_lr)) +print('AdaBoost 
classifier: tpr = {}, tpn = {}, f1 = {}'.format( + tpr_ada, tnr_ada, f_one_ada)) + +print('true positive and true negative rates after calibration') +print('Logistic Regression classifier: tpr = {}, tnr = {}, f1 = {}'.format( + tpr_lr_max_tpr, tnr_lr_max_tpr, f_one_lr_f_beta)) +print('AdaBoost classifier: tpr = {}, tnr = {}, f1 = {}'.format( + tpr_ada_max_tpr, tnr_ada_max_tpr, f_one_ada_f_beta)) + +######### +# plots # +######### +bar_width = 0.2 + +plt.subplot(2, 1, 1) +index = np.asarray([1, 2]) +plt.bar(index, [f_one_lr, f_one_ada], bar_width, color='r', + label='Before calibration') + +plt.bar(index + bar_width, [f_one_lr_f_beta, f_one_ada_f_beta], bar_width, + color='b', label='After calibration') + +plt.xticks(index + bar_width / 2, ('f1 logistic', 'f1 adaboost')) + +plt.ylabel('scores') +plt.title('f1 score') +plt.legend(bbox_to_anchor=(.5, -.2), loc='center', borderaxespad=0.) + +plt.subplot(2, 1, 2) +index = np.asarray([1, 2, 3, 4]) +plt.bar(index, [tpr_lr, tnr_lr, tpr_ada, tnr_ada], + bar_width, color='r', label='Before calibration') + +plt.bar(index + bar_width, + [tpr_lr_max_tpr, tnr_lr_max_tpr, tpr_ada_max_tpr, tnr_ada_max_tpr], + bar_width, color='b', label='After calibration') + +plt.xticks( + index + bar_width / 2, + ('tpr logistic', 'tnr logistic', 'tpr adaboost', 'tnr adaboost')) +plt.ylabel('scores') +plt.title('true positive & true negative rate') + +plt.subplots_adjust(hspace=0.6) +plt.show() diff --git a/sklearn/calibration.py b/sklearn/calibration.py index ed80523880cfd..46120a647d535 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -4,6 +4,7 @@ # Balazs Kegl # Jan Hendrik Metzen # Mathieu Blondel +# Prokopios Gryllos # # License: BSD 3 clause @@ -16,7 +17,8 @@ from scipy.optimize import fmin_bfgs from sklearn.preprocessing import LabelEncoder -from .base import BaseEstimator, ClassifierMixin, RegressorMixin, clone +from .base import (BaseEstimator, ClassifierMixin, RegressorMixin, + MetaEstimatorMixin, clone) from .preprocessing import label_binarize, LabelBinarizer from .utils import check_X_y, check_array, indexable, column_or_1d from .utils.validation import check_is_fitted, check_consistent_length @@ -25,6 +27,368 @@ from .svm import LinearSVC from .model_selection import check_cv from .metrics.classification import _check_binary_probabilistic_predictions +from .metrics.ranking import precision_recall_curve, roc_curve +from .utils.multiclass import type_of_target + + +class CutoffClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin): + """Decision threshold calibration for binary classification + + Meta estimator that calibrates the decision threshold (cutoff point) + that is used for prediction. The methods for picking cutoff points make use + of traditional binary classification evaluation statistics such as the + true positive and true negative rates and F-scores. + + If cv="prefit" the base estimator is assumed to be fitted and all data will + be used for the selection of the cutoff point. Otherwise the decision + threshold is calculated as the average of the thresholds resulting from the + cross-validation loop. + + Parameters + ---------- + base_estimator : obj + The binary classifier whose decision threshold will be adapted + according to the acquired cutoff point. 
The estimator must have a + decision_function or a predict_proba + + strategy : str, optional (default='roc') + The strategy to use for choosing the cutoff point + + 'roc' + selects the point on the roc curve that is closest to the ideal + corner (0, 1) + + 'f_beta' + selects a decision threshold that maximizes the f_beta score + + 'max_tpr' + selects the point that yields the highest true positive rate with + true negative rate at least equal to the value of the parameter + threshold + + 'max_tnr' + selects the point that yields the highest true negative rate with + true positive rate at least equal to the value of the parameter + threshold + + method : str or None, optional (default=None) + The method to be used for acquiring the score + + 'decision_function' + base_estimator.decision_function will be used for scoring + + 'predict_proba' + base_estimator.predict_proba will be used for scoring + + None + base_estimator.decision_function will be used first and if not + available base_estimator.predict_proba + + beta : float in [0, 1], optional (default=None) + beta value to be used in case strategy == 'f_beta' + + threshold : float in [0, 1] or None, (default=None) + In case strategy is 'max_tpr' or 'max_tnr' this parameter must be set + to specify the threshold for the true negative rate or true positive + rate respectively that needs to be achieved + + pos_label : object, optional (default=1) + Object representing the positive label + + cv : int, cross-validation generator, iterable or 'prefit', optional + (default=3). Determines the cross-validation splitting strategy. + If cv='prefit' the base estimator is assumed to be fitted and all data + will be used for the calibration of the probability threshold + + Attributes + ---------- + decision_threshold_ : float + Decision threshold for the positive class. Determines the output of + predict + + std_ : float + Standard deviation of the obtained decision thresholds for when the + provided base estimator is not pre-trained and the decision_threshold_ + is computed as the mean of the decision threshold of each + cross-validation iteration. If the base estimator is pre-trained then + std_ = None + + classes_ : array, shape (n_classes) + The class labels. + + References + ---------- + .. [1] Receiver-operating characteristic (ROC) plots: a fundamental + evaluation tool in clinical medicine, MH Zweig, G Campbell - + Clinical chemistry, 1993 + + """ + def __init__(self, base_estimator, strategy='roc', method=None, beta=None, + threshold=None, pos_label=1, cv=3): + self.base_estimator = base_estimator + self.strategy = strategy + self.method = method + self.beta = beta + self.threshold = threshold + self.pos_label = pos_label + self.cv = cv + + def fit(self, X, y): + """Fit model + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + Training data + + y : array-like, shape (n_samples,) + Target values. There must be two 2 distinct values + + Returns + ------- + self : object + Instance of self + """ + if (not hasattr(self.base_estimator, 'decision_function') and + not hasattr(self.base_estimator, 'predict_proba')): + raise TypeError('The base_estimator needs to implement either a ' + 'decision_function or a predict_proba method') + + if self.strategy not in ('roc', 'f_beta', 'max_tpr', 'max_tnr'): + raise ValueError('strategy can either be "roc" or "max_tpr" or ' + '"max_tnr. 
Got {} instead'.format(self.strategy)) + + if self.method not in (None, 'decision_function', 'predict_proba'): + raise ValueError('method param can either be "decision_function" ' + 'or "predict_proba" or None. ' + 'Got {} instead'.format(self.method)) + + if self.strategy == 'max_tpr' or self.strategy == 'max_tnr': + if (not self.threshold or not + isinstance(self.threshold, (int, float)) + or not self.threshold >= 0 or not self.threshold <= 1): + raise ValueError('parameter threshold must be a number in' + '[0, 1]. ' + 'Got {} instead'.format(self.threshold)) + + if self.strategy == 'f_beta': + if not self.beta or not isinstance(self.beta, (int, float)): + raise ValueError('parameter beta must be a real number.' + 'Got {} instead'.format(type(self.beta))) + + X, y = check_X_y(X, y) + + y_type = type_of_target(y) + if y_type != 'binary': + raise ValueError('Expected target of binary type. Got {}'.format( + y_type)) + + self.label_encoder_ = LabelEncoder().fit(y) + self.classes_ = self.label_encoder_.classes_ + + y = self.label_encoder_.transform(y) + self.pos_label = self.label_encoder_.transform([self.pos_label])[0] + + if self.cv == 'prefit': + self.decision_threshold_ = _CutoffClassifier( + self.base_estimator, self.strategy, self.method, self.beta, + self.threshold, self.pos_label + ).fit(X, y).decision_threshold_ + self.std_ = None + else: + cv = check_cv(self.cv, y, classifier=True) + decision_thresholds = [] + + for train, test in cv.split(X, y): + estimator = clone(self.base_estimator).fit(X[train], y[train]) + decision_thresholds.append( + _CutoffClassifier(estimator, self.strategy, self.method, + self.beta, self.threshold, + self.pos_label).fit( + X[test], y[test] + ).decision_threshold_ + ) + self.decision_threshold_ = np.mean(decision_thresholds) + self.std_ = np.std(decision_thresholds) + + self.base_estimator.fit(X, y) + return self + + def predict(self, X): + """Predict using the calibrated decision threshold + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + The samples + + Returns + ------- + C : array, shape (n_samples,) + The predicted class + """ + X = check_array(X) + check_is_fitted( + self, ["label_encoder_", "decision_threshold_", "std_", "classes_"] + ) + + y_score = _get_binary_score(self.base_estimator, X, self.method, + self.pos_label) + return self.label_encoder_.inverse_transform( + (y_score > self.decision_threshold_).astype(int) + ) + + +class _CutoffClassifier(object): + """Cutoff point selection. + + It assumes that base_estimator has already been fit, and uses the input set + of the fit function to select a cutoff point. Note that this class should + not be used as an estimator directly. Use the CutoffClassifier with + cv="prefit" instead. + + Parameters + ---------- + base_estimator : obj + The binary classifier whose decision threshold will be adapted + according to the acquired cutoff point. The estimator must have a + decision_function or a predict_proba + + strategy : 'roc' or 'f_beta' or 'max_tpr' or 'max_tnr' + The method to use for choosing the cutoff point + + method : str or None, optional (default=None) + The method to be used for acquiring the score. Can either be + "decision_function" or "predict_proba" or None. 
If None then + decision_function will be used first and if not available + predict_proba + + beta : float in [0, 1] + beta value to be used in case strategy == 'f_beta' + + threshold : float in [0, 1] + minimum required value for the true negative rate (specificity) in case + strategy 'max_tpr' is used or for the true positive rate (sensitivity) + in case method 'max_tnr' is used + + pos_label : object + Label considered as positive during the roc_curve construction + + Attributes + ---------- + decision_threshold_ : float + Acquired decision threshold for the positive class + """ + def __init__(self, base_estimator, strategy, method, beta, threshold, + pos_label): + self.base_estimator = base_estimator + self.strategy = strategy + self.method = method + self.beta = beta + self.threshold = threshold + self.pos_label = pos_label + + def fit(self, X, y): + """Select a decision threshold for the fitted model's positive class + using one of the available methods + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + Training data + + y : array-like, shape (n_samples,) + Target values + + Returns + ------- + self : object + Instance of self + """ + y_score = _get_binary_score(self.base_estimator, X, self.method, + self.pos_label) + if self.strategy == 'f_beta': + precision, recall, thresholds = precision_recall_curve( + y, y_score, pos_label=self.pos_label + ) + f_beta = ((1 + self.beta**2) * (precision * recall) / + (self.beta**2 * precision + recall)) + self.decision_threshold_ = thresholds[np.argmax(f_beta)] + return self + + fpr, tpr, thresholds = roc_curve(y, y_score, pos_label=self.pos_label) + + if self.strategy == 'roc': + # we find the threshold of the point (fpr, tpr) with the smallest + # euclidean distance from the "ideal" corner (0, 1) + self.decision_threshold_ = thresholds[ + np.argmin(fpr ** 2 + (tpr - 1) ** 2) + ] + elif self.strategy == 'max_tpr': + indices = np.where(1 - fpr >= self.threshold)[0] + max_tpr_index = np.argmax(tpr[indices]) + self.decision_threshold_ = thresholds[indices[max_tpr_index]] + else: + indices = np.where(tpr >= self.threshold)[0] + max_tnr_index = np.argmax(1 - fpr[indices]) + self.decision_threshold_ = thresholds[indices[max_tnr_index]] + return self + + +def _get_binary_score(clf, X, method=None, pos_label=1): + """Binary classification score for the positive label (0 or 1) + + Returns the score that a binary classifier outputs for the positive label + acquired either from decision_function or predict_proba + + Parameters + ---------- + clf : object + Classifier object to be used for acquiring the scores. Needs to have + a decision_function or a predict_proba method + + X : array-like, shape (n_samples, n_features) + The samples + + pos_label : int, optional (default=1) + The positive label. Can either be 0 or 1 + + method : str or None, optional (default=None) + The method to be used for acquiring the score. Can either be + "decision_function" or "predict_proba" or None. If None then + decision_function will be used first and if not available + predict_proba + + Returns + ------- + y_score : array-like, shape (n_samples,) + The return value of the provided classifier's decision_function or + predict_proba depending on the method used. + """ + if len(clf.classes_) != 2: + raise ValueError('Expected binary classifier. 
Found {} classes'.format( + len(clf.classes_) + )) + + if method not in (None, 'decision_function', 'predict_proba'): + raise ValueError('scoring param can either be "decision_function" ' + 'or "predict_proba" or None. ' + 'Got {} instead'.format(method)) + + if not method: + try: + y_score = clf.decision_function(X) + if pos_label == clf.classes_[0]: + y_score = -y_score + except (NotImplementedError, AttributeError): + y_score = clf.predict_proba(X)[:, pos_label] + elif method == 'decision_function': + y_score = clf.decision_function(X) + if pos_label == clf.classes_[0]: + y_score = - y_score + else: + y_score = clf.predict_proba(X)[:, pos_label] + return y_score class CalibratedClassifierCV(BaseEstimator, ClassifierMixin): diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py index e454633a3a294..dfea4fbc76c7f 100644 --- a/sklearn/tests/test_calibration.py +++ b/sklearn/tests/test_calibration.py @@ -1,30 +1,236 @@ # Authors: Alexandre Gramfort +# Prokopios Gryllos # License: BSD 3 clause from __future__ import division import pytest import numpy as np from scipy import sparse -from sklearn.model_selection import LeaveOneOut +from sklearn.model_selection import LeaveOneOut, train_test_split from sklearn.utils.testing import (assert_array_almost_equal, assert_equal, assert_greater, assert_almost_equal, assert_greater_equal, assert_array_equal, - assert_raises, - ignore_warnings) + assert_raises) from sklearn.datasets import make_classification, make_blobs +from sklearn.linear_model import LogisticRegression from sklearn.naive_bayes import MultinomialNB from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor from sklearn.svm import LinearSVC from sklearn.pipeline import Pipeline from sklearn.impute import SimpleImputer -from sklearn.metrics import brier_score_loss, log_loss -from sklearn.calibration import CalibratedClassifierCV +from sklearn.metrics import (brier_score_loss, log_loss, confusion_matrix, + f1_score, recall_score) +from sklearn.calibration import CalibratedClassifierCV, CutoffClassifier +from sklearn.calibration import _get_binary_score from sklearn.calibration import _sigmoid_calibration, _SigmoidCalibration from sklearn.calibration import calibration_curve +def test_cutoff_prefit(): + calibration_samples = 200 + X, y = make_classification(n_samples=1000, n_features=6, random_state=42, + n_classes=2) + + X_train, X_test, y_train, y_test = train_test_split(X, y, + test_size=0.4, + random_state=42) + lr = LogisticRegression(solver='liblinear').fit(X_train, y_train) + + clf_roc = CutoffClassifier(lr, strategy='roc', cv='prefit').fit( + X_test[:calibration_samples], y_test[:calibration_samples] + ) + + y_pred = lr.predict(X_test[calibration_samples:]) + y_pred_roc = clf_roc.predict(X_test[calibration_samples:]) + + tn, fp, fn, tp = confusion_matrix( + y_test[calibration_samples:], y_pred).ravel() + tn_roc, fp_roc, fn_roc, tp_roc = confusion_matrix( + y_test[calibration_samples:], y_pred_roc).ravel() + + tpr = tp / (tp + fn) + tnr = tn / (tn + fp) + + tpr_roc = tp_roc / (tp_roc + fn_roc) + tnr_roc = tn_roc / (tn_roc + fp_roc) + + # check that the sum of tpr and tnr has improved + assert tpr_roc + tnr_roc > tpr + tnr + + clf_f1 = CutoffClassifier( + lr, strategy='f_beta', method='predict_proba', beta=1, + cv='prefit').fit( + X_test[:calibration_samples], y_test[:calibration_samples] + ) + + y_pred_f1 = clf_f1.predict(X_test[calibration_samples:]) + assert (f1_score(y_test[calibration_samples:], y_pred_f1) > + 
f1_score(y_test[calibration_samples:], y_pred)) + + clf_fbeta = CutoffClassifier( + lr, strategy='f_beta', method='predict_proba', beta=2, + cv='prefit').fit( + X_test[:calibration_samples], y_test[:calibration_samples] + ) + + y_pred_fbeta = clf_fbeta.predict(X_test[calibration_samples:]) + assert (recall_score(y_test[calibration_samples:], y_pred_fbeta) > + recall_score(y_test[calibration_samples:], y_pred)) + + clf_max_tpr = CutoffClassifier( + lr, strategy='max_tpr', threshold=0.7, cv='prefit' + ).fit(X_test[:calibration_samples], y_test[:calibration_samples]) + + y_pred_max_tpr = clf_max_tpr.predict(X_test[calibration_samples:]) + + tn_max_tpr, fp_max_tpr, fn_max_tpr, tp_max_tpr = confusion_matrix( + y_test[calibration_samples:], y_pred_max_tpr).ravel() + + tpr_max_tpr = tp_max_tpr / (tp_max_tpr + fn_max_tpr) + tnr_max_tpr = tn_max_tpr / (tn_max_tpr + fp_max_tpr) + + # check that the tpr increases with tnr >= min_val_tnr + assert tpr_max_tpr > tpr + assert tpr_max_tpr > tpr_roc + assert tnr_max_tpr >= 0.7 + + clf_max_tnr = CutoffClassifier( + lr, strategy='max_tnr', threshold=0.7, cv='prefit' + ).fit(X_test[:calibration_samples], y_test[:calibration_samples]) + + y_pred_clf = clf_max_tnr.predict(X_test[calibration_samples:]) + + tn_clf, fp_clf, fn_clf, tp_clf = confusion_matrix( + y_test[calibration_samples:], y_pred_clf).ravel() + + tnr_clf_max_tnr = tn_clf / (tn_clf + fp_clf) + tpr_clf_max_tnr = tp_clf / (tp_clf + fn_clf) + + # check that the tnr increases with tpr >= min_val_tpr + assert tnr_clf_max_tnr > tnr + assert tnr_clf_max_tnr > tnr_roc + assert tpr_clf_max_tnr >= 0.7 + + # check error cases + clf_bad_base_estimator = CutoffClassifier([]) + with pytest.raises(TypeError): + clf_bad_base_estimator.fit(X_train, y_train) + + X_non_binary, y_non_binary = make_classification( + n_samples=20, n_features=6, random_state=42, n_classes=4, + n_informative=4 + ) + with pytest.raises(ValueError): + clf_roc.fit(X_non_binary, y_non_binary) + + clf_foo = CutoffClassifier(lr, strategy='f_beta', beta='foo') + with pytest.raises(ValueError): + clf_foo.fit(X_train, y_train) + + clf_foo = CutoffClassifier(lr, strategy='foo') + with pytest.raises(ValueError): + clf_foo.fit(X_train, y_train) + + for method in ['max_tpr', 'max_tnr']: + clf_missing_info = CutoffClassifier(lr, strategy=method) + with pytest.raises(ValueError): + clf_missing_info.fit(X_train, y_train) + + +def test_cutoff_cv(): + X, y = make_classification(n_samples=1000, n_features=6, random_state=42, + n_classes=2) + + X_train, X_test, y_train, y_test = train_test_split(X, y, + test_size=0.4, + random_state=42) + lr = LogisticRegression(solver='liblinear').fit(X_train, y_train) + clf_roc = CutoffClassifier(LogisticRegression(solver='liblinear'), + strategy='roc', + cv=3).fit( + X_train, y_train + ) + + assert clf_roc.decision_threshold_ != 0 + assert clf_roc.std_ is not None and clf_roc.std_ != 0 + + y_pred = lr.predict(X_test) + y_pred_roc = clf_roc.predict(X_test) + + tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel() + tn_roc, fp_roc, fn_roc, tp_roc = confusion_matrix( + y_test, y_pred_roc + ).ravel() + + tpr = tp / (tp + fn) + tnr = tn / (tn + fp) + + tpr_roc = tp_roc / (tp_roc + fn_roc) + tnr_roc = tn_roc / (tn_roc + fp_roc) + + # check that the sum of tpr + tnr has improved + assert tpr_roc + tnr_roc > tpr + tnr + + +def test_get_binary_score(): + X, y = make_classification(n_samples=200, n_features=6, random_state=42, + n_classes=2) + + X_train, X_test, y_train, _ = train_test_split(X, y, test_size=0.4, + 
random_state=42) + lr = LogisticRegression(solver='liblinear').fit(X_train, y_train) + y_pred_proba = lr.predict_proba(X_test) + y_pred_score = lr.decision_function(X_test) + + assert_array_equal( + y_pred_score, _get_binary_score( + lr, X_test, method='decision_function', pos_label=1) + ) + + assert_array_equal( + - y_pred_score, _get_binary_score( + lr, X_test, method='decision_function', pos_label=0) + ) + + assert_array_equal( + y_pred_proba[:, 1], _get_binary_score( + lr, X_test, method='predict_proba', pos_label=1) + ) + + assert_array_equal( + y_pred_proba[:, 0], _get_binary_score( + lr, X_test, method='predict_proba', pos_label=0) + ) + + assert_array_equal( + y_pred_score, + _get_binary_score(lr, X_test, method=None, pos_label=1) + ) + + with pytest.raises(ValueError): + _get_binary_score(lr, X_test, method='foo') + + # classifier that does not have a decision_function + rf = RandomForestClassifier(n_estimators=10).fit(X_train, y_train) + y_pred_proba_rf = rf.predict_proba(X_test) + assert_array_equal( + y_pred_proba_rf[:, 1], + _get_binary_score(rf, X_test, method=None, pos_label=1) + ) + + X_non_binary, y_non_binary = make_classification( + n_samples=20, n_features=6, random_state=42, n_classes=4, + n_informative=4 + ) + + rf_non_bin = RandomForestClassifier(n_estimators=10).fit(X_non_binary, + y_non_binary) + with pytest.raises(ValueError): + _get_binary_score(rf_non_bin, X_non_binary) + + @pytest.mark.filterwarnings('ignore:The default value of n_estimators') @pytest.mark.filterwarnings('ignore: You should specify a value') # 0.22 def test_calibration(): diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index 75b3789619dd3..3f5d5dd3aa13b 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -587,7 +587,7 @@ def uninstall_mldata_mock(): "MultiOutputRegressor", "MultiOutputClassifier", "OutputCodeClassifier", "OneVsRestClassifier", "RFE", "RFECV", "BaseEnsemble", "ClassifierChain", - "RegressorChain"] + "RegressorChain", "CutoffClassifier"] # estimators that there is no way to default-construct sensibly OTHER = ["Pipeline", "FeatureUnion", "GridSearchCV", "RandomizedSearchCV",
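
For a quick, self-contained illustration of the ``cv='prefit'`` workflow
documented above, the following sketch calibrates a pre-trained classifier with
the ``max_tpr`` strategy (it assumes this branch of scikit-learn, i.e. the
``CutoffClassifier`` API introduced in this patch; the selected threshold and
the scores depend on the random split)::

    import numpy as np

    from sklearn.calibration import CutoffClassifier
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split

    X, y = make_classification(n_samples=1000, n_features=6, random_state=0)

    # hold out part of the data: cv='prefit' needs a calibration set that the
    # wrapped estimator has not been trained on
    X_train, X_rest, y_train, y_rest = train_test_split(
        X, y, test_size=0.4, random_state=0)
    X_calib, X_test, y_calib, y_test = train_test_split(
        X_rest, y_rest, test_size=0.5, random_state=0)

    base_clf = LogisticRegression(solver='liblinear').fit(X_train, y_train)

    # pick the threshold that maximizes the true positive rate while keeping
    # the true negative rate at or above 0.7, using the calibration set
    clf = CutoffClassifier(base_clf, strategy='max_tpr',
                           method='predict_proba', threshold=0.7,
                           cv='prefit').fit(X_calib, y_calib)

    print('selected decision threshold: {}'.format(clf.decision_threshold_))
    print('test accuracy: {}'.format(np.mean(clf.predict(X_test) == y_test)))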