FIX TunedThresholdClassifierCV error or warn with informative message on invalid metrics #29082
base: main
Changes from all commits
5817ee9
0acb511
a1b5518
a0d62c5
a874855
4f82762
505baff
b6b13d3
942d58a
964a079
c8e811d
5a464e4
942221d
397ee8d
553a1b8
2b4927a
a37edcd
aadae00
856d075
f4c43e3
351a523
0af3d32
@@ -1,5 +1,6 @@
 from collections.abc import MutableMapping
 from numbers import Integral, Real
+from warnings import warn

 import numpy as np

@@ -656,7 +657,17 @@ class TunedThresholdClassifierCV(BaseThresholdClassifier):

         * a string associated to a scoring function for binary classification
           (see :ref:`scoring_parameter`);
-        * a scorer callable object created with :func:`~sklearn.metrics.make_scorer`;
+        * a scorer callable object created with :func:`~sklearn.metrics.make_scorer`.
+
+        Note that the scoring objective should introduce a trade-off between false
+        negatives and false positives, otherwise the tuned threshold would be
+        trivial and the resulting classifier would be equivalent to constantly
+        classifying one of the two possible classes. This would be the case
+        when passing scoring="precision" or scoring="recall" for instance.
+        Furthermore, the scoring objective should evaluate thresholded
+        classifier predictions: as a result, metrics such as ROC AUC, Average
+        Precision, log loss or the Brier score are not valid scoring metrics in
+        this context.

     response_method : {"auto", "decision_function", "predict_proba"}, default="auto"
         Methods by the classifier `estimator` corresponding to the
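As an illustration of the constraint documented in the docstring note above, here is a minimal, hedged usage sketch (the dataset and estimator are arbitrary placeholders, not part of the PR): both a string metric such as "balanced_accuracy" and a scorer built with make_scorer around fbeta_score evaluate hard predictions and trade off false positives against false negatives, so they are valid choices here.

# Illustrative sketch only (not part of the diff): valid scoring choices for
# TunedThresholdClassifierCV per the docstring note above.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import fbeta_score, make_scorer
from sklearn.model_selection import TunedThresholdClassifierCV

X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=0)

# A string metric that balances false positives and false negatives.
tuned = TunedThresholdClassifierCV(LogisticRegression(), scoring="balanced_accuracy")
tuned.fit(X, y)
print(tuned.best_threshold_, tuned.best_score_)

# A scorer callable built with make_scorer; F2 still trades off precision and recall.
f2_scorer = make_scorer(fbeta_score, beta=2)
tuned_f2 = TunedThresholdClassifierCV(LogisticRegression(), scoring=f2_scorer).fit(X, y)
print(tuned_f2.best_threshold_, tuned_f2.best_score_)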
@@ -947,6 +958,35 @@ def _fit(self, X, y, **params):
         best_idx = objective_scores.argmax()
         self.best_score_ = objective_scores[best_idx]
         self.best_threshold_ = decision_thresholds[best_idx]

+        if self.best_threshold_ == min_threshold:
+            trivial_kind = "positive"
+        elif self.best_threshold_ == max_threshold:
+            trivial_kind = "negative"
+        else:
+            trivial_kind = None
+
+        if (
+            objective_scores.max() - objective_scores.min()
+            <= np.finfo(objective_scores.dtype).eps
+        ):
+            warn(
+                f"The objective metric {self.scoring!r} is constant at "
+                f"{self.best_score_} across all thresholds. Falling back "
+                "to the default 0.5 threshold. Please instead pass a scoring "
+                "metric that varies with the decision threshold.",
+                UserWarning,
+            )
+            self.best_threshold_ = 0.5
+        elif trivial_kind is not None:
+            warn(
+                f"Tuning the decision threshold on {self.scoring} "
+                "leads to a trivial classifier that classifies all samples as "
+                f"the {trivial_kind} class. Consider revising the scoring parameter "
+                "to include a trade-off between false positives and false negatives.",
+                UserWarning,
+            )
+
         if self.store_cv_results:
             self.cv_results_ = {
                 "thresholds": decision_thresholds,

Comment on lines +973 to +974:

- This is a good idea. Might help some users. Also I like that it is a warning, not an error.
- I updated the PR in a0d62c5 to keep the 0.5 threshold in that case. Using an extreme near-zero threshold would introduce a very weird / unexpectedly biased behavior. Better to keep a more neutral behavior in such a pathological situation.
- I still like this warning.

Comment on lines +981 to +988:

- I think we have a similar warning when the found solution is using a parameter from the edges of the bound. Could they all be in the same utility function? It would make the messages more consistent. Could also be a separate PR.
- I agree we would need a similar feature for hparam search, but I think both the code and the messages deserve to be specialized. For hparam tuning, we need to take the bounds of parameter validation into account (some hparams are naturally bounded, e.g. …). For threshold tuning, I think it's important to mention the positive / negative classifications, as I do in the 2 custom warnings of this PR, to get a more explicit and actionable error message. The message for hparam tuning would be more generic.
@@ -1012,8 +1052,24 @@ def get_metadata_routing(self):

     def _get_curve_scorer(self):
         """Get the curve scorer based on the objective metric used."""
-        scoring = check_scoring(self.estimator, scoring=self.scoring)
+        scorer = check_scoring(self.estimator, scoring=self.scoring)
+        # XXX: at the time of writing, there is no very explicit way to check
+        # if a scorer expects thresholded binary classification predictions.
+        # TODO: update this condition when a better way is available.
+        scorer_response_methods = getattr(scorer, "_response_method", "predict")
+        if isinstance(scorer_response_methods, str):
+            scorer_response_methods = {scorer_response_methods}
+        else:
+            scorer_response_methods = set(scorer_response_methods)
+
+        if scorer_response_methods.issubset({"predict_proba", "decision_function"}):
+            raise ValueError(
+                f"{self.__class__.__name__} expects a scoring metric that evaluates "
+                f"the thresholded predictions of a binary classifier, got: "
+                f"{self.scoring!r} which expects unthresholded predictions computed by "
+                f"the {scorer._response_method!r} method(s) of the classifier."
+            )
         curve_scorer = _CurveScorer.from_scorer(
-            scoring, self._get_response_method(), self.thresholds
+            scorer, self._get_response_method(), self.thresholds
         )
         return curve_scorer

Review comments on this check:

- Here, I'm not so sure whether we overly constrain possible scorers from the user. But I also do not 100% follow how the curve scorer works.
- If we don't do that (as is the case in main) and the user passes a non-thresholded metric like ROC AUC / average precision / log loss / Brier, the metric is evaluated on the thresholded (binary) predictions, which is really misleading. You still get a "tuned" threshold but its meaning is really confusing.
- I think that we can first restrain and be conservative. If a real use case is reported then we can rework and make sure that the API is right and not just working as a side effect.
- I don't like that here we're accessing the private `_response_method` attribute.
- … Edit: That wasn't very helpful. I think I got something wrong.
- Let me know if the following helps address those concerns: https://github.com/scikit-learn/scikit-learn/pull/29082/files#r1613325997
- This is still the most controversial part. Without it, this PR would be almost merged, right?
- What do you think of #29082 (comment)?
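For context on what this check means for users, here is a hedged sketch (it assumes this PR's branch and a recent scikit-learn where scorer objects carry the private _response_method attribute the check relies on): ranking metrics such as "roc_auc" request unthresholded outputs, so fitting with them should now fail fast with the ValueError above instead of silently scoring binarized predictions.

# Illustrative sketch only (assumes this PR's branch): an unthresholded metric
# such as "roc_auc" is rejected up front.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import get_scorer
from sklearn.model_selection import TunedThresholdClassifierCV

X, y = make_classification(n_samples=500, random_state=0)

# The check above inspects the scorer's (private) _response_method attribute;
# for "roc_auc" it requests probabilities or decision-function values.
print(getattr(get_scorer("roc_auc"), "_response_method", "predict"))

try:
    TunedThresholdClassifierCV(LogisticRegression(), scoring="roc_auc").fit(X, y)
except ValueError as exc:
    print(exc)  # explains that a metric over thresholded predictions is expected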
Further review comments:

- I wonder if this is something we should be doing everywhere where we "search" over a few models using a scorer. Kind of arbitrary to have it here and not in other places.
- I would also be in favor of raising a similar warning for *SearchCV meta-estimators when `mean_test_score` is constant for all hyper-parameters. Not sure how frequent this is though. But it's true that it's actually not that frequent for TunedThresholdClassifierCV either. The original problem I encountered that triggered the selection of an extreme threshold is actually of a different nature. Let me update this PR accordingly to discuss that further (maybe tomorrow).
- Done in #29082 (comment).
- I think that checking for the constant case is not really important, as I do not expect this to happen often in practice. The other cases (extreme thresholds due to a lack of trade-off) are more useful to warn against (and they do easily happen in practice, as shown in the tests). But I would rather keep the constant-metric warning, both to make the message more precise for this particular edge case and to fall back to the neutral 0.5 threshold, keeping the estimator behavior symmetric.
- Can we move this to a utility function which takes a bunch of scores and warns, like `_warn_on_constant_metrics`, and call it in a few places where it's relevant? (A single PR for this would be nice, which would include this use case.)
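The _warn_on_constant_metrics helper suggested above does not exist in scikit-learn; purely as a sketch of the idea (the name is taken from the comment, while the signature, parameters and message wording are hypothetical), such a shared utility could look roughly like this:

# Hypothetical sketch of the shared helper proposed above; not part of the PR,
# signature and message are illustrative only.
from warnings import warn

import numpy as np


def _warn_on_constant_metrics(scores, *, metric_name, context):
    """Warn when the objective is constant across all searched candidates.

    `scores` holds one objective value per candidate (threshold or
    hyper-parameter combination), `metric_name` names the scoring metric and
    `context` names the caller, e.g. "TunedThresholdClassifierCV".
    """
    scores = np.asarray(scores, dtype=np.float64)
    if scores.size and scores.max() - scores.min() <= np.finfo(scores.dtype).eps:
        warn(
            f"The objective metric {metric_name!r} is constant at {scores[0]} "
            f"for all candidates searched by {context}. The selected candidate "
            "is therefore arbitrary; consider a scoring metric that varies with "
            "the searched parameter.",
            UserWarning,
        )
        return True
    return False

Both the constant-threshold fallback in this PR and a possible future *SearchCV check could then share the same message format, which is the consistency the comment asks for.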