FIX Fixes pandas extension arrays in check_array by thomasjpfan · Pull Request #25813 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content
Merged
FIX Fixes pandas extension arrays in check_array #25813
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,9 @@ Changelog
:pr:`25733` by :user:`Brigitta Sipőcz <bsipocz>` and
:user:`Jérémie du Boisberranger <jeremiedbb>`.

- |FIX| Fixes :func:`utils.validation.check_array` to properly convert pandas
extension arrays. :pr:`25813` by `Thomas Fan`_.

:mod:`sklearn.semi_supervised`
..............................

Expand Down
8 changes: 6 additions & 2 deletions sklearn/preprocessing/tests/test_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,19 @@ def test_label_binarizer_set_label_encoding():


@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"])
def test_label_binarizer_pandas_nullable(dtype):
@pytest.mark.parametrize("unique_first", [True, False])
def test_label_binarizer_pandas_nullable(dtype, unique_first):
"""Checks that LabelBinarizer works with pandas nullable dtypes.

Non-regression test for gh-25637.
"""
pd = pytest.importorskip("pandas")
from sklearn.preprocessing import LabelBinarizer

y_true = pd.Series([1, 0, 0, 1, 0, 1, 1, 0, 1], dtype=dtype)
if unique_first:
# Calling unique creates a pandas array which has a different interface
# compared to a pandas Series. Specifically, pandas arrays do not have "iloc".
y_true = y_true.unique()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a comment that explains what the returned object / type of unique is.

lb = LabelBinarizer().fit(y_true)
y_out = lb.transform([1, 0])

Expand Down
19 changes: 19 additions & 0 deletions sklearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1762,6 +1762,25 @@ def test_boolean_series_remains_boolean():
assert_array_equal(res, expected)


@pytest.mark.parametrize("input_values", [[0, 1, 0, 1, 0, np.nan], [0, 1, 0, 1, 0, 1]])
def test_pandas_array_returns_ndarray(input_values):
    """Check pandas array with extension dtypes returns a numeric ndarray.

    The input is a pandas array (built with ``pd.array``), not a Series, using
    the nullable ``Int32`` dtype, both with and without a missing value.

    Non-regression test for gh-25637.
    """
    pd = importorskip("pandas")
    pd_array = pd.array(input_values, dtype="Int32")
    result = check_array(
        pd_array,
        dtype=None,
        ensure_2d=False,
        allow_nd=False,
        force_all_finite=False,
    )
    assert isinstance(result, np.ndarray)
    # Compare the kind code directly. Passing the one-character kind to
    # `np.issubdtype` would reinterpret it as a dtype string ("f" -> float32),
    # which tests something other than the dtype of `result`.
    assert result.dtype.kind == "f"
    assert_allclose(result, input_values)


@pytest.mark.parametrize("array_namespace", ["numpy.array_api", "cupy.array_api"])
def test_check_array_array_api_has_non_finite(array_namespace):
"""Checks that Array API arrays checks non-finite correctly."""
Expand Down
13 changes: 12 additions & 1 deletion sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,15 @@ def _pandas_dtype_needs_early_conversion(pd_dtype):
return False


def _is_extension_array_dtype(array):
try:
from pandas.api.types import is_extension_array_dtype

return is_extension_array_dtype(array)
except ImportError:
return False


def check_array(
array,
accept_sparse=False,
Expand Down
Expand Up @@ -777,7 +786,9 @@ def check_array(
if all(isinstance(dtype_iter, np.dtype) for dtype_iter in dtypes_orig):
dtype_orig = np.result_type(*dtypes_orig)

elif hasattr(array, "iloc") and hasattr(array, "dtype"):
elif (_is_extension_array_dtype(array) or hasattr(array, "iloc")) and hasattr(
array, "dtype"
):
# array is a pandas series
pandas_requires_conversion = _pandas_dtype_needs_early_conversion(array.dtype)
if isinstance(array.dtype, np.dtype):
Expand Down
0