MNT Refactor scorer using _get_response · scikit-learn/scikit-learn@cc27a27 · GitHub
[go: up one dir, main page]

Skip to content

Commit cc27a27

Browse files
committed
MNT Refactor scorer using _get_response
1 parent 193670c commit cc27a27

File tree

7 files changed

+228
-213
lines changed

7 files changed

+228
-213
lines changed

sklearn/metrics/_base.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import numpy as np
1818

19+
from ..base import is_classifier
1920
from ..utils import check_array, check_consistent_length
2021
from ..utils.multiclass import type_of_target
2122

@@ -249,3 +250,138 @@ def _check_pos_label_consistency(pos_label, y_true):
249250
pos_label = 1.0
250251

251252
return pos_label
253+
254+
255+
def _check_classifier_response_method(estimator, response_method):
256+
"""Return prediction method from the `response_method`.
257+
258+
Parameters
259+
----------
260+
estimator : estimator instance
261+
Classifier to check.
262+
8000
263+
response_method : {'auto', 'predict_proba', 'decision_function', 'predict'}
264+
Specifies whether to use :term:`predict_proba` or
265+
:term:`decision_function` as the target response. If set to 'auto',
266+
:term:`predict_proba` is tried first and if it does not exist
267+
:term:`decision_function` is tried next and :term:`predict` last.
268+
269+
Returns
270+
-------
271+
prediction_method : callable
272+
Prediction method of estimator.
273+
"""
274+
275+
possible_response_methods = (
276+
"predict", "predict_proba", "decision_function", "auto"
277+
)
278+
if response_method not in possible_response_methods:
279+
raise ValueError(
280+
f"response_method must be one of "
281+
f"{','.join(possible_response_methods)}."
282+
)
283+
284+
error_msg = "response method {} is not defined in {}"
285+
if response_method != "auto":
286+
prediction_method = getattr(estimator, response_method, None)
287+
if prediction_method is None:
288+
raise ValueError(
289+
error_msg.format(response_method, estimator.__class__.__name__)
290+
)
291+
else:
292+
predict_proba = getattr(estimator, 'predict_proba', None)
293+
decision_function = getattr(estimator, 'decision_function', None)
294+
predict = getattr(estimator, 'predict', None)
295+
prediction_method = predict_proba or decision_function or predict
296+
if prediction_method is None:
297+
raise ValueError(
298+
error_msg.format(
299+
"decision_function, predict_proba or predict",
300+
estimator.__class__.__name__
301+
)
302+
)
303+
304+
return prediction_method
305+
306+
307+
def _get_response(
    estimator,
    X,
    y_true,
    response_method,
    pos_label=None,
):
    """Return response and positive label.

    Parameters
    ----------
    estimator : estimator instance
        Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
        in which the last estimator is a classifier. Estimators for which
        `is_classifier` is false are only supported with
        `response_method` set to 'predict' or 'auto'.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    y_true : array-like of shape (n_samples,)
        The true label.

    response_method : {'auto', 'predict_proba', 'decision_function', 'predict'}
        Specifies whether to use :term:`predict_proba` or
        :term:`decision_function` as the target response. If set to 'auto',
        :term:`predict_proba` is tried first and if it does not exist
        :term:`decision_function` is tried next and :term:`predict` last.

    pos_label : str or int, default=None
        The class considered as the positive class when computing
        the metrics. By default, `estimator.classes_[-1]` is
        considered as the positive class when `y_true` is binary.

    Returns
    -------
    y_pred : ndarray
        Target scores calculated from the provided response_method
        and pos_label. For a binary `y_true`, a two-column
        :term:`predict_proba` output is reduced to the column of
        `pos_label`, and a :term:`decision_function` output is negated
        when `pos_label` is `estimator.classes_[0]`. In every other case
        the output of the prediction method is returned unchanged.

    pos_label : str, int or None
        The class considered as the positive class when computing
        the metrics. It is `None` when `estimator` is not a classifier, or
        when no `pos_label` was given and `y_true` is not binary.

    Raises
    ------
    ValueError
        If `estimator` is not a classifier and `response_method` is neither
        'predict' nor 'auto'; if `pos_label` is not one of
        `estimator.classes_`; or if `y_true` is binary and
        :term:`predict_proba` returns fewer than two columns.
    """
    if not is_classifier(estimator):
        if response_method not in ("predict", "auto"):
            raise ValueError(
                f"{estimator.__class__.__name__} should be a classifier"
            )
        return estimator.predict(X), None

    y_type = type_of_target(y_true)
    classes = estimator.classes_
    prediction_method = _check_classifier_response_method(
        estimator, response_method
    )

    # Validate `pos_label` before predicting so that an invalid label fails
    # fast instead of after a potentially expensive prediction.
    if pos_label is not None and pos_label not in classes:
        raise ValueError(
            f"pos_label={pos_label} is not a valid label: It should be "
            f"one of {classes}"
        )
    if pos_label is None and y_type == "binary":
        pos_label = classes[-1]

    y_pred = prediction_method(X)

    method_name = prediction_method.__name__
    if method_name == "predict_proba":
        if y_type == "binary" and y_pred.shape[1] <= 2:
            if y_pred.shape[1] != 2:
                raise ValueError(
                    f"Got predict_proba of shape {y_pred.shape}, but need "
                    "classifier with two classes."
                )
            # Keep only the probability column of the positive class.
            col_idx = np.flatnonzero(classes == pos_label)[0]
            y_pred = y_pred[:, col_idx]
    elif method_name == "decision_function":
        if y_type == "binary" and pos_label == classes[0]:
            # Flip the sign so that larger scores favour `pos_label`. Build a
            # new array rather than negating in place, so the object handed
            # back by the estimator is never mutated.
            y_pred = y_pred * -1

    return y_pred, pos_label

sklearn/metrics/_plot/base.py

Lines changed: 0 additions & 114 deletions
This file was deleted.

sklearn/metrics/_plot/det_curve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import scipy as sp
22

3-
from .base import _get_response
3+
from .._base import _get_response
44

55
from .. import det_curve
66

sklearn/metrics/_plot/precision_recall_curve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .base import _get_response
1+
from .._base import _get_response
22

33
from .. import average_precision_score
44
from .. import precision_recall_curve

sklearn/metrics/_plot/roc_curve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .base import _get_response
1+
from .._base import _get_response
22

33
from .. import auc
44
from .. import roc_curve

0 commit comments

Comments
 (0)
0