ENH Allows target to be pandas nullable dtypes (#25638) · jeremiedbb/scikit-learn@6671d60 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6671d60

Browse files
authored
ENH Allows target to be pandas nullable dtypes (scikit-learn#25638)
1 parent 6adb209 commit 6671d60

File tree

5 files changed

+102
-6
lines changed

5 files changed

+102
-6
lines changed

doc/whats_new/v1.3.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,9 @@ Changelog
244244
during `transform` with no prior call to `fit` or `fit_transform`.
245245
:pr:`25190` by :user:`Vincent Maladière <Vincent-Maladiere>`.
246246

247+
- |Enhancement| :func:`utils.multiclass.type_of_target` can identify pandas
248+
nullable data types as classification targets. :pr:`25638` by `Thomas Fan`_.
249+
247250
:mod:`sklearn.semi_supervised`
248251
..............................
249252

sklearn/metrics/tests/test_classification.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,6 +1079,24 @@ def test_confusion_matrix_dtype():
10791079
assert cm[1, 1] == -2
10801080

10811081

1082+
@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"])
1083+
def test_confusion_matrix_pandas_nullable(dtype):
1084+
"""Checks that confusion_matrix works with pandas nullable dtypes.
1085+
1086+
Non-regression test for gh-25635.
1087+
"""
1088+
pd = pytest.importorskip("pandas")
1089+
1090+
y_ndarray = np.array([1, 0, 0, 1, 0, 1, 1, 0, 1])
1091+
y_true = pd.Series(y_ndarray, dtype=dtype)
1092+
y_predicted = pd.Series([0, 0, 1, 1, 0, 1, 1, 1, 1], dtype="int64")
1093+
1094+
output = confusion_matrix(y_true, y_predicted)
1095+
expected_output = confusion_matrix(y_ndarray, y_predicted)
1096+
1097+
assert_array_equal(output, expected_output)
1098+
1099+
10821100
def test_classification_report_multiclass():
10831101
# Test performance report
10841102
iris = datasets.load_iris()

sklearn/preprocessing/tests/test_label.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,22 @@ def test_label_binarizer_set_label_encoding():
117117
assert_array_equal(lb.inverse_transform(got), inp)
118118

119119

120+
@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"])
121+
def test_label_binarizer_pandas_nullable(dtype):
122+
"""Checks that LabelBinarizer works with pandas nullable dtypes.
123+
124+
Non-regression test for gh-25637.
125+
"""
126+
pd = pytest.importorskip("pandas")
127+
from sklearn.preprocessing import LabelBinarizer
128+
129+
y_true = pd.Series([1, 0, 0, 1, 0, 1, 1, 0, 1], dtype=dtype)
130+
lb = LabelBinarizer().fit(y_true)
131+
y_out = lb.transform([1, 0])
132+
133+
assert_array_equal(y_out, [[1], [0]])
134+
135+
120136
@ignore_warnings
121137
def test_label_binarizer_errors():
122138
# Check that invalid arguments yield ValueError

sklearn/utils/multiclass.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,25 @@ def is_multilabel(y):
155155
if hasattr(y, "__array__") or isinstance(y, Sequence) or is_array_api:
156156
# DeprecationWarning will be replaced by ValueError, see NEP 34
157157
# https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
158+
check_y_kwargs = dict(
159+
accept_sparse=True,
160+
allow_nd=True,
161+
force_all_finite=False,
162+
ensure_2d=False,
163+
ensure_min_samples=0,
164+
ensure_min_features=0,
165+
)
158166
with warnings.catch_warnings():
159167
warnings.simplefilter("error", np.VisibleDeprecationWarning)
160168
try:
161-
y = xp.asarray(y)
162-
except (np.VisibleDeprecationWarning, ValueError):
169+
y = check_array(y, dtype=None, **check_y_kwargs)
170+
except (np.VisibleDeprecationWarning, ValueError) as e:
171+
if str(e).startswith("Complex data not supported"):
172+
raise
173+
163174
# dtype=object should be provided explicitly for ragged arrays,
164175
# see NEP 34
165-
y = xp.asarray(y, dtype=object)
176+
y = check_array(y, dtype=object, **check_y_kwargs)
166177

167178
if not (hasattr(y, "shape") and y.ndim == 2 and y.shape[1] > 1):
168179
return False
@@ -302,15 +313,27 @@ def type_of_target(y, input_name=""):
302313
# https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
303314
# We therefore catch both deprecation (NumPy < 1.24) warning and
304315
# value error (NumPy >= 1.24).
316+
check_y_kwargs = dict(
317+
accept_sparse=True,
318+
allow_nd=True,
319+
force_all_finite=False,
320+
ensure_2d=False,
321+
ensure_min_samples=0,
322+
ensure_min_features=0,
323+
)
324+
305325
with warnings.catch_warnings():
306326
warnings.simplefilter("error", np.VisibleDeprecationWarning)
307327
if not issparse(y):
308328
try:
309-
y = xp.asarray(y)
310-
except (np.VisibleDeprecationWarning, ValueError):
329+
y = check_array(y, dtype=None, **check_y_kwargs)
330+
except (np.VisibleDeprecationWarning, ValueError) as e:
331+
if str(e).startswith("Complex data not supported"):
332+
raise
333+
311334
# dtype=object should be provided explicitly for ragged arrays,
312335
# see NEP 34
313-
y = xp.asarray(y, dtype=object)
336+
y = check_array(y, dtype=object, **check_y_kwargs)
314337

315338
# The old sequence of sequences format
316339
try:

sklearn/utils/tests/test_multiclass.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,42 @@ def test_type_of_target_pandas_sparse():
346346
type_of_target(y)
347347

348348

349+
def test_type_of_target_pandas_nullable():
350+
"""Check that type_of_target works with pandas nullable dtypes."""
351+
pd = pytest.importorskip("pandas")
352+
353+
for dtype in ["Int32", "Float32"]:
354+
y_true = pd.Series([1, 0, 2, 3, 4], dtype=dtype)
355+
assert type_of_target(y_true) == "multiclass"
356+
357+
y_true = pd.Series([1, 0, 1, 0], dtype=dtype)
358+
assert type_of_target(y_true) == "binary"
359+
360+
y_true = pd.DataFrame([[1.4, 3.1], [3.1, 1.4]], dtype="Float32")
361+
assert type_of_target(y_true) == "continuous-multioutput"
362+
363+
y_true = pd.DataFrame([[0, 1], [1, 1]], dtype="Int32")
364+
assert type_of_target(y_true) == "multilabel-indicator"
365+
366+
y_true = pd.DataFrame([[1, 2], [3, 1]], dtype="Int32")
367+
assert type_of_target(y_true) == "multiclass-multioutput"
368+
369+
370+
@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"])
371+
def test_unique_labels_pandas_nullable(dtype):
372+
"""Checks that unique_labels work with pandas nullable dtypes.
373+
374+
Non-regression test for gh-25634.
375+
"""
376+
pd = pytest.importorskip("pandas")
377+
378+
y_true = pd.Series([1, 0, 0, 1, 0, 1, 1, 0, 1], dtype=dtype)
379+
y_predicted = pd.Series([0, 0, 1, 1, 0, 1, 1, 1, 1], dtype="int64")
380+
381+
labels = unique_labels(y_true, y_predicted)
382+
assert_array_equal(labels, [0, 1])
383+
384+
349385
def test_class_distribution():
350386
y = np.array(
351387
[

0 commit comments

Comments
 (0)
0