8000 FIX handle outlier detector in _get_response_values (#27565) · jeremiedbb/scikit-learn@8912619 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8912619

Browse files
authored
FIX handle outlier detector in _get_response_values (scikit-learn#27565)
1 parent 5444030 commit 8912619

File tree

3 files changed

+71
-4
lines changed

3 files changed

+71
-4
lines changed

sklearn/inspection/_plot/tests/test_boundary_decision_display.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
make_classification,
1111
make_multilabel_classification,
1212
)
13+
from sklearn.ensemble import IsolationForest
1314
from sklearn.inspection import DecisionBoundaryDisplay
1415
from sklearn.inspection._plot.decision_boundary import _check_boundary_response_method
1516
from sklearn.linear_model import LogisticRegression
@@ -240,6 +241,39 @@ def test_decision_boundary_display_classifier(
240241
assert disp.figure_ == fig2
241242

242243

244+
@pytest.mark.parametrize("response_method", ["auto", "predict", "decision_function"])
245+
@pytest.mark.parametrize("plot_method", ["contourf", "contour"])
246+
def test_decision_boundary_display_outlier_detector(
247+
pyplot, response_method, plot_method
248+
):
249+
"""Check that decision boundary is correct for outlier detector."""
250+
fig, ax = pyplot.subplots()
251+
eps = 2.0
252+
outlier_detector = IsolationForest(random_state=0).fit(X, y)
253+
disp = DecisionBoundaryDisplay.from_estimator(
254+
outlier_detector,
255+
X,
256+
grid_resolution=5,
257+
response_method=response_method,
258+
plot_method=plot_method,
259+
eps=eps,
260+
ax=ax,
261+
)
262+
assert isinstance(disp.surface_, pyplot.matplotlib.contour.QuadContourSet)
263+
assert disp.ax_ == ax
264+
assert disp.figure_ == fig
265+
266+
x0, x1 = X[:, 0], X[:, 1]
267+
268+
x0_min, x0_max = x0.min() - eps, x0.max() + eps
269+
x1_min, x1_max = x1.min() - eps, x1.max() + eps
270+
271+
assert disp.xx0.min() == pytest.approx(x0_min)
272+
assert disp.xx0.max() == pytest.approx(x0_max)
273+
assert disp.xx1.min() == pytest.approx(x1_min)
274+
assert disp.xx1.max() == pytest.approx(x1_max)
275+
276+
243277
@pytest.mark.parametrize("response_method", ["auto", "predict"])
244278
@pytest.mark.parametrize("plot_method", ["contourf", "contour"])
245279
def test_decision_boundary_display_regressor(pyplot, response_method, plot_method):

sklearn/utils/_response.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,13 +116,14 @@ def _get_response_values(
116116
pos_label=None,
117117
return_response_method_used=False,
118118
):
119-
"""Compute the response values of a classifier or a regressor.
119+
"""Compute the response values of a classifier, an outlier detector, or a regressor.
120120
121121
The response values are predictions such that it follows the following shape:
122122
123123
- for binary classification, it is a 1d array of shape `(n_samples,)`;
124124
- for multiclass classification, it is a 2d array of shape `(n_samples, n_classes)`;
125125
- for multilabel classification, it is a 2d array of shape `(n_samples, n_outputs)`;
126+
- for outlier detection, it is a 1d array of shape `(n_samples,)`;
126127
- for regression, it is a 1d array of shape `(n_samples,)`.
127128
128129
If `estimator` is a binary classifier, also return the label for the
@@ -135,8 +136,9 @@ def _get_response_values(
135136
Parameters
136137
----------
137138
estimator : estimator instance
138-
Fitted classifier or regressor or a fitted :class:`~sklearn.pipeline.Pipeline`
139-
in which the last estimator is a classifier or a regressor.
139+
Fitted classifier, outlier detector, or regressor or a
140+
fitted :class:`~sklearn.pipeline.Pipeline` in which the last estimator is a
141+
classifier, an outlier detector, or a regressor.
140142
141143
X : {array-like, sparse matrix} of shape (n_samples, n_features)
142144
Input values.
@@ -188,7 +190,7 @@ def _get_response_values(
188190
If the response method can be applied to a classifier only and
189191
`estimator` is a regressor.
190192
"""
191-
from sklearn.base import is_classifier # noqa
193+
from sklearn.base import is_classifier, is_outlier_detector # noqa
192194

193195
if is_classifier(estimator):
194196
prediction_method = _check_response_method(estimator, response_method)
@@ -220,6 +222,9 @@ def _get_response_values(
220222
classes=classes,
221223
pos_label=pos_label,
222224
)
225+
elif is_outlier_detector(estimator):
226+
prediction_method = _check_response_method(estimator, response_method)
227+
y_pred, pos_label = prediction_method(X), None
223228
else: # estimator is a regressor
224229
if response_method != "predict":
225230
raise ValueError(

sklearn/utils/tests/test_response.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
make_multilabel_classification,
88
make_regression,
99
)
10+
from sklearn.ensemble import IsolationForest
1011
from sklearn.linear_model import (
1112
LinearRegression,
1213
LogisticRegression,
@@ -52,6 +53,33 @@ def test_get_response_values_regressor(return_response_method_used):
5253
assert results[2] == "predict"
5354

5455

56+
@pytest.mark.parametrize(
57+
"response_method",
58+
["predict", "decision_function", ["decision_function", "predict"]],
59+
)
60+
@pytest.mark.parametrize("return_response_method_used", [True, False])
61+
def test_get_response_values_outlier_detection(
62+
response_method, return_response_method_used
63+
):
64+
"""Check the behaviour of `_get_response_values` with outlier detector."""
65+
X, y = make_classification(n_samples=50, random_state=0)
66+
outlier_detector = IsolationForest(random_state=0).fit(X, y)
67+
results = _get_response_values(
68+
outlier_detector,
69+
X,
70+
response_method=response_method,
71+
return_response_method_used=return_response_method_used,
72+
)
73+
chosen_response_method = (
74+
response_method[0] if isinstance(response_method, list) else response_method
75+
)
76+
prediction_method = getattr(outlier_detector, chosen_response_method)
77+
assert_array_equal(results[0], prediction_method(X))
78+
assert results[1] is None
79+
if return_response_method_used:
80+
assert results[2] == chosen_response_method
81+
82+
5583
@pytest.mark.parametrize(
5684
"response_method",
5785
["predict_proba", "decision_function", "predict"],

0 commit comments

Comments
 (0)
0