FIX add max_iter, and func uses class (#15120) · crankycoder/scikit-learn@ac72a48 · GitHub
[go: up one dir, main page]

Skip to content

Commit ac72a48

Browse files
adrinjalali authored and rth committed
FIX add max_iter, and func uses class (scikit-learn#15120)
1 parent 32d5a76 commit ac72a48

File tree

3 files changed

+111
-73
lines changed

3 files changed

+111
-73
lines changed

doc/whats_new/v0.22.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ Changelog
8989
producing Segmentation Fault on large arrays due to integer index overflow.
9090
:pr:`15057` by :user:`Vladimir Korolev <balodja>`.
9191

92+
- |Fix| :class:`~cluster.MeanShift` now accepts a :term:`max_iter` with a
93+
default value of 300 instead of always using the default 300. It also now
94+
exposes an ``n_iter_`` attribute indicating the maximum number of iterations performed
95+
on each seed. :pr:`15120` by `Adrin Jalali`_.
96+
9297
:mod:`sklearn.compose`
9398
......................
9499

sklearn/cluster/mean_shift_.py

Lines changed: 93 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,9 @@ def _mean_shift_single_seed(my_mean, X, nbrs, max_iter):
101101
# If converged or at max_iter, adds the cluster
102102
if (np.linalg.norm(my_mean - my_old_mean) < stop_thresh or
103103
completed_iterations == max_iter):
104-
return tuple(my_mean), len(points_within)
104+
break
105105
completed_iterations += 1
106+
return tuple(my_mean), len(points_within), completed_iterations
106107

107108

108109
def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False,
@@ -178,72 +179,12 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False,
178179
<sphx_glr_auto_examples_cluster_plot_mean_shift.py>`.
179180
180181
"""
181-
182-
if bandwidth is None:
183-
bandwidth = estimate_bandwidth(X, n_jobs=n_jobs)
184-
elif bandwidth <= 0:
185-
raise ValueError("bandwidth needs to be greater than zero or None,"
186-
" got %f" % bandwidth)
187-
if seeds is None:
188-
if bin_seeding:
189-
seeds = get_bin_seeds(X, bandwidth, min_bin_freq)
190-
else:
191-
seeds = X
192-
n_samples, n_features = X.shape
193-
center_intensity_dict = {}
194-
195-
# We use n_jobs=1 because this will be used in nested calls under
196-
# parallel calls to _mean_shift_single_seed so there is no need for
197-
# for further parallelism.
198-
nbrs = NearestNeighbors(radius=bandwidth, n_jobs=1).fit(X)
199-
200-
# execute iterations on all seeds in parallel
201-
all_res = Parallel(n_jobs=n_jobs)(
202-
delayed(_mean_shift_single_seed)
203-
(seed, X, nbrs, max_iter) for seed in seeds)
204-
# copy results in a dictionary
205-
for i in range(len(seeds)):
206-
if all_res[i] is not None:
207-
center_intensity_dict[all_res[i][0]] = all_res[i][1]
208-
209-
if not center_intensity_dict:
210-
# nothing near seeds
211-
raise ValueError("No point was within bandwidth=%f of any seed."
212-
" Try a different seeding strategy \
213-
or increase the bandwidth."
214-
% bandwidth)
215-
216-
# POST PROCESSING: remove near duplicate points
217-
# If the distance between two kernels is less than the bandwidth,
218-
# then we have to remove one because it is a duplicate. Remove the
219-
# one with fewer points.
220-
221-
sorted_by_intensity = sorted(center_intensity_dict.items(),
222-
key=lambda tup: (tup[1], tup[0]),
223-
reverse=True)
224-
sorted_centers = np.array([tup[0] for tup in sorted_by_intensity])
225-
unique = np.ones(len(sorted_centers), dtype=np.bool)
226-
nbrs = NearestNeighbors(radius=bandwidth,
227-
n_jobs=n_jobs).fit(sorted_centers)
228-
for i, center in enumerate(sorted_centers):
229-
if unique[i]:
230-
neighbor_idxs = nbrs.radius_neighbors([center],
231-
return_distance=False)[0]
232-
unique[neighbor_idxs] = 0
233-
unique[i] = 1 # leave the current point as unique
234-
cluster_centers = sorted_centers[unique]
235-
236-
# ASSIGN LABELS: a point belongs to the cluster that it is closest to
237-
nbrs = NearestNeighbors(n_neighbors=1, n_jobs=n_jobs).fit(cluster_centers)
238-
labels = np.zeros(n_samples, dtype=np.int)
239-
distances, idxs = nbrs.kneighbors(X)
240-
if cluster_all:
241-
labels = idxs.flatten()
242-
else:
243-
labels.fill(-1)
244-
bool_selector = distances.flatten() <= bandwidth
245-
labels[bool_selector] = idxs.flatten()[bool_selector]
246-
return cluster_centers, labels
182+
model = MeanShift(bandwidth=bandwidth, seeds=seeds,
183+
min_bin_freq=min_bin_freq,
184+
bin_seeding=bin_seeding,
185+
cluster_all=cluster_all, n_jobs=n_jobs,
186+
max_iter=max_iter).fit(X)
187+
return model.cluster_centers_, model.labels_
247188

248189

249190
def get_bin_seeds(X, bin_size, min_bin_freq=1):
@@ -347,6 +288,12 @@ class MeanShift(ClusterMixin, BaseEstimator):
347288
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
348289
for more details.
349290
291+
max_iter : int, default=300
292+
Maximum number of iterations, per seed point before the clustering
293+
operation terminates (for that seed point), if it has not converged yet.
294+
295+
.. versionadded:: 0.22
296+
350297
Attributes
351298
----------
352299
cluster_centers_ : array, [n_clusters, n_features]
@@ -355,6 +302,11 @@ class MeanShift(ClusterMixin, BaseEstimator):
355302
labels_ :
356303
Labels of each point.
357304
305+
n_iter_ : int
306+
Maximum number of iterations performed on each seed.
307+
308+
.. versionadded:: 0.22
309+
358310
Examples
359311
--------
360312
>>> from sklearn.cluster import MeanShift
@@ -395,13 +347,14 @@ class MeanShift(ClusterMixin, BaseEstimator):
395347
396348
"""
397349
def __init__(self, bandwidth=None, seeds=None, bin_seeding=False,
398-
min_bin_freq=1, cluster_all=True, n_jobs=None):
350+
min_bin_freq=1, cluster_all=True, n_jobs=None, max_iter=300):
399351
self.bandwidth = bandwidth
400352
self.seeds = seeds
401353
self.bin_seeding = bin_seeding
402354
self.cluster_all = cluster_all
403355
self.min_bin_freq = min_bin_freq
404356
self.n_jobs = n_jobs
357+
self.max_iter = max_iter
405358

406359
def fit(self, X, y=None):
407360
"""Perform clustering.
@@ -415,11 +368,78 @@ def fit(self, X, y=None):
415368
416369
"""
417370
X = check_array(X)
418-
self.cluster_centers_, self.labels_ = \
419-
mean_shift(X, bandwidth=self.bandwidth, seeds=self.seeds,
420-
min_bin_freq=self.min_bin_freq,
421-
bin_seeding=self.bin_seeding,
422-
cluster_all=self.cluster_all, n_jobs=self.n_jobs)
371+
bandwidth = self.bandwidth
372+
if bandwidth is None:
373+
bandwidth = estimate_bandwidth(X, n_jobs=self.n_jobs)
374+
elif bandwidth <= 0:
375+
raise ValueError("bandwidth needs to be greater than zero or None,"
376+
" got %f" % bandwidth)
377+
378+
seeds = self.seeds
379+
if seeds is None:
380+
if self.bin_seeding:
381+
seeds = get_bin_seeds(X, bandwidth, self.min_bin_freq)
382+
else:
383+
seeds = X
384+
n_samples, n_features = X.shape
385+
center_intensity_dict = {}
386+
387+
# We use n_jobs=1 because this will be used in nested calls under
388+
# parallel calls to _mean_shift_single_seed so there is no need for
389+
# further parallelism.
390+
nbrs = NearestNeighbors(radius=bandwidth, n_jobs=1).fit(X)
391+
392+
# execute iterations on all seeds in parallel
393+
all_res = Parallel(n_jobs=self.n_jobs)(
394+
delayed(_mean_shift_single_seed)
395+
(seed, X, nbrs, self.max_iter) for seed in seeds)
396+
# copy results in a dictionary
397+
for i in range(len(seeds)):
398+
if all_res[i][1]: # i.e. len(points_within) > 0
399+
center_intensity_dict[all_res[i][0]] = all_res[i][1]
400+
401+
self.n_iter_ = max([x[2] for x in all_res])
402+
403+
if not center_intensity_dict:
404+
# nothing near seeds
405+
raise ValueError("No point was within bandwidth=%f of any seed."
406+
" Try a different seeding strategy \
407+
or increase the bandwidth."
408+
% bandwidth)
409+
410+
# POST PROCESSING: remove near duplicate points
411+
# If the distance between two kernels is less than the bandwidth,
412+
# then we have to remove one because it is a duplicate. Remove the
413+
# one with fewer points.
414+
415+
sorted_by_intensity = sorted(center_intensity_dict.items(),
416+
key=lambda tup: (tup[1], tup[0]),
417+
reverse=True)
418+
sorted_centers = np.array([tup[0] for tup in sorted_by_intensity])
419+
unique = np.ones(len(sorted_centers), dtype=np.bool)
420+
nbrs = NearestNeighbors(radius=bandwidth,
421+
n_jobs=self.n_jobs).fit(sorted_centers)
422+
for i, center in enumerate(sorted_centers):
423+
if unique[i]:
424+
neighbor_idxs = nbrs.radius_neighbors([center],
425+
return_distance=False)[0]
426+
unique[neighbor_idxs] = 0
427+
unique[i] = 1 # leave the current point as unique
428+
cluster_centers = sorted_centers[unique]
429+
430+
# ASSIGN LABELS: a point belongs to the cluster that it is closest to
431+
nbrs = NearestNeighbors(n_neighbors=1,
432+
n_jobs=self.n_jobs).fit(cluster_centers)
433+
labels = np.zeros(n_samples, dtype=np.int)
434+
distances, idxs = nbrs.kneighbors(X)
435+
if self.cluster_all:
436+
labels = idxs.flatten()
437+
else:
438+
labels.fill(-1)
439+
bool_selector = distances.flatten() <= bandwidth
440+
labels[bool_selector] = idxs.flatten()[bool_selector]
441+
442+
self.cluster_centers_, self.labels_ = cluster_centers, labels
423443
return self
424444

425445
def predict(self, X):

sklearn/cluster/tests/test_mean_shift.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,16 @@ def test_bin_seeds():
155155
cluster_std=0.1, random_state=0)
156156
test_bins = get_bin_seeds(X, 1)
157157
assert_array_equal(test_bins, [[0, 0], [1, 1]])
158+
159+
160+
@pytest.mark.parametrize('max_iter', [1, 100])
161+
def test_max_iter(max_iter):
162+
clusters1, _ = mean_shift(X, max_iter=max_iter)
163+
ms = MeanShift(max_iter=max_iter).fit(X)
164+
clusters2 = ms.cluster_centers_
165+
166+
assert ms.n_iter_ <= ms.max_iter
167+
assert len(clusters1) == len(clusters2)
168+
169+
for c1, c2 in zip(clusters1, clusters2):
170+
assert np.allclose(c1, c2)

0 commit comments

Comments
 (0)
0