diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 91aa2a2859fd8..2c1b23dbc41ad 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -158,6 +158,13 @@ Changelog `-1` and the original warning message is shown. :pr:`22217` by :user:`Meekail Zain `. +:mod:`sklearn.compose` +...................... + +- |Fix| :class:`compose.ColumnTransformer` now removes validation errors from + `__init__` and `set_params` methods. + :pr:`22537` by :user:`iofall ` and :user:`Arisa Y. `. + :mod:`sklearn.cross_decomposition` .................................. diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index 68fc2086c3699..0520840328951 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -222,14 +222,20 @@ def _transformers(self): of get_params via BaseComposition._get_params which expects lists of tuples of len 2. """ - return [(name, trans) for name, trans, _ in self.transformers] + try: + return [(name, trans) for name, trans, _ in self.transformers] + except (TypeError, ValueError): + return self.transformers @_transformers.setter def _transformers(self, value): - self.transformers = [ - (name, trans, col) - for ((name, trans), (_, _, col)) in zip(value, self.transformers) - ] + try: + self.transformers = [ + (name, trans, col) + for ((name, trans), (_, _, col)) in zip(value, self.transformers) + ] + except (TypeError, ValueError): + self.transformers = value def get_params(self, deep=True): """Get parameters for this estimator. diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 350e1e95d9882..44a2938775981 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -415,7 +415,6 @@ def test_transformers_get_feature_names_out(transformer): VALIDATE_ESTIMATOR_INIT = [ - "ColumnTransformer", "SGDOneClassSVM", "TheilSenRegressor", "TweedieRegressor", @@ -438,7 +437,7 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator): if param.kind != Parameter.VAR_KEYWORD ] - smoke_test_values = [-1, 3.0, "helloworld", np.array([1.0, 4.0]), {}, []] + smoke_test_values = [-1, 3.0, "helloworld", np.array([1.0, 4.0]), [1], {}, []] for value in smoke_test_values: new_params = {key: value for key in params} diff --git a/sklearn/utils/metaestimators.py b/sklearn/utils/metaestimators.py index 944b54f062c55..63e010e970039 100644 --- a/sklearn/utils/metaestimators.py +++ b/sklearn/utils/metaestimators.py @@ -8,6 +8,7 @@ from operator import attrgetter from functools import update_wrapper import numpy as np +from contextlib import suppress from ..utils import _safe_indexing from ..utils._tags import _safe_tags @@ -56,10 +57,13 @@ def _set_params(self, attr, **params): items = getattr(self, attr) if isinstance(items, list) and items: # Get item names used to identify valid names in params - item_names, _ = zip(*items) - for name in list(params.keys()): - if "__" not in name and name in item_names: - self._replace_estimator(attr, name, params.pop(name)) + # `zip` raises a TypeError when `items` does not contains + # elements of length 2 + with suppress(TypeError): + item_names, _ = zip(*items) + for name in list(params.keys()): + if "__" not in name and name in item_names: + self._replace_estimator(attr, name, params.pop(name)) # 3. Step parameters and other initialisation arguments super().set_params(**params)