From 5cc8032487f9535dfec56ed145186dc37b080d2a Mon Sep 17 00:00:00 2001
From: Arnaud Joly
Date: Thu, 7 Mar 2013 10:01:56 +0100
Subject: [PATCH] ENH add normalize option to accuracy_score + FIX bug with 1d array

---
 doc/modules/model_evaluation.rst      |   4 +-
 doc/whats_new.rst                     |   4 +
 sklearn/metrics/metrics.py            | 203 ++++++++++++++++++++++----
 sklearn/metrics/tests/test_metrics.py | 144 ++++++++++++++++--
 4 files changed, 309 insertions(+), 46 deletions(-)

diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst
index 4386d501fb14e..e89df88d26c0e 100644
--- a/doc/modules/model_evaluation.rst
+++ b/doc/modules/model_evaluation.rst
@@ -72,7 +72,7 @@ Accuracy score
 ---------------
 The :func:`accuracy_score` function computes the `accuracy `_, the fraction
-of correct predictions. In multilabel classification,
+(default) or the number of correct predictions. In multilabel classification,
 the function returns the subset accuracy: the entire set of labels for a
 sample must be entirely correct or the sample has an accuracy of zero.

@@ -96,6 +96,8 @@ where :math:`1(x)` is the `indicator function
   >>> y_true = [0, 1, 2, 3]
   >>> accuracy_score(y_true, y_pred)
   0.5
+  >>> accuracy_score(y_true, y_pred, normalize=False)
+  2

 In the multilabel case with binary indicator format:

diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index bca47266bc0e3..15238ceb1fb8c 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -65,6 +65,10 @@ Changelog
    - Performance improvements in :class:`isotonic.IsotonicRegression` by
      Nelle Varoquaux.

+   - :func:`metrics.accuracy_score` has a new ``normalize`` option to return
+     either the fraction or the number of correctly classified samples,
+     by `Arnaud Joly`_.
+
 API changes summary
 -------------------

diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py
index 2ea9d0d600319..b06b1c27d7bba 100644
--- a/sklearn/metrics/metrics.py
+++ b/sklearn/metrics/metrics.py
@@ -24,7 +24,8 @@
 from ..externals.six.moves import zip

 from ..preprocessing import LabelBinarizer
-from ..utils import check_arrays, deprecated
+from ..utils import check_arrays
+from ..utils import deprecated
 from ..utils.multiclass import is_label_indicator_matrix
 from ..utils.multiclass import is_multilabel
 from ..utils.multiclass import unique_labels
@@ -33,6 +34,121 @@
 ###############################################################################
 # General utilities
 ###############################################################################
+def _is_1d(x):
+    """Return True if x can be considered as a 1d vector.
+
+    This function distinguishes between a 1d vector, e.g.:
+
+      - ``np.array([1, 2])``
+      - ``np.array([[1, 2]])``
+      - ``np.array([[1], [2]])``
+
+    and a 2d matrix, e.g.:
+
+      - ``np.array([[1, 2], [3, 4]])``
+
+    Parameters
+    ----------
+    x : numpy array
+
+    Returns
+    -------
+    is_1d : boolean
+        True if x can be considered as a 1d vector.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from sklearn.metrics.metrics import _is_1d
+    >>> _is_1d([1, 2, 3])
+    True
+    >>> _is_1d(np.array([1, 2, 3]))
+    True
+    >>> _is_1d([[1, 2, 3]])
+    True
+    >>> _is_1d(np.array([[1, 2, 3]]))
+    True
+    >>> _is_1d([[1], [2], [3]])
+    True
+    >>> _is_1d(np.array([[1], [2], [3]]))
+    True
+    >>> _is_1d([[1, 2], [3, 4]])
+    False
+    >>> _is_1d(np.array([[1, 2], [3, 4]]))
+    False
+
+    See also
+    --------
+    _check_1d_array
+
+    """
+    return np.size(x) == np.max(np.shape(x))
+
+
+def _check_1d_array(y1, y2, ravel=False):
+    """Check that y1 and y2 are vectors of the same shape.
+
+    It converts 1d arrays (y1 and y2) of various shapes to a common shape
+    representation. Note that ``y1`` and ``y2`` should have the same number
+    of elements.
+
+    Parameters
+    ----------
+    y1 : array-like
+        y1 must be a "vector".
+
+    y2 : array-like
+        y2 must be a "vector".
+
+    ravel : boolean, optional (default=False)
+        If ``ravel`` is set to ``True``, then ``y1`` and ``y2`` are raveled.
+
+    Returns
+    -------
+    y1 : numpy array
+        If ``ravel`` is set to ``True``, return ``np.ravel(y1)``, else
+        return ``y1``.
+
+    y2 : numpy array
+        Return ``y2`` reshaped to have the shape of ``y1``.
+
+    Examples
+    --------
+    >>> from numpy import array
+    >>> from sklearn.metrics.metrics import _check_1d_array
+    >>> _check_1d_array([1, 2], [[3, 4]])
+    (array([1, 2]), array([3, 4]))
+    >>> _check_1d_array([[1, 2]], [[3], [4]])
+    (array([[1, 2]]), array([[3, 4]]))
+    >>> _check_1d_array([[1], [2]], [[3, 4]])
+    (array([[1],
+           [2]]), array([[3],
+           [4]]))
+    >>> _check_1d_array([[1], [2]], [[3, 4]], ravel=True)
+    (array([1, 2]), array([3, 4]))
+
+    See also
+    --------
+    _is_1d
+
+    """
+    y1 = np.asarray(y1)
+    y2 = np.asarray(y2)
+
+    if not _is_1d(y1):
+        raise ValueError("y1 can't be considered as a vector")
+
+    if not _is_1d(y2):
+        raise ValueError("y2 can't be considered as a vector")
+
+    if ravel:
+        return np.ravel(y1), np.ravel(y2)
+    else:
+        if np.shape(y1) != np.shape(y2):
+            y2 = np.reshape(y2, np.shape(y1))
+
+        return y1, y2
+
+
 def auc(x, y, reorder=False):
     """Compute Area Under the Curve (AUC) using the trapezoidal rule
@@ -47,7 +163,7 @@ def auc(x, y, reorder=False):
     y : array, shape = [n]
         y coordinates.

-    reorder : boolean, optional
+    reorder : boolean, optional (default=False)
         If True, assume that the curve is ascending in the case of ties, as
         for an ROC curve. If the curve is non-ascending, the result will be
         wrong.
@@ -299,6 +415,9 @@ def matthews_corrcoef(y_true, y_pred):
     -0.33...

     """
+    y_true, y_pred = check_arrays(y_true, y_pred)
+    y_true, y_pred = _check_1d_array(y_true, y_pred, ravel=True)
+
     mcc = np.corrcoef(y_true, y_pred)[0, 1]

     if np.isnan(mcc):
         return 0.
@@ -655,8 +774,8 @@ def zero_one_loss(y_true, y_pred, normalize=True):
     y_pred : array-like or list of labels or label indicator matrix
         Predicted labels, as returned by a classifier.

-    normalize : bool, optional
-        If ``False`` (default), return the number of misclassifications.
+    normalize : bool, optional (default=True)
+        If ``False``, return the number of misclassifications.
         Otherwise, return the fraction of misclassifications.

     Returns
@@ -696,34 +815,19 @@ def zero_one_loss(y_true, y_pred, normalize=True):
     """
     y_true, y_pred = check_arrays(y_true, y_pred, allow_lists=True)
+    score = accuracy_score(y_true, y_pred, normalize=normalize)

-    if is_multilabel(y_true):
-        # Handle mix representation
-        if type(y_true) != type(y_pred):
-            labels = unique_labels(y_true, y_pred)
-            lb = LabelBinarizer()
-            lb.fit([labels.tolist()])
-            y_true = lb.transform(y_true)
-            y_pred = lb.transform(y_pred)
+    if normalize:
+        return 1 - score
+    else:
+        if hasattr(y_true, "shape"):
+            n_samples = (np.max(y_true.shape) if _is_1d(y_true)
+                         else y_true.shape[0])

-        if is_label_indicator_matrix(y_true):
-            loss = (y_pred != y_true).sum(axis=1) > 0
         else:
-            # numpy 1.3 : it is required to perform a unique before setxor1d
-            # to get unique label in numpy 1.3.
-            # This is needed in order to handle redundant labels.
-            # FIXME : check if this can be simplified when 1.3 is removed
-            loss = np.array([np.size(np.setxor1d(np.unique(pred),
-                                                 np.unique(true))) > 0
-                             for pred, true in zip(y_pred, y_true)])
-    else:
-        y_true, y_pred = check_arrays(y_true, y_pred)
-        loss = y_true != y_pred
+            n_samples = len(y_true)

-    if normalize:
-        return np.mean(loss)
-    else:
-        return np.sum(loss)
+        return n_samples - score


 @deprecated("Function 'zero_one' has been renamed to "
@@ -743,7 +847,7 @@ def zero_one(y_true, y_pred, normalize=False):

     y_pred : array-like

-    normalize : bool, optional
+    normalize : bool, optional (default=False)
         If ``False`` (default), return the number of misclassifications.
         Otherwise, return the fraction of misclassifications.

@@ -771,7 +875,7 @@
 ###############################################################################
 # Multiclass score functions
 ###############################################################################
-def accuracy_score(y_true, y_pred):
+def accuracy_score(y_true, y_pred, normalize=True):
     """Accuracy classification score.

     Parameters
     ----------
     y_true : array-like or list of labels or label indicator matrix
         Ground truth (correct) labels.

     y_pred : array-like or list of labels or label indicator matrix
         Predicted labels, as returned by a classifier.

+    normalize : bool, optional (default=True)
+        If ``False``, return the number of correctly classified samples.
+        Otherwise, return the fraction of correctly classified samples.
+
     Returns
     -------
     score : float
@@ -806,6 +914,8 @@
     >>> y_true = [0, 1, 2, 3]
     >>> accuracy_score(y_true, y_pred)
     0.5
+    >>> accuracy_score(y_true, y_pred, normalize=False)
+    2

     In the multilabel case with binary indicator format:

@@ -841,9 +951,15 @@
                              for pred, true in zip(y_pred, y_true)])
     else:
         y_true, y_pred = check_arrays(y_true, y_pred)
+
+        # Handle mixed shape representations
+        y_true, y_pred = _check_1d_array(y_true, y_pred, ravel=True)
         score = y_true == y_pred

-    return np.mean(score)
+    if normalize:
+        return np.mean(score)
+    else:
+        return np.sum(score)


 def f1_score(y_true, y_pred, labels=None, pos_label=1, average='weighted'):
@@ -1146,6 +1262,8 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
         raise ValueError("beta should be >0 in the F-beta score")

     y_true, y_pred = check_arrays(y_true, y_pred)
+    y_true, y_pred = _check_1d_array(y_true, y_pred)
+
     if labels is None:
         labels = unique_labels(y_true, y_pred)
     else:
@@ -1589,6 +1707,9 @@ def hamming_loss(y_true, y_pred, classes=None):
         return np.mean(loss) / np.size(classes)

     else:
+        y_true, y_pred = check_arrays(y_true, y_pred)
+        y_true, y_pred = _check_1d_array(y_true, y_pred)
+
         return sp_hamming(y_true, y_pred)


@@ -1625,6 +1746,11 @@ def mean_absolute_error(y_true, y_pred):
     """
     y_true, y_pred = check_arrays(y_true, y_pred)
+
+    # Handle mixed 1d representations
+    if _is_1d(y_true):
+        y_true, y_pred = _check_1d_array(y_true, y_pred)
+
     return np.mean(np.abs(y_pred - y_true))


@@ -1658,6 +1784,11 @@ def mean_squared_error(y_true, y_pred):
     """
     y_true, y_pred = check_arrays(y_true, y_pred)
+
+    # Handle mixed 1d representations
+    if _is_1d(y_true):
+        y_true, y_pred = _check_1d_array(y_true, y_pred)
+
     return np.mean((y_pred - y_true) ** 2)


@@ -1696,6 +1827,11 @@ def explained_variance_score(y_true, y_pred):
     """
     y_true, y_pred = check_arrays(y_true, y_pred)
+
+    # Handle mixed 1d representations
+    if _is_1d(y_true):
+        y_true, y_pred = _check_1d_array(y_true, y_pred)
+
     numerator = np.var(y_true - y_pred)
     denominator = np.var(y_true)
     if denominator == 0.0:
@@ -1752,6 +1888,11 @@ def r2_score(y_true, y_pred):
     """
     y_true, y_pred = check_arrays(y_true, y_pred)
+
+    # Handle mixed 1d representations
+    if _is_1d(y_true):
+        y_true, y_pred = _check_1d_array(y_true, y_pred, ravel=True)
+
     if len(y_true) == 1:
         raise ValueError("r2_score can only be computed given more than one"
                          " sample.")
diff --git a/sklearn/metrics/tests/test_metrics.py b/sklearn/metrics/tests/test_metrics.py
index 59e857e19aad7..f00b47f02bbec 100644
--- a/sklearn/metrics/tests/test_metrics.py
+++ b/sklearn/metrics/tests/test_metrics.py
@@ -46,6 +46,22 @@
                              zero_one_score,
                              zero_one_loss)

+ALL_METRICS = [accuracy_score,
+               lambda y1, y2: accuracy_score(y1, y2, normalize=False),
+               hamming_loss,
+               zero_one_loss,
+               lambda y1, y2: zero_one_loss(y1, y2, normalize=False),
+               precision_score,
+               recall_score,
+               f1_score,
+               lambda y1, y2: fbeta_score(y1, y2, beta=2),
+               lambda y1, y2: fbeta_score(y1, y2, beta=0.5),
+               matthews_corrcoef,
+               mean_absolute_error,
+               mean_squared_error,
+               explained_variance_score,
+               r2_score]
+

 def make_prediction(dataset=None, binary=False):
     """Make some classification predictions on a toy dataset using a SVC
@@ -569,6 +585,9 @@ def test_losses():
     assert_equal(accuracy_score(y_true, y_pred),
                  1 - zero_one_loss(y_true, y_pred))

+    assert_equal(accuracy_score(y_true, y_pred, normalize=False),
+                 n_samples - zero_one_loss(y_true, y_pred, normalize=False))
+
     with warnings.catch_warnings(True):
         # Throw deprecated warning
         assert_equal(zero_one_score(y_true, y_pred),
@@ -616,6 +635,7 @@ def test_symmetry():
     # Symmetric metric
     for metric in [accuracy_score,
+                   lambda y1, y2: accuracy_score(y1, y2, normalize=False),
                    zero_one_loss,
                    lambda y1, y2: zero_one_loss(y1, y2, normalize=False),
                    hamming_loss,
@@ -658,20 +678,7 @@ def test_sample_order_invariance():
     y_true_shuffle, y_pred_shuffle = shuffle(y_true, y_pred, random_state=0)

-    for metric in [accuracy_score,
-                   hamming_loss,
-                   zero_one_loss,
-                   lambda y1, y2: zero_one_loss(y1, y2, normalize=False),
-                   precision_score,
-                   recall_score,
-                   f1_score,
-                   lambda y1, y2: fbeta_score(y1, y2, beta=2),
-                   lambda y1, y2: fbeta_score(y1, y2, beta=0.5),
-                   matthews_corrcoef,
-                   mean_absolute_error,
-                   mean_squared_error,
-                   explained_variance_score,
-                   r2_score]:
+    for metric in ALL_METRICS:

         assert_almost_equal(metric(y_true, y_pred),
                             metric(y_true_shuffle, y_pred_shuffle),
@@ -679,6 +686,88 @@
                             % metric)


+def test_format_invariance_with_1d_vectors():
+    y1, y2, _ = make_prediction(binary=True)
+
+    y1_list = list(y1)
+    y2_list = list(y2)
+
+    y1_1d, y2_1d = np.array(y1), np.array(y2)
+    assert_equal(y1_1d.ndim, 1)
+    assert_equal(y2_1d.ndim, 1)
+    y1_column = np.reshape(y1_1d, (-1, 1))
+    y2_column = np.reshape(y2_1d, (-1, 1))
+    y1_row = np.reshape(y1_1d, (1, -1))
+    y2_row = np.reshape(y2_1d, (1, -1))
+
+    for metric in ALL_METRICS:
+
+        measure = metric(y1, y2)
+
+        assert_almost_equal(measure,
+                            metric(y1_list, y2_list),
+                            err_msg="%s is not representation invariant "
+                                    "with list" % metric)
+
+        assert_almost_equal(measure,
+                            metric(y1_1d, y2_1d),
+                            err_msg="%s is not representation invariant "
+                                    "with np-array-1d" % metric)
+
+        assert_almost_equal(measure,
+                            metric(y1_column, y2_column),
+                            err_msg="%s is not representation invariant "
+                                    "with np-array-column" % metric)
+
+        assert_almost_equal(measure,
+                            metric(y1_row, y2_row),
+                            err_msg="%s is not representation invariant "
+                                    "with np-array-row" % metric)
+
+        # Mix format support
+        assert_almost_equal(measure,
+                            metric(y1_1d, y2_list),
+                            err_msg="%s is not representation invariant "
+                                    "with mix np-array-1d and list" % metric)
+
+        assert_almost_equal(measure,
+                            metric(y1_list, y2_1d),
+                            err_msg="%s is not representation invariant "
+                                    "with mix np-array-1d and list" % metric)
+
+        assert_almost_equal(measure,
+                            metric(y1_1d, y2_column),
+                            err_msg="%s is not representation invariant "
+                                    "with mix np-array-1d and np-array-column"
+                                    % metric)
+
+        assert_almost_equal(measure,
+                            metric(y1_column, y2_1d),
+                            err_msg="%s is not representation invariant "
+                                    "with mix np-array-1d and np-array-column"
+                                    % metric)
+
+        assert_almost_equal(measure,
+                            metric(y1_list, y2_column),
+                            err_msg="%s is not representation invariant "
+                                    "with mix list and np-array-column"
+                                    % metric)
+
+        assert_almost_equal(measure,
+                            metric(y1_column, y2_list),
+                            err_msg="%s is not representation invariant "
+                                    "with mix list and np-array-column"
+                                    % metric)
+
+        # At the moment, these mix representations aren't allowed
+        assert_raises(ValueError, metric, y1_1d, y2_row)
+        assert_raises(ValueError, metric, y1_row, y2_1d)
+        assert_raises(ValueError, metric, y1_list, y2_row)
+        assert_raises(ValueError, metric, y1_row, y2_list)
+        assert_raises(ValueError, metric, y1_column, y2_row)
+        assert_raises(ValueError, metric, y1_row, y2_column)
+
+
 def test_hinge_loss_binary():
     y_true = np.array([-1, 1, 1, -1])
     pred_decision = np.array([-8.5, 0.5, 1.5, -0.3])
@@ -859,6 +948,14 @@ def test_multilabel_zero_one_loss():
     assert_equal(1.0, zero_one_loss(y1, np.zeros(y1.shape)))
     assert_equal(1.0, zero_one_loss(y2, np.zeros(y1.shape)))

+    assert_equal(1, zero_one_loss(y1, y2, normalize=False))
+    assert_equal(0, zero_one_loss(y1, y1, normalize=False))
+    assert_equal(0, zero_one_loss(y2, y2, normalize=False))
+    assert_equal(2, zero_one_loss(y2, np.logical_not(y2), normalize=False))
+    assert_equal(2, zero_one_loss(y1, np.logical_not(y1), normalize=False))
+    assert_equal(2, zero_one_loss(y1, np.zeros(y1.shape), normalize=False))
+    assert_equal(2, zero_one_loss(y2, np.zeros(y1.shape), normalize=False))
+
     # List of tuple of label
     y1 = [(1, 2,), (0, 2,)]

@@ -872,6 +969,12 @@
     assert_equal(1.0, zero_one_loss(y2, [(), ()]))
     assert_equal(1.0, zero_one_loss(y2, [tuple(), (10, )]))

+    assert_equal(1, zero_one_loss(y1, y2, normalize=False))
+    assert_equal(0, zero_one_loss(y1, y1, normalize=False))
+    assert_equal(0, zero_one_loss(y2, y2, normalize=False))
+    assert_equal(2, zero_one_loss(y2, [(), ()], normalize=False))
+    assert_equal(2, zero_one_loss(y2, [tuple(), (10, )], normalize=False))
+

 def test_multilabel_hamming_loss():
     # Dense label indicator matrix format
@@ -919,6 +1022,14 @@ def test_multilabel_accuracy_score():
     assert_equal(0.0, accuracy_score(y1, np.zeros(y1.shape)))
     assert_equal(0.0, accuracy_score(y2, np.zeros(y1.shape)))

+    assert_equal(1, accuracy_score(y1, y2, normalize=False))
+    assert_equal(2, accuracy_score(y1, y1, normalize=False))
+    assert_equal(2, accuracy_score(y2, y2, normalize=False))
+    assert_equal(0, accuracy_score(y2, np.logical_not(y2), normalize=False))
+    assert_equal(0, accuracy_score(y1, np.logical_not(y1), normalize=False))
+    assert_equal(0, accuracy_score(y1, np.zeros(y1.shape), normalize=False))
+    assert_equal(0, accuracy_score(y2, np.zeros(y1.shape), normalize=False))
+
     # List of tuple of label
     y1 = [(1, 2,), (0, 2,)]

@@ -930,3 +1041,8 @@
     assert_equal(1.0, accuracy_score(y1, y1))
     assert_equal(1.0, accuracy_score(y2, y2))
     assert_equal(0.0, accuracy_score(y2, [(), ()]))
+
+    assert_equal(1, accuracy_score(y1, y2, normalize=False))
+    assert_equal(2, accuracy_score(y1, y1, normalize=False))
+    assert_equal(2, accuracy_score(y2, y2, normalize=False))
+    assert_equal(0, accuracy_score(y2, [(), ()], normalize=False))
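
A minimal usage sketch of the behaviour introduced above (illustrative only, not part of the
diff; it assumes a checkout with this patch applied and follows the doctests shown in the
hunks). It shows the new ``normalize`` option for the multiclass case and the repaired
handling of 1d data given in mixed representations (plain list vs. column vector):

    >>> import numpy as np
    >>> from sklearn.metrics import accuracy_score, zero_one_loss
    >>> from sklearn.metrics import mean_absolute_error
    >>> y_pred = [0, 2, 1, 3]
    >>> y_true = [0, 1, 2, 3]
    >>> accuracy_score(y_true, y_pred)                    # fraction of correct labels
    0.5
    >>> accuracy_score(y_true, y_pred, normalize=False)   # number of correct labels
    2
    >>> zero_one_loss(y_true, y_pred, normalize=False)    # number of misclassified samples
    2
    >>> # 1d data passed as a column vector is reshaped to match the other input
    >>> mean_absolute_error(np.array([[1], [2], [3], [4]]), [1, 2, 3, 6])
    0.5

Keeping ``normalize=True`` as the default preserves the previous behaviour of returning a
fraction, so existing callers of ``accuracy_score`` are unaffected.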