diff --git a/sklearn/neighbors/tests/test_ball_tree.py b/sklearn/neighbors/tests/test_ball_tree.py index 5263f201f320b..5f244c272d773 100644 --- a/sklearn/neighbors/tests/test_ball_tree.py +++ b/sklearn/neighbors/tests/test_ball_tree.py @@ -20,6 +20,7 @@ "manhattan": {}, "minkowski": dict(p=3), "chebyshev": {}, + "precomputed": {}, } DISCRETE_METRICS = ["hamming", "canberra", "braycurtis"] @@ -41,6 +42,9 @@ def brute_force_neighbors(X, Y, k, metric, **kwargs): from sklearn.metrics import DistanceMetric + + if metric == "precomputed": + return Y, np.argsort(Y, axis=1)[:, :k] X, Y = check_array(X), check_array(Y) D = DistanceMetric.get_metric(metric, **kwargs).pairwise(Y, X) @@ -73,7 +77,8 @@ def test_ball_tree_query_metrics(metric, array_type, BallTreeImplementation): dist1, ind1 = bt.query(Y, k) dist2, ind2 = brute_force_neighbors(X, Y, k, metric) assert_array_almost_equal(dist1, dist2) - + + @pytest.mark.parametrize( "BallTreeImplementation, decimal_tol", zip(BALL_TREE_CLASSES, [6, 5])