-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
[MRG] fixes #4577 adds interpolation to PR curve #4936
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f9af288
8e281b2
b12ccfb
7b47627
8c8c1cd
a6227b1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -335,7 +335,7 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None): | |
|
||
|
||
def precision_recall_curve(y_true, probas_pred, pos_label=None, | ||
sample_weight=None): | ||
sample_weight=None, interpolate=False): | ||
"""Compute precision-recall pairs for different probability thresholds | ||
|
||
Note: this implementation is restricted to the binary classification task. | ||
|
@@ -364,11 +364,14 @@ def precision_recall_curve(y_true, probas_pred, pos_label=None, | |
Estimated probabilities or decision function. | ||
|
||
pos_label : int, optional (default=None) | ||
The label of the positive class | ||
The label of the positive class. | ||
|
||
sample_weight : array-like of shape = [n_samples], optional | ||
Sample weights. | ||
|
||
interpolate : boolean, optional (default=False) | ||
Interpolates precision score, to de-noise PR curve. Based on [1] | ||
|
||
Returns | ||
------- | ||
precision : array, shape = [n_thresholds + 1] | ||
|
@@ -383,6 +386,12 @@ def precision_recall_curve(y_true, probas_pred, pos_label=None, | |
Increasing thresholds on the decision function used to compute | ||
precision and recall. | ||
|
||
References | ||
---------- | ||
.. [1] Manning, C. D., Raghavan, P., & Schutze, H. (2008). | ||
Introduction to information retrieval (Vol. 1, p. 159). | ||
Cambridge: Cambridge university press. | ||
|
||
Examples | ||
-------- | ||
>>> import numpy as np | ||
|
@@ -399,6 +408,10 @@ def precision_recall_curve(y_true, probas_pred, pos_label=None, | |
array([ 0.35, 0.4 , 0.8 ]) | ||
|
||
""" | ||
warnings.warn("The default behaviour of no interpolation is deprecated." | ||
"Interpolation would be default behaviour in 0.18", | ||
DeprecationWarning) | ||
|
||
fps, tps, thresholds = _binary_clf_curve(y_true, probas_pred, | ||
pos_label=pos_label, | ||
sample_weight=sample_weight) | ||
|
@@ -410,7 +423,20 @@ def precision_recall_curve(y_true, probas_pred, pos_label=None, | |
# and reverse the outputs so recall is decreasing | ||
last_ind = tps.searchsorted(tps[-1]) | ||
sl = slice(last_ind, None, -1) | ||
return np.r_[precision[sl], 1], np.r_[recall[sl], 0], thresholds[sl] | ||
|
||
if interpolate: | ||
prec = np.r_[precision[sl], 1] | ||
p_temp = prec[0] | ||
n = len(prec) | ||
for i in range(n): | ||
if prec[i] < p_temp: | ||
prec[i] = p_temp | ||
else: | ||
p_temp = prec[i] | ||
return prec, np.r_[recall[sl], 0], thresholds[sl] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why don't you use the interpolate module from scipy? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This interpolation refers to the issue #4577. The interpolation logic has been taken from http://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-ranked-retrieval-results-1.html There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure. This is ok. I am just asking why not using: ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. because that does something differently, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. scipy.interpolate function is used to perform curve fitting on a set of data points. Here, we don't need to perform curve fitting. all we need is to ensure that the value of precision does not fall, for decreasing values of recall. |
||
|
||
else: | ||
return np.r_[precision[sl], 1], np.r_[recall[sl], 0], thresholds[sl] | ||
|
||
|
||
def roc_curve(y_true, y_score, pos_label=None, sample_weight=None): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sentence is too long. Cut in pieces to make it clearer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yup, you got that! will comply.