8000 FEA Adds FeatureUnion.__getitem__ to access transformers (#25093) · scikit-learn/scikit-learn@f7e6977 · GitHub
[go: up one dir, main page]

Skip to content

Commit f7e6977

Browse files
FEA Adds FeatureUnion.__getitem__ to access transformers (#25093)
Co-authored-by: Meekail Zain <34613774+Micky774@users.noreply.github.com>
1 parent 1f9dc71 commit f7e6977

File tree

3 files changed

+41
-0
lines changed

3 files changed

+41
-0
lines changed

doc/whats_new/v1.3.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ Changelog
3636
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
3737
where 123456 is the *pull request* number, not the issue number.
3838
39+
:mod:`sklearn.pipeline`
40+
.......................
41+
- |Feature| :class:`pipeline.FeatureUnion` can now use indexing notation (e.g.
42+
`feature_union["scalar"]`) to access transformers by name. :pr:`25093` by
43+
`Thomas Fan`_.
44+
3945
Code and Documentation Contributors
4046
-----------------------------------
4147

sklearn/pipeline.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,6 +1286,12 @@ def _sk_visual_block_(self):
12861286
names, transformers = zip(*self.transformer_list)
12871287
return _VisualBlock("parallel", transformers, names=names)
12881288

1289+
def __getitem__(self, name):
1290+
"""Return transformer with name."""
1291+
if not isinstance(name, str):
1292+
raise KeyError("Only string keys are supported")
1293+
return self.named_transformers[name]
1294+
12891295

12901296
def make_union(*transformers, n_jobs=None, verbose=False):
12911297
"""Construct a FeatureUnion from the given transformers.

sklearn/tests/test_pipeline.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1658,3 +1658,32 @@ def test_feature_union_set_output():
16581658
assert isinstance(X_trans, pd.DataFrame)
16591659
assert_array_equal(X_trans.columns, union.get_feature_names_out())
16601660
assert_array_equal(X_trans.index, X_test.index)
1661+
1662+
1663+
def test_feature_union_getitem():
1664+
"""Check FeatureUnion.__getitem__ returns expected results."""
1665+
scalar = StandardScaler()
1666+
pca = PCA()
1667+
union = FeatureUnion(
1668+
[
1669+
("scalar", scalar),
1670+
("pca", pca),
1671+
("pass", "passthrough"),
1672+
("drop_me", "drop"),
1673+
]
1674+
)
1675+
assert union["scalar"] is scalar
1676+
assert union["pca"] is pca
1677+
assert union["pass"] == "passthrough"
1678+
assert union["drop_me"] == "drop"
1679+
1680+
1681+
@pytest.mark.parametrize("key", [0, slice(0, 2)])
1682+
def test_feature_union_getitem_error(key):
1683+
"""Raise error when __getitem__ gets a non-string input."""
1684+
1685+
union = FeatureUnion([("scalar", StandardScaler()), ("pca", PCA())])
1686+
1687+
msg = "Only string keys are supported"
1688+
with pytest.raises(KeyError, match=msg):
1689+
union[key]

0 commit comments

Comments
 (0)
0