FIX Fixes pandas extension arrays in check_array by thomasjpfan · Pull Request #25813 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

FIX Fixes pandas extension arrays in check_array #25813

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 22, 2023
3 changes: 3 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,9 @@ Changelog
:pr:`25733` by :user:`Brigitta Sipőcz <bsipocz>` and
:user:`Jérémie du Boisberranger <jeremiedbb>`.

- |FIX| Fixes :func:`utils.validation.check_array` to properly convert pandas
extension arrays. :pr:`25813` by `Thomas Fan`_.

:mod:`sklearn.semi_supervised`
..............................

Expand Down
8 changes: 6 additions & 2 deletions sklearn/preprocessing/tests/test_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,19 @@ def test_label_binarizer_set_label_encoding():


@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"])
def test_label_binarizer_pandas_nullable(dtype):
@pytest.mark.parametrize("unique_first", [True, False])
def test_label_binarizer_pandas_nullable(dtype, unique_first):
"""Checks that LabelBinarizer works with pandas nullable dtypes.

Non-regression test for gh-25637.
"""
pd = pytest.importorskip("pandas")
from sklearn.preprocessing import LabelBinarizer

y_true = pd.Series([1, 0, 0, 1, 0, 1, 1, 0, 1], dtype=dtype)
if unique_first:
# Calling unique creates a pandas array which has a different interface
# compared to a pandas Series. Specifically, pandas arrays do not have "iloc".
y_true = y_true.unique()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a comment that explains what the returned object / type of unique is.

lb = LabelBinarizer().fit(y_true)
y_out = lb.transform([1, 0])

Expand Down
19 changes: 19 additions & 0 deletions sklearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1762,6 +1762,25 @@ def test_boolean_series_remains_boolean():
assert_array_equal(res, expected)


@pytest.mark.parametrize("input_values", [[0, 1, 0, 1, 0, np.nan], [0, 1, 0, 1, 0, 1]])
def test_pandas_array_returns_ndarray(input_values):
    """Check pandas array with extension dtypes returns a numeric ndarray.

    The input is built with ``pd.array`` (not ``pd.Series``), using the
    nullable "Int32" dtype, both with and without a missing value.

    Non-regression test for gh-25637.
    """
    pd = importorskip("pandas")
    input_array = pd.array(input_values, dtype="Int32")
    result = check_array(
        input_array,
        dtype=None,
        ensure_2d=False,
        allow_nd=False,
        force_all_finite=False,
    )
    assert isinstance(result, np.ndarray)
    # Pass the dtype itself: a one-character `dtype.kind` string would be
    # re-interpreted by numpy as a type-code, which only works by coincidence.
    assert np.issubdtype(result.dtype, np.floating)
    assert_allclose(result, input_values)


@pytest.mark.parametrize("array_namespace", ["numpy.array_api", "cupy.array_api"])
def test_check_array_array_api_has_non_finite(array_namespace):
"""Checks that Array API arrays checks non-finite correctly."""
Expand Down
13 changes: 12 additions & 1 deletion sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,15 @@ def _pandas_dtype_needs_early_conversion(pd_dtype):
return False


def _is_extension_array_dtype(array):
try:
from pandas.api.types import is_extension_array_dtype

return is_extension_array_dtype(array)
except ImportError:
return False


def check_array(
array,
accept_sparse=False,
Expand Down Expand Up @@ -777,7 +786,9 @@ def check_array(
if all(isinstance(dtype_iter, np.dtype) for dtype_iter in dtypes_orig):
dtype_orig = np.result_type(*dtypes_orig)

elif hasattr(array, "iloc") and hasattr(array, "dtype"):
elif (_is_extension_array_dtype(array) or hasattr(array, "iloc")) and hasattr(
array, "dtype"
):
# array is a pandas series or a pandas extension array
pandas_requires_conversion = _pandas_dtype_needs_early_conversion(array.dtype)
if isinstance(array.dtype, np.dtype):
Expand Down
0