TST add unit tests for current _get_response (#21041) · scikit-learn/scikit-learn@eda0473 · GitHub
[go: up one dir, main page]

Skip to content

Commit eda0473

Browse files
glemaitre and ogrisel committed
TST add unit tests for current _get_response (#21041)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 8f903ca commit eda0473

File tree

2 files changed

+89
-20
lines changed

2 files changed

+89
-20
lines changed

sklearn/metrics/_plot/base.py

+14-20
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import numpy as np
2-
31
from ...base import is_classifier
42

53

@@ -91,14 +89,18 @@ def _get_response(X, estimator, response_method, pos_label=None):
9189
raise ValueError(classification_error)
9290

9391
prediction_method = _check_classifier_response_method(estimator, response_method)
94-
9592
y_pred = prediction_method(X)
96-
97-
if pos_label is not None and pos_label not in estimator.classes_:
98-
raise ValueError(
99-
"The class provided by 'pos_label' is unknown. Got "
100-
f"{pos_label} instead of one of {estimator.classes_}"
101-
)
93+
if pos_label is not None:
94+
try:
95+
class_idx = estimator.classes_.tolist().index(pos_label)
96+
except ValueError as e:
97+
raise ValueError(
98+
"The class provided by 'pos_label' is unknown. Got "
99+
f"{pos_label} instead of one of {set(estimator.classes_)}"
100+
) from e
101+
else:
102+
class_idx = 1
103+
pos_label = estimator.classes_[class_idx]
102104

103105
if y_pred.ndim != 1: # `predict_proba`
104106
y_pred_shape = y_pred.shape[1]
@@ -107,16 +109,8 @@ def _get_response(X, estimator, response_method, pos_label=None):
107109
f"{classification_error} fit on multiclass ({y_pred_shape} classes)"
108110
" data"
109111
)
110-
if pos_label is None:
111-
pos_label = estimator.classes_[1]
112-
y_pred = y_pred[:, 1]
113-
else:
114-
class_idx = np.flatnonzero(estimator.classes_ == pos_label)
115-
y_pred = y_pred[:, class_idx]
116-
else:
117-
if pos_label is None:
118-
pos_label = estimator.classes_[1]
119-
elif pos_label == estimator.classes_[0]:
120-
y_pred *= -1
112+
y_pred = y_pred[:, class_idx]
113+
elif pos_label == estimator.classes_[0]: # `decision_function`
114+
y_pred *= -1
121115

122116
return y_pred, pos_label
+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import numpy as np
2+
import pytest
3+
4+
from sklearn.datasets import load_iris
5+
from sklearn.linear_model import LogisticRegression
6+
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
7+
8+
from sklearn.metrics._plot.base import _get_response
9+
10+
11+
@pytest.mark.parametrize(
12+
"estimator, err_msg, params",
13+
[
14+
(
15+
DecisionTreeRegressor(),
16+
"Expected 'estimator' to be a binary classifier",
17+
{"response_method": "auto"},
18+
),
19+
(
20+
DecisionTreeClassifier(),
21+
"The class provided by 'pos_label' is unknown.",
22+
{"response_method": "auto", "pos_label": "unknown"},
23+
),
24+
(
25+
DecisionTreeClassifier(),
26+
"fit on multiclass",
27+
{"response_method": "predict_proba"},
28+
),
29+
],
30+
)
31+
def test_get_response_error(estimator, err_msg, params):
32+
"""Check that we raise the proper error messages in `_get_response`."""
33+
X, y = load_iris(return_X_y=True)
34+
35+
estimator.fit(X, y)
36+
with pytest.raises(ValueError, match=err_msg):
37+
_get_response(X, estimator, **params)
38+
39+
40+
def test_get_response_predict_proba():
41+
"""Check the behaviour of `_get_response` using `predict_proba`."""
42+
X, y = load_iris(return_X_y=True)
43+
X_binary, y_binary = X[:100], y[:100]
44+
45+
classifier = DecisionTreeClassifier().fit(X_binary, y_binary)
46+
y_proba, pos_label = _get_response(
47+
X_binary, classifier, response_method="predict_proba"
48+
)
49+
np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 1])
50+
assert pos_label == 1
51+
52+
y_proba, pos_label = _get_response(
53+
X_binary, classifier, response_method="predict_proba", pos_label=0
54+
)
55+
np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 0])
56+
assert pos_label == 0
57+
58+
59+
def test_get_response_decision_function():
60+
"""Check the behaviour of `_get_response` using `decision_function`."""
61+
X, y = load_iris(return_X_y=True)
62+
X_binary, y_binary = X[:100], y[:100]
63+
64+
classifier = LogisticRegression().fit(X_binary, y_binary)
65+
y_score, pos_label = _get_response(
66+
X_binary, classifier, response_method="decision_function"
67+
)
68+
np.testing.assert_allclose(y_score, classifier.decision_function(X_binary))
69+
assert pos_label == 1
70+
71+
y_score, pos_label = _get_response(
72+
X_binary, classifier, response_method="decision_function", pos_label=0
73+
)
74+
np.testing.assert_allclose(y_score, classifier.decision_function(X_binary) * -1)
75+
assert pos_label == 0

0 commit comments

Comments
 (0)
0