FIX TunedThresholdClassifierCV error or warn with informative message on invalid metrics by ogrisel · Pull Request #29082 · scikit-learn/scikit-learn · GitHub

FIX TunedThresholdClassifierCV error or warn with informative message on invalid metrics #29082

Open

wants to merge 22 commits into base: main

Changes from all commits (22 commits)
5817ee9
FIX TunedThresholdClassifierCV error or warn with informative message…
ogrisel May 22, 2024
0acb511
Update sklearn/model_selection/tests/test_classification_threshold.py
ogrisel May 23, 2024
a1b5518
Update sklearn/model_selection/tests/test_classification_threshold.py
ogrisel May 23, 2024
a0d62c5
Use 0.5 threshold in case of constant scores.
ogrisel May 23, 2024
a874855
Typo
ogrisel May 23, 2024
4f82762
Apply suggestions from code review
ogrisel May 23, 2024
505baff
Improve & fix tests
ogrisel May 23, 2024
b6b13d3
Improve test docstring.
ogrisel May 23, 2024
942d58a
Merge branch 'main' into fix-tuned-threshold-on-invalid-metrics
ogrisel May 23, 2024
964a079
Add changelog entry for 1.5.1
ogrisel May 23, 2024
c8e811d
Linter fix
ogrisel May 23, 2024
5a464e4
Add TODO comment to make it clear that the code would be clearer with…
ogrisel May 24, 2024
942221d
Add a dedicated warning to guide the user into crafting non-trivial s…
ogrisel May 24, 2024
397ee8d
Improve test
ogrisel May 24, 2024
553a1b8
Update changelog
ogrisel May 24, 2024
2b4927a
DOC improve the docstring for the scoring parameter of TunedThreshold…
ogrisel May 24, 2024
a37edcd
DOC more precise phrasing in scoring docstring
lorentzenchr May 24, 2024
aadae00
Merge branch 'main' into fix-tuned-threshold-on-invalid-metrics
ogrisel May 24, 2024
856d075
Grammar fix in comment
ogrisel May 28, 2024
f4c43e3
Trim last sentence of paragraph to make docstring a bit shorter.
ogrisel Jun 4, 2024
351a523
Avoid failing if the private `_response_method` attribute does not exist
ogrisel Jun 4, 2024
0af3d32
Merge branch 'main' into fix-tuned-threshold-on-invalid-metrics
ogrisel Jun 4, 2024
5 changes: 5 additions & 0 deletions doc/whats_new/v1.5.rst
@@ -38,6 +38,11 @@ Changelog
grids that have heterogeneous parameter values.
:pr:`29078` by :user:`Loïc Estève <lesteve>`.

- |Fix| Fix :class:`model_selection.TunedThresholdClassifierCV` to raise
`ValueError` when passed a `scoring` argument intended for unthresholded
predictions. It now also warns when the choice of `scoring` leads to a
degenerate tuned threshold.
:pr:`29082` by :user:`Olivier Grisel <ogrisel>`.

.. _changes_1_5:

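A minimal sketch of the new behavior described in the changelog entry above; the dataset and estimator below are arbitrary choices for illustration and are not part of this PR.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import TunedThresholdClassifierCV

X, y = make_classification(random_state=0)

# "roc_auc" evaluates unthresholded predictions: now rejected with a ValueError.
try:
    TunedThresholdClassifierCV(LogisticRegression(), scoring="roc_auc").fit(X, y)
except ValueError as exc:
    print(exc)

# "precision" carries no penalty for false negatives: fitting succeeds but a
# UserWarning explains that the tuned threshold yields a trivial classifier.
TunedThresholdClassifierCV(LogisticRegression(), scoring="precision").fit(X, y)
```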
62 changes: 59 additions & 3 deletions sklearn/model_selection/_classification_threshold.py
@@ -1,5 +1,6 @@
from collections.abc import MutableMapping
from numbers import Integral, Real
from warnings import warn

import numpy as np

@@ -656,7 +657,17 @@ class TunedThresholdClassifierCV(BaseThresholdClassifier):

* a string associated to a scoring function for binary classification
(see :ref:`scoring_parameter`);
* a scorer callable object created with :func:`~sklearn.metrics.make_scorer`;
* a scorer callable object created with :func:`~sklearn.metrics.make_scorer`.

Note that the scoring objective should introduce a trade-off between false
negatives and false positives, otherwise the tuned threshold would be
trivial and the resulting classifier would be equivalent to constantly
classifying all samples as one of the two possible classes. This would be
the case when passing scoring="precision" or scoring="recall", for instance.
Furthermore, the scoring objective should evaluate thresholded
classifier predictions: as a result, metrics such as ROC AUC, Average
Precision, log loss or the Brier score are not valid scoring metrics in
this context.
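For illustration (not part of the diff above), one way to build a scoring objective that does encode such a trade-off is with make_scorer; the 5x false-negative cost below is an assumption chosen for the example, not something prescribed by this PR.

```python
# Sketch: a cost-sensitive objective that penalizes false negatives more than
# false positives (the 5.0 / 1.0 costs are illustrative assumptions).
from sklearn.metrics import confusion_matrix, make_scorer

def fn_fp_tradeoff(y_true, y_pred, fn_cost=5.0, fp_cost=1.0):
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    return -(fn_cost * fn + fp_cost * fp)

# Evaluates thresholded predictions and varies with the decision threshold,
# so it is a valid `scoring` value for TunedThresholdClassifierCV.
tradeoff_scorer = make_scorer(fn_fp_tradeoff, response_method="predict")
```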

response_method : {"auto", "decision_function", "predict_proba"}, default="auto"
Methods by the classifier `estimator` corresponding to the
@@ -947,6 +958,35 @@ def _fit(self, X, y, **params):
best_idx = objective_scores.argmax()
self.best_score_ = objective_scores[best_idx]
self.best_threshold_ = decision_thresholds[best_idx]

if self.best_threshold_ == min_threshold:
trivial_kind = "positive"
elif self.best_threshold_ == max_threshold:
trivial_kind = "negative"
else:
trivial_kind = None

if (
objective_scores.max() - objective_scores.min()
<= np.finfo(objective_scores.dtype).eps
Comment on lines +970 to +971
Member:
I wonder if this is something we should be doing everywhere where we "search" over a few models using a scorer. Kind of arbitrary to have it here and not in other places.

Member Author (ogrisel):
I would also be in favor of raising a similar warning for *SearchCV meta-estimators when mean_test_score is constant for all hyper-parameters. Not sure how frequent this is, though.
But it's true that it's actually not that frequent for TunedThresholdClassifierCV either.
The original problem I encountered that triggered the selection of an extreme threshold is actually of a different nature. Let me update this PR accordingly to discuss that further (maybe tomorrow).

Member Author (ogrisel):
Done in #29082 (comment).

Member Author (ogrisel, May 24, 2024):
I think that checking for the constant case is not really important, since I do not expect this to happen often in practice. The other cases (extreme thresholds due to a lack of trade-off) are more useful to warn against (and they do easily happen in practice, as shown in the tests).
But I would rather keep the constant-score warning, both to make the warning message more precise for this particular edge case and to fall back to the neutral 0.5 threshold, keeping the estimator behavior symmetric.

Member:
Can we move this to a utility function which takes a bunch of scores and warns, like _warn_on_constant_metrics, and call it in a few places where it's relevant? (A single PR for this would be nice, which would include this use case.)

):
warn(
f"The objective metric {self.scoring!r} is constant at "
Member:
This is a good idea. Might help some users. Also I like that it is a warning, not an error.

Member Author (ogrisel, May 23, 2024):
I updated the PR in a0d62c5 to keep the 0.5 threshold in that case. Using an extreme near-zero threshold would introduce a very weird / unexpectedly biased behavior. Better to keep a more neutral behavior in such a pathological situation.

Comment on lines +973 to +974

Member:
I still like this warning.

f"{self.best_score_} across all thresholds. Falling back "
"to the default 0.5 threshold. Please instead pass a scoring "
"metric that varies with the decision threshold.",
UserWarning,
)
self.best_threshold_ = 0.5
elif trivial_kind is not None:
warn(
f"Tuning the decision threshold on {self.scoring} "
"leads to a trivial classifier that classifies all samples as "
f"the {trivial_kind} class. Consider revising the scoring parameter "
"to include a trade-off between false positives and false negatives.",
UserWarning,
)
Comment on lines +981 to +988
Member:
I think we have a similar warning when the found solution is using a parameter from the edges of the bounds. Could they all be in the same utility function? It would make the messages more consistent. Could also be a separate PR.

Member Author (ogrisel):
I agree we would need a similar feature for hparam search, but I think both the code and the messages deserve to be specialized.
For the case of hparam tuning, we need to take the bounds of parameter validation into account (e.g. some hparams are naturally bounded, such as alpha=0 or l1_ratio=1.0, so we should not raise a warning if we reach those bounds).
For the case of threshold tuning, I think it's important to mention the positive / negative classifications, as I do in the 2 custom warnings of this PR, to get a more explicit and actionable error message. The message for hparam tuning would be more generic.
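A rough sketch of the kind of shared helper floated in the comments above; the name `_warn_on_constant_metrics` comes from the earlier review suggestion, while the signature and message are assumptions rather than code from this PR.

```python
from warnings import warn

import numpy as np

def _warn_on_constant_metrics(scores, scoring_name):
    """Hypothetical helper: warn when a metric is constant across candidates."""
    scores = np.asarray(scores, dtype=np.float64)
    if scores.max() - scores.min() <= np.finfo(scores.dtype).eps:
        warn(
            f"The metric {scoring_name!r} is constant at {scores[0]} for all "
            "candidates searched over; the selected candidate is arbitrary.",
            UserWarning,
        )
```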


if self.store_cv_results:
self.cv_results_ = {
"thresholds": decision_thresholds,
@@ -1012,8 +1052,24 @@ def get_metadata_routing(self):

def _get_curve_scorer(self):
"""Get the curve scorer based on the objective metric used."""
scoring = check_scoring(self.estimator, scoring=self.scoring)
scorer = check_scoring(self.estimator, scoring=self.scoring)
# XXX: at the time of writing, there is no very explicit way to check
# if a scorer expects thresholded binary classification predictions.
# TODO: update this condition when a better way is available.
scorer_response_methods = getattr(scorer, "_response_method", "predict")
if isinstance(scorer_response_methods, str):
scorer_response_methods = {scorer_response_methods}
else:
scorer_response_methods = set(scorer_response_methods)

if scorer_response_methods.issubset({"predict_proba", "decision_function"}):
raise ValueError(
Member:
Here, I'm not so sure whether we overly constrain the possible scorers from the user. But I also do not 100% follow how the curve scorer works.

Member Author (ogrisel, May 23, 2024):
If we don't do that (as is the case in main) and the user passes a non-thresholded metric like ROC AUC / average precision / log loss / Brier, the metric is evaluated on the thresholded (binary) predictions, which is really misleading. You still get a "tuned" threshold but its meaning is really confusing.
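To make that concern concrete, a small sketch (not from the PR) of how ROC AUC computed on thresholded 0/1 predictions differs from the intended ROC AUC on probability scores:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

X, y = make_classification(random_state=0)
clf = LogisticRegression().fit(X, y)

proba = clf.predict_proba(X)[:, 1]
hard = (proba >= 0.5).astype(int)

print(roc_auc_score(y, proba))  # intended usage: AUC over continuous scores
print(roc_auc_score(y, hard))   # degenerate two-point ROC curve on 0/1 labels
```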

Member:
I think that we can be conservative at first. If a real use case is reported, we can then rework and make sure that the API is right and not just working as a side effect.

Member:
I don't like that here we're accessing the private _response_method as a proxy to see if it's the right kind of scorer or not. Scorers should have a public API for this (cc @StefanieSenger).

Contributor (StefanieSenger, May 24, 2024):
It seems to me that since response_method="predict" is the default in _BaseScorer, this error might not reach the users most vulnerable to being confused by wrong outputs. Is there a more automated way to check whether the scorer fits the purpose that does not depend on user input?
Edit: That wasn't very helpful. I think I got something wrong.

Member Author (ogrisel):
Let me know if the following helps address those concerns:
https://github.com/scikit-learn/scikit-learn/pull/29082/files#r1613325997

Member:
This is still the most controversial part. Without it, this PR would be almost merged, right?

Member Author (ogrisel):
What do you think of #29082 (comment)?

f"{self.__class__.__name__} expects a scoring metric that evaluates "
f"the thresholded predictions of a binary classifier, got: "
f"{self.scoring!r} which expects unthresholded predictions computed by "
f"the {scorer._response_method!r} method(s) of the classifier."
)
curve_scorer = _CurveScorer.from_scorer(
scoring, self._get_response_method(), self.thresholds
scorer, self._get_response_method(), self.thresholds
)
return curve_scorer
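For reference, the check above keys on the private `_BaseScorer._response_method` attribute discussed in the review thread; the values below reflect the current internals and are not a stable public API.

```python
from sklearn.metrics import get_scorer

# Scorers for unthresholded predictions are rejected by the check above.
print(get_scorer("roc_auc")._response_method)       # ('decision_function', 'predict_proba')
print(get_scorer("neg_log_loss")._response_method)  # 'predict_proba'

# A scorer for thresholded predictions is accepted.
print(get_scorer("balanced_accuracy")._response_method)  # 'predict'
```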
93 changes: 93 additions & 0 deletions sklearn/model_selection/tests/test_classification_threshold.py
@@ -1,3 +1,5 @@
import re

import numpy as np
import pytest

@@ -14,6 +16,7 @@
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
balanced_accuracy_score,
confusion_matrix,
f1_score,
fbeta_score,
make_scorer,
@@ -684,6 +687,96 @@ def test_fixed_threshold_classifier_metadata_routing():
assert_allclose(classifier_default_threshold.estimator_.coef_, classifier.coef_)


@pytest.mark.parametrize(
"scoring_name, expected_method_names",
[
("roc_auc", "('decision_function', 'predict_proba')"),
("average_precision", "('decision_function', 'predict_proba')"),
("neg_log_loss", "'predict_proba'"),
],
)
def test_error_on_unthresholded_classification_metrics(
scoring_name, expected_method_names
):
"""Check error raised with metrics meant for unthresholded predictions."""
X, y = make_classification(random_state=0)
estimator = LogisticRegression()
err_msg = re.escape(
"TunedThresholdClassifierCV expects a scoring metric that evaluates the "
f"thresholded predictions of a binary classifier, got: '{scoring_name}' "
"which expects unthresholded predictions computed by the "
f"{expected_method_names} method(s) of the classifier."
)
with pytest.raises(ValueError, match=err_msg):
TunedThresholdClassifierCV(estimator, scoring=scoring_name).fit(X, y)


def test_warn_on_constant_scores():
"""Check that a warning is raised when the score is constant."""
X, y = make_classification(random_state=0)
estimator = LogisticRegression()

def constant_score_func(y_true, y_pred):
return 1.0

scorer = make_scorer(constant_score_func, response_method="predict")

warn_msg = re.escape(
"The objective metric make_scorer(constant_score_func, "
"response_method='predict') is constant at 1.0 across all thresholds. Falling "
"back to the default 0.5 threshold. Please instead pass a scoring metric that "
"varies with the decision threshold."
)
with pytest.warns(UserWarning, match=warn_msg):
tuned_clf = TunedThresholdClassifierCV(
estimator, scoring=scorer, store_cv_results=True
).fit(X, y)
assert_allclose(tuned_clf.cv_results_["scores"], np.ones(shape=100))
assert tuned_clf.best_threshold_ == pytest.approx(0.5)


def always_prefer_positive_class(y_observed, y_pred):
tn, fp, fn, tp = confusion_matrix(y_observed, y_pred, normalize="all").ravel()
return tp - 2 * fn


def always_prefer_negative_class(y_observed, y_pred):
tn, fp, fn, tp = confusion_matrix(y_observed, y_pred, normalize="all").ravel()
return tn - 2 * fp


@pytest.mark.parametrize(
"scoring, kind",
[
(make_scorer(always_prefer_positive_class), "positive"),
(make_scorer(always_prefer_negative_class), "negative"),
("precision", "negative"),
("recall", "positive"),
],
)
def test_warn_on_trivial_thresholds(scoring, kind):
"""Check that a warning is raised when the score is constant."""
X, y = make_classification(random_state=0)
estimator = LogisticRegression()

warn_msg = re.escape(
f"Tuning the decision threshold on {scoring} leads to a trivial classifier "
f"that classifies all samples as the {kind} class. Consider revising the "
"scoring parameter to include a trade-off between false positives and false "
"negatives."
)
with pytest.warns(UserWarning, match=warn_msg):
tuned_clf = TunedThresholdClassifierCV(
estimator, scoring=scoring, store_cv_results=True
).fit(X, y)

thresholds = tuned_clf.cv_results_["thresholds"]
if kind == "positive":
assert tuned_clf.best_threshold_ == thresholds[0] == thresholds.min()
else:
assert tuned_clf.best_threshold_ == thresholds[-1] == thresholds.max()
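For contrast, a sketch (not part of the test diff) of a scoring choice that trades off both error types and is therefore expected to tune a non-trivial threshold without triggering either warning.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import TunedThresholdClassifierCV

X, y = make_classification(n_samples=1_000, weights=[0.9, 0.1], random_state=0)
tuned_clf = TunedThresholdClassifierCV(
    LogisticRegression(), scoring="balanced_accuracy", store_cv_results=True
).fit(X, y)

# With an imbalanced dataset, the tuned threshold is expected to lie strictly
# between the extreme candidate thresholds rather than at either end.
thresholds = tuned_clf.cv_results_["thresholds"]
assert thresholds.min() < tuned_clf.best_threshold_ < thresholds.max()
```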


class ClassifierLoggingFit(ClassifierMixin, BaseEstimator):
"""Classifier that logs the number of `fit` calls."""
