From d388ffaad39293b65d2e1368ac0baa8d303eef5d Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 30 Jul 2023 18:48:40 -0400 Subject: [PATCH 1/4] ENH Improves warnings when func returns a dataframe in FunctionTransformer --- doc/whats_new/v1.4.rst | 7 +++++ .../preprocessing/_function_transformer.py | 28 +++++++++++++------ .../tests/test_function_transformer.py | 19 +++++++++---- 3 files changed, 39 insertions(+), 15 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 1f66619ae2219..1d5d187afabb1 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -134,6 +134,13 @@ Changelog to :ref:`metadata routing user guide `. :pr:`26789` by `Adrin Jalali`_. +:mod:`sklearn.preprocessing` +............................ + +- |Enhancement| Improves warnings in :class:`preprocessing.FunctionTransfomer` when + `func` returns a pandas dataframe and the output is configured to be pandas. + :pr:`xxxxx` by `Thomas Fan`_. + :mod:`sklearn.model_selection` .............................. diff --git a/sklearn/preprocessing/_function_transformer.py b/sklearn/preprocessing/_function_transformer.py index f1df0f43dc96e..9f53c1f71d1f3 100644 --- a/sklearn/preprocessing/_function_transformer.py +++ b/sklearn/preprocessing/_function_transformer.py @@ -4,10 +4,12 @@ from ..base import BaseEstimator, TransformerMixin, _fit_context from ..utils._param_validation import StrOptions +from ..utils._set_output import _get_output_config from ..utils.metaestimators import available_if from ..utils.validation import ( _allclose_dense_sparse, _check_feature_names_in, + _is_pandas_df, check_array, ) @@ -237,7 +239,20 @@ def transform(self, X): Transformed input. """ X = self._check_input(X, reset=False) - return self._transform(X, func=self.func, kw_args=self.kw_args) + out = self._transform(X, func=self.func, kw_args=self.kw_args) + + output_config = _get_output_config("transform", self)["dense"] + if ( + output_config == "pandas" + and self.feature_names_out is None + and not _is_pandas_df(out) + ): + warnings.warn( + "When set_output is configured to be 'pandas', either `func` returns a " + "DataFrame to follow the set_output API or `feature_names_out` is " + "defined" + ) + return out def inverse_transform(self, X): """Transform X using the inverse function. @@ -338,13 +353,8 @@ def set_output(self, *, transform=None): self : estimator instance Estimator instance. """ - if hasattr(super(), "set_output"): - return super().set_output(transform=transform) - - if transform == "pandas" and self.feature_names_out is None: - warnings.warn( - 'With transform="pandas", `func` should return a DataFrame to follow' - " the set_output API." - ) + if not hasattr(self, "_sklearn_output_config"): + self._sklearn_output_config = {} + self._sklearn_output_config["transform"] = transform return self diff --git a/sklearn/preprocessing/tests/test_function_transformer.py b/sklearn/preprocessing/tests/test_function_transformer.py index fa19171503a1d..9ceace7c29506 100644 --- a/sklearn/preprocessing/tests/test_function_transformer.py +++ b/sklearn/preprocessing/tests/test_function_transformer.py @@ -454,13 +454,20 @@ def test_set_output_func(): assert isinstance(X_trans, pd.DataFrame) assert_array_equal(X_trans.columns, ["a", "b"]) - # If feature_names_out is not defined, then a warning is raised in - # `set_output` ft = FunctionTransformer(lambda x: 2 * x) - msg = "should return a DataFrame to follow the set_output API" - with pytest.warns(UserWarning, match=msg): - ft.set_output(transform="pandas") + ft.set_output(transform="pandas") - X_trans = ft.fit_transform(X) + # no warning is raised when func returns a panda dataframe + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + X_trans = ft.fit_transform(X) assert isinstance(X_trans, pd.DataFrame) assert_array_equal(X_trans.columns, ["a", "b"]) + + # Warning is raised when func returns a ndarray + ft_np = FunctionTransformer(lambda x: np.asarray(x)) + ft_np.set_output(transform="pandas") + + msg = "When set_output is configured to be 'pandas'" + with pytest.warns(UserWarning, match=msg): + ft_np.fit_transform(X) From 554e389e527043e3f71af54f264c9b9c6ec3becf Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 30 Jul 2023 18:50:06 -0400 Subject: [PATCH 2/4] DOC Adds PR number --- doc/whats_new/v1.4.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 1d5d187afabb1..40d3f88836d2b 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -139,7 +139,7 @@ Changelog - |Enhancement| Improves warnings in :class:`preprocessing.FunctionTransfomer` when `func` returns a pandas dataframe and the output is configured to be pandas. - :pr:`xxxxx` by `Thomas Fan`_. + :pr:`26944` by `Thomas Fan`_. :mod:`sklearn.model_selection` .............................. From 5b08530254ae2134153ae401bdc97f7235ea3a4c Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 30 Jul 2023 22:41:42 -0400 Subject: [PATCH 3/4] TST Improve coverage --- sklearn/preprocessing/tests/test_function_transformer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sklearn/preprocessing/tests/test_function_transformer.py b/sklearn/preprocessing/tests/test_function_transformer.py index 9ceace7c29506..cdad6826fe21d 100644 --- a/sklearn/preprocessing/tests/test_function_transformer.py +++ b/sklearn/preprocessing/tests/test_function_transformer.py @@ -471,3 +471,9 @@ def test_set_output_func(): msg = "When set_output is configured to be 'pandas'" with pytest.warns(UserWarning, match=msg): ft_np.fit_transform(X) + + # default transform does not warn + ft_np.set_output(transform="default") + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + ft_np.fit_transform(X) From 51961fa1fcfd64f99a2b64f5393430d3526bff49 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 30 Aug 2023 15:00:10 -0400 Subject: [PATCH 4/4] Apply suggestions from code review Co-authored-by: Olivier Grisel --- sklearn/preprocessing/_function_transformer.py | 6 +++--- sklearn/preprocessing/tests/test_function_transformer.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/preprocessing/_function_transformer.py b/sklearn/preprocessing/_function_transformer.py index 9f53c1f71d1f3..fa755265d7bc2 100644 --- a/sklearn/preprocessing/_function_transformer.py +++ b/sklearn/preprocessing/_function_transformer.py @@ -248,9 +248,9 @@ def transform(self, X): and not _is_pandas_df(out) ): warnings.warn( - "When set_output is configured to be 'pandas', either `func` returns a " - "DataFrame to follow the set_output API or `feature_names_out` is " - "defined" + "When `set_output` is configured to be 'pandas', `func` should return " + "a DataFrame to follow the `set_output` API or `feature_names_out` " + "should be defined." ) return out diff --git a/sklearn/preprocessing/tests/test_function_transformer.py b/sklearn/preprocessing/tests/test_function_transformer.py index cdad6826fe21d..36081a8ce6380 100644 --- a/sklearn/preprocessing/tests/test_function_transformer.py +++ b/sklearn/preprocessing/tests/test_function_transformer.py @@ -468,7 +468,7 @@ def test_set_output_func(): ft_np = FunctionTransformer(lambda x: np.asarray(x)) ft_np.set_output(transform="pandas") - msg = "When set_output is configured to be 'pandas'" + msg = "When `set_output` is configured to be 'pandas'" with pytest.warns(UserWarning, match=msg): ft_np.fit_transform(X)