diff --git a/sklearn/_config.py b/sklearn/_config.py index 05549c88a9ddc..af1342d69c5b4 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -8,18 +8,19 @@ from contextlib import contextmanager as contextmanager _global_config = { + "array_api_dispatch": False, "assume_finite": bool(os.environ.get("SKLEARN_ASSUME_FINITE", False)), - "working_memory": int(os.environ.get("SKLEARN_WORKING_MEMORY", 1024)), - "print_changed_only": True, "display": "diagram", + "enable_cython_pairwise_dist": True, + "enable_metadata_routing": False, + "inverse_transform_output": "default", "pairwise_dist_chunk_size": int( os.environ.get("SKLEARN_PAIRWISE_DIST_CHUNK_SIZE", 256) ), - "enable_cython_pairwise_dist": True, - "array_api_dispatch": False, - "transform_output": "default", - "enable_metadata_routing": False, + "print_changed_only": True, "skip_parameter_validation": False, + "transform_output": "default", + "working_memory": int(os.environ.get("SKLEARN_WORKING_MEMORY", 1024)), } _threadlocal = threading.local() @@ -66,6 +67,7 @@ def set_config( enable_cython_pairwise_dist=None, array_api_dispatch=None, transform_output=None, + inverse_transform_output=None, enable_metadata_routing=None, skip_parameter_validation=None, ): @@ -205,6 +207,8 @@ def set_config( local_config["array_api_dispatch"] = array_api_dispatch if transform_output is not None: local_config["transform_output"] = transform_output + if inverse_transform_output is not None: + local_config["inverse_transform_output"] = inverse_transform_output if enable_metadata_routing is not None: local_config["enable_metadata_routing"] = enable_metadata_routing if skip_parameter_validation is not None: @@ -222,6 +226,7 @@ def config_context( enable_cython_pairwise_dist=None, array_api_dispatch=None, transform_output=None, + inverse_transform_output=None, enable_metadata_routing=None, skip_parameter_validation=None, ): @@ -366,6 +371,7 @@ def config_context( enable_cython_pairwise_dist=enable_cython_pairwise_dist, 
array_api_dispatch=array_api_dispatch, transform_output=transform_output, + inverse_transform_output=inverse_transform_output, enable_metadata_routing=enable_metadata_routing, skip_parameter_validation=skip_parameter_validation, ) diff --git a/sklearn/cluster/_feature_agglomeration.py b/sklearn/cluster/_feature_agglomeration.py index 1983aae00ecbb..1a59b91d1a391 100644 --- a/sklearn/cluster/_feature_agglomeration.py +++ b/sklearn/cluster/_feature_agglomeration.py @@ -19,7 +19,7 @@ # Mixin class for feature agglomeration. -class AgglomerationTransform(TransformerMixin): +class AgglomerationTransform(TransformerMixin, auto_wrap_output_keys=("transform",)): """ A class for feature agglomeration via the transform interface. """ @@ -84,6 +84,9 @@ def inverse_transform(self, X=None, *, Xt=None): A vector of size `n_samples` with the values of `Xred` assigned to each of the cluster of samples. """ + # because this method takes X and Xt (deprecated), + # auto_wrap_output is not configured for this method. + X = _deprecate_Xt_in_inverse_transform(X, Xt) check_is_fitted(self) diff --git a/sklearn/decomposition/_nmf.py b/sklearn/decomposition/_nmf.py index 6be97f2223fb5..21f02e90fbaf9 100644 --- a/sklearn/decomposition/_nmf.py +++ b/sklearn/decomposition/_nmf.py @@ -1132,7 +1132,13 @@ def non_negative_factorization( return W, H, n_iter -class _BaseNMF(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator, ABC): +class _BaseNMF( + ClassNamePrefixFeaturesOutMixin, + TransformerMixin, + BaseEstimator, + ABC, + auto_wrap_output_keys=("transform",), +): """Base class for NMF and MiniBatchNMF.""" # This prevents ``set_split_inverse_transform`` to be generated for the @@ -1318,6 +1324,8 @@ def inverse_transform(self, X=None, *, Xt=None): Returns a data matrix of the original shape. """ + # because this method takes X and Xt (deprecated), + # auto_wrap_output is not configured for this method. 
X = _deprecate_Xt_in_inverse_transform(X, Xt) check_is_fitted(self) diff --git a/sklearn/preprocessing/_discretization.py b/sklearn/preprocessing/_discretization.py index 6a6a739c469fa..a6f9369e8ce23 100644 --- a/sklearn/preprocessing/_discretization.py +++ b/sklearn/preprocessing/_discretization.py @@ -22,7 +22,9 @@ from ._encoders import OneHotEncoder -class KBinsDiscretizer(TransformerMixin, BaseEstimator): +class KBinsDiscretizer( + TransformerMixin, BaseEstimator, auto_wrap_output_keys=("transform",) +): """ Bin continuous data into intervals. @@ -412,6 +414,9 @@ def inverse_transform(self, X=None, *, Xt=None): Xinv : ndarray, dtype={np.float32, np.float64} Data in the original feature space. """ + # because this method takes X and Xt (deprecated), + # auto_wrap_output is not configured for this method. + X = _deprecate_Xt_in_inverse_transform(X, Xt) check_is_fitted(self) diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index fbdb0e2884d32..7e859b0348664 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -20,6 +20,7 @@ def test_config_context(): "pairwise_dist_chunk_size": 256, "enable_cython_pairwise_dist": True, "transform_output": "default", + "inverse_transform_output": "default", "enable_metadata_routing": False, "skip_parameter_validation": False, } @@ -38,6 +39,7 @@ def test_config_context(): "pairwise_dist_chunk_size": 256, "enable_cython_pairwise_dist": True, "transform_output": "default", + "inverse_transform_output": "default", "enable_metadata_routing": False, "skip_parameter_validation": False, } @@ -73,6 +75,7 @@ def test_config_context(): "pairwise_dist_chunk_size": 256, "enable_cython_pairwise_dist": True, "transform_output": "default", + "inverse_transform_output": "default", "enable_metadata_routing": False, "skip_parameter_validation": False, } diff --git a/sklearn/utils/_set_output.py b/sklearn/utils/_set_output.py index 963e5e5bf6d77..18059b493a726 --- 
a/sklearn/utils/_set_output.py +++ b/sklearn/utils/_set_output.py @@ -341,27 +341,30 @@ def _auto_wrap_is_configured(estimator): is manually disabled. """ auto_wrap_output_keys = getattr(estimator, "_sklearn_auto_wrap_output_keys", set()) - return ( - hasattr(estimator, "get_feature_names_out") - and "transform" in auto_wrap_output_keys + return hasattr(estimator, "get_feature_names_out") and ( + "transform" in auto_wrap_output_keys + or "inverse_transform" in auto_wrap_output_keys ) class _SetOutputMixin: """Mixin that dynamically wraps methods to return container based on config. - Currently `_SetOutputMixin` wraps `transform` and `fit_transform` and configures - it based on `set_output` of the global configuration. + Currently `_SetOutputMixin` wraps `transform`, `fit_transform` and + `inverse_transform` and configures it based on `set_output` of the global + configuration. `set_output` is only defined if `get_feature_names_out` is defined and `auto_wrap_output_keys` is the default value. """ - def __init_subclass__(cls, auto_wrap_output_keys=("transform",), **kwargs): + def __init_subclass__( + cls, auto_wrap_output_keys=("transform", "inverse_transform"), **kwargs + ): super().__init_subclass__(**kwargs) # Dynamically wraps `transform` and `fit_transform` and configure it's - # output based on `set_output`. + # `inverse_transform` output based on `set_output`. 
if not ( isinstance(auto_wrap_output_keys, tuple) or auto_wrap_output_keys is None ): @@ -375,6 +378,7 @@ def __init_subclass__(cls, auto_wrap_output_keys=("transform",), **kwargs): method_to_key = { "transform": "transform", "fit_transform": "transform", + "inverse_transform": "inverse_transform", } cls._sklearn_auto_wrap_output_keys = set() @@ -390,7 +394,7 @@ def __init_subclass__(cls, auto_wrap_output_keys=("transform",), **kwargs): setattr(cls, method, wrapped_method) @available_if(_auto_wrap_is_configured) - def set_output(self, *, transform=None): + def set_output(self, *, inverse_transform=None, transform=None): """Set output container. See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py` @@ -414,13 +418,14 @@ def set_output(self, *, transform=None): self : estimator instance Estimator instance. """ - if transform is None: - return self - - if not hasattr(self, "_sklearn_output_config"): - self._sklearn_output_config = {} - - self._sklearn_output_config["transform"] = transform + if transform is not None or inverse_transform is not None: + if not hasattr(self, "_sklearn_output_config"): + self._sklearn_output_config = {} + + if inverse_transform is not None: + self._sklearn_output_config["inverse_transform"] = inverse_transform + if transform is not None: + self._sklearn_output_config["transform"] = transform return self