FIX convergence criterion of MeanShift (#28951) · scikit-learn/scikit-learn@e796d0a · GitHub
Commit e796d0a

akikuno, ogrisel, and jeremiedbb authored
FIX convergence criterion of MeanShift (#28951)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
1 parent 87ceec2 commit e796d0a

3 files changed: 13 additions, 1 deletion

doc/whats_new/v1.5.rst

Lines changed: 3 additions & 0 deletions
@@ -183,6 +183,9 @@ Changelog
 :mod:`sklearn.cluster`
 ......................
 
+- |Fix| The :class:`cluster.MeanShift` class now properly converges for constant data.
+  :pr:`28951` by :user:`Akihiro Kuno <akikuno>`.
+
 - |FIX| Create copy of precomputed sparse matrix within the `fit` method of
   :class:`~cluster.OPTICS` to avoid in-place modification of the sparse matrix.
   :pr:`28491` by :user:`Thanh Lam Dang <lamdang2k>`.

sklearn/cluster/_mean_shift.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ def _mean_shift_single_seed(my_mean, X, nbrs, max_iter):
         my_mean = np.mean(points_within, axis=0)
         # If converged or at max_iter, adds the cluster
         if (
-            np.linalg.norm(my_mean - my_old_mean) < stop_thresh
+            np.linalg.norm(my_mean - my_old_mean) <= stop_thresh
             or completed_iterations == max_iter
         ):
             break
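
Note on why the one-character change matters: the sketch below is a toy, pure-Python 1D version of the single-seed loop, not scikit-learn's implementation. It assumes that the stop threshold is proportional to the bandwidth (`1e-3 * bandwidth` here) and that the bandwidth for constant data comes out as zero, so the threshold is zero as well. The helper name `iterations_until_stop` and its `inclusive` flag are hypothetical and exist only for illustration.

```python
def iterations_until_stop(points, bandwidth, max_iter=300, inclusive=True):
    """Toy 1D mean-shift loop for one seed (hypothetical helper)."""
    # Assumption: the threshold scales with the bandwidth, so a zero
    # bandwidth yields a zero threshold.
    stop_thresh = 1e-3 * bandwidth
    mean = points[0]
    completed_iterations = 0
    while True:
        # Points within the bandwidth of the current mean.
        within = [p for p in points if abs(p - mean) <= bandwidth]
        old_mean = mean
        mean = sum(within) / len(within)
        shift = abs(mean - old_mean)
        # inclusive=True mimics the new `<=` test, False the old `<` test.
        converged = shift <= stop_thresh if inclusive else shift < stop_thresh
        if converged or completed_iterations == max_iter:
            break
        completed_iterations += 1
    return completed_iterations


constant = [1.0] * 10
new_behaviour = iterations_until_stop(constant, bandwidth=0.0, inclusive=True)
old_behaviour = iterations_until_stop(constant, bandwidth=0.0, inclusive=False)
```

With a zero threshold the shift is exactly zero on the first pass, so `0 < 0` is never true and the strict comparison only stops at `max_iter`, whereas `0 <= 0` stops immediately. This is the behaviour the non-regression test checks through `n_iter_ < max_iter`.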

sklearn/cluster/tests/test_mean_shift.py

Lines changed: 9 additions & 0 deletions
@@ -25,6 +25,15 @@
 )
 
 
+def test_convergence_of_1d_constant_data():
+    # Test convergence using 1D constant data
+    # Non-regression test for:
+    # https://github.com/scikit-learn/scikit-learn/issues/28926
+    model = MeanShift()
+    n_iter = model.fit(np.ones(10).reshape(-1, 1)).n_iter_
+    assert n_iter < model.max_iter
+
+
 def test_estimate_bandwidth():
     # Test estimate_bandwidth
     bandwidth = estimate_bandwidth(X, n_samples=200)
