diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 8dbd867b0c9ba..c56e67024a890 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -250,6 +250,13 @@ Changelog - |Enhancement| Added `neg_root_mean_squared_log_error_scorer` as scorer :pr:`26734` by :user:`Alejandro Martin Gil <101AlexMartin>`. +:mod:`sklearn.preprocessing` +............................ + +- |Enhancement| Improves warnings in :class:`preprocessing.FunctionTransformer` when + `func` returns a pandas dataframe and the output is configured to be pandas. + :pr:`26944` by `Thomas Fan`_. + :mod:`sklearn.model_selection` .............................. diff --git a/sklearn/preprocessing/_function_transformer.py b/sklearn/preprocessing/_function_transformer.py index f1df0f43dc96e..fa755265d7bc2 100644 --- a/sklearn/preprocessing/_function_transformer.py +++ b/sklearn/preprocessing/_function_transformer.py @@ -4,10 +4,12 @@ from ..base import BaseEstimator, TransformerMixin, _fit_context from ..utils._param_validation import StrOptions +from ..utils._set_output import _get_output_config from ..utils.metaestimators import available_if from ..utils.validation import ( _allclose_dense_sparse, _check_feature_names_in, + _is_pandas_df, check_array, ) @@ -237,7 +239,20 @@ def transform(self, X): Transformed input. """ X = self._check_input(X, reset=False) - return self._transform(X, func=self.func, kw_args=self.kw_args) + out = self._transform(X, func=self.func, kw_args=self.kw_args) + + output_config = _get_output_config("transform", self)["dense"] + if ( + output_config == "pandas" + and self.feature_names_out is None + and not _is_pandas_df(out) + ): + warnings.warn( + "When `set_output` is configured to be 'pandas', `func` should return " + "a DataFrame to follow the `set_output` API or `feature_names_out` " + "should be defined." + ) + return out def inverse_transform(self, X): """Transform X using the inverse function. 
@@ -338,13 +353,8 @@ def set_output(self, *, transform=None): self : estimator instance Estimator instance. """ - if hasattr(super(), "set_output"): - return super().set_output(transform=transform) - - if transform == "pandas" and self.feature_names_out is None: - warnings.warn( - 'With transform="pandas", `func` should return a DataFrame to follow' - " the set_output API." - ) + if not hasattr(self, "_sklearn_output_config"): + self._sklearn_output_config = {} + self._sklearn_output_config["transform"] = transform return self diff --git a/sklearn/preprocessing/tests/test_function_transformer.py b/sklearn/preprocessing/tests/test_function_transformer.py index d843f56002619..c4b2f79f288f0 100644 --- a/sklearn/preprocessing/tests/test_function_transformer.py +++ b/sklearn/preprocessing/tests/test_function_transformer.py @@ -451,13 +451,26 @@ def test_set_output_func(): assert isinstance(X_trans, pd.DataFrame) assert_array_equal(X_trans.columns, ["a", "b"]) - # If feature_names_out is not defined, then a warning is raised in - # `set_output` ft = FunctionTransformer(lambda x: 2 * x) - msg = "should return a DataFrame to follow the set_output API" - with pytest.warns(UserWarning, match=msg): - ft.set_output(transform="pandas") + ft.set_output(transform="pandas") - X_trans = ft.fit_transform(X) + # no warning is raised when func returns a pandas dataframe + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + X_trans = ft.fit_transform(X) assert isinstance(X_trans, pd.DataFrame) assert_array_equal(X_trans.columns, ["a", "b"]) + + # Warning is raised when func returns an ndarray + ft_np = FunctionTransformer(lambda x: np.asarray(x)) + ft_np.set_output(transform="pandas") + + msg = "When `set_output` is configured to be 'pandas'" + with pytest.warns(UserWarning, match=msg): + ft_np.fit_transform(X) + + # default transform does not warn + ft_np.set_output(transform="default") + with warnings.catch_warnings(): + warnings.simplefilter("error", 
UserWarning) + ft_np.fit_transform(X)