BUG Fixes pandas dataframe bug with boolean dtypes (#15797) · scikit-learn/scikit-learn@1b1c869 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1b1c869

Browse files
thomasjpfan authored and glemaitre committed
BUG Fixes pandas dataframe bug with boolean dtypes (#15797)
1 parent 3e26ea3 commit 1b1c869

File tree

3 files changed

+47
-2
lines changed

3 files changed

+47
-2
lines changed

doc/whats_new/v0.22.rst

+19
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,25 @@
22

33
.. currentmodule:: sklearn
44

5+
.. _changes_0_22_1:
6+
7+
Version 0.22.1
8+
==============
9+
10+
**In Development**
11+
12+
This is a bug-fix release to primarily resolve some packaging issues in version
13+
0.22.0. It also includes minor documentation improvements and some bug fixes.
14+
15+
Changelog
16+
---------
17+
18+
:mod:`sklearn.utils`
19+
....................
20+
21+
- |Fix| :func:`utils.check_array` now correctly converts pandas DataFrame with
22+
boolean columns to floats. :pr:`15797` by `Thomas Fan`_.
23+
524
.. _changes_0_22:
625

726
Version 0.22.0

sklearn/utils/tests/test_validation.py

+21
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,27 @@ def test_check_dataframe_warns_on_dtype():
826826
assert len(record) == 0
827827

828828

829+
def test_check_dataframe_mixed_float_dtypes():
830+
# pandas dataframe will coerce a boolean into an object, this is a mismatch
831+
# with np.result_type which will return a float
832+
# check_array needs to explicitly check for bool dtype in a dataframe for
833+
# this situation
834+
# https://github.com/scikit-learn/scikit-learn/issues/15787
835+
836+
pd = importorskip("pandas")
837+
df = pd.DataFrame({
838+
'int': [1, 2, 3],
839+
'float': [0, 0.1, 2.1],
840+
'bool': [True, False, True]}, columns=['int', 'float', 'bool'])
841+
842+
array = check_array(df, dtype=(np.float64, np.float32, np.float16))
843+
expected_array = np.array(
844+
[[1.0, 0.0, 1.0],
845+
[2.0, 0.1, 0.0],
846+
[3.0, 2.1, 1.0]], dtype=np.float)
847+
assert_allclose_dense_sparse(array, expected_array)
848+
849+
829850
class DummyMemory:
830851
def cache(self, func):
831852
return func

sklearn/utils/validation.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -454,9 +454,14 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
454454
# DataFrame), and store them. If not, store None.
455455
dtypes_orig = None
456456
if hasattr(array, "dtypes") and hasattr(array.dtypes, '__array__'):
457-
dtypes_orig = np.array(array.dtypes)
457+
dtypes_orig = list(array.dtypes)
458+
# pandas boolean dtype __array__ interface coerces bools to objects
459+
for i, dtype_iter in enumerate(dtypes_orig):
460+
if dtype_iter.kind == 'b':
461+
dtypes_orig[i] = np.object
462+
458463
if all(isinstance(dtype, np.dtype) for dtype in dtypes_orig):
459-
dtype_orig = np.result_type(*array.dtypes)
464+
dtype_orig = np.result_type(*dtypes_orig)
460465

461466
if dtype_numeric:
462467
if dtype_orig is not None and dtype_orig.kind == "O":

0 commit comments

Comments
 (0)
0