From 45d305b5e73db53a0d06a7dc9cf752bffc3d069c Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 7 Feb 2022 23:45:15 -0500 Subject: [PATCH 1/4] FIX Only raise feature name warning with mixed types and strings --- doc/whats_new/v1.1.rst | 4 ++++ sklearn/utils/tests/test_validation.py | 13 ++++++++----- sklearn/utils/validation.py | 7 +++---- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 56d5fe010fc9a..4c602294e1117 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -45,6 +45,10 @@ random sampling procedures. `force_finite=False` if you really want to get non-finite values and keep the old behavior. +- |Fix| Panda's DataFrames with all non-string columns such as a MultiIndex + will no longer warn when passed into an Estimator. :pr:`xxxxx` by + `Thomas Fan`_. + Changelog --------- diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index a4f587a68bd2e..3fc6ed2c6efc3 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -1613,11 +1613,14 @@ def test_check_array_deprecated_matrix(): @pytest.mark.parametrize( "names", - [list(range(2)), range(2), None], - ids=["list-int", "range", "default"], + [list(range(2)), range(2), None, [["a", "b"], ["c", "d"]]], + ids=["list-int", "range", "default", "MultiIndex"], ) def test_get_feature_names_pandas_with_ints_no_warning(names): - """Get feature names with pandas dataframes with ints without warning""" + """Get feature names with pandas dataframes without warning. + + Column names with consistent dtypes will not warn, such as int or MultiIndex. + """ pd = pytest.importorskip("pandas") X = pd.DataFrame([[1, 2], [4, 5], [5, 6]], columns=names) @@ -1648,10 +1651,10 @@ def test_get_feature_names_numpy(): @pytest.mark.parametrize( "names, dtypes", [ - ([["a", "b"], ["c", "d"]], "['tuple']"), (["a", 1], "['int', 'str']"), + (["pizza", ["a", "b"]], "['list', 'str']"), ], - ids=["multi-index", "mixed"], + ids=["int-str", "list-str"], ) def test_get_feature_names_invalid_dtypes_warns(names, dtypes): """Get feature names warns when the feature names have mixed dtypes""" diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index a6459059ba2f6..7fffa09da137e 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1810,9 +1810,8 @@ def _get_feature_names(X): types = sorted(t.__qualname__ for t in set(type(v) for v in feature_names)) - # Warn when types are mixed. - # ints and strings do not warn - if len(types) > 1 or not (types[0].startswith("int") or types[0] == "str"): + # Warn when types are mixed and string is one of the types + if len(types) > 1 and "str" in types: # TODO: Convert to an error in 1.2 warnings.warn( "Feature names only support names that are all strings. " @@ -1823,7 +1822,7 @@ def _get_feature_names(X): return # Only feature names of all strings are supported - if types[0] == "str": + if len(types) == 1 and types[0] == "str": return feature_names From a82352efc0ad129c42133059d26dbe79350513e3 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 7 Feb 2022 23:58:55 -0500 Subject: [PATCH 2/4] DOC Add pr number --- doc/whats_new/v1.1.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 4c602294e1117..8542c08dde020 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -46,7 +46,7 @@ random sampling procedures. the old behavior. - |Fix| Panda's DataFrames with all non-string columns such as a MultiIndex - will no longer warn when passed into an Estimator. :pr:`xxxxx` by + will no longer warn when passed into an Estimator. :pr:`22410` by `Thomas Fan`_. Changelog From 28abce885366f82aeb710e329641296c84a9209a Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 8 Feb 2022 00:02:51 -0500 Subject: [PATCH 3/4] DOC Adds more details --- doc/whats_new/v1.1.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 8542c08dde020..44457746f4393 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -46,8 +46,10 @@ random sampling procedures. the old behavior. - |Fix| Panda's DataFrames with all non-string columns such as a MultiIndex - will no longer warn when passed into an Estimator. :pr:`22410` by - `Thomas Fan`_. + no longer warns when passed into an Estimator. These non-string columns + will be supported in 1.3. Note that `feature_names_in_` will continue to + **not** be defined. For `feature_names_in_` to be defined, columns must be + all strings. :pr:`22410` by `Thomas Fan`_. Changelog --------- From cf77c1f97310d98a5feeb2606602d24b52119ee2 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 8 Feb 2022 10:22:01 -0500 Subject: [PATCH 4/4] DOC Improve wording --- doc/whats_new/v1.1.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 44457746f4393..39f8e405ebf7c 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -45,11 +45,11 @@ random sampling procedures. `force_finite=False` if you really want to get non-finite values and keep the old behavior. -- |Fix| Panda's DataFrames with all non-string columns such as a MultiIndex - no longer warns when passed into an Estimator. These non-string columns - will be supported in 1.3. Note that `feature_names_in_` will continue to - **not** be defined. For `feature_names_in_` to be defined, columns must be - all strings. :pr:`22410` by `Thomas Fan`_. +- |Fix| Panda's DataFrames with all non-string columns such as a MultiIndex no + longer warns when passed into an Estimator. Estimators will continue to + ignore the column names in DataFrames with non-string columns. For + `feature_names_in_` to be defined, columns must be all strings. :pr:`22410` by + `Thomas Fan`_. Changelog ---------