@@ -280,44 +280,39 @@ def predict(self, X):
280
280
List of class labels (one for each data sample).
281
281
"""
282
282
X = atleast2d_or_csr (X )
283
+ n_samples = X .shape [0 ]
283
284
284
285
neigh_dist , neigh_ind = self .radius_neighbors (X )
285
- pred_labels = [self ._y [ind ] for ind in neigh_ind ]
286
+ inliers = [i for i , nind in enumerate (neigh_ind )
287
+ if len (nind ) != 0 ]
288
+ outliers = [i for i , nind in enumerate (neigh_ind )
289
+ if len (nind ) == 0 ]
286
290
287
291
if self .outlier_label is not None :
288
- outlier_label = np .array ([self .outlier_label ])
289
- small_value = np .array ([1e-6 ])
290
- for i , pl in enumerate (pred_labels ):
291
- # Check that all have at least 1 neighbor
292
- if len (pl ) < 1 :
293
- pred_labels [i ] = outlier_label
294
- neigh_dist [i ] = small_value
295
- else :
296
- for i , pl in enumerate (pred_labels ):
297
- # Check that all have at least 1 neighbor
298
- # TODO we should gather all outliers, or the first k,
299
- # before constructing the error message.
300
- if len (pl ) < 1 :
301
- raise ValueError ('No neighbors found for test sample %d, '
302
- 'you can try using larger radius, '
303
- 'give a label for outliers, '
304
- 'or consider removing it from your '
305
- 'dataset.' % i )
292
+ neigh_dist [outliers ] = 1e-6
293
+ elif outliers :
294
+ raise ValueError ('No neighbors found for test samples %r, '
295
+ 'you can try using larger radius, '
296
+ 'give a label for outliers, '
297
+ 'or consider removing them from your dataset.'
298
+ % outliers )
306
299
307
300
weights = _get_weights (neigh_dist , self .weights )
308
301
302
+ pred_labels = np .array ([self ._y [ind ] for ind in neigh_ind ],
303
+ dtype = object )
309
304
if weights is None :
310
- mode = np .array ([stats .mode (pl )[0 ] for pl in pred_labels ],
305
+ mode = np .array ([stats .mode (pl )[0 ] for pl in pred_labels [ inliers ] ],
311
306
dtype = np .int )
312
307
else :
313
308
mode = np .array ([weighted_mode (pl , w )[0 ]
314
- for (pl , w ) in zip (pred_labels , weights )],
309
+ for (pl , w ) in zip (pred_labels [ inliers ] , weights )],
315
310
dtype = np .int )
316
311
317
312
mode = mode .ravel ().astype (np .int )
318
- # map indices to classes
319
- prediction = self .classes_ .take (mode )
320
- if self . outlier_label is not None :
321
- # reset outlier label
322
- prediction [ prediction == outlier_label ] = self . outlier_label
313
+ prediction = np . empty ( n_samples , dtype = self . classes_ . dtype )
314
+ prediction [ inliers ] = self .classes_ .take (mode )
315
+ if outliers :
316
+ prediction [ outliers ] = self . outlier_label
317
+
323
318
return prediction
0 commit comments