ENH Make `KNeighborsClassifier.predict` handle `X=None` (#30047) · commit-0/scikit-learn@bcc6430 · GitHub
[go: up one dir, main page]

Skip to content

Commit bcc6430

Browse files
authored
ENH Make KNeighborsClassifier.predict handle X=None (scikit-learn#30047)
1 parent c08b433 commit bcc6430

File tree

4 files changed

+143
-16
lines changed

4 files changed

+143
-16
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
- Make `predict`, `predict_proba`, and `score` of
2+
:class:`neighbors.KNeighborsClassifier` and
3+
:class:`neighbors.RadiusNeighborsClassifier` accept `X=None` as input. In this case
4+
predictions for all training set points are returned, and points are not included
5+
among their own neighbors.
6+
:pr:`30047` by :user:`Dmitry Kobak <dkobak>`.

sklearn/neighbors/_classification.py

Lines changed: 85 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,10 @@ def predict(self, X):
244244
Parameters
245245
----------
246246
X : {array-like, sparse matrix} of shape (n_queries, n_features), \
247-
or (n_queries, n_indexed) if metric == 'precomputed'
248-
Test samples.
247+
or (n_queries, n_indexed) if metric == 'precomputed', or None
248+
Test samples. If `None`, predictions for all indexed points are
249+
returned; in this case, points are not considered their own
250+
neighbors.
249251
250252
Returns
251253
-------
@@ -281,7 +283,7 @@ def predict(self, X):
281283
classes_ = [self.classes_]
282284

283285
n_outputs = len(classes_)
284-
n_queries = _num_samples(X)
286+
n_queries = _num_samples(self._fit_X if X is None else X)
285287
weights = _get_weights(neigh_dist, self.weights)
286288
if weights is not None and _all_with_any_reduction_axis_1(weights, value=0):
287289
raise ValueError(
@@ -311,8 +313,10 @@ def predict_proba(self, X):
311313
Parameters
312314
----------
313315
X : {array-like, sparse matrix} of shape (n_queries, n_features), \
314-
or (n_queries, n_indexed) if metric == 'precomputed'
315-
Test samples.
316+
or (n_queries, n_indexed) if metric == 'precomputed', or None
317+
Test samples. If `None`, predictions for all indexed points are
318+
returned; in this case, points are not considered their own
319+
neighbors.
316320
317321
Returns
318322
-------
@@ -375,7 +379,7 @@ def predict_proba(self, X):
375379
_y = self._y.reshape((-1, 1))
376380
classes_ = [self.classes_]
377381

378-
n_queries = _num_samples(X)
382+
n_queries = _num_samples(self._fit_X if X is None else X)
379383

380384
weights = _get_weights(neigh_dist, self.weights)
381385
if weights is None:
@@ -408,6 +412,39 @@ def predict_proba(self, X):
408412

409413
return probabilities
410414

415+
# This function is defined here only to modify the parent docstring
416+
# and add information about X=None
417+
def score(self, X, y, sample_weight=None):
418+
"""
419+
Return the mean accuracy on the given test data and labels.
420+
421+
In multi-label classification, this is the subset accuracy
422+
which is a harsh metric since you require for each sample that
423+
each label set be correctly predicted.
424+
425+
Parameters
426+
----------
427+
X : array-like of shape (n_samples, n_features), or None
428+
Test samples. If `None`, predictions for all indexed points are
429+
used; in this case, points are not considered their own
430+
neighbors. This means that `knn.fit(X, y).score(None, y)`
431+
implicitly performs a leave-one-out cross-validation procedure
432+
and is equivalent to `cross_val_score(knn, X, y, cv=LeaveOneOut()).mean()`
433+
but typically much faster.
434+
435+
y : array-like of shape (n_samples,) or (n_samples, n_outputs)
436+
True labels for `X`.
437+
438+
sample_weight : array-like of shape (n_samples,), default=None
439+
Sample weights.
440+
441+
Returns
442+
-------
443+
score : float
444+
Mean accuracy of ``self.predict(X)`` w.r.t. `y`.
445+
"""
446+
return super().score(X, y, sample_weight)
447+
411448
def __sklearn_tags__(self):
412449
tags = super().__sklearn_tags__()
413450
tags.classifier_tags.multi_label = True
@@ -692,8 +729,10 @@ def predict(self, X):
692729
Parameters
693730
----------
694731
X : {array-like, sparse matrix} of shape (n_queries, n_features), \
695-
or (n_queries, n_indexed) if metric == 'precomputed'
696-
Test samples.
732+
or (n_queries, n_indexed) if metric == 'precomputed', or None
733+
Test samples. If `None`, predictions for all indexed points are
734+
returned; in this case, points are not considered their own
735+
neighbors.
697736
698737
Returns
699738
-------
@@ -734,8 +773,10 @@ def predict_proba(self, X):
734773
Parameters
735774
----------
736775
X : {array-like, sparse matrix} of shape (n_queries, n_features), \
737-
or (n_queries, n_indexed) if metric == 'precomputed'
738-
Test samples.
776+
or (n_queries, n_indexed) if metric == 'precomputed', or None
777+
Test samples. If `None`, predictions for all indexed points are
778+
returned; in this case, points are not considered their own
779+
neighbors.
739780
740781
Returns
741782
-------
@@ -745,7 +786,7 @@ def predict_proba(self, X):
745786
by lexicographic order.
746787
"""
747788
check_is_fitted(self, "_fit_method")
748-
n_queries = _num_samples(X)
789+
n_queries = _num_samples(self._fit_X if X is None else X)
749790

750791
metric, metric_kwargs = _adjusted_metric(
751792
metric=self.metric, metric_kwargs=self.metric_params, p=self.p
@@ -846,6 +887,39 @@ def predict_proba(self, X):
846887

847888
return probabilities
848889

890+
# This function is defined here only to modify the parent docstring
891+
# and add information about X=None
892+
def score(self, X, y, sample_weight=None):
893+
"""
894+
Return the mean accuracy on the given test data and labels.
895+
896+
In multi-label classification, this is the subset accuracy
897+
which is a harsh metric since you require for each sample that
898+
each label set be correctly predicted.
899+
900+
Parameters
901+
----------
902+
X : array-like of shape (n_samples, n_features), or None
903+
Test samples. If `None`, predictions for all indexed points are
904+
used; in this case, points are not considered their own
905+
neighbors. This means that `knn.fit(X, y).score(None, y)`
906+
implicitly performs a leave-one-out cross-validation procedure
907+
and is equivalent to `cross_val_score(knn, X, y, cv=LeaveOneOut()).mean()`
908+
but typically much faster.
909+
910+
y : array-like of shape (n_samples,) or (n_samples, n_outputs)
911+
True labels for `X`.
912+
913+
sample_weight : array-like of shape (n_samples,), default=None
914+
Sample weights.
915+
916+
Returns
917+
-------
918+
score : float
919+
Mean accuracy of ``self.predict(X)`` w.r.t. `y`.
920+
"""
921+
return super().score(X, y, sample_weight)
922+
849923
def __sklearn_tags__(self):
850924
tags = super().__sklearn_tags__()
851925
tags.classifier_tags.multi_label = True

sklearn/neighbors/_regression.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,10 @@ def predict(self, X):
234234
Parameters
235235
----------
236236
X : {array-like, sparse matrix} of shape (n_queries, n_features), \
237-
or (n_queries, n_indexed) if metric == 'precomputed'
238-
Test samples.
237+
or (n_queries, n_indexed) if metric == 'precomputed', or None
238+
Test samples. If `None`, predictions for all indexed points are
239+
returned; in this case, points are not considered their own
240+
neighbors.
239241
240242
Returns
241243
-------
@@ -464,8 +466,10 @@ def predict(self, X):
464466
Parameters
465467
----------
466468
X : {array-like, sparse matrix} of shape (n_queries, n_features), \
467-
or (n_queries, n_indexed) if metric == 'precomputed'
468-
Test samples.
469+
or (n_queries, n_indexed) if metric == 'precomputed', or None
470+
Test samples. If `None`, predictions for all indexed points are
471+
returned; in this case, points are not considered their own
472+
neighbors.
469473
470474
Returns
471475
-------

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@
2424
assert_compatible_argkmin_results,
2525
assert_compatible_radius_results,
2626
)
27-
from sklearn.model_selection import cross_val_score, train_test_split
27+
from sklearn.model_selection import (
28+
LeaveOneOut,
29+
cross_val_predict,
30+
cross_val_score,
31+
train_test_split,
32+
)
2833
from sklearn.neighbors import (
2934
VALID_METRICS_SPARSE,
3035
KNeighborsRegressor,
@@ -2390,3 +2395,41 @@ def _weights(dist):
23902395

23912396
with pytest.raises(ValueError, match=msg):
23922397
est.predict_proba([[1.1, 1.1]])
2398+
2399+
2400+
@pytest.mark.parametrize(
2401+
"nn_model",
2402+
[
2403+
neighbors.KNeighborsClassifier(n_neighbors=10),
2404+
neighbors.RadiusNeighborsClassifier(radius=5.0),
2405+
],
2406+
)
2407+
def test_neighbor_classifiers_loocv(nn_model):
2408+
"""Check that `predict` and related functions work fine with X=None"""
2409+
X, y = datasets.make_blobs(n_samples=500, centers=5, n_features=2, random_state=0)
2410+
2411+
loocv = cross_val_score(nn_model, X, y, cv=LeaveOneOut())
2412+
nn_model.fit(X, y)
2413+
2414+
assert np.all(loocv == (nn_model.predict(None) == y))
2415+
assert np.mean(loocv) == nn_model.score(None, y)
2416+
assert nn_model.score(None, y) < nn_model.score(X, y)
2417+
2418+
2419+
@pytest.mark.parametrize(
2420+
"nn_model",
2421+
[
2422+
neighbors.KNeighborsRegressor(n_neighbors=10),
2423+
neighbors.RadiusNeighborsRegressor(radius=0.5),
2424+
],
2425+
)
2426+
def test_neighbor_regressors_loocv(nn_model):
2427+
"""Check that `predict` and related functions work fine with X=None"""
2428+
X, y = datasets.load_diabetes(return_X_y=True)
2429+
2430+
# Only checking cross_val_predict and not cross_val_score because
2431+
# cross_val_score does not work with LeaveOneOut() for a regressor
2432+
loocv = cross_val_predict(nn_model, X, y, cv=LeaveOneOut())
2433+
nn_model.fit(X, y)
2434+
2435+
assert np.all(loocv == nn_model.predict(None))

0 commit comments

Comments
 (0)
0