ENH Adds get_feature_names to cluster module (#22255) · scikit-learn/scikit-learn@5219b6f · GitHub
[go: up one dir, main page]

Skip to content

Commit 5219b6f

Browse files
authored
ENH Adds get_feature_names to cluster module (#22255)
1 parent de3373b commit 5219b6f

File tree

8 files changed

+73
-7
lines changed

8 files changed

+73
-7
lines changed

doc/whats_new/v1.1.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ Changelog
9595
See :func:`cluster.spectral_clustering` for more details.
9696
:pr:`21148` by :user:`Andrew Knyazev <lobpcg>`
9797

98+
- |Enhancement| Adds :term:`get_feature_names_out` to :class:`cluster.Birch`,
99+
:class:`cluster.FeatureAgglomeration`, :class:`cluster.KMeans`,
100+
:class:`cluster.MiniBatchKMeans`. :pr:`22255` by `Thomas Fan`_.
101+
98102
- |Efficiency| In :class:`cluster.KMeans`, the default ``algorithm`` is now
99103
``"lloyd"`` which is the full classical EM-style algorithm. Both ``"auto"``
100104
and ``"full"`` are deprecated and will be removed in version 1.3. They are

sklearn/cluster/_agglomerative.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from scipy import sparse
1515
from scipy.sparse.csgraph import connected_components
1616

17-
from ..base import BaseEstimator, ClusterMixin
17+
from ..base import BaseEstimator, ClusterMixin, _ClassNamePrefixFeaturesOutMixin
1818
from ..metrics.pairwise import paired_distances
1919
from ..metrics import DistanceMetric
2020
from ..metrics._dist_metrics import METRIC_MAPPING
@@ -1054,7 +1054,9 @@ def fit_predict(self, X, y=None):
10541054
return super().fit_predict(X, y)
10551055

10561056

1057-
class FeatureAgglomeration(AgglomerativeClustering, AgglomerationTransform):
1057+
class FeatureAgglomeration(
1058+
_ClassNamePrefixFeaturesOutMixin, AgglomerativeClustering, AgglomerationTransform
1059+
):
10581060
"""Agglomerate features.
10591061
10601062
Recursively merges pair of clusters of features.
@@ -1236,6 +1238,7 @@ def fit(self, X, y=None):
12361238
"""
12371239
X = self._validate_data(X, ensure_min_features=2)
12381240
super()._fit(X.T)
1241+
self._n_features_out = self.n_clusters_
12391242
return self
12401243

12411244
@property

sklearn/cluster/_birch.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111

1212
from ..metrics import pairwise_distances_argmin
1313
from ..metrics.pairwise import euclidean_distances
14-
from ..base import TransformerMixin, ClusterMixin, BaseEstimator
14+
from ..base import (
15+
TransformerMixin,
16+
ClusterMixin,
17+
BaseEstimator,
18+
_ClassNamePrefixFeaturesOutMixin,
19+
)
1520
from ..utils.extmath import row_norms
1621
from ..utils import check_scalar, deprecated
1722
from ..utils.validation import check_is_fitted
@@ -342,7 +347,9 @@ def radius(self):
342347
return sqrt(max(0, sq_radius))
343348

344349

345-
class Birch(ClusterMixin, TransformerMixin, BaseEstimator):
350+
class Birch(
351+
_ClassNamePrefixFeaturesOutMixin, ClusterMixin, TransformerMixin, BaseEstimator
352+
):
346353
"""Implements the BIRCH clustering algorithm.
347354
348355
It is a memory-efficient, online-learning algorithm provided as an
@@ -599,6 +606,7 @@ def _fit(self, X, partial):
599606

600607
centroids = np.concatenate([leaf.centroids_ for leaf in self._get_leaves()])
601608
self.subcluster_centers_ = centroids
609+
self._n_features_out = self.subcluster_centers_.shape[0]
602610

603611
self._global_clustering(X)
604612
return self

sklearn/cluster/_kmeans.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616
import numpy as np
1717
import scipy.sparse as sp
1818

19-
from ..base import BaseEstimator, ClusterMixin, TransformerMixin
19+
from ..base import (
20+
BaseEstimator,
21+
ClusterMixin,
22+
TransformerMixin,
23+
_ClassNamePrefixFeaturesOutMixin,
24+
)
2025
from ..metrics.pairwise import euclidean_distances
2126
from ..metrics.pairwise import _euclidean_distances
2227
from ..utils.extmath import row_norms, stable_cumsum
@@ -767,7 +772,9 @@ def _labels_inertia_threadpool_limit(
767772
return labels, inertia
768773

769774

770-
class KMeans(TransformerMixin, ClusterMixin, BaseEstimator):
775+
class KMeans(
776+
_ClassNamePrefixFeaturesOutMixin, TransformerMixin, ClusterMixin, BaseEstimator
777+
):
771778
"""K-Means clustering.
772779
773780
Read more in the :ref:`User Guide <k_means>`.
@@ -1240,6 +1247,7 @@ def fit(self, X, y=None, sample_weight=None):
12401247
)
12411248

12421249
self.cluster_centers_ = best_centers
1250+
self._n_features_out = self.cluster_centers_.shape[0]
12431251
self.labels_ = best_labels
12441252
self.inertia_ = best_inertia
12451253
self.n_iter_ = best_n_iter
@@ -2020,6 +2028,7 @@ def fit(self, X, y=None, sample_weight=None):
20202028
break
20212029

20222030
self.cluster_centers_ = centers
2031+
self._n_features_out = self.cluster_centers_.shape[0]
20232032

20242033
self.n_steps_ = i + 1
20252034
self.n_iter_ = int(np.ceil(((i + 1) * self._batch_size) / n_samples))
@@ -2134,6 +2143,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
21342143
)
21352144

21362145
self.n_steps_ += 1
2146+
self._n_features_out = self.cluster_centers_.shape[0]
21372147

21382148
return self
21392149

sklearn/cluster/tests/test_birch.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,14 @@ def test_birch_params_validation(params, err_type, err_msg):
219219
X, _ = make_blobs(n_samples=80, centers=4)
220220
with pytest.raises(err_type, match=err_msg):
221221
Birch(**params).fit(X)
222+
223+
224+
def test_feature_names_out():
225+
"""Check `get_feature_names_out` for `Birch`."""
226+
X, _ = make_blobs(n_samples=80, n_features=4, random_state=0)
227+
brc = Birch(n_clusters=4)
228+
brc.fit(X)
229+
n_clusters = brc.subcluster_centers_.shape[0]
230+
231+
names_out = brc.get_feature_names_out()
232+
assert_array_equal([f"birch{i}" for i in range(n_clusters)], names_out)

sklearn/cluster/tests/test_feature_agglomeration.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
# Authors: Sergul Aydore 2017
55
import numpy as np
66
import pytest
7+
8+
from numpy.testing import assert_array_equal
79
from sklearn.cluster import FeatureAgglomeration
810
from sklearn.utils._testing import assert_array_almost_equal
11+
from sklearn.datasets import make_blobs
912

1013

1114
def test_feature_agglomeration():
@@ -41,3 +44,16 @@ def test_feature_agglomeration():
4144

4245
assert_array_almost_equal(agglo_mean.transform(X_full_mean), Xt_mean)
4346
assert_array_almost_equal(agglo_median.transform(X_full_median), Xt_median)
47+
48+
49+
def test_feature_agglomeration_feature_names_out():
50+
"""Check `get_feature_names_out` for `FeatureAgglomeration`."""
51+
X, _ = make_blobs(n_features=6, random_state=0)
52+
agglo = FeatureAgglomeration(n_clusters=3)
53+
agglo.fit(X)
54+
n_clusters = agglo.n_clusters_
55+
56+
names_out = agglo.get_feature_names_out()
57+
assert_array_equal(
58+
[f"featureagglomeration{i}" for i in range(n_clusters)], names_out
59+
)

sklearn/cluster/tests/test_k_means.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,3 +1205,18 @@ def test_is_same_clustering():
12051205
# mapped to a same value
12061206
labels3 = np.array([1, 0, 0, 2, 2, 0, 2, 1], dtype=np.int32)
12071207
assert not _is_same_clustering(labels1, labels3, 3)
1208+
1209+
1210+
@pytest.mark.parametrize(
1211+
"Klass, method",
1212+
[(KMeans, "fit"), (MiniBatchKMeans, "fit"), (MiniBatchKMeans, "partial_fit")],
1213+
)
1214+
def test_feature_names_out(Klass, method):
1215+
"""Check `feature_names_out` for `KMeans` and `MiniBatchKMeans`."""
1216+
class_name = Klass.__name__.lower()
1217+
kmeans = Klass()
1218+
getattr(kmeans, method)(X)
1219+
n_clusters = kmeans.cluster_centers_.shape[0]
1220+
1221+
names_out = kmeans.get_feature_names_out()
1222+
assert_array_equal([f"{class_name}{i}" for i in range(n_clusters)], names_out)

sklearn/tests/test_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,6 @@ def test_pandas_column_name_consistency(estimator):
380380
# TODO: As more modules support get_feature_names_out they should be removed
381381
# from this list to be tested
382382
GET_FEATURES_OUT_MODULES_TO_IGNORE = [
383-
"cluster",
384383
"ensemble",
385384
"isotonic",
386385
"kernel_approximation",

0 commit comments

Comments
 (0)
0