10000 FIX check_array dtype check for pandas series (#12625) · scikit-learn/scikit-learn@104f684 · GitHub
[go: up one dir, main page]

Skip to content

Commit 104f684

Browse files
amuellerjnothman
authored andcommitted
FIX check_array dtype check for pandas series (#12625)
1 parent 0b8650a commit 104f684

File tree

4 files changed

+14
-2
lines changed

4 files changed

+14
-2
lines changed

doc/whats_new/_contributors.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848

4949
.. _Bertrand Thirion: https://team.inria.fr/parietal/bertrand-thirions-page
5050

51-
.. _Andreas Müller: https://peekaboo-vision.blogspot.com/
51+
.. _Andreas Müller: https://amueller.github.io/
5252

5353
.. _Matthieu Perrot: http://brainvisa.info/biblio/lnao/en/Author/PERROT-M.html
5454

doc/whats_new/v0.20.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,10 @@ Changelog
180180
precision issues in :class:`preprocessing.StandardScaler` and
181181
:class:`decomposition.IncrementalPCA` when using float32 datasets.
182182
:issue:`12338` by :user:`bauks <bauks>`.
183+
184+
- |Fix| Calling :func:`utils.check_array` on `pandas.Series`, which
185+
raised an error in 0.20.0, now returns the expected output again.
186+
:issue:`12625` by `Andreas Müller`_
183187

184188
Miscellaneous
185189
.............

sklearn/utils/tests/test_validation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,14 @@ def test_suppress_validation():
694694
assert_raises(ValueError, assert_all_finite, X)
695695

696696

697+
def test_check_array_series():
698+
# regression test that check_array works on pandas Series
699+
pd = importorskip("pandas")
700+
res = check_array(pd.Series([1, 2, 3]), ensure_2d=False,
701+
warn_on_dtype=True)
702+
assert_array_equal(res, np.array([1, 2, 3]))
703+
704+
697705
def test_check_dataframe_warns_on_dtype():
698706
# Check that warn_on_dtype also works for DataFrames.
699707
# https://github.com/scikit-learn/scikit-learn/issues/10948

sklearn/utils/validation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
477477
# check if the object contains several dtypes (typically a pandas
478478
# DataFrame), and store them. If not, store None.
479479
dtypes_orig = None
480-
if hasattr(array, "dtypes") and hasattr(array, "__array__"):
480+
if hasattr(array, "dtypes") and len(array.dtypes):
481481
dtypes_orig = np.array(array.dtypes)
482482

483483
if dtype_numeric:

0 commit comments

Comments
 (0)
0