8000 API Deprecate using labels in bytes format by cozek · Pull Request #18555 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

API Deprecate using labels in bytes format #18555

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 35 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
4ebc75d
added error message for using bytes as labels
cozek Oct 7, 2020
6e8304d
Merge remote-tracking branch 'upstream/master' into add-bytes-error
cozek Jan 16, 2021
5bbabf3
wip
cozek Jan 19, 2021
4f44ec4
Merge branch 'main' into add-bytes-error
cozek Mar 3, 2021
0a94e1f
implemented reviewer recommendations
cozek Mar 3, 2021
5ecb970
added changes in changelog
cozek Mar 3, 2021
f13c949
Merge branch 'main' into add-bytes-error
cozek Apr 10, 2021
2b4b735
precomit fix
cozek Apr 10, 2021
cf32454
Merge branch 'main' into add-bytes-error
cozek Apr 10, 2021
a34de68
removing bytes test
cozek Apr 11, 2021
7deab3c
Merge branch 'main' into add-bytes-error
cozek May 28, 2021
2531335
Merge remote-tracking branch 'origin/main' into pr/cozek/18555
glemaitre Jul 23, 2021
3dcf861
move whats new entry
glemaitre Jul 23, 2021
3f1dbe4
mergin main
cozek Apr 16, 2023
f3e47cb
precommit
cozek Apr 16, 2023
c1b56e3
Merge branch 'main' into add-bytes-error
cozek May 7, 2023
fe350a7
reviewer changes
cozek May 7, 2023
66d71c3
fixing linting issue
cozek May 8, 2023
f6cb141
fix flake8 errors
cozek May 8, 2023
847dde4
Merge branch 'main' into add-bytes-error
cozek May 19, 2023
b7fb09c
updates
cozek May 19, 2023
0add79d
resolving conflicts
cozek Jun 4, 2023
9cd1b53
changes
cozek Jun 4, 2023
5572db4
Merge branch 'main' into add-bytes-error
cozek Jun 8, 2023
f438106
Merge branch 'main' into add-bytes-error
cozek Jun 17, 2023
11ad744
Merge branch 'main' into add-bytes-error
cozek Jul 28, 2023
8f2c4a2
updates
cozek Jul 28, 2023
93736e9
Merge remote-tracking branch 'upstream/main' into add-bytes-error
cozek Jul 29, 2023
9d4200c
Merge branch 'main' into add-bytes-error
cozek Jul 31, 2023
5a4de53
Merge branch 'main' into add-bytes-error
cozek Aug 4, 2023
ad65d47
Merge branch 'main' into add-bytes-error
cozek Aug 8, 2023
64fddfe
Merge branch 'main' into add-bytes-error
cozek Aug 23, 2023
3e684f4
Merge remote-tracking branch 'upstream/main' into add-bytes-error
adrinjalali Apr 15, 2024
1f4e936
fix error message version
adrinjalali Apr 15, 2024
d6f8022
fix row or val
adrinjalali Apr 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions doc/whats_new/v1.5.rst
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,10 @@ Changelog
in favor of `y_proba`. `y_prob` will be removed in version 1.7.
:pr:`28092` by :user:`Adam Li <adam2392>`.

- |API| For classifiers and classification metrics, labels encoded as bytes
is deprecated and will raise an error in v1.6.
:pr:`18555` by :user:`Kaushik Amar Das <cozek>`.

:mod:`sklearn.mixture`
......................

Expand Down Expand Up @@ -418,6 +422,11 @@ Changelog
- |API| :func:`utils.tosequence` is deprecated and will be removed in version 1.7.
:pr:`28763` by :user:`Jérémie du Boisberranger <jeremiedbb>`.

- |API| Raise informative warning message in :func:`type_of_target` when
represented as bytes. For classifiers and classification metrics, labels encoded
as bytes is deprecated and will raise an error in v1.6.
:pr:`18555` by :user:`Kaushik Amar Das <cozek>`.

- |Fix| :func:`~utils._safe_indexing` now works correctly for polars DataFrame when
`axis=0` and supports indexing polars Series.
:pr:`28521` by :user:`Yao Xiao <Charlie-XIAO>`.
Expand Down
30 changes: 19 additions & 11 deletions sklearn/metrics/tests/test_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
from sklearn.preprocessing import label_binarize
from sklearn.random_projection import _sparse_random_matrix
from sklearn.utils._testing import (
_convert_container,
assert_allclose,
assert_almost_equal,
assert_array_almost_equal,
assert_array_equal,
ignore_warnings,
)
from sklearn.utils.extmath import softmax
from sklearn.utils.fixes import CSR_CONTAINERS
Expand Down Expand Up @@ -864,17 +866,6 @@ def test_binary_clf_curve_implicit_pos_label(curve_func):
with pytest.raises(ValueError, match=msg):
curve_func(np.array(["a", "b"], dtype=object), [0.0, 1.0])

# The error message is slightly different for bytes-encoded
# class labels, but otherwise the behavior is the same:
msg = (
"y_true takes value in {b'a', b'b'} and pos_label is "
"not specified: either make y_true take "
"value in {0, 1} or {-1, 1} or pass pos_label "
"explicitly."
)
with pytest.raises(ValueError, match=msg):
curve_func(np.array([b"a", b"b"], dtype="<S1"), [0.0, 1.0])

# Check that it is possible to use floating point class labels
# that are interpreted similarly to integer class labels:
y_pred = [0.0, 1.0, 0.2, 0.42]
Expand All @@ -884,6 +875,23 @@ def test_binary_clf_curve_implicit_pos_label(curve_func):
np.testing.assert_allclose(int_curve_part, float_curve_part)


# TODO(1.5): Update test to check for error when bytes support is removed.
@ignore_warnings(category=FutureWarning)
@pytest.mark.parametrize("curve_func", [precision_recall_curve, roc_curve])
@pytest.mark.parametrize("labels_type", ["list", "array"])
def test_binary_clf_curve_implicit_bytes_pos_label(curve_func, labels_type):
# Check that using bytes class labels raises an informative
# error for any supported string dtype:
labels = _convert_container([b"a", b"b"], labels_type)
msg = (
"y_true takes value in {b'a', b'b'} and pos_label is not "
"specified: either make y_true take value in {0, 1} or "
"{-1, 1} or pass pos_label explicitly."
)
with pytest.raises(ValueError, match=msg):
curve_func(labels, [0.0, 1.0])


@pytest.mark.parametrize("curve_func", CURVE_FUNCS)
def test_binary_clf_curve_zero_sample_weight(curve_func):
y_true = [0, 0, 1, 1, 1]
Expand Down
27 changes: 19 additions & 8 deletions sklearn/utils/multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,13 +342,24 @@ def type_of_target(y, input_name=""):
# see NEP 34
y = check_array(y, dtype=object, **check_y_kwargs)

# The old sequence of sequences format
try:
first_row = y[[0], :] if issparse(y) else y[0]
# TODO(1.7): Change to ValueError when byte labels is deprecated.
# labels in bytes format
first_row_or_val = y[[0], :] if issparse(y) else y[0]
if isinstance(first_row_or_val, bytes):
warnings.warn(
(
"Support for labels represented as bytes is deprecated in v1.5 and"
" will error in v1.7. Convert the labels to a string or integer"
" format."
),
FutureWarning,
)
# The old sequence of sequences format
if (
not hasattr(first_row, "__array__")
and isinstance(first_row, Sequence)
and not isinstance(first_row, str)
not hasattr(first_row_or_val, "__array__")
and isinstance(first_row_or_val, Sequence)
and not isinstance(first_row_or_val, str)
):
raise ValueError(
"You appear to be using a legacy multi-label data"
Expand Down Expand Up @@ -390,9 +401,9 @@ def type_of_target(y, input_name=""):
return "continuous" + suffix

# Check multiclass
if issparse(first_row):
first_row = first_row.data
if xp.unique_values(y).shape[0] > 2 or (y.ndim == 2 and len(first_row) > 1):
if issparse(first_row_or_val):
first_row_or_val = first_row_or_val.data
if xp.unique_values(y).shape[0] > 2 or (y.ndim == 2 and len(first_row_or_val) > 1):
# [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]
return "multiclass" + suffix
else:
Expand Down
16 changes: 16 additions & 0 deletions sklearn/utils/tests/test_multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sklearn.utils._array_api import yield_namespace_device_dtype_combinations
from sklearn.utils._testing import (
_array_api_for_tests,
_convert_container,
assert_allclose,
assert_array_almost_equal,
assert_array_equal,
Expand Down Expand Up @@ -595,3 +596,18 @@ def test_ovr_decision_function():
]

assert_allclose(dec_values, dec_values_one, atol=1e-6)


# TODO(1.6): Change to ValueError when byte labels is deprecated.
@pytest.mark.parametrize("input_type", ["list", "array"])
def test_labels_in_bytes_format(input_type):
# check that we raise an error with bytes encoded labels
# non-regression test for:
# https://github.com/scikit-learn/scikit-learn/issues/16980
target = _convert_container([b"a", b"b"], input_type)
err_msg = (
"Support for labels represented as bytes is deprecated in v1.5 and will"
" error in v1.7. Convert the labels to a string or integer format."
)
with pytest.warns(FutureWarning, match=err_msg):
type_of_target(target)
0