10000 FIX validate input array-like in KDTree and BallTree (#18691) · scikit-learn/scikit-learn@61beb3b · GitHub
[go: up one dir, main page]

Skip to content

Commit 61beb3b

Browse files
authored
FIX validate input array-like in KDTree and BallTree (#18691)
1 parent c9543e1 commit 61beb3b

File tree

4 files changed

+48
-10
lines changed

4 files changed

+48
-10
lines changed

doc/whats_new/v0.24.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,11 @@ Changelog
623623
the data intrinsic dimensionality is too high for tree-based methods.
624624
:pr:`17148` by :user:`Geoffrey Bolmier <gbolmier>`.
625625

626+
- |Fix| :class:`neighbors.BinaryTree`
627+
will raise a `ValueError` when fitting on data array having points with
628+
different dimensions.
629+
:pr:`18691` by :user:`Chiara Marmo <cmarmo>`.
630+
626631
- |Fix| :class:`neighbors.NearestCentroid` with a numerical `shrink_threshold`
627632
will raise a `ValueError` when fitting on data with all constant features.
628633
:pr:`18370` by :user:`Trevor Waite <trewaite>`.

sklearn/neighbors/_binary_tree.pxi

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,17 +1049,17 @@ cdef class BinaryTree:
10491049
def __init__(self, data,
10501050
leaf_size=40, metric='minkowski', sample_weight=None, **kwargs):
10511051
# validate data
1052-
if data.size == 0:
1052+
self.data_arr = check_array(data, dtype=DTYPE, order='C')
1053+
if self.data_arr.size == 0:
10531054
raise ValueError("X is an empty array")
10541055

1056+
n_samples = self.data_arr.shape[0]
1057+
n_features = self.data_arr.shape[1]
1058+
10551059
if leaf_size < 1:
10561060
raise ValueError("leaf_size must be greater than or equal to 1")
1057-
1058-
n_samples = data.shape[0]
1059-
n_features = data.shape[1]
1060-
1061-
self.data_arr = np.asarray(data, dtype=DTYPE, order='C')
10621061
self.leaf_size = leaf_size
1062+
10631063
self.dist_metric = DistanceMetric.get_metric(metric, **kwargs)
10641064
self.euclidean = (self.dist_metric.__class__.__name__
10651065
== 'EuclideanDistance')
@@ -1069,7 +1069,7 @@ cdef class BinaryTree:
10691069
raise ValueError('metric {metric} is not valid for '
10701070
'{BinaryTree}'.format(metric=metric,
10711071
**DOC_DICT))
1072-
self.dist_metric._validate_data(data)
1072+
self.dist_metric._validate_data(self.data_arr)
10731073

10741074
# determine number of levels in the tree, and from this
10751075
# the number of nodes in the tree. This results in leaf nodes

sklearn/neighbors/tests/test_ball_tree.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from sklearn.neighbors._ball_tree import BallTree
77
from sklearn.neighbors import DistanceMetric
88
from sklearn.utils import check_random_state
9+
from sklearn.utils.validation import check_array
10+
from sklearn.utils._testing import _convert_container
911

1012
rng = np.random.RandomState(10)
1113
V_mahalanobis = rng.rand(3, 3)
@@ -31,22 +33,28 @@
3133

3234

3335
def brute_force_neighbors(X, Y, k, metric, **kwargs):
36+
X, Y = check_array(X), check_array(Y)
3437
D = DistanceMetric.get_metric(metric, **kwargs).pairwise(Y, X)
3538
ind = np.argsort(D, axis=1)[:, :k]
3639
dist = D[np.arange(Y.shape[0])[:, None], ind]
3740
return dist, ind
3841

3942

40-
@pytest.mark.parametrize('metric',
41-
itertools.chain(BOOLEAN_METRICS, DISCRETE_METRICS))
42-
def test_ball_tree_query_metrics(metric):
43+
@pytest.mark.parametrize(
44+
'metric',
45+
itertools.chain(BOOLEAN_METRICS, DISCRETE_METRICS)
46+
)
47+
@pytest.mark.parametrize("array_type", ["list", "array"])
48+
def test_ball_tree_query_metrics(metric, array_type):
4349
rng = check_random_state(0)
4450
if metric in BOOLEAN_METRICS:
4551
X = rng.random_sample((40, 10)).round(0)
4652
Y = rng.random_sample((10, 10)).round(0)
4753
elif metric in DISCRETE_METRICS:
4854
X = (4 * rng.random_sample((40, 10))).round(0)
4955
Y = (4 * rng.random_sample((10, 10))).round(0)
56+
X = _convert_container(X, array_type)
57+
Y = _convert_container(Y, array_type)
5058

5159
k = 5
5260

@@ -65,3 +73,13 @@ def test_query_haversine():
6573

6674
assert_array_almost_equal(dist1, dist2)
6775
assert_array_almost_equal(ind1, ind2)
76+
77+
78+
def test_array_object_type():
79+
"""Check that we do not accept object dtype array."""
80+
X = np.array([(1, 2, 3), (2, 5), (5, 5, 1, 2)], dtype=object)
81+
with pytest.raises(
82+
ValueError,
83+
match="setting an array element with a sequence"
84+
):
85+
BallTree(X)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,21 @@
1+
import numpy as np
2+
import pytest
3+
4+
from sklearn.neighbors._kd_tree import KDTree
5+
16
DIMENSION = 3
27

38
METRICS = {'euclidean': {},
49
'manhattan': {},
510
'chebyshev': {},
611
'minkowski': dict(p=3)}
12+
13+
14+
def test_array_object_type():
15+
"""Check that we do not accept object dtype array."""
16+
X = np.array([(1, 2, 3), (2, 5), (5, 5, 1, 2)], dtype=object)
17+
with pytest.raises(
18+
ValueError,
19+
match="setting an array element with a sequence"
20+
):
21+
KDTree(X)

0 commit comments

Comments
 (0)
0