8000 FEAT Adds __getitem__ to ColumnTransformer (#27990) · punndcoder28/scikit-learn@678e399 · GitHub
[go: up one dir, main page]

Skip to content

Commit 678e399

Browse files
FEAT Adds __getitem__ to ColumnTransformer (scikit-learn#27990)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 0d4a88e commit 678e399

File tree

3 files changed

+34
-0
lines changed

3 files changed

+34
-0
lines changed

doc/whats_new/v1.5.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,9 @@ Thanks to everyone who has contributed to the maintenance and improvement of
3232
the project since version 1.4, including:
3333

3434
TODO: update at the time of the release.
35+
36+
:mod:`sklearn.compose`
37+
......................
38+
39+
- |Feature| A fitted :class:`compose.ColumnTransformer` now implements `__getitem__`
40+
which returns the fitted transformers by name. :pr:`27990` by `Thomas Fan`_.

sklearn/compose/_column_transformer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,16 @@ def _sk_visual_block_(self):
10821082
"parallel", transformers, names=names, name_details=name_details
10831083
)
10841084

1085+
def __getitem__(self, key):
1086+
try:
1087+
return self.named_transformers_[key]
1088+
except AttributeError as e:
1089+
raise TypeError(
1090+
"ColumnTransformer is subscriptable after it is fitted"
1091+
) from e
1092+
except KeyError as e:
1093+
raise KeyError(f"'{key}' is not a valid transformer name") from e
1094+
10851095
def _get_empty_routing(self):
10861096
"""Return empty routing.
10871097

sklearn/compose/tests/test_column_transformer.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2301,6 +2301,24 @@ def test_dataframe_different_dataframe_libraries():
23012301
assert_array_equal(out_pd_in, X_test_np)
23022302

23032303

2304+
def test_column_transformer__getitem__():
2305+
"""Check __getitem__ for ColumnTransformer."""
2306+
X = np.array([[0, 1, 2], [3, 4, 5]])
2307+
ct = ColumnTransformer([("t1", Trans(), [0, 1]), ("t2", Trans(), [1, 2])])
2308+
2309+
msg = "ColumnTransformer is subscriptable after it is fitted"
2310+
with pytest.raises(TypeError, match=msg):
2311+
ct["t1"]
2312+
2313+
ct.fit(X)
2314+
assert ct["t1"] is ct.named_transformers_["t1"]
2315+
assert ct["t2"] is ct.named_transformers_["t2"]
2316+
2317+
msg = "'does_not_exist' is not a valid transformer name"
2318+
with pytest.raises(KeyError, match=msg):
2319+
ct["does_not_exist"]
2320+
2321+
23042322
# Metadata Routing Tests
23052323
# ======================
23062324

0 commit comments

Comments
 (0)
0