ENH rewrite radius-NN classifier's outlier handling · deepatdotnet/scikit-learn@3ff1a15 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3ff1a15

Browse files
larsmans, amueller
authored and committed
ENH rewrite radius-NN classifier's outlier handling
Fixes a bug introduced in 2dfe13d. IMHO, the code is much more readable this way.
1 parent 2f2a382 commit 3ff1a15

File tree

1 file changed

+21
-26
lines changed

1 file changed

+21
-26
lines changed

sklearn/neighbors/classification.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -280,44 +280,39 @@ def predict(self, X):
280280
List of class labels (one for each data sample).
281281
"""
282282
X = atleast2d_or_csr(X)
283+
n_samples = X.shape[0]
283284

284285
neigh_dist, neigh_ind = self.radius_neighbors(X)
285-
pred_labels = [self._y[ind] for ind in neigh_ind]
286+
inliers = [i for i, nind in enumerate(neigh_ind)
287+
if len(nind) != 0]
288+
outliers = [i for i, nind in enumerate(neigh_ind)
289+
if len(nind) == 0]
286290

287291
if self.outlier_label is not None:
288-
outlier_label = np.array([self.outlier_label])
289-
small_value = np.array([1e-6])
290-
for i, pl in enumerate(pred_labels):
291-
# Check that all have at least 1 neighbor
292-
if len(pl) < 1:
293-
pred_labels[i] = outlier_label
294-
neigh_dist[i] = small_value
295-
else:
296-
for i, pl in enumerate(pred_labels):
297-
# Check that all have at least 1 neighbor
298-
# TODO we should gather all outliers, or the first k,
299-
# before constructing the error message.
300-
if len(pl) < 1:
301-
raise ValueError('No neighbors found for test sample %d, '
302-
'you can try using larger radius, '
303-
'give a label for outliers, '
304-
'or consider removing it from your '
305-
'dataset.' % i)
292+
neigh_dist[outliers] = 1e-6
293+
elif outliers:
294+
raise ValueError('No neighbors found for test samples %r, '
295+
'you can try using larger radius, '
296+
'give a label for outliers, '
297+
'or consider removing them from your dataset.'
298+
% outliers)
306299

307300
weights = _get_weights(neigh_dist, self.weights)
308301

302+
pred_labels = np.array([self._y[ind] for ind in neigh_ind],
303+
dtype=object)
309304
if weights is None:
310-
mode = np.array([stats.mode(pl)[0] for pl in pred_labels],
305+
mode = np.array([stats.mode(pl)[0] for pl in pred_labels[inliers]],
311306
dtype=np.int)
312307
else:
313308
mode = np.array([weighted_mode(pl, w)[0]
314-
for (pl, w) in zip(pred_labels, weights)],
309+
for (pl, w) in zip(pred_labels[inliers], weights)],
315310
dtype=np.int)
316311

317312
mode = mode.ravel().astype(np.int)
318-
# map indices to classes
319-
prediction = self.classes_.take(mode)
320-
if self.outlier_label is not None:
321-
# reset outlier label
322-
prediction[prediction == outlier_label] = self.outlier_label
313+
prediction = np.empty(n_samples, dtype=self.classes_.dtype)
314+
prediction[inliers] = self.classes_.take(mode)
315+
if outliers:
316+
prediction[outliers] = self.outlier_label
317+
323318
return prediction

0 commit comments

Comments
 (0)
0