diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index 8bc27194a63b5..57754988f4686 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -1344,30 +1344,30 @@ probability outputs (``predict_proba``) of a classifier instead of its discrete predictions. For binary classification with a true label :math:`y \in \{0,1\}` -and a probability estimate :math:`p = \operatorname{Pr}(y = 1)`, +and a probability estimate :math:`\hat{p} \approx \operatorname{Pr}(y = 1)`, the log loss per sample is the negative log-likelihood of the classifier given the true label: .. math:: - L_{\log}(y, p) = -\log \operatorname{Pr}(y|p) = -(y \log (p) + (1 - y) \log (1 - p)) + L_{\log}(y, \hat{p}) = -\log \operatorname{Pr}(y|\hat{p}) = -(y \log (\hat{p}) + (1 - y) \log (1 - \hat{p})) This extends to the multiclass case as follows. Let the true labels for a set of samples be encoded as a 1-of-K binary indicator matrix :math:`Y`, i.e., :math:`y_{i,k} = 1` if sample :math:`i` has label :math:`k` taken from a set of :math:`K` labels. -Let :math:`P` be a matrix of probability estimates, -with :math:`p_{i,k} = \operatorname{Pr}(y_{i,k} = 1)`. +Let :math:`\hat{P}` be a matrix of probability estimates, +with elements :math:`\hat{p}_{i,k} \approx \operatorname{Pr}(y_{i,k} = 1)`. Then the log loss of the whole set is .. math:: - L_{\log}(Y, P) = -\log \operatorname{Pr}(Y|P) = - \frac{1}{N} \sum_{i=0}^{N-1} \sum_{k=0}^{K-1} y_{i,k} \log p_{i,k} + L_{\log}(Y, \hat{P}) = -\log \operatorname{Pr}(Y|\hat{P}) = - \frac{1}{N} \sum_{i=0}^{N-1} \sum_{k=0}^{K-1} y_{i,k} \log \hat{p}_{i,k} To see how this generalizes the binary log loss given above, note that in the binary case, -:math:`p_{i,0} = 1 - p_{i,1}` and :math:`y_{i,0} = 1 - y_{i,1}`, +:math:`\hat{p}_{i,0} = 1 - \hat{p}_{i,1}` and :math:`y_{i,0} = 1 - y_{i,1}`, so expanding the inner sum over :math:`y_{i,k} \in \{0,1\}` gives the binary log loss. @@ -1923,41 +1923,64 @@ set [0,1] has an error:: Brier score loss ---------------- -The :func:`brier_score_loss` function computes the -`Brier score `_ -for binary classes [Brier1950]_. Quoting Wikipedia: +The :func:`brier_score_loss` function computes the `Brier score +`_ for binary and multiclass +probabilistic predictions and is equivalent to the mean squared error. +Quoting Wikipedia: - "The Brier score is a proper score function that measures the accuracy of - probabilistic predictions. It is applicable to tasks in which predictions - must assign probabilities to a set of mutually exclusive discrete outcomes." + "The Brier score is a strictly proper scoring rule that measures the accuracy of + probabilistic predictions. [...] [It] is applicable to tasks in which predictions + must assign probabilities to a set of mutually exclusive discrete outcomes or + classes." -This function returns the mean squared error of the actual outcome -:math:`y \in \{0,1\}` and the predicted probability estimate -:math:`p = \operatorname{Pr}(y = 1)` (:term:`predict_proba`) as outputted by: +Let the true labels for a set of :math:`N` data points be encoded as a 1-of-K binary +indicator matrix :math:`Y`, i.e., :math:`y_{i,k} = 1` if sample :math:`i` has +label :math:`k` taken from a set of :math:`K` labels. Let :math:`\hat{P}` be a matrix +of probability estimates with elements :math:`\hat{p}_{i,k} \approx \operatorname{Pr}(y_{i,k} = 1)`. +Following the original definition by [Brier1950]_, the Brier score is given by: .. 
math::

-    BS = \frac{1}{n_{\text{samples}}} \sum_{i=0}^{n_{\text{samples}} - 1}(y_i - p_i)^2
+    BS(Y, \hat{P}) = \frac{1}{N}\sum_{i=0}^{N-1}\sum_{k=0}^{K-1}(y_{i,k} - \hat{p}_{i,k})^{2}

-The Brier score loss is also between 0 to 1 and the lower the value (the mean
-square difference is smaller), the more accurate the prediction is.
+The Brier score lies in the interval :math:`[0, 2]` and the lower the value,
+the better the probability estimates are (the mean squared difference is smaller).
+In fact, the Brier score is a strictly proper scoring rule, meaning that it
+achieves the best score only when the estimated probabilities equal the
+true ones.
+
+Note that in the binary case, the Brier score is usually divided by two and
+then lies in :math:`[0, 1]`. For binary targets :math:`y_i \in \{0, 1\}` and
+probability estimates :math:`\hat{p}_i \approx \operatorname{Pr}(y_i = 1)`
+for the positive class, the Brier score is then equal to:
+
+.. math::
+
+    BS(y, \hat{p}) = \frac{1}{N} \sum_{i=0}^{N - 1}(y_i - \hat{p}_i)^2
+
+The :func:`brier_score_loss` function computes the Brier score given the
+ground-truth labels and predicted probabilities, as returned by an estimator's
+``predict_proba`` method. The `scale_by_half` parameter controls which of the
+two definitions above is used.

-Here is a small example of usage of this function::

     >>> import numpy as np
     >>> from sklearn.metrics import brier_score_loss
     >>> y_true = np.array([0, 1, 1, 0])
     >>> y_true_categorical = np.array(["spam", "ham", "ham", "spam"])
     >>> y_prob = np.array([0.1, 0.9, 0.8, 0.4])
-    >>> y_pred = np.array([0, 1, 1, 0])
     >>> brier_score_loss(y_true, y_prob)
     0.055
     >>> brier_score_loss(y_true, 1 - y_prob, pos_label=0)
     0.055
     >>> brier_score_loss(y_true_categorical, y_prob, pos_label="ham")
     0.055
-    >>> brier_score_loss(y_true, y_prob > 0.5)
-    0.0
+    >>> brier_score_loss(
+    ...     ["eggs", "ham", "spam"],
+    ...     [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]],
+    ...     labels=["eggs", "ham", "spam"],
+    ... )
+    0.146...

 The Brier score can be used to assess how well a classifier is calibrated.
 However, a lower Brier score loss does not always mean a better calibration.
diff --git a/doc/whats_new/upcoming_changes/sklearn.metrics/22046.feature.rst b/doc/whats_new/upcoming_changes/sklearn.metrics/22046.feature.rst
new file mode 100644
index 0000000000000..dbe9166aa1314
--- /dev/null
+++ b/doc/whats_new/upcoming_changes/sklearn.metrics/22046.feature.rst
@@ -0,0 +1,6 @@
+- :func:`metrics.brier_score_loss` implements the Brier score for multiclass
+  classification problems and adds a `scale_by_half` argument. This metric is
+  notably useful to assess both sharpness and calibration of probabilistic
+  classifiers. See the docstrings for more details. By
+  :user:`Varun Aggarwal `, :user:`Olivier Grisel ` and
+  :user:`Antoine Baker `.
diff --git a/doc/whats_new/upcoming_changes/sklearn.metrics/22046.fix.rst b/doc/whats_new/upcoming_changes/sklearn.metrics/22046.fix.rst
new file mode 100644
index 0000000000000..7ba041f2686cf
--- /dev/null
+++ b/doc/whats_new/upcoming_changes/sklearn.metrics/22046.fix.rst
@@ -0,0 +1,3 @@
+- :func:`metrics.log_loss` now raises a `ValueError` if values of `y_true`
+  are missing in `labels`. By :user:`Varun Aggarwal `,
+  :user:`Olivier Grisel ` and :user:`Antoine Baker `.
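As a quick illustration of the two definitions above, here is a small doctest-style
sketch (not part of the patch: it assumes the `pos_label`, `labels` and
`scale_by_half` parameters behave exactly as documented in this changeset, and the
expected outputs are worked out by hand from the formulas, using probabilities that
are exactly representable in floating point so the printed scores are clean)::

    >>> import numpy as np
    >>> from sklearn.metrics import brier_score_loss
    >>> # binary case: mean squared error of the positive-class probability,
    >>> # halved by default (scale_by_half="auto" rescales to [0, 1])
    >>> y_true = ["ham", "spam", "ham", "spam"]
    >>> y_prob = np.array([0.75, 0.5, 0.25, 1.0])  # estimated Pr(y == "spam")
    >>> brier_score_loss(y_true, y_prob, pos_label="spam")
    0.21875
    >>> # the original Brier (1950) definition sums over both classes
    >>> brier_score_loss(y_true, y_prob, pos_label="spam", scale_by_half=False)
    0.4375
    >>> # multiclass case: one column per class, no rescaling by default
    >>> brier_score_loss(
    ...     ["eggs", "ham", "spam"],
    ...     [[1.0, 0.0, 0.0], [0.25, 0.5, 0.25], [0.25, 0.25, 0.5]],
    ...     labels=["eggs", "ham", "spam"],
    ... )
    0.25

With three or more classes the "auto" option keeps the original [0, 2] range,
which is why the multiclass score above is not halved.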
diff --git a/examples/calibration/plot_calibration_multiclass.py b/examples/calibration/plot_calibration_multiclass.py
index 2208292d1ccc9..782a59133fcca 100644
--- a/examples/calibration/plot_calibration_multiclass.py
+++ b/examples/calibration/plot_calibration_multiclass.py
@@ -212,14 +212,30 @@ class of an instance (red: class 1, green: class 2, blue: class 3).
 from sklearn.metrics import log_loss
 
-score = log_loss(y_test, clf_probs)
-cal_score = log_loss(y_test, cal_clf_probs)
+loss = log_loss(y_test, clf_probs)
+cal_loss = log_loss(y_test, cal_clf_probs)
 
-print("Log-loss of")
-print(f" * uncalibrated classifier: {score:.3f}")
-print(f" * calibrated classifier: {cal_score:.3f}")
+print("Log-loss of:")
+print(f" - uncalibrated classifier: {loss:.3f}")
+print(f" - calibrated classifier: {cal_loss:.3f}")
 
 # %%
+# We can also assess calibration with the Brier score for probabilistic predictions
+# (lower is better, possible range is [0, 2]):
+
+from sklearn.metrics import brier_score_loss
+
+loss = brier_score_loss(y_test, clf_probs)
+cal_loss = brier_score_loss(y_test, cal_clf_probs)
+
+print("Brier score of:")
+print(f" - uncalibrated classifier: {loss:.3f}")
+print(f" - calibrated classifier: {cal_loss:.3f}")
+
+# %%
+# According to the Brier score, the calibrated classifier is not better than
+# the original model.
+#
 # Finally we generate a grid of possible uncalibrated probabilities over
 # the 2-simplex, compute the corresponding calibrated probabilities and
 # plot arrows for each. The arrows are colored according the highest
@@ -274,3 +290,15 @@ class of an instance (red: class 1, green: class 2, blue: class 3).
 plt.ylim(-0.05, 1.05)
 plt.show()
+
+# %%
+# One can observe that, on average, the calibrator is pushing highly confident
+# predictions away from the boundaries of the simplex while simultaneously
+# moving uncertain predictions towards one of the three modes, one for each class.
+# We can also observe that the mapping is not symmetric. Furthermore, some
+# arrows seem to cross class-assignment boundaries, which is not necessarily
+# what one would expect from a calibration map, as it means that some predicted
+# classes will change after calibration.
+#
+# All in all, the One-vs-Rest multiclass calibration strategy implemented in
+# `CalibratedClassifierCV` should not be trusted blindly.
diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index 2e23c251af58a..5d9987497ca28 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -152,6 +152,139 @@ def _check_targets(y_true, y_pred):
     return y_type, y_true, y_pred
 
 
+def _validate_multiclass_probabilistic_prediction(
+    y_true, y_prob, sample_weight, labels
+):
+    r"""Convert y_true and y_prob to shape (n_samples, n_classes)
+
+    1. Verify that y_true, y_prob, and sample_weight have the same first dim
+    2. Ensure there are 2 or more classes in y_true, i.e. a valid classification
+       task. The classes are provided by the labels argument, or inferred using
+       y_true. When inferring, y_true is assumed to be binary if it has shape
+       (n_samples, ).
+    3. Validate that y_true and y_prob have the same number of classes. Convert
+       both to shape (n_samples, n_classes)
+
+    Parameters
+    ----------
+    y_true : array-like or label indicator matrix
+        Ground truth (correct) labels for n_samples samples.
+
+    y_prob : array-like of float, shape=(n_samples, n_classes) or (n_samples,)
+        Predicted probabilities, as returned by a classifier's
+        predict_proba method. 
If `y_prob.shape = (n_samples,)` + the probabilities provided are assumed to be that of the + positive class. The labels in `y_prob` are assumed to be + ordered lexicographically, as done by + :class:`preprocessing.LabelBinarizer`. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. + + labels : array-like, default=None + If not provided, labels will be inferred from y_true. If `labels` + is `None` and `y_prob` has shape `(n_samples,)` the labels are + assumed to be binary and are inferred from `y_true`. + + Returns + ------- + transformed_labels : array of shape (n_samples, n_classes) + + y_prob : array of shape (n_samples, n_classes) + """ + y_prob = check_array( + y_prob, ensure_2d=False, dtype=[np.float64, np.float32, np.float16] + ) + + if y_prob.max() > 1: + raise ValueError(f"y_prob contains values greater than 1: {y_prob.max()}") + if y_prob.min() < 0: + raise ValueError(f"y_prob contains values lower than 0: {y_prob.min()}") + + check_consistent_length(y_prob, y_true, sample_weight) + lb = LabelBinarizer() + + if labels is not None: + lb = lb.fit(labels) + # LabelBinarizer does not respect the order implied by labels, which + # can be misleading. + if not np.all(lb.classes_ == labels): + warnings.warn( + f"Labels passed were {labels}. But this function " + "assumes labels are ordered lexicographically. " + f"Pass the ordered labels={lb.classes_.tolist()} and ensure that " + "the columns of y_prob correspond to this ordering.", + UserWarning, + ) + if not np.isin(y_true, labels).all(): + undeclared_labels = set(y_true) - set(labels) + raise ValueError( + f"y_true contains values {undeclared_labels} not belonging " + f"to the passed labels {labels}." + ) + + else: + lb = lb.fit(y_true) + + if len(lb.classes_) == 1: + if labels is None: + raise ValueError( + "y_true contains only one label ({0}). Please " + "provide the list of all expected class labels explicitly through the " + "labels argument.".format(lb.classes_[0]) + ) + else: + raise ValueError( + "The labels array needs to contain at least two " + "labels, got {0}.".format(lb.classes_) + ) + + transformed_labels = lb.transform(y_true) + + if transformed_labels.shape[1] == 1: + transformed_labels = np.append( + 1 - transformed_labels, transformed_labels, axis=1 + ) + + # If y_prob is of single dimension, assume y_true to be binary + # and then check. + if y_prob.ndim == 1: + y_prob = y_prob[:, np.newaxis] + if y_prob.shape[1] == 1: + y_prob = np.append(1 - y_prob, y_prob, axis=1) + + eps = np.finfo(y_prob.dtype).eps + + # Make sure y_prob is normalized + y_prob_sum = y_prob.sum(axis=1) + if not np.allclose(y_prob_sum, 1, rtol=np.sqrt(eps)): + warnings.warn( + "The y_prob values do not sum to one. Make sure to pass probabilities.", + UserWarning, + ) + + # Check if dimensions are consistent. + transformed_labels = check_array(transformed_labels) + if len(lb.classes_) != y_prob.shape[1]: + if labels is None: + raise ValueError( + "y_true and y_prob contain different number of " + "classes: {0} vs {1}. Please provide the true " + "labels explicitly through the labels argument. " + "Classes found in " + "y_true: {2}".format( + transformed_labels.shape[1], y_prob.shape[1], lb.classes_ + ) + ) + else: + raise ValueError( + "The number of classes in labels is different " + "from that in y_prob. 
Classes found in " + "labels: {0}".format(lb.classes_) + ) + + return transformed_labels, y_prob + + @validate_params( { "y_true": ["array-like", "sparse matrix"], @@ -3092,79 +3225,14 @@ def log_loss(y_true, y_pred, *, normalize=True, sample_weight=None, labels=None) ... [[.1, .9], [.9, .1], [.8, .2], [.35, .65]]) 0.21616... """ - y_pred = check_array( - y_pred, ensure_2d=False, dtype=[np.float64, np.float32, np.float16] + transformed_labels, y_pred = _validate_multiclass_probabilistic_prediction( + y_true, y_pred, sample_weight, labels ) - check_consistent_length(y_pred, y_true, sample_weight) - lb = LabelBinarizer() - - if labels is not None: - lb.fit(labels) - else: - lb.fit(y_true) - - if len(lb.classes_) == 1: - if labels is None: - raise ValueError( - "y_true contains only one label ({0}). Please " - "provide the true labels explicitly through the " - "labels argument.".format(lb.classes_[0]) - ) - else: - raise ValueError( - "The labels array needs to contain at least two " - "labels for log_loss, " - "got {0}.".format(lb.classes_) - ) - - transformed_labels = lb.transform(y_true) - - if transformed_labels.shape[1] == 1: - transformed_labels = np.append( - 1 - transformed_labels, transformed_labels, axis=1 - ) - - # If y_pred is of single dimension, assume y_true to be binary - # and then check. - if y_pred.ndim == 1: - y_pred = y_pred[:, np.newaxis] - if y_pred.shape[1] == 1: - y_pred = np.append(1 - y_pred, y_pred, axis=1) - - eps = np.finfo(y_pred.dtype).eps - - # Make sure y_pred is normalized - y_pred_sum = y_pred.sum(axis=1) - if not np.allclose(y_pred_sum, 1, rtol=np.sqrt(eps)): - warnings.warn( - "The y_pred values do not sum to one. Make sure to pass probabilities.", - UserWarning, - ) - # Clipping + eps = np.finfo(y_pred.dtype).eps y_pred = np.clip(y_pred, eps, 1 - eps) - # Check if dimensions are consistent. - transformed_labels = check_array(transformed_labels) - if len(lb.classes_) != y_pred.shape[1]: - if labels is None: - raise ValueError( - "y_true and y_pred contain different number of " - "classes {0}, {1}. Please provide the true " - "labels explicitly through the labels argument. " - "Classes found in " - "y_true: {2}".format( - transformed_labels.shape[1], y_pred.shape[1], lb.classes_ - ) - ) - else: - raise ValueError( - "The number of classes in labels is different " - "from that in y_pred. Classes found in " - "labels: {0}".format(lb.classes_) - ) - loss = -xlogy(transformed_labels, y_pred).sum(axis=1) return float(_average(loss, weights=sample_weight, normalize=normalize)) @@ -3322,38 +3390,105 @@ def hinge_loss(y_true, pred_decision, *, labels=None, sample_weight=None): return float(np.average(losses, weights=sample_weight)) +def _validate_binary_probabilistic_prediction(y_true, y_prob, sample_weight, pos_label): + r"""Convert y_true and y_prob in binary classification to shape (n_samples, 2) + + Parameters + ---------- + y_true : array-like of shape (n_samples,) + True labels. + + y_prob : array-like of shape (n_samples,) + Probabilities of the positive class. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. + + pos_label : int, float, bool or str, default=None + Label of the positive class. If None, `pos_label` will be inferred + in the following manner: + + * if `y_true` in {-1, 1} or {0, 1}, `pos_label` defaults to 1; + * else if `y_true` contains string, an error will be raised and + `pos_label` should be explicitly specified; + * otherwise, `pos_label` defaults to the greater label, + i.e. `np.unique(y_true)[-1]`. 
+
+    Returns
+    -------
+    transformed_labels : array of shape (n_samples, 2)
+
+    y_prob : array of shape (n_samples, 2)
+    """
+    # sanity checks on y_true and y_prob
+    y_true = column_or_1d(y_true)
+    y_prob = column_or_1d(y_prob)
+
+    assert_all_finite(y_true)
+    assert_all_finite(y_prob)
+
+    check_consistent_length(y_prob, y_true, sample_weight)
+
+    y_type = type_of_target(y_true, input_name="y_true")
+    if y_type != "binary":
+        raise ValueError(
+            f"The type of the target inferred from y_true is {y_type} but should be "
+            "binary according to the shape of y_prob."
+        )
+
+    if y_prob.max() > 1:
+        raise ValueError(f"y_prob contains values greater than 1: {y_prob.max()}")
+    if y_prob.min() < 0:
+        raise ValueError(f"y_prob contains values less than 0: {y_prob.min()}")
+
+    # check that pos_label is consistent with y_true
+    try:
+        pos_label = _check_pos_label_consistency(pos_label, y_true)
+    except ValueError:
+        classes = np.unique(y_true)
+        if classes.dtype.kind not in ("O", "U", "S"):
+            # for backward compatibility, if classes are not string then
+            # `pos_label` will correspond to the greater label
+            pos_label = classes[-1]
+        else:
+            raise
+
+    # convert (n_samples,) to (n_samples, 2) shape
+    y_true = np.array(y_true == pos_label, int)
+    transformed_labels = np.column_stack((1 - y_true, y_true))
+    y_prob = np.column_stack((1 - y_prob, y_prob))
+
+    return transformed_labels, y_prob
+
+
 @validate_params(
     {
         "y_true": ["array-like"],
         "y_proba": ["array-like", Hidden(None)],
         "sample_weight": ["array-like", None],
         "pos_label": [Real, str, "boolean", None],
+        "labels": ["array-like", None],
+        "scale_by_half": ["boolean", StrOptions({"auto"})],
         "y_prob": ["array-like", Hidden(StrOptions({"deprecated"}))],
     },
     prefer_skip_nested_validation=True,
 )
 def brier_score_loss(
-    y_true, y_proba=None, *, sample_weight=None, pos_label=None, y_prob="deprecated"
+    y_true,
+    y_proba=None,
+    *,
+    sample_weight=None,
+    pos_label=None,
+    labels=None,
+    scale_by_half="auto",
+    y_prob="deprecated",
 ):
-    """Compute the Brier score loss.
+    r"""Compute the Brier score loss.
 
     The smaller the Brier score loss, the better, hence the naming with "loss".
     The Brier score measures the mean squared difference between the predicted
-    probability and the actual outcome. The Brier score always
-    takes on a value between zero and one, since this is the largest
-    possible difference between a predicted probability (which must be
-    between zero and one) and the actual outcome (which can take on values
-    of only 0 and 1). It can be decomposed as the sum of refinement loss and
-    calibration loss.
-
-    The Brier score is appropriate for binary and categorical outcomes that
-    can be structured as true or false, but is inappropriate for ordinal
-    variables which can take on three or more values (this is because the
-    Brier score assumes that all possible outcomes are equivalently
-    "distant" from one another). Which label is considered to be the positive
-    label is controlled via the parameter `pos_label`, which defaults to
-    the greater label unless `y_true` is all 0 or all -1, in which case
-    `pos_label` defaults to 1.
+    probability and the actual outcome. The Brier score is a strictly proper
+    scoring rule.
 
     Read more in the :ref:`User Guide `.
 
@@ -3362,14 +3497,20 @@ def brier_score_loss(
     y_true : array-like of shape (n_samples,)
         True targets.
 
-    y_proba : array-like of shape (n_samples,)
-        Probabilities of the positive class.
+    y_proba : array-like of shape (n_samples,) or (n_samples, n_classes)
+        Predicted probabilities. If `y_proba.shape = (n_samples,)`,
+        the probabilities provided are assumed to be those of the
+        positive class. If `y_proba.shape = (n_samples, n_classes)`,
+        the columns in `y_proba` are assumed to correspond to the
+        labels in alphabetical order, as done by
+        :class:`~sklearn.preprocessing.LabelBinarizer`.
 
     sample_weight : array-like of shape (n_samples,), default=None
         Sample weights.
 
     pos_label : int, float, bool or str, default=None
-        Label of the positive class. `pos_label` will be inferred in the
+        Label of the positive class when `y_proba.shape = (n_samples,)`.
+        If not provided, `pos_label` will be inferred in the
         following manner:
 
         * if `y_true` in {-1, 1} or {0, 1}, `pos_label` defaults to 1;
@@ -3378,6 +3519,20 @@ def brier_score_loss(
         * otherwise, `pos_label` defaults to the greater label,
           i.e. `np.unique(y_true)[-1]`.
 
+    labels : array-like of shape (n_classes,), default=None
+        Class labels when `y_proba.shape = (n_samples, n_classes)`.
+        If not provided, labels will be inferred from `y_true`.
+
+        .. versionadded:: 1.7
+
+    scale_by_half : bool or "auto", default="auto"
+        When True, scale the Brier score by 1/2 to lie in the [0, 1] range instead
+        of the [0, 2] range. The default "auto" option implements the rescaling to
+        [0, 1] only for binary classification (as customary) but keeps the
+        original [0, 2] range for multiclass classification.
+
+        .. versionadded:: 1.7
+
     y_prob : array-like of shape (n_samples,)
         Probabilities of the positive class.
 
@@ -3390,6 +3545,30 @@
     score : float
         Brier score loss.
 
+    Notes
+    -----
+
+    For :math:`N` observations labeled from :math:`C` possible classes, the Brier
+    score is defined as:
+
+    .. math::
+        \frac{1}{N}\sum_{i=1}^{N}\sum_{c=1}^{C}(y_{ic} - \hat{p}_{ic})^{2}
+
+    where :math:`y_{ic}` is 1 if observation `i` belongs to class `c` and 0
+    otherwise, and :math:`\hat{p}_{ic}` is the predicted probability for
+    observation `i` to belong to class `c`.
+    The Brier score then lies in the :math:`[0, 2]` range.
+
+    In binary classification tasks, the Brier score is usually divided by
+    two and then lies in the :math:`[0, 1]` range. It can alternatively
+    be written as:
+
+    .. math::
+        \frac{1}{N}\sum_{i=1}^{N}(y_{i} - \hat{p}_{i})^{2}
+
+    where :math:`y_{i}` is the binary target and :math:`\hat{p}_{i}`
+    is the predicted probability of the positive class.
+
     References
     ----------
     .. [1] `Wikipedia entry for the Brier score
@@ -3410,6 +3589,14 @@ def brier_score_loss(
     0.037...
     >>> brier_score_loss(y_true, np.array(y_prob) > 0.5)
     0.0
+    >>> brier_score_loss(y_true, y_prob, scale_by_half=False)
+    0.074...
+    >>> brier_score_loss(
+    ...     ["eggs", "ham", "spam"],
+    ...     [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]],
+    ...     labels=["eggs", "ham", "spam"]
+    ... )
+    0.146...
     """
     # TODO(1.7): remove in 1.7 and reset y_proba to be required
     # Note: validate params will raise an error if y_prob is not array-like,
@@ -3429,36 +3616,29 @@ def brier_score_loss(
         )
         y_proba = y_prob
 
-    y_true = column_or_1d(y_true)
-    y_proba = column_or_1d(y_proba)
-    assert_all_finite(y_true)
-    assert_all_finite(y_proba)
-    check_consistent_length(y_true, y_proba, sample_weight)
+    y_proba = check_array(
+        y_proba, ensure_2d=False, dtype=[np.float64, np.float32, np.float16]
+    )
 
-    y_type = type_of_target(y_true, input_name="y_true")
-    if y_type != "binary":
-        raise ValueError(
-            "Only binary classification is supported. The type of the target "
-            f"is {y_type}."
+ if y_proba.ndim == 1 or y_proba.shape[1] == 1: + transformed_labels, y_proba = _validate_binary_probabilistic_prediction( + y_true, y_proba, sample_weight, pos_label + ) + else: + transformed_labels, y_proba = _validate_multiclass_probabilistic_prediction( + y_true, y_proba, sample_weight, labels ) - if y_proba.max() > 1: - raise ValueError("y_proba contains values greater than 1.") - if y_proba.min() < 0: - raise ValueError("y_proba contains values less than 0.") + brier_score = np.average( + np.sum((transformed_labels - y_proba) ** 2, axis=1), weights=sample_weight + ) - try: - pos_label = _check_pos_label_consistency(pos_label, y_true) - except ValueError: - classes = np.unique(y_true) - if classes.dtype.kind not in ("O", "U", "S"): - # for backward compatibility, if classes are not string then - # `pos_label` will correspond to the greater label - pos_label = classes[-1] - else: - raise - y_true = np.array(y_true == pos_label, int) - return float(np.average((y_true - y_proba) ** 2, weights=sample_weight)) + if scale_by_half == "auto": + scale_by_half = y_proba.ndim == 1 or y_proba.shape[1] < 3 + if scale_by_half: + brier_score *= 0.5 + + return float(brier_score) @validate_params( diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index b67c91737960c..0c79420e3cb6f 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -2777,6 +2777,17 @@ def test_log_loss(): with pytest.raises(ValueError): log_loss(y_true, y_pred) + # raise error if labels do not contain all values of y_true + y_true = ["a", "b", "c"] + y_pred = [[0.9, 0.1, 0.0], [0.1, 0.9, 0.0], [0.1, 0.1, 0.8]] + labels = ["a", "c", "d"] + error_str = ( + "y_true contains values {'b'} not belonging to the passed " + "labels ['a', 'c', 'd']." + ) + with pytest.raises(ValueError, match=re.escape(error_str)): + log_loss(y_true, y_pred, labels=labels) + # case when y_true is a string array object y_true = ["ham", "spam", "spam", "ham"] y_pred = [[0.3, 0.7], [0.6, 0.4], [0.4, 0.6], [0.7, 0.3]] @@ -2789,15 +2800,15 @@ def test_log_loss(): y_pred = [[0.2, 0.8], [0.6, 0.4]] y_score = np.array([[0.1, 0.9], [0.1, 0.9]]) error_str = ( - r"y_true contains only one label \(2\). Please provide " - r"the true labels explicitly through the labels argument." + "y_true contains only one label (2). Please provide the list of all " + "expected class labels explicitly through the labels argument." 
) - with pytest.raises(ValueError, match=error_str): + with pytest.raises(ValueError, match=re.escape(error_str)): log_loss(y_true, y_pred) y_pred = [[0.2, 0.8], [0.6, 0.4], [0.7, 0.3]] - error_str = r"Found input variables with inconsistent numbers of samples: \[3, 2\]" - with pytest.raises(ValueError, match=error_str): + error_str = "Found input variables with inconsistent numbers of samples: [3, 2]" + with pytest.raises(ValueError, match=re.escape(error_str)): log_loss(y_true, y_pred) # works when the labels argument is used @@ -2833,7 +2844,7 @@ def test_log_loss_not_probabilities_warning(dtype): y_true = np.array([0, 1, 1, 0]) y_pred = np.array([[0.2, 0.7], [0.6, 0.3], [0.4, 0.7], [0.8, 0.3]], dtype=dtype) - with pytest.warns(UserWarning, match="The y_pred values do not sum to one."): + with pytest.warns(UserWarning, match="The y_prob values do not sum to one."): log_loss(y_true, y_pred) @@ -2869,39 +2880,188 @@ def test_log_loss_pandas_input(): assert_allclose(loss, 0.7469410) -def test_brier_score_loss(): +def test_log_loss_warnings(): + expected_message = re.escape( + "Labels passed were ['spam', 'eggs', 'ham']. But this function " + "assumes labels are ordered lexicographically. " + "Pass the ordered labels=['eggs', 'ham', 'spam'] and ensure that " + "the columns of y_prob correspond to this ordering." + ) + with pytest.warns(UserWarning, match=expected_message): + log_loss( + ["eggs", "spam", "ham"], + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + labels=["spam", "eggs", "ham"], + ) + + +def test_brier_score_loss_binary(): # Check brier_score_loss function y_true = np.array([0, 1, 1, 0, 1, 1]) - y_pred = np.array([0.1, 0.8, 0.9, 0.3, 1.0, 0.95]) - true_score = linalg.norm(y_true - y_pred) ** 2 / len(y_true) + y_prob = np.array([0.1, 0.8, 0.9, 0.3, 1.0, 0.95]) + true_score = linalg.norm(y_true - y_prob) ** 2 / len(y_true) assert_almost_equal(brier_score_loss(y_true, y_true), 0.0) - assert_almost_equal(brier_score_loss(y_true, y_pred), true_score) - assert_almost_equal(brier_score_loss(1.0 + y_true, y_pred), true_score) - assert_almost_equal(brier_score_loss(2 * y_true - 1, y_pred), true_score) + assert_almost_equal(brier_score_loss(y_true, y_prob), true_score) + assert_almost_equal(brier_score_loss(1.0 + y_true, y_prob), true_score) + assert_almost_equal(brier_score_loss(2 * y_true - 1, y_prob), true_score) + + # check that using (n_samples, 2) y_prob or y_true gives the same score + y_prob_reshaped = np.column_stack((1 - y_prob, y_prob)) + y_true_reshaped = np.column_stack((1 - y_true, y_true)) + assert_almost_equal(brier_score_loss(y_true, y_prob_reshaped), true_score) + assert_almost_equal(brier_score_loss(y_true_reshaped, y_prob_reshaped), true_score) + + # check scale_by_half argument + assert_almost_equal( + brier_score_loss(y_true, y_prob, scale_by_half="auto"), true_score + ) + assert_almost_equal( + brier_score_loss(y_true, y_prob, scale_by_half=True), true_score + ) + assert_almost_equal( + brier_score_loss(y_true, y_prob, scale_by_half=False), 2 * true_score + ) + + # calculate correctly when there's only one class in y_true + assert_almost_equal(brier_score_loss([-1], [0.4]), 0.4**2) + assert_almost_equal(brier_score_loss([0], [0.4]), 0.4**2) + assert_almost_equal(brier_score_loss([1], [0.4]), (1 - 0.4) ** 2) + assert_almost_equal(brier_score_loss(["foo"], [0.4], pos_label="bar"), 0.4**2) + assert_almost_equal( + brier_score_loss(["foo"], [0.4], pos_label="foo"), + (1 - 0.4) ** 2, + ) + + +def test_brier_score_loss_multiclass(): + # test cases for multi-class + 
assert_almost_equal( + brier_score_loss( + ["eggs", "spam", "ham"], + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0]], + labels=["eggs", "ham", "spam", "yams"], + ), + 2 / 3, + ) + + assert_almost_equal( + brier_score_loss( + [1, 0, 2], [[0.2, 0.7, 0.1], [0.6, 0.2, 0.2], [0.6, 0.1, 0.3]] + ), + 0.41333333, + ) + + # check perfect predictions for 3 classes + assert_almost_equal( + brier_score_loss( + [0, 1, 2], [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] + ), + 0, + ) + + # check perfectly incorrect predictions for 3 classes + assert_almost_equal( + brier_score_loss( + [0, 1, 2], [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]] + ), + 2, + ) + + +def test_brier_score_loss_invalid_inputs(): + # binary case + y_true = np.array([0, 1, 1, 0, 1, 1]) + y_prob = np.array([0.1, 0.8, 0.9, 0.3, 1.0, 0.95]) with pytest.raises(ValueError): - brier_score_loss(y_true, y_pred[1:]) + # bad length of y_prob + brier_score_loss(y_true, y_prob[1:]) with pytest.raises(ValueError): - brier_score_loss(y_true, y_pred + 1.0) + # y_pred has value greater than 1 + brier_score_loss(y_true, y_prob + 1.0) with pytest.raises(ValueError): - brier_score_loss(y_true, y_pred - 1.0) + # y_pred has value less than 0 + brier_score_loss(y_true, y_prob - 1.0) - # ensure to raise an error for multiclass y_true + # multiclass case + y_true = np.array([1, 0, 2]) + y_prob = np.array([[0.2, 0.7, 0.1], [0.6, 0.2, 0.2], [0.6, 0.1, 0.3]]) + with pytest.raises(ValueError): + # bad length of y_pred + brier_score_loss(y_true, y_prob[1:]) + with pytest.raises(ValueError): + # y_pred has value greater than 1 + brier_score_loss(y_true, y_prob + 1.0) + with pytest.raises(ValueError): + # y_pred has value less than 0 + brier_score_loss(y_true, y_prob - 1.0) + + # raise an error for multiclass y_true and binary y_prob y_true = np.array([0, 1, 2, 0]) - y_pred = np.array([0.8, 0.6, 0.4, 0.2]) + y_prob = np.array([0.8, 0.6, 0.4, 0.2]) + error_message = re.escape( + "The type of the target inferred from y_true is multiclass " + "but should be binary according to the shape of y_prob." + ) + with pytest.raises(ValueError, match=error_message): + brier_score_loss(y_true, y_prob) + + # raise an error for wrong number of classes + y_true = [0, 1, 2] + y_prob = [[1, 0], [0, 1], [0, 1]] error_message = ( - "Only binary classification is supported. The type of the target is multiclass" + "y_true and y_prob contain different number of " + "classes: 3 vs 2. Please provide the true " + "labels explicitly through the labels argument. " + "Classes found in " + "y_true: [0 1 2]" ) + with pytest.raises(ValueError, match=re.escape(error_message)): + brier_score_loss(y_true, y_prob) - with pytest.raises(ValueError, match=error_message): - brier_score_loss(y_true, y_pred) + y_true = ["eggs", "spam", "ham"] + y_prob = [[1, 0, 0], [0, 1, 0], [0, 1, 0]] + labels = ["eggs", "spam", "ham", "yams"] + error_message = ( + "The number of classes in labels is different " + "from that in y_prob. 
Classes found in " + "labels: ['eggs' 'ham' 'spam' 'yams']" + ) + with pytest.raises(ValueError, match=re.escape(error_message)): + brier_score_loss(y_true, y_prob, labels=labels) - # calculate correctly when there's only one class in y_true - assert_almost_equal(brier_score_loss([-1], [0.4]), 0.16) - assert_almost_equal(brier_score_loss([0], [0.4]), 0.16) - assert_almost_equal(brier_score_loss([1], [0.4]), 0.36) - assert_almost_equal(brier_score_loss(["foo"], [0.4], pos_label="bar"), 0.16) - assert_almost_equal(brier_score_loss(["foo"], [0.4], pos_label="foo"), 0.36) + # raise error message when there's only one class in y_true + y_true = ["eggs"] + y_prob = [[0.9, 0.1]] + error_message = ( + "y_true contains only one label (eggs). Please " + "provide the list of all expected class labels explicitly through the " + "labels argument." + ) + with pytest.raises(ValueError, match=re.escape(error_message)): + brier_score_loss(y_true, y_prob) + + # error is fixed when labels is specified + assert_almost_equal(brier_score_loss(y_true, y_prob, labels=["eggs", "ham"]), 0.01) + + +def test_brier_score_loss_warnings(): + expected_message = re.escape( + "Labels passed were ['spam', 'eggs', 'ham']. But this function " + "assumes labels are ordered lexicographically. " + "Pass the ordered labels=['eggs', 'ham', 'spam'] and ensure that " + "the columns of y_prob correspond to this ordering." + ) + with pytest.warns(UserWarning, match=expected_message): + brier_score_loss( + ["eggs", "spam", "ham"], + [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + ], + labels=["spam", "eggs", "ham"], + ) def test_balanced_accuracy_score_unseen(): @@ -3190,7 +3350,7 @@ def test_d2_log_loss_score_raises(): # check error if the number of classes in labels do not match the number # of classes in y_pred. 
-    y_true = ["a", "b", "c"]
+    y_true = [0, 1, 2]
     y_pred = [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
     labels = [0, 1, 2]
     err = "number of classes in labels is different"
@@ -3213,7 +3373,7 @@ def test_d2_log_loss_score_raises():
     # check error when y_true only has 1 label
     y_true = [1, 1, 1]
-    y_pred = [[0.5, 0.5], [0.5, 0.5], [0.5, 5]]
+    y_pred = [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
     err = "y_true contains only one label"
     with pytest.raises(ValueError, match=err):
         d2_log_loss_score(y_true, y_pred)
@@ -3222,7 +3382,7 @@ def test_d2_log_loss_score_raises():
     # only 1 label
     y_true = [1, 1, 1]
     labels = [1]
-    y_pred = [[0.5, 0.5], [0.5, 0.5], [0.5, 5]]
+    y_pred = [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
     err = "The labels array needs to contain at least two"
    with pytest.raises(ValueError, match=err):
         d2_log_loss_score(y_true, y_pred, labels=labels)
diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index e1c102670aec1..8f412133813d6 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -303,7 +303,6 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
 
 # Those metrics don't support multiclass inputs
 METRIC_UNDEFINED_MULTICLASS = {
-    "brier_score_loss",
     "micro_roc_auc",
     "samples_roc_auc",
     "partial_roc_auc",
@@ -398,6 +397,8 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
     "unnormalized_multilabel_confusion_matrix",
     "unnormalized_multilabel_confusion_matrix_sample",
     "cohen_kappa_score",
+    "log_loss",
+    "brier_score_loss",
 }
 
 # Metrics with a "normalize" option
@@ -411,6 +412,7 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
 THRESHOLDED_MULTILABEL_METRICS = {
     "log_loss",
     "unnormalized_log_loss",
+    "brier_score_loss",
     "roc_auc_score",
     "weighted_roc_auc",
     "samples_roc_auc",
@@ -638,20 +640,46 @@ def test_symmetric_metric(name):
 
 @pytest.mark.parametrize("name", sorted(NOT_SYMMETRIC_METRICS))
 def test_not_symmetric_metric(name):
+    # Test the symmetry of score and loss functions
     random_state = check_random_state(0)
-    y_true = random_state.randint(0, 2, size=(20,))
-    y_pred = random_state.randint(0, 2, size=(20,))
-
-    if name in METRICS_REQUIRE_POSITIVE_Y:
-        y_true, y_pred = _require_positive_targets(y_true, y_pred)
-
     metric = ALL_METRICS[name]
 
-    # use context manager to supply custom error message
-    with pytest.raises(AssertionError):
-        assert_array_equal(metric(y_true, y_pred), metric(y_pred, y_true))
-        raise ValueError("%s seems to be symmetric" % name)
+    # The metric can be accidentally symmetric on a random draw.
+    # We run several random draws to check that at least one of them
+    # gives an asymmetric result.
+ always_symmetric = True + for _ in range(5): + y_true = random_state.randint(0, 2, size=(20,)) + y_pred = random_state.randint(0, 2, size=(20,)) + + if name in METRICS_REQUIRE_POSITIVE_Y: + y_true, y_pred = _require_positive_targets(y_true, y_pred) + + nominal = metric(y_true, y_pred) + swapped = metric(y_pred, y_true) + if not np.allclose(nominal, swapped): + always_symmetric = False + break + + if always_symmetric: + raise ValueError(f"{name} seems to be symmetric") + + +def test_symmetry_tests(): + # check test_symmetric_metric and test_not_symmetric_metric + sym = "accuracy_score" + not_sym = "recall_score" + # test_symmetric_metric passes on a symmetric metric + # but fails on a not symmetric metric + test_symmetric_metric(sym) + with pytest.raises(AssertionError, match=f"{not_sym} is not symmetric"): + test_symmetric_metric(not_sym) + # test_not_symmetric_metric passes on a not symmetric metric + # but fails on a symmetric metric + test_not_symmetric_metric(not_sym) + with pytest.raises(ValueError, match=f"{sym} seems to be symmetric"): + test_not_symmetric_metric(sym) @pytest.mark.parametrize(