8000 ENH Improve warnings if func returns a dataframe in FunctionTransformer by thomasjpfan · Pull Request #26944 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

ENH Improve warnings if func returns a dataframe in FunctionTransformer #26944

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions doc/whats_new/v1.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,13 @@ Changelog
- |Enhancement| Added `neg_root_mean_squared_log_error_scorer` as scorer
:pr:`26734` by :user:`Alejandro Martin Gil <101AlexMartin>`.

:mod:`sklearn.preprocessing`
............................

- |Enhancement| Improves warnings in :class:`preprocessing.FunctionTransfomer` when
`func` returns a pandas dataframe and the output is configured to be pandas.
:pr:`26944` by `Thomas Fan`_.

:mod:`sklearn.model_selection`
..............................

Expand Down
28 changes: 19 additions & 9 deletions sklearn/preprocessing/_function_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

from ..base import BaseEstimator, TransformerMixin, _fit_context
from ..utils._param_validation import StrOptions
from ..utils._set_output import _get_output_config
from ..utils.metaestimators import available_if
from ..utils.validation import (
_allclose_dense_sparse,
_check_feature_names_in,
_is_pandas_df,
check_array,
)

Expand Down Expand Up @@ -237,7 +239,20 @@ def transform(self, X):
Transformed input.
"""
X = self._check_input(X, reset=False)
return self._transform(X, func=self.func, kw_args=self.kw_args)
out = self._transform(X, func=self.func, kw_args=self.kw_args)

output_config = _get_output_config("transform", self)["dense"]
if (
output_config == "pandas"
and self.feature_names_out is None
and not _is_pandas_df(out)
):
warnings.warn(
"When `set_output` is configured to be 'pandas', `func` should return "
"a DataFrame to follow the `set_output` API or `feature_names_out` "
"should be defined."
)
return out

def inverse_transform(self, X):
"""Transform X using the inverse function.
Expand Down Expand Up @@ -338,13 +353,8 @@ def set_output(self, *, transform=None):
self : estimator instance
Estimator instance.
"""
if hasattr(super(), "set_output"):
return super().set_output(transform=transform)

if transform == "pandas" and self.feature_names_out is None:
warnings.warn(
'With transform="pandas", `func` should return a DataFrame to follow'
" the set_output API."
)
if not hasattr(self, "_sklearn_output_config"):
self._sklearn_output_config = {}

self._sklearn_output_config["transform"] = transform
return self
25 changes: 19 additions & 6 deletions sklearn/preprocessing/tests/test_function_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,13 +451,26 @@ def test_set_output_func():
assert isinstance(X_trans, pd.DataFrame)
assert_array_equal(X_trans.columns, ["a", "b"])

# If feature_names_out is not defined, then a warning is raised in
# `set_output`
ft = FunctionTransformer(lambda x: 2 * x)
msg = "should return a DataFrame to follow the set_output API"
with pytest.warns(UserWarning, match=msg):
ft.set_output(transform="pandas")
ft.set_output(transform="pandas")

X_trans = ft.fit_transform(X)
# no warning is raised when func returns a panda dataframe
with warnings.catch_warnings():
warnings.simplefilter("error", UserWarning)
X_trans = ft.fit_transform(X)
assert isinstance(X_trans, pd.DataFrame)
assert_array_equal(X_trans.columns, ["a", "b"])

# Warning is raised when func returns a ndarray
ft_np = FunctionTransformer(lambda x: np.asarray(x))
ft_np.set_output(transform="pandas")

msg = "When `set_output` is configured to be 'pandas'"
with pytest.warns(UserWarning, match=msg):
ft_np.fit_transform(X)

# default transform does not warn
ft_np.set_output(transform="default")
with warnings.catch_warnings():
warnings.simplefilter("error", UserWarning)
ft_np.fit_transform(X)
0