|
7 | 7 | from itertools import product
|
8 | 8 |
|
9 | 9 | import pytest
|
| 10 | +from pytest import importorskip |
10 | 11 | import numpy as np
|
11 | 12 | import scipy.sparse as sp
|
12 | 13 | from scipy import __version__ as scipy_version
|
@@ -713,6 +714,38 @@ def test_suppress_validation():
|
713 | 714 | assert_raises(ValueError, assert_all_finite, X)
|
714 | 715 |
|
715 | 716 |
|
def test_check_dataframe_warns_on_dtype():
    # Regression test: ``warn_on_dtype`` has to be honoured when the input
    # is a pandas DataFrame, not only for plain arrays.
    # https://github.com/scikit-learn/scikit-learn/issues/10948
    pd = importorskip("pandas")

    # Frame whose columns are all declared as ``object``: converting to a
    # float/numeric dtype must warn, asking for 'object' must stay silent.
    object_frame = pd.DataFrame([[1, 2, 3], [4, 5, 6]], dtype=object)
    expected_message = ("Data with input dtype object were all converted to "
                        "float64.")
    assert_warns_message(DataConversionWarning, expected_message,
                         check_array, object_frame,
                         dtype=np.float64, warn_on_dtype=True)
    assert_warns(DataConversionWarning, check_array, object_frame,
                 dtype='numeric', warn_on_dtype=True)
    assert_no_warnings(check_array, object_frame,
                       dtype='object', warn_on_dtype=True)

    # Frame mixing one string column with two integer columns: a warning is
    # expected for every requested dtype, ``object`` included.
    mixed_frame = pd.DataFrame([['1', 2, 3], ['4', 5, 6]])
    for requested_dtype in (np.float64, 'numeric', object):
        assert_warns(DataConversionWarning, check_array, mixed_frame,
                     dtype=requested_dtype, warn_on_dtype=True)

    # Purely numerical columns can still trigger a conversion, since the
    # resulting array carries a single dtype for all of them.
    numeric_frame = pd.DataFrame([[1., 2, 3], [4., 5, 6]])
    assert_warns(DataConversionWarning, check_array, numeric_frame,
                 dtype='numeric', warn_on_dtype=True)

    # Once every column is cast to int there is nothing left to convert.
    all_int_frame = numeric_frame.astype(int)
    assert_no_warnings(check_array, all_int_frame,
                       dtype='numeric', warn_on_dtype=True)
| 747 | + |
| 748 | + |
716 | 749 | class DummyMemory(object):
|
717 | 750 | def cache(self, func):
|
718 | 751 | return func
|
|
0 commit comments