FIX Ignore zero sample weights in precision recall curve #18328
Changes from all commits
@@ -27,6 +27,7 @@
 from ..utils import assert_all_finite
 from ..utils import check_consistent_length
+from ..utils.validation import _check_sample_weight
 from ..utils import column_or_1d, check_array
 from ..utils.multiclass import type_of_target
 from ..utils.extmath import stable_cumsum
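For context (not part of the diff), a minimal sketch of what the newly imported helper does; _check_sample_weight is a private scikit-learn utility, so its exact behaviour may differ between versions:

import numpy as np
from sklearn.utils.validation import _check_sample_weight

y = np.array([0, 1, 1, 0])

# None is expanded to a unit weight per sample.
print(_check_sample_weight(None, y))          # [1. 1. 1. 1.]

# An explicit array is length-checked and returned as a float ndarray,
# zeros included -- the filtering itself happens later in _binary_clf_curve.
print(_check_sample_weight([0, 1, 2, 0], y))  # [0. 1. 2. 0.]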
@@ -291,14 +292,14 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None):
     >>> thresholds
     array([0.35, 0.4 , 0.8 ])
     """
-    if len(np.unique(y_true)) != 2:
-        raise ValueError("Only one class present in y_true. Detection error "
-                         "tradeoff curve is not defined in that case.")
-
     fps, tps, thresholds = _binary_clf_curve(
         y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
     )

+    if len(np.unique(y_true)) != 2:
+        raise ValueError("Only one class present in y_true. Detection error "
+                         "tradeoff curve is not defined in that case.")
+
     fns = tps[-1] - tps
     p_count = tps[-1]
     n_count = fps[-1]
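A small sketch (not from the PR; assumes a scikit-learn version where det_curve is available, i.e. 0.24 or later) of the behaviour guarded by the relocated check: a y_true containing only one class still raises the same ValueError, it is simply raised after _binary_clf_curve has validated the inputs:

import numpy as np
from sklearn.metrics import det_curve

y_true = np.array([1, 1, 1, 1])            # only one class present
y_score = np.array([0.1, 0.4, 0.35, 0.8])

try:
    det_curve(y_true, y_score)
except ValueError as exc:
    # "Only one class present in y_true. Detection error tradeoff curve ..."
    print(exc)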
@@ -696,8 +697,14 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
     assert_all_finite(y_true)
     assert_all_finite(y_score)

+    # Filter out zero-weighted samples, as they should not impact the result
     if sample_weight is not None:
         sample_weight = column_or_1d(sample_weight)
+        sample_weight = _check_sample_weight(sample_weight, y_true)
+        nonzero_weight_mask = sample_weight != 0
+        y_true = y_true[nonzero_weight_mask]
+        y_score = y_score[nonzero_weight_mask]
+        sample_weight = sample_weight[nonzero_weight_mask]
Comment on lines +703 to +707

Would it not be better, instead of filtering out zero sample weights, to do the computation of the scores in a way that zero sample weights are not taken into account? This proposed change results in a difference between

For whatever reasons you have some zeros in
Could you elaborate?

Sorry @adrinjalali, I do not understand your point either...

Let's say we want a weighted average of a bunch of numbers. One way is to first filter out the data with zero weight; the other is to have something like:
Now in this case, there's an issue with having samples with zero weight in the data. My question is: if we just filter out the zero weights, then what happens to a sample with weight equal to 1e-32, for instance? That's practically zero and should be treated [almost] the same as the ones with zero sample weight.

AFAIK,
Suggestion: should we just add a test with

Can we reach a consensus? I would like to get this one closed.

I played around a little bit; I'm happy with this solution as is :)
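A sketch of the trade-off being discussed (the function below is illustrative only, not scikit-learn code): an exact zero is removed by the mask, while a tiny weight such as 1e-32 is kept but contributes essentially nothing to the weighted sums, so both end up with (almost) the same result:

import numpy as np

def weighted_mean(values, weights):
    # Plain weighted average: zero weights contribute nothing to either sum.
    values = np.asarray(values, dtype=float)
    weights = np.asarray(weights, dtype=float)
    return np.sum(values * weights) / np.sum(weights)

values = [10.0, 20.0, 30.0]
print(weighted_mean(values, [1.0, 1.0, 0.0]))    # 15.0, third value ignored
print(weighted_mean(values, [1.0, 1.0, 1e-32]))  # 15.0, third value practically ignored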

     pos_label = _check_pos_label_consistency(pos_label, y_true)
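A sketch of the property this block establishes (this is not the test added by the PR, just an illustration with made-up data): giving a sample a weight of exactly zero now yields the same curve as dropping that sample beforehand:

import numpy as np
from numpy.testing import assert_allclose
from sklearn.metrics import precision_recall_curve

y_true = np.array([0, 0, 1, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8, 0.9])
sample_weight = np.array([1.0, 1.0, 0.0, 1.0, 1.0])  # third sample zero-weighted

# Curve computed with the zero weight passed through.
p_w, r_w, t_w = precision_recall_curve(y_true, y_score, sample_weight=sample_weight)

# Curve computed after removing the zero-weighted sample by hand.
mask = sample_weight != 0
p_m, r_m, t_m = precision_recall_curve(
    y_true[mask], y_score[mask], sample_weight=sample_weight[mask]
)

assert_allclose(p_w, p_m)
assert_allclose(r_w, r_m)
assert_allclose(t_w, t_m)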
@@ -759,7 +766,9 @@ def precision_recall_curve(y_true, probas_pred, *, pos_label=None,
         pos_label should be explicitly given.

     probas_pred : ndarray of shape (n_samples,)
-        Estimated probabilities or output of a decision function.
+        Target scores, can either be probability estimates of the positive
+        class, or non-thresholded measure of decisions (as returned by
+        `decision_function` on some classifiers).

     pos_label : int or str, default=None
         The label of the positive class.
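To illustrate the reworded docstring (a sketch, not from the PR; LinearSVC is just one example of an estimator that exposes decision_function rather than probabilities): probas_pred may be probability estimates or non-thresholded decision values:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.svm import LinearSVC
from sklearn.metrics import precision_recall_curve

X, y = make_classification(n_samples=200, random_state=0)
clf = LinearSVC(random_state=0).fit(X, y)

# Non-thresholded decision values are a valid input for probas_pred.
scores = clf.decision_function(X)
precision, recall, thresholds = precision_recall_curve(y, scores)
print(precision.shape, recall.shape, thresholds.shape)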
Why this change? I would imagine it's better to check for it before computing _binary_clf_curve?

As it was, if you passed a multiclass y_true, it would raise the "Only one class present..." ValueError. With the reorder, the multiclass ValueError is raised by _binary_clf_curve.

This needs to be in whats_new since it's changing a behavior.

@adrinjalali are you sure? When passing a multiclass y_true, a ValueError was raised before and is still raised now; that does not change. The only change is the error message:
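A sketch of the point made in the last comment (made-up data; the wording of the new message comes from the input validation in _binary_clf_curve and may differ between versions): a multiclass y_true raises a ValueError both before and after the reorder, only the message changes:

import numpy as np
from sklearn.metrics import det_curve

y_true = np.array([0, 1, 2, 1])            # three classes: not a binary problem
y_score = np.array([0.1, 0.4, 0.35, 0.8])

try:
    det_curve(y_true, y_score)
except ValueError as exc:
    # Previously: "Only one class present in y_true. ..."
    # Now (raised by _binary_clf_curve): e.g. "multiclass format is not supported"
    print(exc)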