8000 ENH warn_on_dtype for DataFrames (#10949) · meetnaren/scikit-learn@42e6d4e · GitHub
[go: up one dir, main page]

Skip to content

Commit 42e6d4e

Browse files
wdevazelhes authored and jnothman committed
ENH warn_on_dtype for DataFrames (scikit-learn#10949)
1 parent ee264ce commit 42e6d4e

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

sklearn/utils/tests/test_validation.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from itertools import product
88

99
import pytest
10+
from pytest import importorskip
1011
import numpy as np
1112
import scipy.sparse as sp
1213
from scipy import __version__ as scipy_version
@@ -713,6 +714,38 @@ def test_suppress_validation():
713714
assert_raises(ValueError, assert_all_finite, X)
714715

715716

717+
def test_check_dataframe_warns_on_dtype():
    # warn_on_dtype has to be honoured for DataFrame inputs as well.
    # See https://github.com/scikit-learn/scikit-learn/issues/10948
    pd = importorskip("pandas")

    # A frame built with dtype=object: converting it to float must warn with
    # the exact message below, while asking for 'object' must stay silent.
    object_frame = pd.DataFrame([[1, 2, 3], [4, 5, 6]], dtype=object)
    expected_message = ("Data with input dtype object were all converted to "
                        "float64.")
    assert_warns_message(DataConversionWarning, expected_message,
                         check_array, object_frame,
                         dtype=np.float64, warn_on_dtype=True)
    assert_warns(DataConversionWarning, check_array, object_frame,
                 dtype='numeric', warn_on_dtype=True)
    assert_no_warnings(check_array, object_frame,
                       dtype='object', warn_on_dtype=True)

    # A frame mixing string and integer columns is expected to warn for each
    # of the requested target dtypes.
    mixed_frame = pd.DataFrame([['1', 2, 3], ['4', 5, 6]])
    for target_dtype in (np.float64, 'numeric', object):
        assert_warns(DataConversionWarning, check_array, mixed_frame,
                     dtype=target_dtype, warn_on_dtype=True)

    # Purely numerical columns can still be converted, because the resulting
    # array has one uniform dtype: a float column next to integer columns
    # warns, whereas the same frame cast entirely to int does not.
    float_and_int_frame = pd.DataFrame([[1., 2, 3], [4., 5, 6]])
    assert_warns(DataConversionWarning, check_array, float_and_int_frame,
                 dtype='numeric', warn_on_dtype=True)
    all_int_frame = float_and_int_frame.astype(int)
    assert_no_warnings(check_array, all_int_frame,
                       dtype='numeric', warn_on_dtype=True)
747+
748+
716749
class DummyMemory(object):
717750
def cache(self, func):
718751
return func

sklearn/utils/validation.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,12 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
466466
# not a data type (e.g. a column named dtype in a pandas DataFrame)
467467
dtype_orig = None
468468

469+
# check if the object contains several dtypes (typically a pandas
470+
# DataFrame), and store them. If not, store None.
471+
dtypes_orig = None
472+
if hasattr(array, "dtypes") and hasattr(array, "__array__"):
473+
dtypes_orig = np.array(array.dtypes)
474+
469475
if dtype_numeric:
470476
if dtype_orig is not None and dtype_orig.kind == "O":
471477
# if input is object, convert to float.
@@ -581,6 +587,16 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
581587
if copy and np.may_share_memory(array, array_orig):
582588
array = np.array(array, dtype=dtype, order=order)
583589

590+
if (warn_on_dtype and dtypes_orig is not None and
591+
{array.dtype} != set(dtypes_orig)):
592+
# if there was at the beginning some other types than the final one
593+
# (for instance in a DataFrame that can contain several dtypes) then
594+
# some data must have been converted
595+
msg = ("Data with input dtype %s were all converted to %s%s."
596+
% (', '.join(map(str, sorted(set(dtypes_orig)))), array.dtype,
597+
context))
598+
warnings.warn(msg, DataConversionWarning, stacklevel=3)
599+
584600
return array
585601

586602

0 commit comments

Comments
 (0)
0