From 0048b51c3895e07f73976a3b3bbeff595e339936 Mon Sep 17 00:00:00 2001 From: iofall <50991099+iofall@users.noreply.github.com> Date: Fri, 18 Feb 2022 20:42:33 +0530 Subject: [PATCH 1/4] Remove validation from __init__ and set_params for ColumnTransformer Co-authored-by: iofall <50991099+iofall@users.noreply.github.com> Co-authored-by: arisayosh <15692997+arisayosh@users.noreply.github.com> --- sklearn/compose/_column_transformer.py | 16 +++++++++++----- sklearn/tests/test_common.py | 1 - sklearn/utils/metaestimators.py | 11 +++++++---- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index 68fc2086c3699..0520840328951 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -222,14 +222,20 @@ def _transformers(self): of get_params via BaseComposition._get_params which expects lists of tuples of len 2. """ - return [(name, trans) for name, trans, _ in self.transformers] + try: + return [(name, trans) for name, trans, _ in self.transformers] + except (TypeError, ValueError): + return self.transformers @_transformers.setter def _transformers(self, value): - self.transformers = [ - (name, trans, col) - for ((name, trans), (_, _, col)) in zip(value, self.transformers) - ] + try: + self.transformers = [ + (name, trans, col) + for ((name, trans), (_, _, col)) in zip(value, self.transformers) + ] + except (TypeError, ValueError): + self.transformers = value def get_params(self, deep=True): """Get parameters for this estimator. 
diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 350e1e95d9882..9bbd1ac8c57a8 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -415,7 +415,6 @@ def test_transformers_get_feature_names_out(transformer): VALIDATE_ESTIMATOR_INIT = [ - "ColumnTransformer", "SGDOneClassSVM", "TheilSenRegressor", "TweedieRegressor", diff --git a/sklearn/utils/metaestimators.py b/sklearn/utils/metaestimators.py index 944b54f062c55..3ed520730b01f 100644 --- a/sklearn/utils/metaestimators.py +++ b/sklearn/utils/metaestimators.py @@ -56,10 +56,13 @@ def _set_params(self, attr, **params): items = getattr(self, attr) if isinstance(items, list) and items: # Get item names used to identify valid names in params - item_names, _ = zip(*items) - for name in list(params.keys()): - if "__" not in name and name in item_names: - self._replace_estimator(attr, name, params.pop(name)) + try: + item_names, _ = zip(*items) + for name in list(params.keys()): + if "__" not in name and name in item_names: + self._replace_estimator(attr, name, params.pop(name)) + except TypeError: + pass # 3. Step parameters and other initialisation arguments super().set_params(**params) From 1c02aef161e03d1b36ae627e3f60b8b643c3f475 Mon Sep 17 00:00:00 2001 From: iofall <50991099+iofall@users.noreply.github.com> Date: Fri, 18 Feb 2022 21:24:23 +0530 Subject: [PATCH 2/4] Add changelog for PR 22537 Co-authored-by: iofall <50991099+iofall@users.noreply.github.com> Co-authored-by: arisayosh <15692997+arisayosh@users.noreply.github.com> --- doc/whats_new/v1.1.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 91aa2a2859fd8..2c1b23dbc41ad 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -158,6 +158,13 @@ Changelog `-1` and the original warning message is shown. :pr:`22217` by :user:`Meekail Zain <micky774>`. +:mod:`sklearn.compose` +...................... 
+ +- |Fix| :class:`compose.ColumnTransformer` no longer raises validation errors in + `__init__` and `set_params`. + :pr:`22537` by :user:`iofall <iofall>` and :user:`Arisa Y. <arisayosh>`. + :mod:`sklearn.cross_decomposition` .................................. From a647c9cc21676ac60ee55635c9553dc90e913db4 Mon Sep 17 00:00:00 2001 From: iofall <50991099+iofall@users.noreply.github.com> Date: Sat, 19 Feb 2022 03:34:08 +0530 Subject: [PATCH 3/4] Add suppress instead of try-except Co-authored-by: iofall <50991099+iofall@users.noreply.github.com> Co-authored-by: arisayosh <15692997+arisayosh@users.noreply.github.com> --- sklearn/utils/metaestimators.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sklearn/utils/metaestimators.py b/sklearn/utils/metaestimators.py index 3ed520730b01f..63e010e970039 100644 --- a/sklearn/utils/metaestimators.py +++ b/sklearn/utils/metaestimators.py @@ -8,6 +8,7 @@ from operator import attrgetter from functools import update_wrapper import numpy as np +from contextlib import suppress from ..utils import _safe_indexing from ..utils._tags import _safe_tags @@ -56,13 +57,13 @@ def _set_params(self, attr, **params): items = getattr(self, attr) if isinstance(items, list) and items: # Get item names used to identify valid names in params - try: + # `zip` raises a TypeError when `items` does not contain + # elements of length 2 + with suppress(TypeError): item_names, _ = zip(*items) for name in list(params.keys()): if "__" not in name and name in item_names: self._replace_estimator(attr, name, params.pop(name)) - except TypeError: - pass # 3. 
Step parameters and other initialisation arguments super().set_params(**params) From b0697e4abaa11069910b01d489ae3b58f2a725f6 Mon Sep 17 00:00:00 2001 From: iofall <50991099+iofall@users.noreply.github.com> Date: Sat, 19 Feb 2022 03:36:14 +0530 Subject: [PATCH 4/4] Add smoke test value Co-authored-by: iofall <50991099+iofall@users.noreply.github.com> Co-authored-by: arisayosh <15692997+arisayosh@users.noreply.github.com> --- sklearn/tests/test_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 9bbd1ac8c57a8..44a2938775981 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -437,7 +437,7 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator): if param.kind != Parameter.VAR_KEYWORD ] - smoke_test_values = [-1, 3.0, "helloworld", np.array([1.0, 4.0]), {}, []] + smoke_test_values = [-1, 3.0, "helloworld", np.array([1.0, 4.0]), [1], {}, []] for value in smoke_test_values: new_params = {key: value for key in params}