8000 FIX Remove validation from __init__ and set_params for ColumnTransfor… · scikit-learn/scikit-learn@ff1c6f3 · GitHub
[go: up one dir, main page]

Skip to content

Commit ff1c6f3

Browse files
iofallarisayosh
andauthored
FIX Remove validation from __init__ and set_params for ColumnTransformer (#22537)
Co-authored-by: iofall <50991099+iofall@users.noreply.github.com> Co-authored-by: arisayosh <15692997+arisayosh@users.noreply.github.com>
1 parent 6ab950e commit ff1c6f3

File tree

4 files changed

+27
-11
lines changed

4 files changed

+27
-11
lines changed

doc/whats_new/v1.1.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,13 @@ Changelog
171171
`-1` and the original warning message is shown.
172172
:pr:`22217` by :user:`Meekail Zain <micky774>`.
173173

174+
:mod:`sklearn.compose`
175+
......................
176+
177+
- |Fix| :class:`compose.ColumnTransformer` now removes validation errors from
178+
`__init__` and `set_params` methods.
179+
:pr:`22537` by :user:`iofall <iofall>` and :user:`Arisa Y. <arisayosh>`.
180+
174181
:mod:`sklearn.cross_decomposition`
175182
..................................
176183

sklearn/compose/_column_transformer.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -222,14 +222,20 @@ def _transformers(self):
222222
of get_params via BaseComposition._get_params which expects lists
223223
of tuples of len 2.
224224
"""
225-
return [(name, trans) for name, trans, _ in self.transformers]
225+
try:
226+
return [(name, trans) for name, trans, _ in self.transformers]
227+
except (TypeError, ValueError):
228+
return self.transformers
226229

227230
@_transformers.setter
228231
def _transformers(self, value):
229-
self.transformers = [
230-
(name, trans, col)
231-
for ((name, trans), (_, _, col)) in zip(value, self.transformers)
232-
]
232+
try:
233+
self.transformers = [
234+
(name, trans, col)
235+
for ((name, trans), (_, _, col)) in zip(value, self.transformers)
236+
]
237+
except (TypeError, ValueError):
238+
self.transformers = value
233239

234240
def get_params(self, deep=True):
235241
"""Get parameters for this estimator.

sklearn/tests/test_common.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,6 @@ def test_transformers_get_feature_names_out(transformer):
413413

414414

415415
VALIDATE_ESTIMATOR_INIT = [
416-
"ColumnTransformer",
417416
"SGDOneClassSVM",
418417
"TheilSenRegressor",
419418
"TweedieRegressor",
@@ -436,7 +435,7 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
436435
if param.kind != Parameter.VAR_KEYWORD
437436
]
438437

439-
smoke_test_values = [-1, 3.0, "helloworld", np.array([1.0, 4.0]), {}, []]
438+
smoke_test_values = [-1, 3.0, "helloworld", np.array([1.0, 4.0]), [1], {}, []]
440439
for value in smoke_test_values:
441440
new_params = {key: value for key in params}
442441

sklearn/utils/metaestimators.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from operator import attrgetter
99
from functools import update_wrapper
1010
import numpy as np
11+
from contextlib import suppress
1112

1213
from ..utils import _safe_indexing
1314
from ..utils._tags import _safe_tags
@@ -56,10 +57,13 @@ def _set_params(self, attr, **params):
5657
items = getattr(self, attr)
5758
if isinstance(items, list) and items:
5859
# Get item names used to identify valid names in params
59-
item_names, _ = zip(*items)
60-
for name in list(params.keys()):
61-
if "__" not in name and name in item_names:
62-
self._replace_estimator(attr, name, params.pop(name))
60+
# `zip` raises a TypeError when `items` does not contains
61+
# elements of length 2
62+
with suppress(TypeError):
63+
item_names, _ = zip(*items)
64+
for name in list(params.keys()):
65+
if "__" not in name and name in item_names:
66+
self._replace_estimator(attr, name, params.pop(name))
6367

6468
# 3. Step parameters and other initialisation arguments
6569
super().set_params(**params)

0 commit comments

Comments
 (0)
0