API Deprecate using labels in bytes format #18555
Hi @cozek, sorry for the late reply, and thanks for your pull request.
@@ -950,8 +953,7 @@ def test_score_scale_invariance():
     ([1, 0, 1], [0.5, 0.75, 1], [1, 1, 0], [0, 0.5, 0.5]),
     ([1, 0, 1], [0.25, 0.5, 0.75], [1, 1, 0], [0, 0.5, 0.5]),
 ])
-def test_det_curve_toydata(y_true, y_score,
-                           expected_fpr, expected_fnr):
+def test_det_curve_toydata(y_true, y_score, expected_fpr, expected_fnr):
Some function definitions were oddly formatted, so I fixed them.
"value in {0, 1} or {-1, 1} or pass pos_label "
"explicitly.")
with pytest.raises(ValueError, match=msg):
msg = (("y_true takes value in {b'a', b'b'} and pos_label is "
I thought a lot about it, but ultimately it seemed like the best way was to check for both the new error message and the old one together.
I think that we can pass the error message in the parametrization, linked to the metric being tested.
If needed, we can isolate the test just for this use case.
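As background for the idea of checking two messages at once: the `match` argument of `pytest.raises` is applied with `re.search`, so two alternative messages can be combined into one pattern. The sketch below uses only the standard library to show that mechanism. `NEW_MSG` is taken from later in this thread, while `OLD_MSG` is a hypothetical placeholder, since the previous message is truncated in the diff above.

```python
import re

# Message proposed later in this review thread.
NEW_MSG = (
    "Labels are represented as bytes and are not supported. "
    "Convert the labels to Python string or integral format."
)
# Hypothetical placeholder: the real previous message is not fully shown here.
OLD_MSG = "placeholder for the previous message (hypothetical text)"

# re.escape neutralises regex metacharacters such as "." and "(" so that each
# message is matched literally; "|" then accepts either alternative.
either = "|".join(re.escape(message) for message in (NEW_MSG, OLD_MSG))

print(bool(re.search(either, NEW_MSG)))
print(bool(re.search(either, OLD_MSG)))
print(bool(re.search(either, "something unrelated")))
```

Passing one message per parametrized case, as suggested above, avoids the alternation entirely and keeps each assertion specific to the metric under test.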
@@ -437,3 +437,13 @@ def test_ovr_decision_function():
                        n_classes)[0] for i in range(4)]

     assert_allclose(dec_values, dec_values_one, atol=1e-6)
+
+
+def test_labels_in_bytes_format():
I hope this test is enough.
@cmarmo I have fixed the conflicts with the other tests. Please let me know if I need to make any other changes.
A first pass on the PR.
Please add an entry to the change log at doc/whats_new/v*.rst. Like the other entries there, please reference this pull request with :pr: and credit yourself (and other contributors if applicable) with :user:.
In general it looks good.
sklearn/utils/multiclass.py
Outdated
if isinstance(y[0], bytes):
    raise ValueError('Labels represented as bytes is not supported.'
                     ' Convert the labels to a supported format.'
                     ' For example, y = y.astype(str)')
I am not sure that y = y.astype(str) will be the right way to solve the problem:
- it assumes that we have a NumPy array;
- it will convert the labels into a NumPy string array, which we don't always support properly, if I am not mistaken.
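One way to sidestep both concerns on the user side is to decode the labels element by element into plain Python strings, which works for any iterable and does not depend on NumPy. This is only a minimal sketch: `decode_labels` is a hypothetical helper, not part of scikit-learn, and the UTF-8 default is an assumption about how the labels were encoded.

```python
def decode_labels(labels, encoding="utf-8"):
    """Return a list in which every bytes label is decoded to str.

    Hypothetical helper for illustration; not part of scikit-learn.
    Labels that are not bytes (str, int, ...) are passed through unchanged.
    """
    return [
        label.decode(encoding) if isinstance(label, bytes) else label
        for label in labels
    ]


print(decode_labels([b"a", b"b", "c", 1]))  # ['a', 'b', 'c', 1]
```

Returning a plain list leaves the choice of container (list, array, Series) to the caller.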
sklearn/utils/multiclass.py
Outdated
raise ValueError('Labels represented as bytes is not supported.'
                 ' Convert the labels to a supported format.'
                 ' For example, y = y.astype(str)')
-        raise ValueError('Labels represented as bytes is not supported.'
-                         ' Convert the labels to a supported format.'
-                         ' For example, y = y.astype(str)')
+        raise ValueError(
+            'Labels are represented as bytes and are not supported. '
+            'Convert the labels to Python string or integral format.'
+        )
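To make the behaviour under discussion concrete, here is a minimal standalone sketch of such a check, using the message wording from the suggestion above. `check_labels_not_bytes` is a hypothetical stand-in and not the actual code in sklearn/utils/multiclass.py; like the snippet under review, it inspects only the first label, and the empty-input guard is an added assumption.

```python
def check_labels_not_bytes(y):
    """Raise ValueError if the first label is a bytes object.

    Hypothetical stand-in for the check discussed in this thread.
    Only the first element is inspected; empty inputs pass (assumption).
    """
    if len(y) > 0 and isinstance(y[0], bytes):
        raise ValueError(
            "Labels are represented as bytes and are not supported. "
            "Convert the labels to Python string or integral format."
        )


try:
    check_labels_not_bytes([b"a", b"b"])
except ValueError as exc:
    print(exc)

# String and integer labels pass silently.
check_labels_not_bytes(["a", "b"])
check_labels_not_bytes([0, 1])
```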
# test whether labels are represented in bytes format
# and display a helpful message
-    # test whether labels are represented in bytes format
-    # and display a helpful message
+    # check that we raise an error with bytes encoded labels
+    # non-regression test for:
+    # https://github.com/scikit-learn/scikit-learn/issues/16980
       ' Convert the labels to a supported format.'
       ' For example, y = y.astype'r'\(str\)')
with pytest.raises(ValueError, match=msg):
    type_of_target(np.array([b'a', b'b'], dtype='<S1'))
We can pass a list of bytes and then the array, using pytest parametrization to test both cases.
fixed broken tests and added new tests; fixed pep8 violations
@glemaitre I implemented the changes as suggested.
Some extra changes
doc/whats_new/v0.24.rst
Outdated
:mod:`sklearn.metrics`
......................

- |Enhancement| :func:`test_ranking.test_binary_clf_curve_implicit_bytes_pos_label`
You don't need this entry here. It is not relevant to scikit-learn users: we restrict our changelog to public changes.
doc/whats_new/v0.24.rst
Outdated
- |Enhancement| Raise informative error message in :func:`type_of_target` when
  labels encoded as bytes are used. For testing the same,
  :func:`test_multiclass.test_labels_in_bytes_format` is added.
  :pr:`18555` by :user:`Kaushik Amar Das <cozek>`.
-- |Enhancement| Raise informative error message in :func:`type_of_target` when
-  labels encoded as bytes are used. For testing the same,
-  :func:`test_multiclass.test_labels_in_bytes_format` is added.
-  :pr:`18555` by :user:`Kaushik Amar Das <cozek>`.
+- |Enhancement| Raise informative error message in :func:`type_of_target` when
+  labels encoded as bytes are used.
+  :pr:`18555` by :user:`Kaushik Amar Das <cozek>`.
No need to mention the test.
@pytest.mark.parametrize("labels", [
    np.array([b"a", b"b"], dtype='<S1'),
    [b'a', b'b']
])
-@pytest.mark.parametrize("labels", [
-    np.array([b"a", b"b"], dtype='<S1'),
-    [b'a', b'b']
-])
+@pytest.mark.parametrize("labels_type", ["list", "array"])
    np.array([b"a", b"b"], dtype='<S1'),
    [b'a', b'b']
])
def test_binary_clf_curve_implicit_bytes_pos_label(curve_func, labels):
-def test_binary_clf_curve_implicit_bytes_pos_label(curve_func, labels):
+def test_binary_clf_curve_implicit_bytes_pos_label(curve_func, labels_type):
msg = ('Labels are represented as bytes and are not supported. '
       'Convert the labels to Python string or integral format.')
with pytest.raises(ValueError, match=msg):
    roc_curve(labels, [0., 1.])
-        roc_curve(labels, [0., 1.])
+        curve_func(labels, [0., 1.])
@pytest.mark.parametrize("test_input,msg", [
    (np.array([b'a', b'b'], dtype='<S1'),
     ('Labels are represented as bytes and are not supported. '
      'Convert the labels to Python string or integral format.')),
    ([b'a', b'b'],
     ('Labels are represented as bytes and are not supported. '
      'Convert the labels to Python string or integral format.')),
])
def test_labels_in_bytes_format(test_input, msg):
    # check that we raise an error with bytes encoded labels
    # non-regression test for:
    # https://github.com/scikit-learn/scikit-learn/issues/16980
    with pytest.raises(ValueError, match=msg):
        type_of_target(test_input)
-@pytest.mark.parametrize("test_input,msg", [
-    (np.array([b'a', b'b'], dtype='<S1'),
-     ('Labels are represented as bytes and are not supported. '
-      'Convert the labels to Python string or integral format.')),
-    ([b'a', b'b'],
-     ('Labels are represented as bytes and are not supported. '
-      'Convert the labels to Python string or integral format.')),
-])
-def test_labels_in_bytes_format(test_input, msg):
-    # check that we raise an error with bytes encoded labels
-    # non-regression test for:
-    # https://github.com/scikit-learn/scikit-learn/issues/16980
-    with pytest.raises(ValueError, match=msg):
-        type_of_target(test_input)
+@pytest.mark.parametrize("input_type", ["list", "array"])
+def test_labels_in_bytes_format(input_type):
+    # check that we raise an error with bytes encoded labels
+    # non-regression test for:
+    # https://github.com/scikit-learn/scikit-learn/issues/16980
+    target = _convert_container([b'a', b'b'], input_type)
+    err_msg = (
+        "Labels are represented as bytes and are not supported. "
+        "Convert the labels to Python string or integral format."
+    )
+    with pytest.raises(ValueError, match=err_msg):
+        type_of_target(target)
LGTM
Argh, I did not see that there was a merge conflict. @cozek, could you resolve it?
Thank you for the updates!
@thomasjpfan ping 😄
Sorry, this PR dropped off my radar during the 1.3 release. Once the changes are moved to 1.4, I am happy to merge this PR.
No worries. I have made the requested changes. :)
@glemaitre This PR has changed a bit since you approved it. Now the PR deprecates the bytes format as the target instead of raising an error. Are you okay with this change?
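For readers comparing the two behaviours, a deprecation of this kind is usually expressed as a warning rather than an exception, and scikit-learn conventionally uses `FutureWarning` for deprecations. The sketch below is an assumption about the general shape only: `warn_if_bytes_labels` and the warning text are hypothetical and not taken from the PR.

```python
import warnings


def warn_if_bytes_labels(y):
    """Emit a FutureWarning when the first label is a bytes object.

    Hypothetical illustration of deprecating instead of raising.
    """
    if len(y) > 0 and isinstance(y[0], bytes):
        warnings.warn(
            # Illustrative wording, not the message used in the PR.
            "Labels represented as bytes are deprecated; convert them to "
            "Python strings or integers.",
            FutureWarning,
            stacklevel=2,
        )


# Record warnings so the behaviour can be inspected without console noise.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_if_bytes_labels([b"a", b"b"])

print([w.category.__name__ for w in caught])  # ['FutureWarning']
```

Unlike the earlier `ValueError`, existing code that passes bytes labels keeps running during the deprecation period and only starts failing once the behaviour is removed.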
@glemaitre Ping! Hope you don't mind. :)
So in #27274 we have (https://github.com/scikit-learn/scikit-learn/pull/27274/files#diff-c9490289833bdc44eb4aca4a0ba1c1bdef48decb9136081e9563e7bcc06e88a8R344): Which is very related here. I've changed that to
I'm okay with the new name :) I was too focused on sparse matrices at that time and forgot about 1D arrays.
Reference Issues/PRs
Fixes #16980
What does this implement/fix? Explain your changes.
Labels in bytes format are not supported by StratifiedKFold, but at present the error message raised in that case is not helpful.
For instance, in the example below, scikit-learn is unable to catch the error properly and display an appropriate error message.
Current error message:
In the above code, the error message does not make sense since the labels are not multi-label at all.
So I added a check that displays the following error message instead.
Any other comments?
Perhaps we should add some tests for this as well in test_multilabel.py ?
If there is a better way to implement this, I am listening.