@@ -244,8 +244,10 @@ def predict(self, X):
244
244
Parameters
245
245
----------
246
246
X : {array-like, sparse matrix} of shape (n_queries, n_features), \
247
- or (n_queries, n_indexed) if metric == 'precomputed'
248
- Test samples.
247
+ or (n_queries, n_indexed) if metric == 'precomputed', or None
248
+ Test samples. If `None`, predictions for all indexed points are
249
+ returned; in this case, points are not considered their own
250
+ neighbors.
249
251
250
252
Returns
251
253
-------
@@ -281,7 +283,7 @@ def predict(self, X):
281
283
classes_ = [self .classes_ ]
282
284
283
285
n_outputs = len (classes_ )
284
- n_queries = _num_samples (X )
286
+ n_queries = _num_samples (self . _fit_X if X is None else X )
285
287
weights = _get_weights (neigh_dist , self .weights )
286
288
if weights is not None and _all_with_any_reduction_axis_1 (weights , value = 0 ):
287
289
raise ValueError (
@@ -311,8 +313,10 @@ def predict_proba(self, X):
311
313
Parameters
312
314
----------
313
315
X : {array-like, sparse matrix} of shape (n_queries, n_features), \
314
- or (n_queries, n_indexed) if metric == 'precomputed'
315
- Test samples.
316
+ or (n_queries, n_indexed) if metric == 'precomputed', or None
317
+ Test samples. If `None`, predictions for all indexed points are
318
+ returned; in this case, points are not considered their own
319
+ neighbors.
316
320
317
321
Returns
318
322
-------
@@ -375,7 +379,7 @@ def predict_proba(self, X):
375
379
_y = self ._y .reshape ((- 1 , 1 ))
376
380
classes_ = [self .classes_ ]
377
381
378
- n_queries = _num_samples (X )
382
+ n_queries = _num_samples (self . _fit_X if X is None else X )
379
383
380
384
weights = _get_weights (neigh_dist , self .weights )
381
385
if weights is None :
@@ -408,6 +412,39 @@ def predict_proba(self, X):
408
412
409
413
return probabilities
410
414
415
+ # This function is defined here only to modify the parent docstring
416
+ # and add information about X=None
417
+ def score (self , X , y , sample_weight = None ):
418
+ """
419
+ Return the mean accuracy on the given test data and labels.
420
+
421
+ In multi-label classification, this is the subset accuracy
422
+ which is a harsh metric since you require for each sample that
423
+ each label set be correctly predicted.
424
+
425
+ Parameters
426
+ ----------
427
+ X : array-like of shape (n_samples, n_features), or None
428
+ Test samples. If `None`, predictions for all indexed points are
429
+ used; in this case, points are not considered their own
430
+ neighbors. This means that `knn.fit(X, y).score(None, y)`
431
+ implicitly performs a leave-one-out cross-validation procedure
432
+ and is equivalent to `cross_val_score(knn, X, y, cv=LeaveOneOut())`
433
+ but typically much faster.
434
+
435
+ y : array-like of shape (n_samples,) or (n_samples, n_outputs)
436
+ True labels for `X`.
437
+
438
+ sample_weight : array-like of shape (n_samples,), default=None
439
+ Sample weights.
440
+
441
+ Returns
442
+ -------
443
+ score : float
444
+ Mean accuracy of ``self.predict(X)`` w.r.t. `y`.
445
+ """
446
+ return super ().score (X , y , sample_weight )
447
+
411
448
def __sklearn_tags__ (self ):
412
449
tags = super ().__sklearn_tags__ ()
413
450
tags .classifier_tags .multi_label = True
@@ -692,8 +729,10 @@ def predict(self, X):
692
729
Parameters
693
730
----------
694
731
X : {array-like, sparse matrix} of shape (n_queries, n_features), \
695
- or (n_queries, n_indexed) if metric == 'precomputed'
696
- Test samples.
732
+ or (n_queries, n_indexed) if metric == 'precomputed', or None
733
+ Test samples. If `None`, predictions for all indexed points are
734
+ returned; in this case, points are not considered their own
735
+ neighbors.
697
736
698
737
Returns
699
738
-------
@@ -734,8 +773,10 @@ def predict_proba(self, X):
734
773
Parameters
735
774
----------
736
775
X : {array-like, sparse matrix} of shape (n_queries, n_features), \
737
- or (n_queries, n_indexed) if metric == 'precomputed'
738
- Test samples.
776
+ or (n_queries, n_indexed) if metric == 'precomputed', or None
777
+ Test samples. If `None`, predictions for all indexed points are
778
+ returned; in this case, points are not considered their own
779
+ neighbors.
739
780
740
781
Returns
741
782
-------
@@ -745,7 +786,7 @@ def predict_proba(self, X):
745
786
by lexicographic order.
746
787
"""
747
788
check_is_fitted (self , "_fit_method" )
748
- n_queries = _num_samples (X )
789
+ n_queries = _num_samples (self . _fit_X if X is None else X )
749
790
750
791
metric , metric_kwargs = _adjusted_metric (
751
792
metric = self .metric , metric_kwargs = self .metric_params , p = self .p
@@ -846,6 +887,39 @@ def predict_proba(self, X):
846
887
847
888
return probabilities
848
889
890
+ # This function is defined here only to modify the parent docstring
891
+ # and add information about X=None
892
+ def score (self , X , y , sample_weight = None ):
893
+ """
894
+ Return the mean accuracy on the given test data and labels.
895
+
896
+ In multi-label classification, this is the subset accuracy
897
+ which is a harsh metric since you require for each sample that
898
+ each label set be correctly predicted.
899
+
900
+ Parameters
901
+ ----------
902
+ X : array-like of shape (n_samples, n_features), or None
903
+ Test samples. If `None`, predictions for all indexed points are
904
+ used; in this case, points are not considered their own
905
+ neighbors. This means that `knn.fit(X, y).score(None, y)`
906
+ implicitly performs a leave-one-out cross-validation procedure
907
+ and is equivalent to `cross_val_score(knn, X, y, cv=LeaveOneOut())`
908
+ but typically much faster.
909
+
910
+ y : array-like of shape (n_samples,) or (n_samples, n_outputs)
911
+ True labels for `X`.
912
+
913
+ sample_weight : array-like of shape (n_samples,), default=None
914
+ Sample weights.
915
+
916
+ Returns
917
+ -------
918
+ score : float
919
+ Mean accuracy of ``self.predict(X)`` w.r.t. `y`.
920
+ """
921
+ return super ().score (X , y , sample_weight )
922
+
849
923
def __sklearn_tags__ (self ):
850
924
tags = super ().__sklearn_tags__ ()
851
925
tags .classifier_tags .multi_label = True
0 commit comments