From 89f8831acf6a605e2b819b9634f8ad0738a4298b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jochen=20Wersd=C3=B6rfer?= Date: Thu, 6 Sep 2012 23:47:11 +0200 Subject: [PATCH 1/4] ENH added multiclass_log_loss metric --- sklearn/metrics/__init__.py | 2 ++ sklearn/metrics/metrics.py | 28 +++++++++++++++++++++++++++ sklearn/metrics/tests/test_metrics.py | 9 +++++++++ 3 files changed, 39 insertions(+) diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index 403c4ff195a10..aed51201c7f13 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -18,6 +18,7 @@ matthews_corrcoef, mean_squared_error, mean_absolute_error, + multiclass_log_loss, precision_recall_curve, precision_recall_fscore_support, precision_score, @@ -71,6 +72,7 @@ 'matthews_corrcoef', 'mean_squared_error', 'mean_absolute_error', + 'multiclass_log_loss', 'mutual_info_score', 'normalized_mutual_info_score', 'pairwise_distances', diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index 52757a4c65922..66f019dd85d82 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -2111,3 +2111,31 @@ def r2_score(y_true, y_pred): return 0.0 return 1 - numerator / denominator + + +def multiclass_log_loss(y_true, y_pred, eps=1e-15): + """Multi class version of Logarithmic Loss metric. 
+ https://www.kaggle.com/wiki/MultiClassLogLoss + + idea from this post: + http://www.kaggle.com/c/emc-data-science/forums/t/2149/is-anyone-noticing-difference-betwen-validation-and-leaderboard-error/12209#post12209 + + Parameters + ---------- + y_true : array, shape = [n_samples] + y_pred : array, shape = [n_samples, n_classes] + + Returns + ------- + loss : float + """ + predictions = np.clip(y_pred, eps, 1 - eps) + + # normalize row sums to 1 + predictions /= predictions.sum(axis=1)[:, np.newaxis] + + actual = np.zeros(y_pred.shape) + rows = actual.shape[0] + actual[np.arange(rows), y_true.astype(int)] = 1 + vsota = np.sum(actual * np.log(predictions)) + return -1.0 / rows * vsota diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py index 71a0af183d46b..b57bc6a6f1d68 100644 --- a/sklearn/metrics/tests/test_metrics.py +++ b/sklearn/metrics/tests/test_metrics.py @@ -39,6 +39,7 @@ matthews_corrcoef, mean_squared_error, mean_absolute_error, + multiclass_log_loss, precision_recall_curve, precision_recall_fscore_support, precision_score, @@ -1801,3 +1802,11 @@ def test__column_or_1d(): assert_array_equal(_column_or_1d(y), np.ravel(y)) else: assert_raises(ValueError, _column_or_1d, y) + + +def test_multiclass_log_loss(): + y_true = np.array([0, 0, 0, 1, 1, 1]) + y_pred = np.array([[0.5, 0.5], [0.1, 0.9], [0.01, 0.99], + [0.9, 0.1], [0.75, 0.25], [0.001, 0.999]]) + loss = multiclass_log_loss(y_true, y_pred) + assert_equal(loss, 1.8817970689982668) From f2c50883fa03809ae07c1032ed2aa3a1a13bd496 Mon Sep 17 00:00:00 2001 From: Lars Buitinck Date: Tue, 28 May 2013 13:01:22 +0200 Subject: [PATCH 2/4] ENH rewrite multiclass_log_loss, rename log_loss, document it --- doc/modules/classes.rst | 1 + doc/modules/model_evaluation.rst | 46 ++++++++++++++++++++ doc/whats_new.rst | 3 ++ sklearn/metrics/__init__.py | 4 +- sklearn/metrics/metrics.py | 62 ++++++++++++++++++++------- sklearn/metrics/tests/test_metrics.py | 31 +++++++++++--- 6 files 
changed, 124 insertions(+), 23 deletions(-) diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 32f4dbae3c0e9..bfa1455d02432 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -721,6 +721,7 @@ Classification metrics metrics.hamming_loss metrics.hinge_loss metrics.jaccard_similarity_score + metrics.log_loss metrics.matthews_corrcoef metrics.precision_recall_curve metrics.precision_recall_fscore_support diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index e03881c2ad7da..6f3d1352b1e41 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -712,6 +712,52 @@ with a svm classifier:: 0.3... +Log loss +-------- +The log loss, also called logistic regression loss or cross-entropy loss, +is a loss function defined on probability estimates. +It is commonly used in (multinomial) logistic regression and neural networks, +as well as some variants of expectation-maximization, +and can be used to evaluate the probability outputs (``predict_proba``) +of a classifier, rather than its discrete predictions. + +For binary classification with a true label :math:`y_t \in \{0,1\}` +and a probability estimate :math:`y_p = P(y_t = 1)`, +the log loss per sample is the negative log-likelihood +of the true label given the prediction: + +.. math:: + + L_{\log}(y_t, y_p) = -\log P(y_t|y_p) = -(y_t \log y_p + (1 - y_t) \log (1 - y_p)) + +This extends to the multiclass case as follows. +Let the true labels for a set of samples +be encoded as a 1-of-K binary indicator matrix :math:`T`, +i.e. :math:`t_{i,k} = 1` if sample :math:`i` has label :math:`k` +taken from a set of :math:`K` labels. +Let :math:`Y` be a matrix of probability estimates, +with :math:`y_{i,k} = P(t_{i,k} = 1)`. +Then the total log loss of the whole set is + +.. 
math:: + + L_{\log}(T, Y) = -\log P(T|Y) = - \sum_i \sum_j t_{i,k} \log y_{i,k} + +The function :func:`log_loss` computes either total or mean log loss +given a list of ground-truth labels and a probability matrix, +as returned by an estimator's ``predict_proba`` method. + + >>> from sklearn.metrics import log_loss + >>> y_true = [0, 0, 1, 1] + >>> y_pred = [[.9, .1], [.8, .2], [.3, .7], [.01, .99]] + >>> log_loss(y_true, y_pred) # doctest: +ELLIPSIS + 0.1738... + +The first ``[.9, .1]`` in ``y_pred`` +denotes 90% probability that the first sample has label 0. +The log loss is non-negative. + + Matthews correlation coefficient ................................. diff --git a/doc/whats_new.rst b/doc/whats_new.rst index ade65dadb7624..5361e915a73c4 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -99,6 +99,9 @@ Changelog the fraction or the number of correctly classified sample by `Arnaud Joly`_. + - Added :func:`metrics.log_loss` that computes log loss, aka cross-entropy + loss. By Jochen Wersdörfer and `Lars Buitinck`_. + - A bug that caused :class:`ensemble.AdaBoostClassifier`'s to output incorrect probabilities has been fixed. 
diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index aed51201c7f13..f3aca8f62b115 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -15,10 +15,10 @@ hamming_loss, hinge_loss, jaccard_similarity_score, + log_loss, matthews_corrcoef, mean_squared_error, mean_absolute_error, - multiclass_log_loss, precision_recall_curve, precision_recall_fscore_support, precision_score, @@ -69,10 +69,10 @@ 'homogeneity_completeness_v_measure', 'homogeneity_score', 'jaccard_similarity_score', + 'log_loss', 'matthews_corrcoef', 'mean_squared_error', 'mean_absolute_error', - 'multiclass_log_loss', 'mutual_info_score', 'normalized_mutual_info_score', 'pairwise_distances', diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index 66f019dd85d82..05b5ae4fe71e1 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -2113,29 +2113,61 @@ def r2_score(y_true, y_pred): return 1 - numerator / denominator -def multiclass_log_loss(y_true, y_pred, eps=1e-15): - """Multi class version of Logarithmic Loss metric. - https://www.kaggle.com/wiki/MultiClassLogLoss +def log_loss(y_true, y_pred, eps=1e-15, normalize=True): + """Log loss, aka logistic loss or cross-entropy loss. - idea from this post: - http://www.kaggle.com/c/emc-data-science/forums/t/2149/is-anyone-noticing-difference-betwen-validation-and-leaderboard-error/12209#post12209 + This is the loss function used in (multinomial) logistic regression + and extensions of it such as neural networks, defined as the negative + log-likelihood of the true labels given a probabilistic classifier's + predictions. 
For a single sample with true label yt in {0,1} and + estimated probability yp that yt = 1, the log loss is + + -log P(yt|yp) = -(yt log(yp) + (1 - yt) log(1 - yp)) Parameters ---------- - y_true : array, shape = [n_samples] - y_pred : array, shape = [n_samples, n_classes] + y_true : array-like or list of labels or label indicator matrix + Ground truth (correct) labels for n_samples samples. + + y_pred : array-like of float, shape = (n_samples, n_classes) + Predicted probabilities, as returned by a classifier's + predict_proba method. + + eps : float + Log loss is undefined for p=0 or p=1, so probabilities are + clipped to max(eps, min(1 - eps, p)). + + normalize : bool, optional (default=True) + If true, return the mean loss per sample. + Otherwise, return the total loss. Returns ------- loss : float + + Examples + -------- + >>> log_loss(["spam", "ham", "ham", "spam"], # doctest: +ELLIPSIS + ... [[.1, .9], [.9, .1], [.8, .2], [.35, .65]]) + 0.21616... + + References + ---------- + C.M. Bishop (2006). Pattern Recognition and Machine Learning. Springer, + p. 209. + + Notes + ----- + The logarithm used is the natural logarithm (base-e). 
""" - predictions = np.clip(y_pred, eps, 1 - eps) + lb = LabelBinarizer() + T = lb.fit_transform(y_true) + if T.shape[1] == 1: + T = np.append(1 - T, T, axis=1) - # normalize row sums to 1 - predictions /= predictions.sum(axis=1)[:, np.newaxis] + # Clip and renormalize + Y = np.clip(y_pred, eps, 1 - eps) + Y /= Y.sum(axis=1)[:, np.newaxis] - actual = np.zeros(y_pred.shape) - rows = actual.shape[0] - actual[np.arange(rows), y_true.astype(int)] = 1 - vsota = np.sum(actual * np.log(predictions)) - return -1.0 / rows * vsota + loss = -(T * np.log(Y)).sum() + return loss / T.shape[0] if normalize else loss diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py index b57bc6a6f1d68..6f4a10378b398 100644 --- a/sklearn/metrics/tests/test_metrics.py +++ b/sklearn/metrics/tests/test_metrics.py @@ -36,10 +36,10 @@ hamming_loss, hinge_loss, jaccard_similarity_score, + log_loss, matthews_corrcoef, mean_squared_error, mean_absolute_error, - multiclass_log_loss, precision_recall_curve, precision_recall_fscore_support, precision_score, @@ -1804,9 +1804,28 @@ def test__column_or_1d(): assert_raises(ValueError, _column_or_1d, y) -def test_multiclass_log_loss(): - y_true = np.array([0, 0, 0, 1, 1, 1]) +def test_log_loss(): + # binary case with symbolic labels ("no" < "yes") + y_true = ["no", "no", "no", "yes", "yes", "yes"] y_pred = np.array([[0.5, 0.5], [0.1, 0.9], [0.01, 0.99], - [0.9, 0.1], [0.75, 0.25], [0.001, 0.999]]) - loss = multiclass_log_loss(y_true, y_pred) - assert_equal(loss, 1.8817970689982668) + [0.9, 0.1], [0.75, 0.25], [0.001, 0.999]]) + loss = log_loss(y_true, y_pred) + assert_almost_equal(loss, 1.8817971) + + # multiclass case; adapted from http://bit.ly/RJJHWA + y_true = [1, 0, 2] + y_pred = [[0.2, 0.7, 0.1], [0.6, 0.2, 0.2], [0.6, 0.1, 0.3]] + loss = log_loss(y_true, y_pred, normalize=True) + assert_almost_equal(loss, 0.6904911) + + # check that we got all the shapes and axes right + # by doubling the length of y_true and y_pred 
+ y_true *= 2 + y_pred *= 2 + loss = log_loss(y_true, y_pred, normalize=False) + assert_almost_equal(loss, 0.6904911 * 6, decimal=6) + + # check eps and handling of absolute zero and one probabilities + y_pred = np.asarray(y_pred) > .5 + loss = log_loss(y_true, y_pred, normalize=True, eps=.1) + assert_almost_equal(loss, log_loss(y_true, np.clip(y_pred, .1, .9))) From d0cf3a683e5e2d836b4501e57324b34e6c810748 Mon Sep 17 00:00:00 2001 From: Lars Buitinck Date: Mon, 1 Jul 2013 13:33:00 +0200 Subject: [PATCH 3/4] ENH Scorer object for log loss --- doc/modules/model_evaluation.rst | 3 +- sklearn/metrics/scorer.py | 7 ++++- sklearn/metrics/tests/test_score_objects.py | 33 ++++----------------- 3 files changed, 14 insertions(+), 29 deletions(-) diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index 6f3d1352b1e41..e46f7a7ff980c 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -75,7 +75,7 @@ of acceptable values:: >>> model = svm.SVC() >>> cross_validation.cross_val_score(model, X, y, scoring='wrong_choice') Traceback (most recent call last): - ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_rand_score', 'average_precision', 'f1', 'mean_squared_error', 'precision', 'r2', 'recall', 'roc_auc'] + ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_rand_score', 'average_precision', 'f1', 'log_likelihood', 'mean_squared_error', 'precision', 'r2', 'recall', 'roc_auc'] .. 
note:: @@ -1077,6 +1077,7 @@ Scoring Function 'accuracy' :func:`sklearn.metrics.accuracy_score` 'average_precision' :func:`sklearn.metrics.average_precision_score` 'f1' :func:`sklearn.metrics.f1_score` +'log_likelihood' :func:`sklearn.metric.log_loss` 'precision' :func:`sklearn.metrics.precision_score` 'recall' :func:`sklearn.metrics.recall_score` 'roc_auc' :func:`sklearn.metrics.auc_score` diff --git a/sklearn/metrics/scorer.py b/sklearn/metrics/scorer.py index 6f4d37776dedc..7f0f0fba6b1aa 100644 --- a/sklearn/metrics/scorer.py +++ b/sklearn/metrics/scorer.py @@ -23,7 +23,7 @@ from . import (r2_score, mean_squared_error, accuracy_score, f1_score, auc_score, average_precision_score, precision_score, - recall_score) + recall_score, log_loss) from .cluster import adjusted_rand_score from ..externals import six @@ -224,6 +224,10 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False, precision_scorer = make_scorer(precision_score) recall_scorer = make_scorer(recall_score) +# Score function for probabilistic classification +log_likelihood_scorer = make_scorer(log_loss, greater_is_better=False, + needs_proba=True) + # Clustering scores adjusted_rand_scorer = make_scorer(adjusted_rand_score) @@ -232,4 +236,5 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False, accuracy=accuracy_scorer, f1=f1_scorer, roc_auc=auc_scorer, average_precision=average_precision_scorer, precision=precision_scorer, recall=recall_scorer, + log_likelihood=log_likelihood_scorer, adjusted_rand_score=adjusted_rand_scorer) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 5c2278ceaceb8..299fe56f5062c 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -4,7 +4,8 @@ from sklearn.utils.testing import assert_almost_equal, assert_array_equal from sklearn.utils.testing import assert_raises -from sklearn.metrics import f1_score, r2_score, auc_score, 
fbeta_score +from sklearn.metrics import (f1_score, r2_score, auc_score, fbeta_score, + log_loss) from sklearn.metrics.cluster import adjusted_rand_score from sklearn.metrics import make_scorer, SCORERS from sklearn.svm import LinearSVC @@ -59,32 +60,6 @@ def test_regression_scorers(): assert_almost_equal(score1, score2) -def test_proba_scorer(): - """Non-regression test for _ProbaScorer (which did not have __call__).""" - # This test can be removed once we have an actual scorer that uses - # predict_proba, e.g. by merging #2013. - def log_loss(y, p): - """Binary log loss with labels in {0, 1}.""" - return -(y * np.log(p) + (1 - y) * np.log(1 - p)) - - log_loss_scorer = make_scorer(log_loss, greater_is_better=False, - needs_proba=True) - - class DiscreteUniform(object): - def __init__(self, n_classes): - self.n_classes = n_classes - - def predict_proba(self, X): - n = self.n_classes - return np.repeat(1. / n, n) - - estimator = DiscreteUniform(5) - y_true = np.array([0, 0, 1, 1, 1]) - y_pred = estimator.predict_proba(None) - assert_array_equal(log_loss(y_true, y_pred), - -log_loss_scorer(estimator, None, y_true)) - - def test_thresholded_scorers(): """Test scorers that take thresholds.""" X, y = make_blobs(random_state=0, centers=2) @@ -97,6 +72,10 @@ def test_thresholded_scorers(): assert_almost_equal(score1, score2) assert_almost_equal(score1, score3) + logscore = SCORERS['log_likelihood'](clf, X_test, y_test) + logloss = log_loss(y_test, clf.predict_proba(X_test)) + assert_almost_equal(-logscore, logloss) + # same for an estimator without decision_function clf = DecisionTreeClassifier() clf.fit(X_train, y_train) From dd602cdc03e6bd4707bb132c723e68eba0751688 Mon Sep 17 00:00:00 2001 From: Lars Buitinck Date: Fri, 26 Jul 2013 19:37:34 +0200 Subject: [PATCH 4/4] ENH add log_likelihood_score as -log_loss ... or rather, the other way around. 
--- doc/modules/model_evaluation.rst | 37 +++++++++++---------- sklearn/metrics/__init__.py | 1 + sklearn/metrics/metrics.py | 57 ++++++++++++++++++++++++-------- sklearn/metrics/scorer.py | 5 ++- 4 files changed, 66 insertions(+), 34 deletions(-) diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index e46f7a7ff980c..8e912f1c814b3 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -712,23 +712,23 @@ with a svm classifier:: 0.3... -Log loss --------- -The log loss, also called logistic regression loss or cross-entropy loss, -is a loss function defined on probability estimates. -It is commonly used in (multinomial) logistic regression and neural networks, +Log-likelihood and log loss +--------------------------- +Log-likelihood is a score to evaluate probabilistic classifiers +by their probability outputs (``predict_proba``) +rather than their discrete predictions. + +Log loss is negative log-likelihood and is used as the loss function +in logistic regression and neural networks, as well as some variants of expectation-maximization, -and can be used to evaluate the probability outputs (``predict_proba``) -of a classifier, rather than its discrete predictions. For binary classification with a true label :math:`y_t \in \{0,1\}` and a probability estimate :math:`y_p = P(y_t = 1)`, -the log loss per sample is the negative log-likelihood -of the true label given the prediction: +the log-likelihood of the model that predicted :math:`y_p` is: .. math:: - L_{\log}(y_t, y_p) = -\log P(y_t|y_p) = -(y_t \log y_p + (1 - y_t) \log (1 - y_p)) + L(y_t, y_p) = \log P(y_t|y_p) = (y_t \log y_p + (1 - y_t) \log (1 - y_p)) This extends to the multiclass case as follows. Let the true labels for a set of samples @@ -737,25 +737,28 @@ i.e. :math:`t_{i,k} = 1` if sample :math:`i` has label :math:`k` taken from a set of :math:`K` labels. 
Let :math:`Y` be a matrix of probability estimates, with :math:`y_{i,k} = P(t_{i,k} = 1)`. -Then the total log loss of the whole set is +Then the total log-likelihood of the whole set is .. math:: - L_{\log}(T, Y) = -\log P(T|Y) = - \sum_i \sum_j t_{i,k} \log y_{i,k} + L(T, Y) = \log P(T|Y) = \sum_i \sum_k t_{i,k} \log y_{i,k} -The function :func:`log_loss` computes either total or mean log loss +The functions :func:`log_likelihood_score` and :func:`log_loss` +compute either total or mean log-likelihood/loss given a list of ground-truth labels and a probability matrix, as returned by an estimator's ``predict_proba`` method. - >>> from sklearn.metrics import log_loss + >>> from sklearn.metrics import log_likelihood_score, log_loss >>> y_true = [0, 0, 1, 1] >>> y_pred = [[.9, .1], [.8, .2], [.3, .7], [.01, .99]] - >>> log_loss(y_true, y_pred) # doctest: +ELLIPSIS + >>> log_likelihood_score(y_true, y_pred) # doctest: +ELLIPSIS + -0.1738... + >>> log_loss(y_true, y_pred) # doctest: +ELLIPSIS 0.1738... The first ``[.9, .1]`` in ``y_pred`` denotes 90% probability that the first sample has label 0. -The log loss is non-negative. +Log-likelihood is negative or zero (with zero meaning perfect predictions). 
Matthews correlation coefficient @@ -1077,7 +1080,7 @@ Scoring Function 'accuracy' :func:`sklearn.metrics.accuracy_score` 'average_precision' :func:`sklearn.metrics.average_precision_score` 'f1' :func:`sklearn.metrics.f1_score` -'log_likelihood' :func:`sklearn.metric.log_loss` +'log_likelihood' :func:`sklearn.metrics.log_likelihood_score` 'precision' :func:`sklearn.metrics.precision_score` 'recall' :func:`sklearn.metrics.recall_score` 'roc_auc' :func:`sklearn.metrics.auc_score` diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index f3aca8f62b115..64a3e2c96717f 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -15,6 +15,7 @@ hamming_loss, hinge_loss, jaccard_similarity_score, + log_likelihood_score, log_loss, matthews_corrcoef, mean_squared_error, diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py index 05b5ae4fe71e1..77006aac5b716 100644 --- a/sklearn/metrics/metrics.py +++ b/sklearn/metrics/metrics.py @@ -23,7 +23,7 @@ from scipy.spatial.distance import hamming as sp_hamming from ..externals.six.moves import zip -from ..preprocessing import LabelBinarizer +from ..preprocessing import LabelBinarizer, label_binarize from ..utils import check_arrays from ..utils import deprecated from ..utils.fixes import divide @@ -2113,16 +2113,20 @@ def r2_score(y_true, y_pred): return 1 - numerator / denominator -def log_loss(y_true, y_pred, eps=1e-15, normalize=True): - """Log loss, aka logistic loss or cross-entropy loss. +def log_likelihood_score(y_true, y_pred, eps=1e-15, normalize=True): + """Log-likelihood of the model that generated y_pred. - This is the loss function used in (multinomial) logistic regression - and extensions of it such as neural networks, defined as the negative - log-likelihood of the true labels given a probabilistic classifier's - predictions. 
For a single sample with true label yt in {0,1} and - estimated probability yp that yt = 1, the log loss is + This function returns the log-probability of y_true being the true labels + when a probability model returned y_pred, also known as the log-likelihood + of the model. Log-likelihood is the most common optimization objective for + probability models such as logistic regression. - -log P(yt|yp) = -(yt log(yp) + (1 - yt) log(1 - yp)) + For a single sample with true label yt in {0,1} and estimated probability + yp that yt = 1, the log-likelihood is + + log P(yt|yp) = (yt log(yp) + (1 - yt) log(1 - yp)) + + Note that log-probabilities are <= 0 with 0 meaning perfect predictions. Parameters ---------- @@ -2143,19 +2147,23 @@ def log_loss(y_true, y_pred, eps=1e-15, normalize=True): Returns ------- - loss : float + score : float Examples -------- - >>> log_loss(["spam", "ham", "ham", "spam"], # doctest: +ELLIPSIS - ... [[.1, .9], [.9, .1], [.8, .2], [.35, .65]]) - 0.21616... + >>> log_likelihood_score(["spam", "ham", "ham", "spam"], # doctest: +ELLIPSIS + ... [[.1, .9], [.9, .1], [.8, .2], [.35, .65]]) + -0.21616... References ---------- C.M. Bishop (2006). Pattern Recognition and Machine Learning. Springer, p. 209. + See also + -------- + log_loss + Notes ----- The logarithm used is the natural logarithm (base-e). @@ -2169,5 +2177,26 @@ def log_loss(y_true, y_pred, eps=1e-15, normalize=True): Y = np.clip(y_pred, eps, 1 - eps) Y /= Y.sum(axis=1)[:, np.newaxis] - loss = -(T * np.log(Y)).sum() + loss = (T * np.log(Y)).sum() return loss / T.shape[0] if normalize else loss + + +def log_loss(y_true, y_pred, eps=1e-15, normalize=True): + """Log loss, aka logistic loss or cross-entropy loss. + + This is the loss function used in logistic regression and other + probability models, defined as the negative log-likelihood of a model's + prediction given ground truth labels. 
For a single sample with true label + yt in {0,1} and estimated probability yp that yt = 1, the log loss is + + -log P(yt|yp) = -(yt log(yp) + (1 - yt) log(1 - yp)) + + See log_likelihood_score for the parameters. + + Examples + -------- + >>> log_loss(["spam", "ham", "ham", "spam"], # doctest: +ELLIPSIS + ... [[.1, .9], [.9, .1], [.8, .2], [.35, .65]]) + 0.21616... + """ + return -log_likelihood_score(y_true, y_pred, eps, normalize) diff --git a/sklearn/metrics/scorer.py b/sklearn/metrics/scorer.py index 7f0f0fba6b1aa..9e4108499ee3b 100644 --- a/sklearn/metrics/scorer.py +++ b/sklearn/metrics/scorer.py @@ -23,7 +23,7 @@ from . import (r2_score, mean_squared_error, accuracy_score, f1_score, auc_score, average_precision_score, precision_score, - recall_score, log_loss) + recall_score, log_likelihood_score) from .cluster import adjusted_rand_score from ..externals import six @@ -225,8 +225,7 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False, recall_scorer = make_scorer(recall_score) # Score function for probabilistic classification -log_likelihood_scorer = make_scorer(log_loss, greater_is_better=False, - needs_proba=True) +log_likelihood_scorer = make_scorer(log_likelihood_score, needs_proba=True) # Clustering scores adjusted_rand_scorer = make_scorer(adjusted_rand_score)