|
1 | 1 | from itertools import product
|
| 2 | + |
2 | 3 | import numpy as np
|
3 | 4 | from scipy.sparse import (bsr_matrix, coo_matrix, csc_matrix, csr_matrix,
|
4 | 5<
8000
/code> | dok_matrix, lil_matrix)
|
5 | 6 |
|
6 | 7 | from sklearn import metrics
|
7 |
| -from sklearn.model_selection import train_test_split |
| 8 | +from sklearn import neighbors, datasets |
| 9 | +from sklearn.exceptions import DataConversionWarning |
| 10 | +from sklearn.metrics.pairwise import pairwise_distances |
8 | 11 | from sklearn.model_selection import cross_val_score
|
| 12 | +from sklearn.model_selection import train_test_split |
| 13 | +from sklearn.neighbors.base import VALID_METRICS_SPARSE, VALID_METRICS |
9 | 14 | from sklearn.utils.testing import assert_array_almost_equal
|
10 | 15 | from sklearn.utils.testing import assert_array_equal
|
11 |
| -from sklearn.utils.testing import assert_raises |
12 | 16 | from sklearn.utils.testing import assert_equal
|
| 17 | +from sklearn.utils.testing import assert_false |
| 18 | +from sklearn.utils.testing import assert_greater |
| 19 | +from sklearn.utils.testing import assert_in |
| 20 | +from sklearn.utils.testing import assert_raises |
13 | 21 | from sklearn.utils.testing import assert_true
|
14 | 22 | from sklearn.utils.testing import assert_warns
|
15 | 23 | from sklearn.utils.testing import ignore_warnings
|
16 |
| -from sklearn.utils.testing import assert_greater |
17 | 24 | from sklearn.utils.validation import check_random_state
|
18 |
| -from sklearn.metrics.pairwise import pairwise_distances |
19 |
| -from sklearn import neighbors, datasets |
20 |
| -from sklearn.exceptions import DataConversionWarning |
21 | 25 |
|
22 | 26 | rng = np.random.RandomState(0)
|
23 | 27 | # load and shuffle iris dataset
|
@@ -988,6 +992,50 @@ def custom_metric(x1, x2):
|
988 | 992 | assert_array_almost_equal(dist1, dist2)
|
989 | 993 |
|
990 | 994 |
|
| 995 | +def test_valid_brute_metric_for_auto_algorithm(): |
| 996 | + X = rng.rand(12, 12) |
| 997 | + Xcsr = csr_matrix(X) |
| 998 | + |
| 999 | + # check that there is a metric that is valid for brute |
| 1000 | + # but not ball_tree (so we actually test something) |
| 1001 | + assert_in("cosine", VALID_METRICS['brute']) |
| 1002 | + assert_false("cosine" in VALID_METRICS['ball_tree']) |
| 1003 | + |
| 1004 | + # Metric which don't required any additional parameter |
| 1005 | + require_params = ['mahalanobis', 'wminkowski', 'seuclidean'] |
| 1006 | + for metric in VALID_METRICS['brute']: |
| 1007 | + if metric != 'precomputed' and metric not in require_params: |
| 1008 | + nn = neighbors.NearestNeighbors(n_neighbors=3, algorithm='auto', |
| 1009 | + metric=metric).fit(X) |
| 1010 | + nn.kneighbors(X) |
| 1011 | + elif metric == 'precomputed': |
| 1012 | + X_precomputed = rng.random_sample((10, 4)) |
| 1013 | + Y_precomputed = rng.random_sample((3, 4)) |
| 1014 | + DXX = metrics.pairwise_distances(X_precomputed, metric='euclidean') |
| 1015 | + DYX = metrics.pairwise_distances(Y_precomputed, X_precomputed, |
| 1016 | + metric='euclidean') |
| 1017 | + nb_p = neighbors.NearestNeighbors(n_neighbors=3) |
| 1018 | + nb_p.fit(DXX) |
| 1019 | + nb_p.kneighbors(DYX) |
| 1020 | + |
| 1021 | + for metric in VALID_METRICS_SPARSE['brute']: |
| 1022 | + if metric != 'precomputed' and metric not in require_params: |
| 1023 | + nn = neighbors.NearestNeighbors(n_neighbors=3, algorithm='auto', |
| 1024 | + metric=metric).fit(Xcsr) |
| 1025 | + nn.kneighbors(Xcsr) |
| 1026 | + |
| 1027 | + # Metric with parameter |
| 1028 | + VI = np.dot(X, X.T) |
| 1029 | + list_metrics = [('seuclidean', dict(V=rng.rand(12))), |
| 1030 | + ('wminkowski', dict(w=rng.rand(12))), |
| 1031 | + ('mahalanobis', dict(VI=VI))] |
| 1032 | + for metric, params in list_metrics: |
| 1033 | + nn = neighbors.NearestNeighbors(n_neighbors=3, algorithm='auto', |
| 1034 | + metric=metric, |
| 1035 | + metric_params=params).fit(X) |
| 1036 | + nn.kneighbors(X) |
| 1037 | + |
| 1038 | + |
991 | 1039 | def test_metric_params_interface():
|
992 | 1040 | assert_warns(SyntaxWarning, neighbors.KNeighborsClassifier,
|
993 | 1041 | metric_params={'p': 3})
|
|
0 commit comments