10BC0 [MRG+1] break the tie in Meanshift in case cluster intensities are th… · scikit-learn/scikit-learn@a9c6ad9 · GitHub
[go: up one dir, main page]

Skip to content

Commit a9c6ad9

Browse files
adrinjalaliogrisel
authored andcommitted
[MRG+1] break the tie in Meanshift in case cluster intensities are the same (#11901)
1 parent 242410f commit a9c6ad9

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

doc/whats_new/v0.20.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ parameters, may produce different models from the previous version. This often
6363
occurs due to changes in the modelling logic (bug fixes or enhancements), or in
6464
random sampling procedures.
6565

66+
- :class:`cluster.MeanShift` (bug fix)
6667
- :class:`decomposition.IncrementalPCA` in Python 2 (bug fix)
6768
- :class:`decomposition.SparsePCA` (bug fix)
6869
- :class:`ensemble.GradientBoostingClassifier` (bug fix affecting feature importances)
@@ -151,6 +152,11 @@ Support for Python 3.3 has been officially dropped.
151152
``n_iter_`` attribute in the docstring of :class:`cluster.KMeans`.
152153
:issue:`11353` by :user:`Jeremie du Boisberranger <jeremiedbb>`.
153154

155+
- |Fix| Fixed a bug in :func:`cluster.mean_shift` where the assigned labels
156+
were not deterministic if there were multiple clusters with the same
157+
intensities.
158+
:issue:`11901` by :user:`Adrin Jalali <adrinjalali>`.
159+
154160
- |API| Deprecate ``pooling_func`` unused parameter in
155161
:class:`cluster.AgglomerativeClustering`.
156162
:issue:`9875` by :user:`Kumar Ashutosh <thechargedneutron>`.

sklearn/cluster/mean_shift_.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,10 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False,
215215
# If the distance between two kernels is less than the bandwidth,
216216
# then we have to remove one because it is a duplicate. Remove the
217217
# one with fewer points.
218+
218219
sorted_by_intensity = sorted(center_intensity_dict.items(),
219-
key=lambda tup: tup[1], reverse=True)
220+
key=lambda tup: (tup[1], tup[0]),
221+
reverse=True)
220222
sorted_centers = np.array([tup[0] for tup in sorted_by_intensity])
221223
unique = np.ones(len(sorted_centers), dtype=np.bool)
222224
nbrs = NearestNeighbors(radius=bandwidth,
@@ -359,9 +361,9 @@ class MeanShift(BaseEstimator, ClusterMixin):
359361
... [4, 7], [3, 5], [3, 6]])
360362
>>> clustering = MeanShift(bandwidth=2).fit(X)
361363
>>> clustering.labels_
362-
array([0, 0, 0, 1, 1, 1])
364+
array([1, 1, 1, 0, 0, 0])
363365
>>> clustering.predict([[0, 0], [5, 5]])
364-
array([0, 1])
366+
array([1, 0])
365367
>>> clustering # doctest: +NORMALIZE_WHITESPACE
366368
MeanShift(bandwidth=2, bin_seeding=False, cluster_all=True, min_bin_freq=1,
367369
n_jobs=None, seeds=None)

sklearn/cluster/tests/test_mean_shift.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,18 @@ def test_unfitted():
101101
assert_false(hasattr(ms, "labels_"))
102102

103103

104+
def test_cluster_intensity_tie():
105+
X = np.array([[1, 1], [2, 1], [1, 0],
106+
[4, 7], [3, 5], [3, 6]])
107+
c1 = MeanShift(bandwidth=2).fit(X)
108+
109+
X = np.array([[4, 7], [3, 5], [3, 6],
110+
[1, 1], [2, 1], [1, 0]])
111+
c2 = MeanShift(bandwidth=2).fit(X)
112+
assert_array_equal(c1.labels_, [1, 1, 1, 0, 0, 0])
113+
assert_array_equal(c2.labels_, [0, 0, 0, 1, 1, 1])
114+
115+
104116
def test_bin_seeds():
105117
# Test the bin seeding technique which can be used in the mean shift
106118
# algorithm

0 commit comments

Comments
 (0)
0