8000 [MRG] fixes #4577 adds interpolation to PR curve by chiragnagpal · Pull Request #4936 · scikit-learn/scikit-learn · GitHub
[MRG] fixes #4577 adds interpolation to PR curve #4936

Closed
wants to merge 6 commits into from
55 changes: 55 additions & 0 deletions examples/model_selection/plot_precision_recall.py
@@ -64,6 +64,15 @@
a precision-recall curve by considering each element of the label indicator
matrix as a binary prediction (micro-averaging).

Increasing the threshold over a small range can reduce both recall and
precision, causing large jitters. Over small ranges, while the number of
true positives :math:`T_p` decreases, the sum of true positives and
false positives :math:`T_p+F_p` may not decrease in the same proportion,
which reduces precision. Interpolation removes this discrepancy and makes
the plot smoother: it ensures that increasing the threshold never lets
precision drop below its value at lower thresholds.

Member: sentence is too long. Cut in pieces to make it clearer.

Author: yup, you got that! will comply.
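The effect can be sketched with a toy example (not part of the PR; plain NumPy, assuming precision values ordered by decreasing recall):

```python
import numpy as np

# Hypothetical precision values, ordered by decreasing recall,
# showing the jitter described above.
raw_precision = np.array([0.5, 0.4, 0.67, 0.5, 1.0])

# Interpolation replaces each value with the highest precision seen
# at this or any higher recall, so precision never drops as recall
# decreases.
interpolated = np.maximum.accumulate(raw_precision)
# interpolated is now [0.5, 0.5, 0.67, 0.67, 1.0]
```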


.. note::

See also :func:`sklearn.metrics.average_precision_score`,
@@ -148,3 +157,49 @@
plt.title('Extension of Precision-Recall curve to multi-class')
plt.legend(loc="lower right")
plt.show()
# Interpolated Precision Recall Curve
precision = dict()
recall = dict()
average_precision = dict()
for i in range(n_classes):
precision[i], recall[i], _ = precision_recall_curve(y_test[:, i],
y_score[:, i],
interpolate=True)
average_precision[i] = average_precision_score(y_test[:, i], y_score[:, i])

# Compute micro-average precision-recall curve and average precision
precision["micro"], recall["micro"], _ = precision_recall_curve(y_test.ravel(),
y_score.ravel(),
interpolate=True)

average_precision["micro"] = average_precision_score(y_test, y_score,
average="micro")

# Plot Precision-Recall curve
plt.clf()
plt.plot(recall[0], precision[0], label='Precision-Recall curve')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.ylim([0.0, 1.05])
plt.xlim([0.0, 1.0])
plt.title('Precision-Recall example: AUC={0:0.2f}'.format(average_precision[0]))
plt.legend(loc="lower left")
plt.show()

# Plot Precision-Recall curve for each class
plt.clf()
plt.plot(recall["micro"], precision["micro"],
label='micro-average Precision-recall curve (area = {0:0.2f})'
''.format(average_precision["micro"]))
for i in range(n_classes):
plt.plot(recall[i], precision[i],
label='Precision-recall curve of class {0} (area = {1:0.2f})'
''.format(i, average_precision[i]))

plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Interpolated multi-class Precision-Recall curve')
plt.legend(loc="lower right")
plt.show()
32 changes: 29 additions & 3 deletions sklearn/metrics/ranking.py
@@ -335,7 +335,7 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):


def precision_recall_curve(y_true, probas_pred, pos_label=None,
sample_weight=None):
sample_weight=None, interpolate=False):
"""Compute precision-recall pairs for different probability thresholds

Note: this implementation is restricted to the binary classification task.
@@ -364,11 +364,14 @@ def precision_recall_curve(y_true, probas_pred, pos_label=None,
Estimated probabilities or decision function.

pos_label : int, optional (default=None)
The label of the positive class
The label of the positive class.

sample_weight : array-like of shape = [n_samples], optional
Sample weights.

interpolate : boolean, optional (default=False)
If True, interpolate the precision values to de-noise the
precision-recall curve, following [1]_.

Returns
-------
precision : array, shape = [n_thresholds + 1]
@@ -383,6 +386,12 @@
Increasing thresholds on the decision function used to compute
precision and recall.

References
----------
.. [1] Manning, C. D., Raghavan, P., & Schutze, H. (2008).
Introduction to information retrieval (Vol. 1, p. 159).
Cambridge: Cambridge university press.

Examples
--------
>>> import numpy as np
@@ -399,6 +408,10 @@
array([ 0.35, 0.4 , 0.8 ])

"""
warnings.warn("The default behaviour of no interpolation is deprecated. "
"Interpolation will be the default behaviour in 0.18.",
DeprecationWarning)

fps, tps, thresholds = _binary_clf_curve(y_true, probas_pred,
pos_label=pos_label,
sample_weight=sample_weight)
@@ -410,7 +423,20 @@
# and reverse the outputs so recall is decreasing
last_ind = tps.searchsorted(tps[-1])
sl = slice(last_ind, None, -1)
return np.r_[precision[sl], 1], np.r_[recall[sl], 0], thresholds[sl]

if interpolate:
prec = np.r_[precision[sl], 1]
# Running maximum: recall is decreasing along prec, so this
# ensures precision never drops below its value at lower thresholds.
p_temp = prec[0]
n = len(prec)
for i in range(n):
if prec[i] < p_temp:
prec[i] = p_temp
else:
p_temp = prec[i]
return prec, np.r_[recall[sl], 0], thresholds[sl]
Member: why don't you use the interpolate module from scipy?

Author: This interpolation refers to issue #4577. The interpolation logic has been taken from http://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-ranked-retrieval-results-1.html

Member: because that does something differently, right?

Author: The scipy.interpolate functions are used to perform curve fitting on a set of data points. Here, we don't need to perform curve fitting; all we need is to ensure that the value of precision does not fall for decreasing values of recall.


else:
return np.r_[precision[sl], 1], np.r_[recall[sl], 0], thresholds[sl]
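As a sketch (not part of the PR), the explicit loop above computes a running maximum, which NumPy can express in a single vectorized call; the values here are the intermediate precision values for the case exercised by the new test below (labels [1, 0, 0, 1], scores [1, 2, 3, 4], with the final 1 appended):

```python
import numpy as np

# Precision values ordered by decreasing recall, with the final 1
# appended, as built inside precision_recall_curve for labels
# [1, 0, 0, 1] and scores [1, 2, 3, 4].
prec = np.array([0.5, 1. / 3, 0.5, 1., 1.])

# Running maximum: equivalent to the explicit loop in the diff.
interpolated = np.maximum.accumulate(prec)
# interpolated is now [0.5, 0.5, 0.5, 1., 1.]
```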


def roc_curve(y_true, y_score, pos_label=None, sample_weight=None):
11 changes: 11 additions & 0 deletions sklearn/metrics/tests/test_ranking.py
@@ -455,6 +455,17 @@ def test_precision_recall_curve():
assert_equal(p.size, t.size + 1)


def test_precision_recall_interpolate():
labels = [1, 0, 0, 1]
predict_probas = [1, 2, 3, 4]
p, r, t = precision_recall_curve(labels, predict_probas, interpolate=True)
assert_array_almost_equal(p, np.array([0.5, 0.5, 0.5, 1., 1.]))
assert_array_almost_equal(r, np.array([1., 0.5, 0.5, 0.5, 0.]))
assert_array_almost_equal(t, np.array([1, 2, 3, 4]))
assert_equal(p.size, r.size)
assert_equal(p.size, t.size + 1)


def test_precision_recall_curve_pos_label():
y_true, _, probas_pred = make_prediction(binary=False)
pos_label = 2