From 606f9f9baffc444b44265055a3877925b12dc0f8 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Tue, 10 Apr 2018 17:02:44 +0200 Subject: [PATCH 1/8] [WIP] fixes #10948 --- sklearn/utils/validation.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 70e968ee6d36b..2b064eee419a9 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -445,6 +445,12 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None, # not a data type (e.g. a column named dtype in a pandas DataFrame) dtype_orig = None + # check if the object contains several dtypes (typically a pandas + # DataFrame), and store them. If not, store None. + dtypes_orig = None + if hasattr(array, "dtypes"): + dtypes_orig = array.dtypes.get_values() + if dtype_numeric: if dtype_orig is not None and dtype_orig.kind == "O": # if input is object, convert to float. @@ -556,6 +562,16 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None, msg = ("Data with input dtype %s was converted to %s%s." % (dtype_orig, array.dtype, context)) warnings.warn(msg, DataConversionWarning) + + if warn_on_dtype and dtypes_orig is not None and {array.dtype} != \ + set(dtypes_orig): + # if there was at the beginning some other types than the final one + # (for instance in a DataFrame that can contain several dtypes) then + # some data must have been converted + msg = ("Data with input dtype %s were all converted to %s%s." + % (', '.join(map(str, set(dtypes_orig))), array.dtype, context)) + warnings.warn(msg, DataConversionWarning) + return array From 35bb7a130d646bb6353a8a567b249ab2640f838e Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Wed, 11 Apr 2018 08:20:35 +0200 Subject: [PATCH 2/8] fix travis test with pandas 0.20 --- sklearn/utils/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 2b064eee419a9..fe388a80bfd44 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -449,7 +449,7 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None, # DataFrame), and store them. If not, store None. dtypes_orig = None if hasattr(array, "dtypes"): - dtypes_orig = array.dtypes.get_values() + dtypes_orig = np.array(array.dtypes) if dtype_numeric: if dtype_orig is not None and dtype_orig.kind == "O": From 17de6aa5cb53a09d1a9362491af8a61959458fa9 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Wed, 11 Apr 2018 11:27:12 +0200 Subject: [PATCH 3/8] Add test. --- sklearn/utils/tests/test_validation.py | 33 ++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index a3a4175d7eff4..2d2953ca9128b 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -9,6 +9,7 @@ import pytest import numpy as np import scipy.sparse as sp +from pytest import importorskip from sklearn.utils.testing import assert_true, assert_false, assert_equal from sklearn.utils.testing import assert_raises @@ -664,6 +665,38 @@ def test_suppress_validation(): assert_raises(ValueError, assert_all_finite, X) +def test_check_dataframe_warns_on_dtype(): + # Check that warn_on_dtype also works for DataFrames. + # https://github.com/scikit-learn/scikit-learn/issues/10948 + pd = importorskip("pandas") + + df = pd.DataFrame([[1, 2, 3], [4, 5, 6]], dtype=object) + assert_warns_message(DataConversionWarning, + "Data with input dtype object were all converted to " + "float64.", + check_array, df, dtype=np.float64, warn_on_dtype=True) + assert_warns(DataConversionWarning, check_array, df, + dtype='numeric', warn_on_dtype=True) + assert_no_warnings(check_array, df, dtype='object', warn_on_dtype=True) + + # Also check that it raises a warning for mixed dtypes in a DataFrame. + df_mixed = pd.DataFrame([['1', 2, 3], ['4', 5, 6]]) + assert_warns(DataConversionWarning, check_array, df_mixed, + dtype=np.float64, warn_on_dtype=True) + assert_warns(DataConversionWarning, check_array, df_mixed, + dtype='numeric', warn_on_dtype=True) + assert_warns(DataConversionWarning, check_array, df_mixed, + dtype=object, warn_on_dtype=True) + + # Even with numerical dtypes, a conversion can be made because dtypes are + # uniformized throughout the array. + df_mixed_numeric = pd.DataFrame([[1., 2, 3], [4., 5, 6]]) + assert_warns(DataConversionWarning, check_array, df_mixed_numeric, + dtype='numeric', warn_on_dtype=True) + assert_no_warnings(check_array, df_mixed_numeric.astype(int), + dtype='numeric', warn_on_dtype=True) + + class DummyMemory(object): def cache(self, func): return func From 5a3aa9e61c85260c59ddcffa7bc592fba7ff522b Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Fri, 25 May 2018 13:53:21 +0200 Subject: [PATCH 4/8] ENH: add stacklevel=2 for better warning log --- sklearn/utils/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index c705a806f518a..db1468853818a 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -576,7 +576,7 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None, # some data must have been converted msg = ("Data with input dtype %s were all converted to %s%s." % (', '.join(map(str, set(dtypes_orig))), array.dtype, context)) - warnings.warn(msg, DataConversionWarning) + warnings.warn(msg, DataConversionWarning, stacklevel=2) return array From f6843c795b14d7c80abcd977a6f0f4c16a43d133 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Fri, 25 May 2018 13:54:20 +0200 Subject: [PATCH 5/8] ENH: add more precise duck typing (see comment https://github.com/scikit-learn/scikit-learn/pull/10949#discussion_r182351141) --- sklearn/utils/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index db1468853818a..8c9f29aa04ecc 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -452,7 +452,7 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None, # check if the object contains several dtypes (typically a pandas # DataFrame), and store them. If not, store None. dtypes_orig = None - if hasattr(array, "dtypes"): + if hasattr(array, "dtypes") and hasattr(array, "__array__"): dtypes_orig = np.array(array.dtypes) if dtype_numeric: From b8e0d10234adcadb3bd0b69ca5d52dc83786275c Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Fri, 25 May 2018 13:56:33 +0200 Subject: [PATCH 6/8] FIX: flake8 remove blank line --- sklearn/utils/validation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 8c9f29aa04ecc..4f7a75b5010df 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -568,7 +568,6 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None, if copy and np.may_share_memory(array, array_orig): array = np.array(array, dtype=dtype, order=order) - if warn_on_dtype and dtypes_orig is not None and {array.dtype} != \ set(dtypes_orig): # if there was at the beginning some other types than the final one From 2b906b0bb956c806c4bd588b5cfe976b22d15beb Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Thu, 28 Jun 2018 09:32:49 +0200 Subject: [PATCH 7/8] FIX take into account review comments: - sort dtypes to make output deterministic - put parenthesis instead of backslash for continuation line - put stacklevel=3 for more targeted output --- sklearn/utils/validation.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 4f7a75b5010df..37fee790fc2f0 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -568,14 +568,15 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None, if copy and np.may_share_memory(array, array_orig): array = np.array(array, dtype=dtype, order=order) - if warn_on_dtype and dtypes_orig is not None and {array.dtype} != \ - set(dtypes_orig): + if (warn_on_dtype and dtypes_orig is not None and {array.dtype} != + set(dtypes_orig)): # if there was at the beginning some other types than the final one # (for instance in a DataFrame that can contain several dtypes) then # some data must have been converted msg = ("Data with input dtype %s were all converted to %s%s." - % (', '.join(map(str, set(dtypes_orig))), array.dtype, context)) - warnings.warn(msg, DataConversionWarning, stacklevel=2) + % (', '.join(map(str, sorted(set(dtypes_orig)))), array.dtype, + context)) + warnings.warn(msg, DataConversionWarning, stacklevel=3) return array From c37e7ba70d87dc12b06858d1c7f202a3f45972b1 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Thu, 28 Jun 2018 13:56:34 +0200 Subject: [PATCH 8/8] REF move line break after logical operator --- sklearn/utils/validation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 0cd74a8ba45d2..a000d935624c6 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -587,8 +587,8 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True, if copy and np.may_share_memory(array, array_orig): array = np.array(array, dtype=dtype, order=order) - if (warn_on_dtype and dtypes_orig is not None and {array.dtype} != - set(dtypes_orig)): + if (warn_on_dtype and dtypes_orig is not None and + {array.dtype} != set(dtypes_orig)): # if there was at the beginning some other types than the final one # (for instance in a DataFrame that can contain several dtypes) then # some data must have been converted