8000 Distances for agglomerativeclustering (#17984) · thomasjpfan/scikit-learn@5a804fb · GitHub
[go: up one dir, main page]

Skip to content

Commit 5a804fb

Browse files
FrancescoCasalegnomriedmannEmilieDel
authored
Distances for agglomerativeclustering (scikit-learn#17984)
Co-authored-by: Michael Riedmann <michael_riedmann@live.com> Co-authored-by: Emilie Delattre <emilie.delattre@epfl.ch> Co-authored-by: EmilieDel <47669575+EmilieDel@users.noreply.github.com>
1 parent 3cb3d41 commit 5a804fb

File tree

3 files changed

+74
-8
lines changed

3 files changed

+74
-8
lines changed

doc/whats_new/v0.24.rst

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,20 @@ Changelog
7070
`init_size_`, are deprecated and will be removed in 0.26. :pr:`17864` by
7171
:user:`Jérémie du Boisberranger <jeremiedbb>`.
7272

73+
- |Fix| :class:`cluster.AgglomerativeClustering` has a new parameter
74+
`compute_distances`. When set to `True`, distances between clusters are
75+
computed and stored in the `distances_` attribute even when the parameter
76+
`distance_threshold` is not used. This new parameter is useful to produce
77+
dendrogram visualizations, but introduces a computational and memory
78+
overhead. :pr:`17984` by :user:`Michael Riedmann <mriedmann>`,
79+
:user:`Emilie Delattre <EmilieDel>`, and
80+
:user:`Francesco Casalegno <FrancescoCasalegno>`.
81+
7382
- |Fix| Fixed a bug in :class:`cluster.AffinityPropagation`, that
7483
gives incorrect clusters when the array dtype is float32.
75-
:pr:`17995` by :user:`Thomaz Santana <Wikilicious>` and :user:`Amanda Dsouza <amy12xx>`.
84+
:pr:`17995` by :user:`Thomaz Santana <Wikilicious>` and
85+
:user:`Amanda Dsouza <amy12xx>`.
86+
7687

7788
:mod:`sklearn.covariance`
7889
.........................

sklearn/cluster/_agglomerative.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,13 @@ class AgglomerativeClustering(ClusterMixin, BaseEstimator):
747747
748748
.. versionadded:: 0.21
749749
750+
compute_distances : bool, default=False
751+
Computes distances between clusters even if `distance_threshold` is not
752+
used. This can be used to make dendrogram visualization, but introduces
753+
a computational and memory overhead.
754+
755+
.. versionadded:: 0.24
756+
750757
Attributes
751758
----------
752759
n_clusters_ : int
@@ -776,7 +783,8 @@ class AgglomerativeClustering(ClusterMixin, BaseEstimator):
776783
777784
distances_ : array-like of shape (n_nodes-1,)
778785
Distances between nodes in the corresponding place in `children_`.
779-
Only computed if distance_threshold is not None.
786+
Only computed if `distance_threshold` is used or `compute_distances`
787+
is set to `True`.
780788
781789
Examples
782790
--------
@@ -795,14 +803,16 @@ class AgglomerativeClustering(ClusterMixin, BaseEstimator):
795803
def __init__(self, n_clusters=2, *, affinity="euclidean",
796804
memory=None,
797805
connectivity=None, compute_full_tree='auto',
798-
linkage='ward', distance_threshold=None):
806+
linkage='ward', distance_threshold=None,
807+
compute_distances=False):
799808
self.n_clusters = n_clusters
800809
self.distance_threshold = distance_threshold
801810
self.memory = memory
802811
self.connectivity = connectivity
803812
self.compute_full_tree = compute_full_tree
804813
self.linkage = linkage
805814
self.affinity = affinity
815+
self.compute_distances = compute_distances
806816

807817
def fit(self, X, y=None):
808818
"""Fit the hierarchical clustering from features, or distance matrix.
@@ -879,7 +889,10 @@ def fit(self, X, y=None):
879889

880890
distance_threshold = self.distance_threshold
881891

882-
return_distance = distance_threshold is not None
892+
return_distance = (
893+
(distance_threshold is not None) or self.compute_distances
894+
)
895+
883896
out = memory.cache(tree_builder)(X, connectivity=connectivity,
884897
n_clusters=n_clusters,
885898
return_distance=return_distance,
@@ -891,9 +904,11 @@ def fit(self, X, y=None):
891904

892905
if return_distance:
893906
self.distances_ = out[-1]
907+
908+
if self.distance_threshold is not None: # distance_threshold is used
894909
self.n_clusters_ = np.count_nonzero(
895910
self.distances_ >= distance_threshold) + 1
896-
else:
911+
else: # n_clusters is used
897912
self.n_clusters_ = self.n_clusters
898913

899914
# Cut the tree
@@ -999,6 +1014,13 @@ class FeatureAgglomeration(AgglomerativeClustering, AgglomerationTransform):
9991014
10001015
.. versionadded:: 0.21
10011016
1017+
compute_distances : bool, default=False
1018+
Computes distances between clusters even if `distance_threshold` is not
1019+
used. This can be used to make dendrogram visualization, but introduces
1020+
a computational and memory overhead.
1021+
1022+
.. versionadded:: 0.24
1023+
10021024
Attributes
10031025
----------
10041026
n_clusters_ : int
@@ -1028,7 +1050,8 @@ class FeatureAgglomeration(AgglomerativeClustering, AgglomerationTransform):
10281050
10291051
distances_ : array-like of shape (n_nodes-1,)
10301052
Distances between nodes in the corresponding place in `children_`.
1031-
Only computed if distance_threshold is not None.
1053+
Only computed if `distance_threshold` is used or `compute_distances`
1054+
is set to `True`.
10321055
10331056
Examples
10341057
--------
@@ -1049,11 +1072,12 @@ def __init__(self, n_clusters=2, *, affinity="euclidean",
10491072
memory=None,
10501073
connectivity=None, compute_full_tree='auto',
10511074
linkage='ward', pooling_func=np.mean,
1052-
distance_threshold=None):
1075+
distance_threshold=None, compute_distances=False):
10531076
super().__init__(
10541077
n_clusters=n_clusters, memory=memory, connectivity=connectivity,
10551078
compute_full_tree=compute_full_tree, linkage=linkage,
1056-
affinity=affinity, distance_threshold=distance_threshold)
1079+
affinity=affinity, distance_threshold=distance_threshold,
1080+
compute_distances=compute_distances)
10571081
self.pooling_func = pooling_func
10581082

10591083
def fit(self, X, y=None, **params):

sklearn/cluster/tests/test_hierarchical.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,37 @@ def test_zero_cosine_linkage_tree():
143143
assert_raise_message(ValueError, msg, linkage_tree, X, affinity='cosine')
144144

145145

146+
@pytest.mark.parametrize('n_clusters, distance_threshold',
147+
[(None, 0.5), (10, None)])
148+
@pytest.mark.parametrize('compute_distances', [True, False])
149+
@pytest.mark.parametrize('linkage', ["ward", "complete", "average", "single"])
150+
def test_agglomerative_clustering_distances(n_clusters,
151+
compute_distances,
152+
distance_threshold,
153+
linkage):
154+
# Check that when `compute_distances` is True or `distance_threshold` is
155+
# given, the fitted model has an attribute `distances_`.
156+
rng = np.random.RandomState(0)
157+
mask = np.ones([10, 10], dtype=bool)
158+
n_samples = 100
159+
X = rng.randn(n_samples, 50)
160+
connectivity = grid_to_graph(*mask.shape)
161+
162+
clustering = AgglomerativeClustering(n_clusters=n_clusters,
163+
connectivity=connectivity,
164+
linkage=linkage,
165+
distance_threshold=distance_threshold,
166+
compute_distances=compute_distances)
167+
clustering.fit(X)
168+
if compute_distances or (distance_threshold is not None):
169+
assert hasattr(clustering, 'distances_')
170+
n_children = clustering.children_.shape[0]
171+
n_nodes = n_children + 1
172+
assert clustering.distances_.shape == (n_nodes-1, )
173+
else:
174+
assert not hasattr(clustering, 'distances_')
175+
176+
146177
def test_agglomerative_clustering():
147178
# Check that we obtain the correct number of clusters with
148179
# agglomerative clustering.

0 commit comments

Comments
 (0)
0