8000 reuse type_of_target · scikit-learn/scikit-learn@abea58f · GitHub
[go: up one dir, main page]

Skip to content

Commit abea58f

Browse files
committed
reuse type_of_target
1 parent a99dfce commit abea58f

File tree

6 files changed

+31
-38
lines changed

6 files changed

+31
-38
lines changed

sklearn/calibration.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from .svm import LinearSVC
5050
from .model_selection import check_cv, cross_val_predict
5151
from .metrics._base import _check_pos_label_consistency
52-
from .metrics._plot.base import _check_estimator_target
52+
from .metrics._plot.base import _check_estimator_and_target_is_binary
5353

5454

5555 8000
class CalibratedClassifierCV(ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
@@ -1218,7 +1218,7 @@ def from_estimator(
12181218
method_name = f"{cls.__name__}.from_estimator"
12191219
check_matplotlib_support(method_name)
12201220

1221-
_check_estimator_target(estimator, y)
1221+
_check_estimator_and_target_is_binary(estimator, y)
12221222

12231223
y_prob, pos_label = _get_response_values(
12241224
estimator, X, y, response_method="predict_proba", pos_label=pos_label
@@ -1337,10 +1337,10 @@ def from_predictions(
13371337
method_name = f"{cls.__name__}.from_predictions"
13381338
check_matplotlib_support(method_name)
13391339

1340-
if type_of_target(y_true) != "binary":
1340+
target_type = type_of_target(y_true)
1341+
if target_type != "binary":
13411342
raise ValueError(
1342-
f"The target y is not binary. Got {type_of_target(y_true)} type of"
1343-
" target."
1343+
f"The target y is not binary. Got {target_type} type of target."
13441344
)
13451345

13461346
prob_true, prob_pred = calibration_curve(

sklearn/metrics/_plot/base.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,8 @@
44
from ...utils.validation import check_is_fitted
55

66

7-
def _check_estimator_target(estimator, y):
8-
"""Helper to check that estimator is a binary classifier and y is binary.
9-
10-
This function is aside from the class `BinaryClassifierCurveDisplayMixin`
11-
below because it allows to have consistent error messages between the
12-
displays and the plotting functions.
13-
14-
FIXME: Move into `BinaryClassifierCurveDisplayMixin.from_estimator` when
15-
the plotting functions will be removed in 1.2.
16-
"""
7+
def _check_estimator_and_target_is_binary(estimator, y):
8000
8+
"""Helper to check that estimator is a binary classifier and y is binary."""
179
try:
1810
check_is_fitted(estimator)
1911
except NotFittedError as e:
@@ -34,7 +26,8 @@ def _check_estimator_target(estimator, y):
3426
"classifier. It was fitted on multiclass problem with "
3527
f"{len(estimator.classes_)} classes."
3628
)
37-
elif type_of_target(y) != "binary":
29+
target_type = type_of_target(y)
30+
if target_type != "binary":
3831
raise ValueError(
39-
f"The target y is not binary. Got {type_of_target(y)} type of target."
32+
f"The target y is not binary. Got {target_type} type of target."
4033
)

sklearn/metrics/_plot/det_curve.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import scipy as sp
22

3-
from .base import _check_estimator_target
3+
from .base import _check_estimator_and_target_is_binary
44

55
from .. import det_curve
66
from .._base import _check_pos_label_consistency
@@ -172,7 +172,7 @@ def from_estimator(
172172
"""
173173
check_matplotlib_support(f"{cls.__name__}.from_estimator")
174174

175-
_check_estimator_target(estimator, y)
175+
_check_estimator_and_target_is_binary(estimator, y)
176176
if response_method == "auto":
177177
response_method = ["predict_proba", "decision_function"]
178178

@@ -275,10 +275,10 @@ def from_predictions(
275275
"""
276276
check_matplotlib_support(f"{cls.__name__}.from_predictions")
277277

278-
if type_of_target(y_true) != "binary":
278+
target_type = type_of_target(y_true)
279+
if target_type != "binary":
279280
raise ValueError(
280-
f"The target y is not binary. Got {type_of_target(y_true)} type of"
281-
" target."
281+
f"The target y is not binary. Got {target_type} type of target."
282282
)
283283

284284
fpr, fnr, _ = det_curve(
@@ -470,7 +470,7 @@ def plot_det_curve(
470470
"""
471471
check_matplotlib_support("plot_det_curve")
472472

473-
_check_estimator_target(estimator, y)
473+
_check_estimator_and_target_is_binary(estimator, y)
474474
if response_method == "auto":
475475
response_method = ["predict_proba", "decision_function"]
476476

sklearn/metrics/_plot/precision_recall_curve.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .base import _check_estimator_target
1+
from .base import _check_estimator_and_target_is_binary
22

33
from .. import average_precision_score
44
from .. import precision_recall_curve
@@ -240,7 +240,7 @@ def from_estimator(
240240
method_name = f"{cls.__name__}.from_estimator"
241241
check_matplotlib_support(method_name)
242242

243-
_check_estimator_target(estimator, y)
243+
_check_estimator_and_target_is_binary(estimator, y)
244244
if response_method == "auto":
245245
response_method = ["predict_proba", "decision_function"]
246246

@@ -333,10 +333,10 @@ def from_predictions(
333333
"""
334334
check_matplotlib_support(f"{cls.__name__}.from_predictions")
335335

336-
if type_of_target(y_true) != "binary":
336+
target_type = type_of_target(y_true)
337+
if target_type != "binary":
337338
raise ValueError(
338-
f"The target y is not binary. Got {type_of_target(y_true)} type of"
339-
" target."
339+
f"The target y is not binary. Got {target_type} type of target."
340340
)
341341

342342
check_consistent_length(y_true, y_pred, sample_weight)
@@ -444,7 +444,7 @@ def plot_precision_recall_curve(
444444
"""
445445
check_matplotlib_support("plot_precision_recall_curve")
446446

447-
_check_estimator_target(estimator, y)
447+
_check_estimator_and_target_is_binary(estimator, y)
448448

449449
if response_method == "auto":
450450
response_method = ["predict_proba", "decision_function"]

sklearn/metrics/_plot/roc_curve.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .base import _check_estimator_target
1+
from .base import _check_estimator_and_target_is_binary
22

33
from .. import auc
44
from .. import roc_curve
@@ -231,7 +231,7 @@ def from_estimator(
231231
"""
232232
check_matplotlib_support(f"{cls.__name__}.from_estimator")
233233

234-
_check_estimator_target(estimator, y)
234+
_check_estimator_and_target_is_binary(estimator, y)
235235
if response_method == "auto":
236236
response_method = ["predict_proba", "decision_function"]
237237

@@ -340,10 +340,10 @@ def from_predictions(
340340
"""
341341
check_matplotlib_support(f"{cls.__name__}.from_predictions")
342342

343-
if type_of_target(y_true) != "binary":
343+
target_type = type_of_target
344+
if target_type != "binary":
344345
raise ValueError(
345-
f"The target y is not binary. Got {type_of_target(y_true)} type of"
346-
" target."
346+
f"The target y is not binary. Got {target_type} type of target."
347347
)
348348

349349
fpr, tpr, _ = roc_curve(
@@ -467,7 +467,7 @@ def plot_roc_curve(
467467
"""
468468
check_matplotlib_support("plot_roc_curve")
469469

470-
_check_estimator_target(estimator, y)
470+
_check_estimator_and_target_is_binary(estimator, y)
471471
if response_method == "auto":
472472
response_method = ["predict_proba", "decision_function"]
473473

sklearn/metrics/_plot/tests/test_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from sklearn.exceptions import NotFittedError
55
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
66

7-
from sklearn.metrics._plot.base import _check_estimator_target
7+
from sklearn.metrics._plot.base import _check_estimator_and_target_is_binary
88

99
X, y = load_iris(return_X_y=True)
1010
X_binary, y_binary = X[:100], y[:100]
@@ -39,7 +39,7 @@
3939
),
4040
],
4141
)
42-
def test_check_estimator_target(estimator, target, err_type, err_msg):
42+
def test_check_estimator_and_target_is_binary(estimator, target, err_type, err_msg):
4343
"""Check that we raise the expected error when checking the estimator and target."""
4444
with pytest.raises(err_type, match=err_msg):
45-
_check_estimator_target(estimator, target)
45+
_check_estimator_and_target_is_binary(estimator, target)

0 commit comments

Comments
 (0)
0