ENH respect dtypes in pandas dataframes if homogeneous (#15094) · crankycoder/scikit-learn@b906078 · GitHub
[go: up one dir, main page]

Skip to content
8000

Commit b906078

Browse files
amueller authored and
jnothman committed
ENH respect dtypes in pandas dataframes if homogeneous (scikit-learn#15094)
Only handles the case that all dtypes are numpy dtypes
1 parent cd3d502 commit b906078

File tree

3 files changed

+47
-1
lines changed

3 files changed

+47
-1
lines changed

doc/whats_new/v0.22.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,10 @@ Changelog
645645
NaN to integer.
646646
:pr:`14872` by `Roman Yurchak`_.
647647

648+
- |Fix| :func:`utils.check_array` will now correctly detect numeric dtypes in
649+
pandas dataframes, fixing a bug where ``float32`` was upcast to ``float64``
650+
unnecessarily. :pr:`15094` by `Andreas Müller`_.
651+
648652
- |API| The following utils have been deprecated and are now private:
649653
- ``choose_check_classifiers_labels``
650654
- ``enforce_estimator_tags_y``

sklearn/utils/tests/test_validation.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@
4343
check_scalar,
4444
_deprecate_positional_args,
4545
_check_sample_weight,
46-
_allclose_dense_sparse)
46+
_allclose_dense_sparse,
47+
FLOAT_DTYPES)
4748
import sklearn
4849

4950
from sklearn.exceptions import NotFittedError
@@ -352,6 +353,45 @@ def test_check_array_pandas_dtype_object_conversion():
352353
assert check_array(X_df, ensure_2d=False).dtype.kind == "f"
353354

354355

356+
def test_check_array_pandas_dtype_casting():
357+
# test that data-frames with homogeneous dtype are not upcast
358+
pd = pytest.importorskip('pandas')
359+
X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
360+
X_df = pd.DataFrame(X)
361+
assert check_array(X_df).dtype == np.float32
362+
assert check_array(X_df, dtype=FLOAT_DTYPES).dtype == np.float32
363+
364+
X_df.iloc[:, 0] = X_df.iloc[:, 0].astype(np.float16)
365+
assert_array_equal(X_df.dtypes,
366+
(np.float16, np.float32, np.float32))
367+
assert check_array(X_df).dtype == np.float32
368+
assert check_array(X_df, dtype=FLOAT_DTYPES).dtype == np.float32
369+
370+
X_df.iloc[:, 1] = X_df.iloc[:, 1].astype(np.int16)
371+
# float16, int16, float32 casts to float32
372+
assert check_array(X_df).dtype == np.float32
373+
assert check_array(X_df, dtype=FLOAT_DTYPES).dtype == np.float32
374+
375+
X_df.iloc[:, 2] = X_df.iloc[:, 2].astype(np.float16)
376+
# float16, int16, float16 casts to float32
377+
assert check_array(X_df).dtype == np.float32
378+
assert check_array(X_df, dtype=FLOAT_DTYPES).dtype == np.float32
379+
380+
X_df = X_df.astype(np.int16)
381+
assert check_array(X_df).dtype == np.int16
382+
# we're not using upcasting rules for determining
383+
# the target type yet, so we cast to the default of float64
384+
assert check_array(X_df, dtype=FLOAT_DTYPES).dtype == np.float64
385+
386+
# check that we handle pandas dtypes in a semi-reasonable way
387+
# this is actually tricky because we can't really know that this
388+
# should be integer ahead of converting it.
389+
cat_df = pd.DataFrame([pd.Categorical([1, 2, 3])])
390+
assert (check_array(cat_df).dtype == np.int64)
391+
assert (check_array(cat_df, dtype=FLOAT_DTYPES).dtype
392+
== np.float64)
393+
394+
355395
def test_check_array_on_mock_dataframe():
356396
arr = np.array([[0.2, 0.7], [0.6, 0.5], [0.4, 0.1], [0.7, 0.2]])
357397
mock_df = MockDataFrame(arr)

sklearn/utils/validation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,8 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
454454
dtypes_orig = None
455455
if hasattr(array, "dtypes") and hasattr(array.dtypes, '__array__'):
456456
dtypes_orig = np.array(array.dtypes)
457+
if all(isinstance(dtype, np.dtype) for dtype in dtypes_orig):
458+
dtype_orig = np.result_type(*array.dtypes)
457459

458460
if dtype_numeric:
459461
if dtype_orig is not None and dtype_orig.kind == "O":

0 commit comments

Comments
 (0)
0