[MRG+2] algorithm='auto' should always work for nearest neighbors (continuation) · raghavrv/scikit-learn@60deaea · GitHub

Commit 60deaea

herilalaina authored and amueller committed
[MRG+2] algorithm='auto' should always work for nearest neighbors (continuation) (scikit-learn#9145)
* Merge neighbors.rst
* issue scikit-learn#4931
* Improve test implementation
* Update base.py
* Remove unused import
* Customize test for precomputed metric
* Change test function name
* Rename _metrics to require_params; add set assert, checking that the test is non-empty, and a more didactic variable name
* Remove blank line
1 parent 93bbe54 · commit 60deaea

File tree: 3 files changed, +72 −16 lines

doc/modules/neighbors.rst

Lines changed: 11 additions & 8 deletions

@@ -419,13 +419,16 @@ depends on a number of factors:
   a significant fraction of the total cost. If very few query points
   will be required, brute force is better than a tree-based method.
 
-Currently, ``algorithm = 'auto'`` selects ``'kd_tree'`` if :math:`k < N/2`
-and the ``'effective_metric_'`` is in the ``'VALID_METRICS'`` list of
-``'kd_tree'``. It selects ``'ball_tree'`` if :math:`k < N/2` and the
-``'effective_metric_'`` is not in the ``'VALID_METRICS'`` list of
-``'kd_tree'``. It selects ``'brute'`` if :math:`k >= N/2`. This choice is based on the assumption that the number of query points is at least the
-same order as the number of training points, and that ``leaf_size`` is
-close to its default value of ``30``.
+Currently, ``algorithm = 'auto'`` selects ``'kd_tree'`` if :math:`k < N/2`
+and the ``'effective_metric_'`` is in the ``'VALID_METRICS'`` list of
+``'kd_tree'``. It selects ``'ball_tree'`` if :math:`k < N/2` and the
+``'effective_metric_'`` is in the ``'VALID_METRICS'`` list of
+``'ball_tree'``. It selects ``'brute'`` if :math:`k < N/2` and the
+``'effective_metric_'`` is not in the ``'VALID_METRICS'`` list of
+``'kd_tree'`` or ``'ball_tree'``. It selects ``'brute'`` if :math:`k >= N/2`.
+This choice is based on the assumption that the number of query points is at
+least the same order as the number of training points, and that ``leaf_size``
+is close to its default value of ``30``.
 
 Effect of ``leaf_size``
 -----------------------

@@ -510,4 +513,4 @@ the model from 0.81 to 0.82.
 .. topic:: Examples:
 
   * :ref:`sphx_glr_auto_examples_neighbors_plot_nearest_centroid.py`: an example of
-    classification using nearest centroid with different shrink thresholds.
+    classification using nearest centroid with different shrink thresholds.
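
Not part of the diff: a minimal sketch of the selection rule described above, assuming a scikit-learn build that includes this change. The private _fit_method attribute is inspected purely for illustration; membership of the metric in VALID_METRICS is what drives the choice.

import numpy as np
from sklearn.neighbors import NearestNeighbors

X = np.random.RandomState(0).rand(100, 2)   # k = 3 < N/2 = 50, so trees are eligible

for metric in ['euclidean', 'haversine', 'cosine']:
    nn = NearestNeighbors(n_neighbors=3, algorithm='auto', metric=metric).fit(X)
    print(metric, '->', nn._fit_method)
# euclidean -> kd_tree    (in VALID_METRICS['kd_tree'])
# haversine -> ball_tree  (in VALID_METRICS['ball_tree'] only)
# cosine    -> brute      (supported by neither tree)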

sklearn/neighbors/base.py

Lines changed: 7 additions & 2 deletions

@@ -125,8 +125,10 @@ def _init_params(self, n_neighbors=None, radius=None,
         if algorithm == 'auto':
             if metric == 'precomputed':
                 alg_check = 'brute'
-            else:
+            elif callable(metric) or metric in VALID_METRICS['ball_tree']:
                 alg_check = 'ball_tree'
+            else:
+                alg_check = 'brute'
         else:
             alg_check = algorithm

@@ -228,8 +230,11 @@ def _fit(self, X):
                 self.metric != 'precomputed'):
             if self.effective_metric_ in VALID_METRICS['kd_tree']:
                 self._fit_method = 'kd_tree'
-            else:
+            elif (callable(self.effective_metric_) or
+                  self.effective_metric_ in VALID_METRICS['ball_tree']):
                 self._fit_method = 'ball_tree'
+            else:
+                self._fit_method = 'brute'
         else:
             self._fit_method = 'brute'
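
The two hunks implement the same fallback chain, once at parameter-validation time (_init_params) and once at fit time (_fit). As a self-contained sketch (a hypothetical helper, not part of the codebase; VALID_METRICS abbreviated here, the real lists live in sklearn/neighbors/base.py), the combined rule is:

VALID_METRICS = {
    'kd_tree': ['euclidean', 'manhattan', 'chebyshev', 'minkowski'],
    'ball_tree': ['euclidean', 'manhattan', 'chebyshev', 'minkowski',
                  'haversine', 'mahalanobis', 'seuclidean', 'wminkowski'],
}

def resolve_auto(metric, k, n_samples):
    """Return the method that algorithm='auto' resolves to (sketch)."""
    if metric == 'precomputed' or k >= n_samples // 2:
        return 'brute'
    if metric in VALID_METRICS['kd_tree']:
        return 'kd_tree'
    if callable(metric) or metric in VALID_METRICS['ball_tree']:
        return 'ball_tree'
    return 'brute'          # the new fallback: no tree supports the metric

assert resolve_auto('euclidean', k=3, n_samples=100) == 'kd_tree'
assert resolve_auto('cosine', k=3, n_samples=100) == 'brute'   # previously mis-routed to ball_tree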

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 54 additions & 6 deletions

@@ -1,23 +1,27 @@
 from itertools import product
+
 import numpy as np
 from scipy.sparse import (bsr_matrix, coo_matrix, csc_matrix, csr_matrix,
                           dok_matrix, lil_matrix)
 
 from sklearn import metrics
-from sklearn.model_selection import train_test_split
+from sklearn import neighbors, datasets
+from sklearn.exceptions import DataConversionWarning
+from sklearn.metrics.pairwise import pairwise_distances
 from sklearn.model_selection import cross_val_score
+from sklearn.model_selection import train_test_split
+from sklearn.neighbors.base import VALID_METRICS_SPARSE, VALID_METRICS
 from sklearn.utils.testing import assert_array_almost_equal
 from sklearn.utils.testing import assert_array_equal
-from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_equal
+from sklearn.utils.testing import assert_false
+from sklearn.utils.testing import assert_greater
+from sklearn.utils.testing import assert_in
+from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_true
 from sklearn.utils.testing import assert_warns
 from sklearn.utils.testing import ignore_warnings
-from sklearn.utils.testing import assert_greater
 from sklearn.utils.validation import check_random_state
-from sklearn.metrics.pairwise import pairwise_distances
-from sklearn import neighbors, datasets
-from sklearn.exceptions import DataConversionWarning
 
 rng = np.random.RandomState(0)
 # load and shuffle iris dataset

@@ -988,6 +992,50 @@ def custom_metric(x1, x2):
     assert_array_almost_equal(dist1, dist2)
 
 
+def test_valid_brute_metric_for_auto_algorithm():
+    X = rng.rand(12, 12)
+    Xcsr = csr_matrix(X)
+
+    # check that there is a metric that is valid for brute
+    # but not ball_tree (so we actually test something)
+    assert_in("cosine", VALID_METRICS['brute'])
+    assert_false("cosine" in VALID_METRICS['ball_tree'])
+
+    # Metrics which require additional parameters are handled separately below
+    require_params = ['mahalanobis', 'wminkowski', 'seuclidean']
+    for metric in VALID_METRICS['brute']:
+        if metric != 'precomputed' and metric not in require_params:
+            nn = neighbors.NearestNeighbors(n_neighbors=3, algorithm='auto',
+                                            metric=metric).fit(X)
+            nn.kneighbors(X)
+        elif metric == 'precomputed':
+            X_precomputed = rng.random_sample((10, 4))
+            Y_precomputed = rng.random_sample((3, 4))
+            DXX = metrics.pairwise_distances(X_precomputed, metric='euclidean')
+            DYX = metrics.pairwise_distances(Y_precomputed, X_precomputed,
                                             metric='euclidean')
+            nb_p = neighbors.NearestNeighbors(n_neighbors=3)
+            nb_p.fit(DXX)
+            nb_p.kneighbors(DYX)
+
+    for metric in VALID_METRICS_SPARSE['brute']:
+        if metric != 'precomputed' and metric not in require_params:
+            nn = neighbors.NearestNeighbors(n_neighbors=3, algorithm='auto',
+                                            metric=metric).fit(Xcsr)
+            nn.kneighbors(Xcsr)
+
+    # Metrics with parameters
+    VI = np.dot(X, X.T)
+    list_metrics = [('seuclidean', dict(V=rng.rand(12))),
+                    ('wminkowski', dict(w=rng.rand(12))),
+                    ('mahalanobis', dict(VI=VI))]
+    for metric, params in list_metrics:
+        nn = neighbors.NearestNeighbors(n_neighbors=3, algorithm='auto',
+                                        metric=metric,
+                                        metric_params=params).fit(X)
+        nn.kneighbors(X)
+
+
 def test_metric_params_interface():
     assert_warns(SyntaxWarning, neighbors.KNeighborsClassifier,
                  metric_params={'p': 3})
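
For reference outside the test suite, this is roughly how a caller exercises the parameterized-metric path the new test covers (a sketch; the data shape is arbitrary):

import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.RandomState(0)
X = rng.rand(12, 12)

# 'seuclidean' needs per-feature variances V passed via metric_params; with
# algorithm='auto' it resolves to ball_tree, which supports this metric.
nn = NearestNeighbors(n_neighbors=3, algorithm='auto', metric='seuclidean',
                      metric_params={'V': rng.rand(12)}).fit(X)
dist, ind = nn.kneighbors(X)   # distances and indices of the 3 nearest neighbors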
