@@ -101,8 +101,9 @@ def _mean_shift_single_seed(my_mean, X, nbrs, max_iter):
101
101
# If converged or at max_iter, adds the cluster
102
102
if (np .linalg .norm (my_mean - my_old_mean ) < stop_thresh or
103
103
completed_iterations == max_iter ):
104
- return tuple ( my_mean ), len ( points_within )
104
+ break
105
105
completed_iterations += 1
106
+ return tuple (my_mean ), len (points_within ), completed_iterations
106
107
107
108
108
109
def mean_shift (X , bandwidth = None , seeds = None , bin_seeding = False ,
@@ -178,72 +179,12 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False,
178
179
<sphx_glr_auto_examples_cluster_plot_mean_shift.py>`.
179
180
180
181
"""
181
-
182
- if bandwidth is None :
183
- bandwidth = estimate_bandwidth (X , n_jobs = n_jobs )
184
- elif bandwidth <= 0 :
185
- raise ValueError ("bandwidth needs to be greater than zero or None,"
186
- " got %f" % bandwidth )
187
- if seeds is None :
188
- if bin_seeding :
189
- seeds = get_bin_seeds (X , bandwidth , min_bin_freq )
190
- else :
191
- seeds = X
192
- n_samples , n_features = X .shape
193
- center_intensity_dict = {}
194
-
195
- # We use n_jobs=1 because this will be used in nested calls under
196
- # parallel calls to _mean_shift_single_seed so there is no need for
197
- # for further parallelism.
198
- nbrs = NearestNeighbors (radius = bandwidth , n_jobs = 1 ).fit (X )
199
-
200
- # execute iterations on all seeds in parallel
201
- all_res = Parallel (n_jobs = n_jobs )(
202
- delayed (_mean_shift_single_seed )
203
- (seed , X , nbrs , max_iter ) for seed in seeds )
204
- # copy results in a dictionary
205
- for i in range (len (seeds )):
206
- if all_res [i ] is not None :
207
- center_intensity_dict [all_res [i ][0 ]] = all_res [i ][1 ]
208
-
209
- if not center_intensity_dict :
210
- # nothing near seeds
211
- raise ValueError ("No point was within bandwidth=%f of any seed."
212
- " Try a different seeding strategy \
213
- or increase the bandwidth."
214
- % bandwidth )
215
-
216
- # POST PROCESSING: remove near duplicate points
217
- # If the distance between two kernels is less than the bandwidth,
218
- # then we have to remove one because it is a duplicate. Remove the
219
- # one with fewer points.
220
-
221
- sorted_by_intensity = sorted (center_intensity_dict .items (),
222
- key = lambda tup : (tup [1 ], tup [0 ]),
223
- reverse = True )
224
- sorted_centers = np .array ([tup [0 ] for tup in sorted_by_intensity ])
225
- unique = np .ones (len (sorted_centers ), dtype = np .bool )
226
- nbrs = NearestNeighbors (radius = bandwidth ,
227
- n_jobs = n_jobs ).fit (sorted_centers )
228
- for i , center in enumerate (sorted_centers ):
229
- if unique [i ]:
230
- neighbor_idxs = nbrs .radius_neighbors ([center ],
231
- return_distance = False )[0 ]
232
- unique [neighbor_idxs ] = 0
233
- unique [i ] = 1 # leave the current point as unique
234
- cluster_centers = sorted_centers [unique ]
235
-
236
- # ASSIGN LABELS: a point belongs to the cluster that it is closest to
237
- nbrs = NearestNeighbors (n_neighbors = 1 , n_jobs = n_jobs ).fit (cluster_centers )
238
- labels = np .zeros (n_samples , dtype = np .int )
239
- distances , idxs = nbrs .kneighbors (X )
240
- if cluster_all :
241
- labels = idxs .flatten ()
242
- else :
243
- labels .fill (- 1 )
244
- bool_selector = distances .flatten () <= bandwidth
245
- labels [bool_selector ] = idxs .flatten ()[bool_selector ]
246
- return cluster_centers , labels
182
+ model = MeanShift (bandwidth = bandwidth , seeds = seeds ,
183
+ min_bin_freq = min_bin_freq ,
184
+ bin_seeding = bin_seeding ,
185
+ cluster_all = cluster_all , n_jobs = n_jobs ,
186
+ max_iter = max_iter ).fit (X )
187
+ return model .cluster_centers_ , model .labels_
247
188
248
189
249
190
def get_bin_seeds (X , bin_size , min_bin_freq = 1 ):
@@ -347,6 +288,12 @@ class MeanShift(ClusterMixin, BaseEstimator):
347
288
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
348
289
for more details.
349
290
291
+ max_iter : int, default=300
292
+ Maximum number of iterations, per seed point before the clustering
293
operation terminates (for that seed point), if it has not converged yet.
294
+
295
+ .. versionadded:: 0.22
296
+
350
297
Attributes
351
298
----------
352
299
cluster_centers_ : array, [n_clusters, n_features]
@@ -355,6 +302,11 @@ class MeanShift(ClusterMixin, BaseEstimator):
355
302
labels_ :
356
303
Labels of each point.
357
304
305
+ n_iter_ : int
306
Largest number of iterations performed on any single seed.
307
+
308
+ .. versionadded:: 0.22
309
+
358
310
Examples
359
311
--------
360
312
>>> from sklearn.cluster import MeanShift
@@ -395,13 +347,14 @@ class MeanShift(ClusterMixin, BaseEstimator):
395
347
396
348
"""
397
349
def __init__ (self , bandwidth = None , seeds = None , bin_seeding = False ,
398
- min_bin_freq = 1 , cluster_all = True , n_jobs = None ):
350
+ min_bin_freq = 1 , cluster_all = True , n_jobs = None , max_iter = 300 ):
399
351
self .bandwidth = bandwidth
400
352
self .seeds = seeds
401
353
self .bin_seeding = bin_seeding
402
354
self .cluster_all = cluster_all
403
355
self .min_bin_freq = min_bin_freq
404
356
self .n_jobs = n_jobs
357
+ self .max_iter = max_iter
405
358
406
359
def fit (self , X , y = None ):
407
360
"""Perform clustering.
@@ -415,11 +368,78 @@ def fit(self, X, y=None):
415
368
416
369
"""
417
370
X = check_array (X )
418
- self .cluster_centers_ , self .labels_ = \
419
- mean_shift (X , bandwidth = self .bandwidth , seeds = self .seeds ,
420
- min_bin_freq = self .min_bin_freq ,
421
- bin_seeding = self .bin_seeding ,
422
- cluster_all = self .cluster_all , n_jobs = self .n_jobs )
371
+ bandwidth = self .bandwidth
372
+ if bandwidth is None :
373
+ bandwidth = estimate_bandwidth (X , n_jobs = self .n_jobs )
374
+ elif bandwidth <= 0 :
375
+ raise ValueError ("bandwidth needs to be greater than zero or None,"
376
+ " got %f" % bandwidth )
377
+
378
+ seeds = self .seeds
379
+ if seeds is None :
380
+ if self .bin_seeding :
381
+ seeds = get_bin_seeds (X , bandwidth , self .min_bin_freq )
382
+ else :
383
+ seeds = X
384
+ n_samples , n_features = X .shape
385
+ center_intensity_dict = {}
386
+
387
+ # We use n_jobs=1 because this will be used in nested calls under
388
+ # parallel calls to _mean_shift_single_seed so there is no need for
389
# further parallelism.
390
+ nbrs = NearestNeighbors (radius = bandwidth , n_jobs = 1 ).fit (X )
391
+
392
+ # execute iterations on all seeds in parallel
393
+ all_res = Parallel (n_jobs = self .n_jobs )(
394
+ delayed (_mean_shift_single_seed )
395
+ (seed , X , nbrs , self .max_iter ) for seed in seeds )
396
+ # copy results in a dictionary
397
+ for i in range (len (seeds )):
398
+ if all_res [i ][1 ]: # i.e. len(points_within) > 0
399
+ center_intensity_dict [all_res [i ][0 ]] = all_res [i ][1 ]
400
+
401
+ self .n_iter_ = max ([x [2 ] for x in all_res ])
402
+
403
+ if not center_intensity_dict :
404
+ # nothing near seeds
405
+ raise ValueError ("No point was within bandwidth=%f of any seed."
406
+ " Try a different seeding strategy \
407
+ or increase the bandwidth."
408
+ % bandwidth )
409
+
410
+ # POST PROCESSING: remove near duplicate points
411
+ # If the distance between two kernels is less than the bandwidth,
412
+ # then we have to remove one because it is a duplicate. Remove the
413
+ # one with fewer points.
414
+
415
+ sorted_by_intensity = sorted (center_intensity_dict .items (),
416
+ key = lambda tup : (tup [1 ], tup [0 ]),
417
+ reverse = True )
418
+ sorted_centers = np .array ([tup [0 ] for tup in sorted_by_intensity ])
419
+ unique = np .ones (len (sorted_centers ), dtype = np .bool )
420
+ nbrs = NearestNeighbors (radius = bandwidth ,
421
+ n_jobs = self .n_jobs ).fit (sorted_centers )
422
+ for i , center in enumerate (sorted_centers ):
423
+ if unique [i ]:
424
+ neighbor_idxs = nbrs .radius_neighbors ([center ],
425
+ return_distance = False )[0 ]
426
+ unique [neighbor_idxs ] = 0
427
+ unique [i ] = 1 # leave the current point as unique
428
+ cluster_centers = sorted_centers [unique ]
429
+
430
+ # ASSIGN LABELS: a point belongs to the cluster that it is closest to
431
+ nbrs = NearestNeighbors (n_neighbors = 1 ,
432
+ n_jobs = self .n_jobs ).fit (cluster_centers )
433
+ labels = np .zeros (n_samples , dtype = np .int )
434
+ distances , idxs = nbrs .kneighbors (X )
435
+ if self .cluster_all :
436
+ labels = idxs .flatten ()
437
+ else :
438
+ labels .fill (- 1 )
439
+ bool_selector = distances .flatten () <= bandwidth
440
+ labels [bool_selector ] = idxs .flatten ()[bool_selector ]
441
+
442
+ self .cluster_centers_ , self .labels_ = cluster_centers , labels
423
443
return self
424
444
425
445
def predict (self , X ):
0 commit comments