MNT refactor _get_response_values #21538
…nto is/18212_again_again
```python
response_method="predict",
target_type=target_type,
)
assert_allclose(y_pred, regressor.predict(X))
```
This one could be checking strict equality, right?
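The reviewer's point can be illustrated with NumPy's test helpers: when predictions are passed through unchanged, the strict `assert_array_equal` should hold, not just the tolerance-based `assert_allclose`. A minimal sketch, with a hypothetical pass-through standing in for the real helper:

```python
import numpy as np
from numpy.testing import assert_allclose, assert_array_equal

# Hypothetical stand-in for the real helper: it returns the
# regressor's predictions unchanged, with no recomputation.
def get_response_values(predictions):
    return np.asarray(predictions)

y = np.array([0.1, 2.5, -3.0])
y_pred = get_response_values(y)

# The tolerance-based check passes...
assert_allclose(y_pred, y)
# ...but since values pass through untouched, strict equality also holds.
assert_array_equal(y_pred, y)
```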
```python
Raises
------
ValueError
```

Suggested change:

```suggestion
AttributeError
```
```python
y_pred = y_pred[:, col_idx]
else:
    err_msg = (
        f"Got predict_proba of shape {y_pred.shape}, but need "
```
I find the error message not super explicit; I would find it easier to understand if it said that the classifier was fitted with only one class (rather than complaining about the shape of `predict_proba`). Maybe there are some edge cases I haven't thought of, though.
```python
        f"one of {classes}"
    )
```
```python
elif pos_label is None and target_type == "binary":
    pos_label = pos_label if pos_label is not None else classes[-1]
```
Although -1 and 1 are equivalent, I would use 1, since `classes[1]` is used elsewhere, e.g. in the `_get_response_values` docstring:
Suggested change:

```suggestion
pos_label = pos_label if pos_label is not None else classes[1]
```
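For context, `classes` comes from the unique labels of a binary target, so it always has exactly two entries and `classes[-1]` and `classes[1]` refer to the same element; the suggestion is purely about consistency. A quick check:

```python
import numpy as np

# For a binary target, np.unique returns exactly two sorted labels.
classes = np.unique(["spam", "ham", "spam", "ham"])
assert classes.tolist() == ["ham", "spam"]

# So classes[1] and classes[-1] are the same element; using 1
# everywhere just keeps the indexing consistent across the codebase.
assert classes[1] == classes[-1] == "spam"
```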
```python
method = (
    ["predict_proba", "decision_function", "predict"]
    if method == "auto"
    else method
)
```
A bit lighter to parse IMO (and also what is used in other places in this PR):
Suggested change:

```suggestion
if method == "auto":
    method = ["predict_proba", "decision_function", "predict"]
```
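For readers following along: the list form matters because the caller tries each response method in order and uses the first one the estimator provides. A minimal sketch of that dispatch, with a hypothetical `first_available` helper (not scikit-learn's actual implementation):

```python
def first_available(estimator, method):
    # Hypothetical helper: expand "auto" into the preference-ordered
    # list, then take the first response method the estimator provides.
    if method == "auto":
        method = ["predict_proba", "decision_function", "predict"]
    elif isinstance(method, str):
        method = [method]
    for name in method:
        if hasattr(estimator, name):
            return getattr(estimator, name)
    raise AttributeError(f"estimator implements none of {method}")

class RegressorLike:
    def predict(self, X):
        return [0.0] * len(X)

# A regressor exposes only predict, so "auto" falls through to it.
response = first_available(RegressorLike(), "auto")
assert response.__name__ == "predict"
```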
```python
"'fit' with appropriate arguments before intending to use it to plotting "
"functionalities."
```
Maybe a bit lighter to read (better suggestion welcome):
Suggested change:

```suggestion
"'fit' with appropriate arguments before using it for plotting "
"functionalities."
```
The CI is red at the moment; maybe I got something wrong in my merge...
Comment that still needs to be addressed, with the same concern:
- https://github.com/scikit-learn/scikit-learn/pull/21538/files#r769872637
- https://github.com/scikit-learn/scikit-learn/pull/21538/files#r769873989
The metric checks if the input is binary. If the *Display object checks too, then there is double validation and `np.unique` is called twice.

I guess this is okay.
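To make the "double validation" concern concrete: both the display and the metric would independently derive the label set, each paying one `np.unique` scan over the target. A toy sketch (the counting wrapper is illustrative, not scikit-learn code):

```python
import numpy as np

calls = {"n": 0}

def check_binary(y):
    # Illustrative validator: counts how many times the labels are scanned.
    calls["n"] += 1
    labels = np.unique(y)
    if len(labels) != 2:
        raise ValueError("only binary targets are supported")
    return labels

y_true = np.array([0, 1, 1, 0])

check_binary(y_true)   # hypothetical *Display-level validation
check_binary(y_true)   # the metric validates again internally

# Two full np.unique passes over y_true for a single plot.
assert calls["n"] == 2
```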
```diff
@@ -330,6 +342,12 @@ def from_predictions(
         """
         check_matplotlib_support(f"{cls.__name__}.from_predictions")

+        target_type = type_of_target(y_true)
+        if target_type != "binary":
```
I do not think we need this check. `roc_curve` calls `_binary_clf_curve`, which ends up doing the binary check itself (scikit-learn/sklearn/metrics/_ranking.py, lines 736 to 738 in 22ca942).

Let's move that to 1.2
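The claim is easy to verify: `roc_curve` itself rejects non-binary targets, so a second check in the display would be redundant. A quick demonstration (assuming scikit-learn is installed):

```python
import numpy as np
from sklearn.metrics import roc_curve

y_true = np.array([0, 1, 2, 2])           # multiclass target
y_score = np.array([0.1, 0.4, 0.35, 0.8])

try:
    roc_curve(y_true, y_score)
except ValueError as exc:
    # The binary check inside _binary_clf_curve fires on its own.
    message = str(exc)

assert "multiclass" in message
```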
Closing this one since we merged #23073 |
Partially addresses #18212
This is a simplification towards merging #20999