|
1 | | -from .base import _check_estimator_target |
| 1 | +from .base import _check_estimator_and_target_is_binary |
2 | 2 |
|
3 | 3 | from .. import average_precision_score |
4 | 4 | from .. import precision_recall_curve |
@@ -240,7 +240,7 @@ def from_estimator( |
240 | 240 | method_name = f"{cls.__name__}.from_estimator" |
241 | 241 | check_matplotlib_support(method_name) |
242 | 242 |
|
243 | | - _check_estimator_target(estimator, y) |
| 243 | + _check_estimator_and_target_is_binary(estimator, y) |
244 | 244 | if response_method == "auto": |
245 | 245 | response_method = ["predict_proba", "decision_function"] |
246 | 246 |
|
@@ -333,10 +333,10 @@ def from_predictions( |
333 | 333 | """ |
334 | 334 | check_matplotlib_support(f"{cls.__name__}.from_predictions") |
335 | 335 |
|
336 | | - if type_of_target(y_true) != "binary": |
| 336 | + target_type = type_of_target(y_true) |
| 337 | + if target_type != "binary": |
337 | 338 | raise ValueError( |
338 | | - f"The target y is not binary. Got {type_of_target(y_true)} type of" |
339 | | - " target." |
| 339 | + f"The target y is not binary. Got {target_type} type of target." |
340 | 340 | ) |
341 | 341 |
|
342 | 342 | check_consistent_length(y_true, y_pred, sample_weight) |
@@ -444,7 +444,7 @@ def plot_precision_recall_curve( |
444 | 444 | """ |
445 | 445 | check_matplotlib_support("plot_precision_recall_curve") |
446 | 446 |
|
447 | | - _check_estimator_target(estimator, y) |
| 447 | + _check_estimator_and_target_is_binary(estimator, y) |
448 | 448 |
|
449 | 449 | if response_method == "auto": |
450 | 450 | response_method = ["predict_proba", "decision_function"] |
|
0 commit comments