8000 FEAT add inverse_transform parameter to `_SetOutputMixin.set_output` by SuccessMoses · Pull Request #30376 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

FEAT add inverse_transform parameter to _SetOutputMixin.set_output #30376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions sklearn/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,19 @@
from contextlib import contextmanager as contextmanager

_global_config = {
"array_api_dispatch": False,
"assume_finite": bool(os.environ.get("SKLEARN_ASSUME_FINITE", False)),
"working_memory": int(os.environ.get("SKLEARN_WORKING_MEMORY", 1024)),
"print_changed_only": True,
"display": "diagram",
"enable_cython_pairwise_dist": True,
"enable_metadata_routing": False,
"inverse_transform_output": "default",
"pairwise_dist_chunk_size": int(
os.environ.get("SKLEARN_PAIRWISE_DIST_CHUNK_SIZE", 256)
),
"enable_cython_pairwise_dist": True,
"array_api_dispatch": False,
"transform_output": "default",
"enable_metadata_routing": False,
"print_changed_only": True,
"skip_parameter_validation": False,
"transform_output": "default",
"working_memory": int(os.environ.get("SKLEARN_WORKING_MEMORY", 1024)),
}
_threadlocal = threading.local()

Expand Down Expand Up @@ -66,6 +67,7 @@ def set_config(
enable_cython_pairwise_dist=None,
array_api_dispatch=None,
transform_output=None,
inverse_transform_output=None,
enable_metadata_routing=None,
skip_parameter_validation=None,
):
Expand Down Expand Up @@ -205,6 +207,8 @@ def set_config(
local_config["array_api_dispatch"] = array_api_dispatch
if transform_output is not None:
local_config["transform_output"] = transform_output
if inverse_transform_output is not None:
local_config["inverse_transform_output"] = inverse_transform_output
if enable_metadata_routing is not None:
local_config["enable_metadata_routing"] = enable_metadata_routing
if skip_parameter_validation is not None:
Expand All @@ -222,6 +226,7 @@ def config_context(
enable_cython_pairwise_dist=None,
array_api_dispatch=None,
transform_output=None,
inverse_transform_output=None,
enable_metadata_routing=None,
skip_parameter_validation=None,
):
Expand Down Expand Up @@ -366,6 +371,7 @@ def config_context(
enable_cython_pairwise_dist=enable_cython_pairwise_dist,
array_api_dispatch=array_api_dispatch,
transform_output=transform_output,
inverse_transform_output=inverse_transform_output,
enable_metadata_routing=enable_metadata_routing,
skip_parameter_validation=skip_parameter_validation,
)
Expand Down
5 changes: 4 additions & 1 deletion sklearn/cluster/_feature_agglomeration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# Mixin class for feature agglomeration.


class AgglomerationTransform(TransformerMixin):
class AgglomerationTransform(TransformerMixin, auto_wrap_output_keys=("transform",)):
"""
A class for feature agglomeration via the transform interface.
"""
Expand Down Expand Up @@ -84,6 +84,9 @@ def inverse_transform(self, X=None, *, Xt=None):
A vector of size `n_samples` with the values of `Xred` assigned to
each of the cluster of samples.
"""
# because this method takes X and Xt(deprecated),
# auto_wrap_output is not configured for this method.

X = _deprecate_Xt_in_inverse_transform(X, Xt)

check_is_fitted(self)
Expand Down
10 changes: 9 additions & 1 deletion sklearn/decomposition/_nmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,7 +1132,13 @@ def non_negative_factorization(
return W, H, n_iter


class _BaseNMF(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator, ABC):
class _BaseNMF(
ClassNamePrefixFeaturesOutMixin,
TransformerMixin,
BaseEstimator,
ABC,
auto_wrap_output_keys=("transform",),
):
"""Base class for NMF and MiniBatchNMF."""

# This prevents ``set_split_inverse_transform`` to be generated for the
Expand Down Expand Up @@ -1318,6 +1324,8 @@ def inverse_transform(self, X=None, *, Xt=None):
Returns a data matrix of the original shape.
"""

# because this method takes X and Xt(deprecated),
# auto_wrap_output is not configured for this method.
X = _deprecate_Xt_in_inverse_transform(X, Xt)

check_is_fitted(self)
Expand Down
7 changes: 6 additions & 1 deletion sklearn/preprocessing/_discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from ._encoders import OneHotEncoder


class KBinsDiscretizer(TransformerMixin, BaseEstimator):
class KBinsDiscretizer(
TransformerMixin, BaseEstimator, auto_wrap_output_keys=("transform",)
):
"""
Bin continuous data into intervals.

Expand Down Expand Up @@ -412,6 +414,9 @@ def inverse_transform(self, X=None, *, Xt=None):
Xinv : ndarray, dtype={np.float32, np.float64}
Data in the original feature space.
"""
# because this method takes X and Xt(deprecated),
# auto_wrap_output is not configured for this method.

X = _deprecate_Xt_in_inverse_transform(X, Xt)

check_is_fitted(self)
Expand Down
3 changes: 3 additions & 0 deletions sklearn/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def test_config_context():
"pairwise_dist_chunk_size": 256,
"enable_cython_pairwise_dist": True,
"transform_output": "default",
"inverse_transform_output": "default",
"enable_metadata_routing": False,
"skip_parameter_validation": False,
}
Expand All @@ -38,6 +39,7 @@ def test_config_context():
"pairwise_dist_chunk_size": 256,
"enable_cython_pairwise_dist": True,
"transform_output": "default",
"inverse_transform_output": "default",
"enable_metadata_routing": False,
"skip_parameter_validation": False,
}
Expand Down Expand Up @@ -73,6 +75,7 @@ def test_config_context():
"pairwise_dist_chunk_size": 256,
"enable_cython_pairwise_dist": True,
"transform_output": "default",
"inverse_transform_output": "default",
"enable_metadata_routing": False,
"skip_parameter_validation": False,
}
Expand Down
35 changes: 20 additions & 15 deletions sklearn/utils/_set_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,27 +341,30 @@ def _auto_wrap_is_configured(estimator):
is manually disabled.
"""
auto_wrap_output_keys = getattr(estimator, "_sklearn_auto_wrap_output_keys", set())
return (
hasattr(estimator, "get_feature_names_out")
and "transform" in auto_wrap_output_keys
return hasattr(estimator, "get_feature_names_out") and (
"transform" in auto_wrap_output_keys
or "inverse_transform" in auto_wrap_output_keys
)


class _SetOutputMixin:
"""Mixin that dynamically wraps methods to return container based on config.

Currently `_SetOutputMixin` wraps `transform` and `fit_transform` and configures
it based on `set_output` of the global configuration.
Currently `_SetOutputMixin` wraps `transform`, `fit_transform` and
`inverse_transform` and configures it based on `set_output` of the global
configuration.

`set_output` is only defined if `get_feature_names_out` is defined and
`auto_wrap_output_keys` is the default value.
"""

def __init_subclass__(cls, auto_wrap_output_keys=("transform",), **kwargs):
def __init_subclass__(
cls, auto_wrap_output_keys=("transform", "inverse_transform"), **kwargs
):
super().__init_subclass__(**kwargs)

# Dynamically wraps `transform` and `fit_transform` and configure it's
# output based on `set_output`.
# `inverse_transform` output based on `set_output`.
if not (
isinstance(auto_wrap_output_keys, tuple) or auto_wrap_output_keys is None
):
Expand All @@ -375,6 +378,7 @@ def __init_subclass__(cls, auto_wrap_output_keys=("transform",), **kwargs):
method_to_key = {
"transform": "transform",
"fit_transform": "transform",
"inverse_transform": "inverse_transform",
}
cls._sklearn_auto_wrap_output_keys = set()

Expand All @@ -390,7 +394,7 @@ def __init_subclass__(cls, auto_wrap_output_keys=("transform",), **kwargs):
setattr(cls, method, wrapped_method)

@available_if(_auto_wrap_is_configured)
def set_output(self, *, transform=None):
def set_output(self, *, inverse_transform=None, transform=None):
"""Set output container.

See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
Expand All @@ -414,13 +418,14 @@ def set_output(self, *, transform=None):
self : estimator instance
Estimator instance.
"""
if transform is None:
return self

if not hasattr(self, "_sklearn_output_config"):
self._sklearn_output_config = {}

self._sklearn_output_config["transform"] = transform
if transform is not None or inverse_transform is not None:
if not hasattr(self, "_sklearn_output_config"):
self._sklearn_output_config = {}

if inverse_transform is not None:
self._sklearn_output_config["inverse_transform"] = inverse_transform
if transform is not None:
self._sklearn_output_config["transform"] = transform
return self


Expand Down
Loading
0