BUG: fix check_array on pandas Series with custom dtype (eg categoric… · scikit-learn/scikit-learn@4f5dd77 · GitHub
Commit 4f5dd77

jorisvandenbossche authored and jnothman committed
BUG: fix check_array on pandas Series with custom dtype (eg categorical) (#12706)
Closes #12699. Related to #12625
1 parent aae4e33 commit 4f5dd77

3 files changed: +13 -1 lines changed
doc/whats_new/v0.20.rst

Lines changed: 7 additions & 0 deletions
@@ -46,6 +46,13 @@ Changelog
   function to return 0 when two all-zero vectors are compared.
   :issue:`12685` by :user:`Thomas Fan <thomasjpfan>`.
 
+:mod:`sklearn.utils`
+....................
+
+- |Fix| Calling :func:`utils.check_array` on `pandas.Series` with categorical
+  data, which raised an error in 0.20.0, now returns the expected output again.
+  :issue:`12699` by `Joris Van den Bossche`_.
+
 .. _changes_0_20_1:
 
 Version 0.20.1

sklearn/utils/tests/test_validation.py

Lines changed: 5 additions & 0 deletions
@@ -701,6 +701,11 @@ def test_check_array_series():
                       warn_on_dtype=True)
     assert_array_equal(res, np.array([1, 2, 3]))
 
+    # with categorical dtype (not a numpy dtype) (GH12699)
+    s = pd.Series(['a', 'b', 'c']).astype('category')
+    res = check_array(s, dtype=None, ensure_2d=False)
+    assert_array_equal(res, np.array(['a', 'b', 'c'], dtype=object))
+
 
 def test_check_dataframe_warns_on_dtype():
     # Check that warn_on_dtype also works for DataFrames.

sklearn/utils/validation.py

Lines changed: 1 addition & 1 deletion
@@ -477,7 +477,7 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
     # check if the object contains several dtypes (typically a pandas
     # DataFrame), and store them. If not, store None.
     dtypes_orig = None
-    if hasattr(array, "dtypes") and len(array.dtypes):
+    if hasattr(array, "dtypes") and hasattr(array.dtypes, '__array__'):
         dtypes_orig = np.array(array.dtypes)
 
     if dtype_numeric:
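
The one-line change in `validation.py` can be illustrated outside scikit-learn. The sketch below is not part of the commit: `old_check` and `new_check` are hypothetical helper names that copy the two conditions from the diff, and the sample objects are made up for illustration. It assumes that `DataFrame.dtypes` is a pandas Series of per-column dtypes (which exposes `__array__`), while `Series.dtypes` is a single dtype object.

```python
import numpy as np
import pandas as pd


def old_check(obj):
    """Condition before the fix (hypothetical helper): uses len() of ``dtypes``."""
    return hasattr(obj, "dtypes") and bool(len(obj.dtypes))


def new_check(obj):
    """Condition after the fix (hypothetical helper): ``dtypes`` must be array-like."""
    return hasattr(obj, "dtypes") and hasattr(obj.dtypes, "__array__")


# Illustrative inputs, not taken from the commit.
df = pd.DataFrame({"a": [1, 2], "b": [0.5, 1.5]})
s_int = pd.Series([1, 2, 3])
s_cat = pd.Series(["a", "b", "c"]).astype("category")

for name, obj in [("DataFrame", df), ("int Series", s_int), ("categorical Series", s_cat)]:
    try:
        old = old_check(obj)
    except TypeError as exc:
        # len() is not defined for every dtype object; this is the kind of
        # failure the commit message describes for custom dtypes.
        old = "TypeError: {}".format(exc)
    print(name, "| old check:", old, "| new check:", new_check(obj))
```

The `hasattr(..., '__array__')` test avoids calling `len()` on a dtype object at all, so it does not depend on whether a given pandas extension dtype defines `__len__`.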
