@@ -101,8 +101,9 @@ def _mean_shift_single_seed(my_mean, X, nbrs, max_iter):
         # If converged or at max_iter, adds the cluster
         if (np.linalg.norm(my_mean - my_old_mean) < stop_thresh or
                 completed_iterations == max_iter):
-            return tuple(my_mean), len(points_within)
+            break
         completed_iterations += 1
+    return tuple(my_mean), len(points_within), completed_iterations


 def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False,
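With this change the helper no longer returns early: it always falls through to a single `return` and reports the number of completed iterations as a third element. A minimal sketch of how a caller can consume that tuple, assuming the private helper is importable from `sklearn.cluster._mean_shift` (the module path in recent releases) and using made-up data:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster._mean_shift import _mean_shift_single_seed  # private helper

X = np.array([[1.0, 1.0], [1.1, 0.9], [5.0, 5.0], [5.2, 4.9]])
nbrs = NearestNeighbors(radius=1.0).fit(X)

# Each seed now yields (center, n_points_within, completed_iterations)
# instead of either (center, n_points_within) or None.
all_res = [_mean_shift_single_seed(seed, X, nbrs, max_iter=300) for seed in X]
n_iter_ = max(res[2] for res in all_res)  # what fit() records further down
```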
@@ -178,72 +179,12 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False,
     <sphx_glr_auto_examples_cluster_plot_mean_shift.py>`.

     """
-
-    if bandwidth is None:
-        bandwidth = estimate_bandwidth(X, n_jobs=n_jobs)
-    elif bandwidth <= 0:
-        raise ValueError("bandwidth needs to be greater than zero or None,"
-                         " got %f" % bandwidth)
-    if seeds is None:
-        if bin_seeding:
-            seeds = get_bin_seeds(X, bandwidth, min_bin_freq)
-        else:
-            seeds = X
-    n_samples, n_features = X.shape
-    center_intensity_dict = {}
-
-    # We use n_jobs=1 because this will be used in nested calls under
-    # parallel calls to _mean_shift_single_seed so there is no need for
-    # for further parallelism.
-    nbrs = NearestNeighbors(radius=bandwidth, n_jobs=1).fit(X)
-
-    # execute iterations on all seeds in parallel
-    all_res = Parallel(n_jobs=n_jobs)(
-        delayed(_mean_shift_single_seed)
-        (seed, X, nbrs, max_iter) for seed in seeds)
-    # copy results in a dictionary
-    for i in range(len(seeds)):
-        if all_res[i] is not None:
-            center_intensity_dict[all_res[i][0]] = all_res[i][1]
-
-    if not center_intensity_dict:
-        # nothing near seeds
-        raise ValueError("No point was within bandwidth=%f of any seed."
-                         " Try a different seeding strategy \
-                         or increase the bandwidth."
-                         % bandwidth)
-
-    # POST PROCESSING: remove near duplicate points
-    # If the distance between two kernels is less than the bandwidth,
-    # then we have to remove one because it is a duplicate. Remove the
-    # one with fewer points.
-
-    sorted_by_intensity = sorted(center_intensity_dict.items(),
-                                 key=lambda tup: (tup[1], tup[0]),
-                                 reverse=True)
-    sorted_centers = np.array([tup[0] for tup in sorted_by_intensity])
-    unique = np.ones(len(sorted_centers), dtype=np.bool)
-    nbrs = NearestNeighbors(radius=bandwidth,
-                            n_jobs=n_jobs).fit(sorted_centers)
-    for i, center in enumerate(sorted_centers):
-        if unique[i]:
-            neighbor_idxs = nbrs.radius_neighbors([center],
-                                                  return_distance=False)[0]
-            unique[neighbor_idxs] = 0
-            unique[i] = 1  # leave the current point as unique
-    cluster_centers = sorted_centers[unique]
-
-    # ASSIGN LABELS: a point belongs to the cluster that it is closest to
-    nbrs = NearestNeighbors(n_neighbors=1, n_jobs=n_jobs).fit(cluster_centers)
-    labels = np.zeros(n_samples, dtype=np.int)
-    distances, idxs = nbrs.kneighbors(X)
-    if cluster_all:
-        labels = idxs.flatten()
-    else:
-        labels.fill(-1)
-        bool_selector = distances.flatten() <= bandwidth
-        labels[bool_selector] = idxs.flatten()[bool_selector]
-    return cluster_centers, labels
+    model = MeanShift(bandwidth=bandwidth, seeds=seeds,
+                      min_bin_freq=min_bin_freq,
+                      bin_seeding=bin_seeding,
+                      cluster_all=cluster_all, n_jobs=n_jobs,
+                      max_iter=max_iter).fit(X)
+    return model.cluster_centers_, model.labels_


 def get_bin_seeds(X, bin_size, min_bin_freq=1):
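Since `mean_shift` is now a thin wrapper that fits a `MeanShift` estimator and returns its fitted attributes, the two call styles below should produce identical results. A short sketch on toy data (the array values are made up):

```python
import numpy as np
from sklearn.cluster import MeanShift, mean_shift

X = np.array([[1, 1], [2, 1], [1, 0],
              [4, 7], [3, 5], [3, 6]], dtype=float)

# Functional form: returns the raw (centers, labels) pair.
centers, labels = mean_shift(X, bandwidth=2)

# Estimator form: the same computation, exposed as fitted attributes.
ms = MeanShift(bandwidth=2).fit(X)

assert np.allclose(centers, ms.cluster_centers_)
assert np.array_equal(labels, ms.labels_)
```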
@@ -347,6 +288,12 @@ class MeanShift(ClusterMixin, BaseEstimator):
         ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
         for more details.

+    max_iter : int, default=300
+        Maximum number of iterations, per seed point, before the clustering
+        operation terminates for that seed point if it has not converged.
+
+        .. versionadded:: 0.22
+
     Attributes
     ----------
     cluster_centers_ : array, [n_clusters, n_features]
@@ -355,6 +302,11 @@ class MeanShift(ClusterMixin, BaseEstimator):
     labels_ :
         Labels of each point.

+    n_iter_ : int
+        Maximum number of iterations performed on each seed.
+
+        .. versionadded:: 0.22
+
     Examples
     --------
     >>> from sklearn.cluster import MeanShift
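Together, the new `max_iter` parameter and `n_iter_` attribute let callers cap the per-seed iteration count and inspect how many iterations were actually needed. A brief sketch reusing the toy data from the class docstring (the exact label ids may differ):

```python
import numpy as np
from sklearn.cluster import MeanShift

X = np.array([[1, 1], [2, 1], [1, 0],
              [4, 7], [3, 5], [3, 6]], dtype=float)

# max_iter bounds the number of shifts performed for each individual seed.
ms = MeanShift(bandwidth=2, max_iter=300).fit(X)

print(ms.cluster_centers_)  # two centers for this toy data set
print(ms.labels_)           # e.g. [1 1 1 0 0 0]
print(ms.n_iter_)           # largest iteration count over all seeds
```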
@@ -395,13 +347,14 @@ class MeanShift(ClusterMixin, BaseEstimator):

     """
     def __init__(self, bandwidth=None, seeds=None, bin_seeding=False,
-                 min_bin_freq=1, cluster_all=True, n_jobs=None):
+                 min_bin_freq=1, cluster_all=True, n_jobs=None, max_iter=300):
         self.bandwidth = bandwidth
         self.seeds = seeds
         self.bin_seeding = bin_seeding
         self.cluster_all = cluster_all
         self.min_bin_freq = min_bin_freq
         self.n_jobs = n_jobs
+        self.max_iter = max_iter

     def fit(self, X, y=None):
         """Perform clustering.
@@ -415,11 +368,78 @@ def fit(self, X, y=None):

         """
         X = check_array(X)
-        self.cluster_centers_, self.labels_ = \
-            mean_shift(X, bandwidth=self.bandwidth, seeds=self.seeds,
-                       min_bin_freq=self.min_bin_freq,
-                       bin_seeding=self.bin_seeding,
-                       cluster_all=self.cluster_all, n_jobs=self.n_jobs)
+        bandwidth = self.bandwidth
+        if bandwidth is None:
+            bandwidth = estimate_bandwidth(X, n_jobs=self.n_jobs)
+        elif bandwidth <= 0:
+            raise ValueError("bandwidth needs to be greater than zero or None,"
+                             " got %f" % bandwidth)
+
+        seeds = self.seeds
+        if seeds is None:
+            if self.bin_seeding:
+                seeds = get_bin_seeds(X, bandwidth, self.min_bin_freq)
+            else:
+                seeds = X
+        n_samples, n_features = X.shape
+        center_intensity_dict = {}
+
+        # We use n_jobs=1 because this will be used in nested calls under
+        # parallel calls to _mean_shift_single_seed, so there is no need
+        # for further parallelism.
+        nbrs = NearestNeighbors(radius=bandwidth, n_jobs=1).fit(X)
+
+        # execute iterations on all seeds in parallel
+        all_res = Parallel(n_jobs=self.n_jobs)(
+            delayed(_mean_shift_single_seed)
+            (seed, X, nbrs, self.max_iter) for seed in seeds)
+        # copy results in a dictionary
+        for i in range(len(seeds)):
+            if all_res[i][1]:  # i.e. len(points_within) > 0
+                center_intensity_dict[all_res[i][0]] = all_res[i][1]
+
+        self.n_iter_ = max([x[2] for x in all_res])
+
+        if not center_intensity_dict:
+            # nothing near seeds
+            raise ValueError("No point was within bandwidth=%f of any seed."
+                             " Try a different seeding strategy \
+                             or increase the bandwidth."
+                             % bandwidth)
+
+        # POST PROCESSING: remove near duplicate points
+        # If the distance between two kernels is less than the bandwidth,
+        # then we have to remove one because it is a duplicate. Remove the
+        # one with fewer points.
+
+        sorted_by_intensity = sorted(center_intensity_dict.items(),
+                                     key=lambda tup: (tup[1], tup[0]),
+                                     reverse=True)
+        sorted_centers = np.array([tup[0] for tup in sorted_by_intensity])
+        unique = np.ones(len(sorted_centers), dtype=np.bool)
+        nbrs = NearestNeighbors(radius=bandwidth,
+                                n_jobs=self.n_jobs).fit(sorted_centers)
+        for i, center in enumerate(sorted_centers):
+            if unique[i]:
+                neighbor_idxs = nbrs.radius_neighbors([center],
+                                                      return_distance=False)[0]
+                unique[neighbor_idxs] = 0
+                unique[i] = 1  # leave the current point as unique
+        cluster_centers = sorted_centers[unique]
+
+        # ASSIGN LABELS: a point belongs to the cluster that it is closest to
+        nbrs = NearestNeighbors(n_neighbors=1,
+                                n_jobs=self.n_jobs).fit(cluster_centers)
+        labels = np.zeros(n_samples, dtype=np.int)
+        distances, idxs = nbrs.kneighbors(X)
+        if self.cluster_all:
+            labels = idxs.flatten()
+        else:
+            labels.fill(-1)
+            bool_selector = distances.flatten() <= bandwidth
+            labels[bool_selector] = idxs.flatten()[bool_selector]
+
+        self.cluster_centers_, self.labels_ = cluster_centers, labels
         return self

     def predict(self, X):
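The relocated label-assignment step keeps the `cluster_all` semantics: with `cluster_all=False`, any sample farther than `bandwidth` from every kernel keeps the orphan label `-1`. A small sketch with made-up data and hand-picked seeds to illustrate:

```python
import numpy as np
from sklearn.cluster import MeanShift

X = np.array([[0.0, 0.0], [0.2, 0.1], [0.1, 0.2],   # a tight group
              [10.0, 10.0]])                         # a far-away point
ms = MeanShift(bandwidth=1.0, seeds=X[:3], cluster_all=False).fit(X)

print(ms.labels_)  # expected: [0 0 0 -1]; the outlier is farther than
                   # `bandwidth` from the only kernel, so it stays orphaned
```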