WIP refactor binary tree classes to use proper inheritance · scikit-learn/scikit-learn@0a78622 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0a78622

Browse files
committed
WIP refactor binary tree classes to use proper inheritance
1 parent 560115e commit 0a78622

File tree

6 files changed

+550
-458
lines changed

6 files changed

+550
-458
lines changed

sklearn/neighbors/ball_tree.pyx

Lines changed: 118 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -24,152 +24,123 @@ VALID_METRICS = ['EuclideanDistance', 'SEuclideanDistance',
2424

2525
include "binary_tree.pxi"
2626

27-
# Inherit BallTree from BinaryTree
27+
2828
cdef class BallTree(BinaryTree):
2929
__doc__ = CLASS_DOC.format(**DOC_DICT)
30-
pass
31-
32-
#----------------------------------------------------------------------
33-
# The functions below specialized the Binary Tree as a Ball Tree
34-
#
35-
# Note that these functions use the concept of "reduced distance".
36-
# The reduced distance, defined for some metrics, is a quantity which
37-
# is more efficient to compute than the distance, but preserves the
38-
# relative rankings of the true distance. For example, the reduced
39-
# distance for the Euclidean metric is the squared-euclidean distance.
40-
# For some metrics, the reduced distance is simply the distance.
41-
42-
43-
cdef int allocate_data(BinaryTree tree, ITYPE_t n_nodes,
44-
ITYPE_t n_features) except -1:
45-
"""Allocate arrays needed for the KD Tree"""
46-
tree.node_bounds_arr = np.zeros((1, n_nodes, n_features), dtype=DTYPE)
47-
tree.node_bounds = tree.node_bounds_arr
48-
return 0
49-
50-
51-
cdef int init_node(BinaryTree tree, ITYPE_t i_node,
52-
ITYPE_t idx_start, ITYPE_t idx_end) except -1:
53-
"""Initialize the node for the dataset stored in tree.data"""
54-
cdef ITYPE_t n_features = tree.data.shape[1]
55-
cdef ITYPE_t n_points = idx_end - idx_start
56-
57-
cdef ITYPE_t i, j
58-
cdef DTYPE_t radius
59-
cdef DTYPE_t *this_pt
60-
61-
cdef ITYPE_t* idx_array = &tree.idx_array[0]
62-
cdef DTYPE_t* data = &tree.data[0, 0]
63-
cdef DTYPE_t* centroid = &tree.node_bounds[0, i_node, 0]
64-
65-
# determine Node centroid
66-
for j in range(n_features):
67-
centroid[j] = 0
68-
69-
for i in range(idx_start, idx_end):
70-
this_pt = data + n_features * idx_array[i]
71-
for j from 0 <= j < n_features:
72-
centroid[j] += this_pt[j]
73-
74-
for j in range(n_features):
75-
centroid[j] /= n_points
76-
77-
# determine Node radius
78-
radius = 0
79-
for i in range(idx_start, idx_end):
80-
radius = fmax(radius,
81-
tree.rdist(centroid,
82-
data + n_features * idx_array[i],
83-
n_features))
84-
85-
tree.node_data[i_node].radius = tree.dist_metric._rdist_to_dist(radius)
86-
tree.node_data[i_node].idx_start = idx_start
87-
tree.node_data[i_node].idx_end = idx_end
88-
return 0
89-
90-
91-
cdef inline DTYPE_t min_dist(BinaryTree tree, ITYPE_t i_node,
92-
DTYPE_t* pt) except -1:
93-
"""Compute the minimum distance between a point and a node"""
94-
cdef DTYPE_t dist_pt = tree.dist(pt, &tree.node_bounds[0, i_node, 0],
95-
tree.data.shape[1])
96-
return fmax(0, dist_pt - tree.node_data[i_node].radius)
97-
98-
99-
cdef inline DTYPE_t max_dist(BinaryTree tree, ITYPE_t i_node,
100-
DTYPE_t* pt) except -1:
101-
"""Compute the maximum distance between a point and a node"""
102-
cdef DTYPE_t dist_pt = tree.dist(pt, &tree.node_bounds[0, i_node, 0],
103-
tree.data.shape[1])
104-
return dist_pt + tree.node_data[i_node].radius
105-
106-
107-
cdef inline int min_max_dist(BinaryTree tree, ITYPE_t i_node, DTYPE_t* pt,
108-
DTYPE_t* min_dist, DTYPE_t* max_dist) except -1:
109-
"""Compute the minimum and maximum distance between a point and a node"""
110-
cdef DTYPE_t dist_pt = tree.dist(pt, &tree.node_bounds[0, i_node, 0],
111-
tree.data.shape[1])
112-
cdef DTYPE_t rad = tree.node_data[i_node].radius
113-
min_dist[0] = fmax(0, dist_pt - rad)
114-
max_dist[0] = dist_pt + rad
115-
return 0
116-
117-
118-
cdef inline DTYPE_t min_rdist(BinaryTree tree, ITYPE_t i_node,
119-
DTYPE_t* pt) except -1:
120-
"""Compute the minimum reduced-distance between a point and a node"""
121-
if tree.euclidean:
122-
return euclidean_dist_to_rdist(min_dist(tree, i_node, pt))
123-
else:
124-
return tree.dist_metric._dist_to_rdist(min_dist(tree, i_node, pt))
125-
126-
127-
cdef inline DTYPE_t max_rdist(BinaryTree tree, ITYPE_t i_node,
128-
DTYPE_t* pt) except -1:
129-
"""Compute the maximum reduced-distance between a point and a node"""
130-
if tree.euclidean:
131-
return euclidean_dist_to_rdist(max_dist(tree, i_node, pt))
132-
else:
133-
return tree.dist_metric._dist_to_rdist(max_dist(tree, i_node, pt))
134-
135-
136-
cdef inline DTYPE_t min_dist_dual(BinaryTree tree1, ITYPE_t i_node1,
137-
BinaryTree tree2, ITYPE_t i_node2) except -1:
138-
"""compute the minimum distance between two nodes"""
139-
cdef DTYPE_t dist_pt = tree1.dist(&tree2.node_bounds[0, i_node2, 0],
140-
&tree1.node_bounds[0, i_node1, 0],
141-
tree1.data.shape[1])
142-
return fmax(0, (dist_pt - tree1.node_data[i_node1].radius
143-
- tree2.node_data[i_node2].radius))
144-
145-
146-
cdef inline DTYPE_t max_dist_dual(BinaryTree tree1, ITYPE_t i_node1,
147-
BinaryTree tree2, ITYPE_t i_node2) except -1:
148-
"""compute the maximum distance between two nodes"""
149-
cdef DTYPE_t dist_pt = tree1.dist(&tree2.node_bounds[0, i_node2, 0],
150-
&tree1.node_bounds[0, i_node1, 0],
151-
tree1.data.shape[1])
152-
return (dist_pt + tree1.node_data[i_node1].radius
153-
+ tree2.node_data[i_node2].radius)
154-
155-
156-
cdef inline DTYPE_t min_rdist_dual(BinaryTree tree1, ITYPE_t i_node1,
157-
BinaryTree tree2, ITYPE_t i_node2) except -1:
158-
"""compute the minimum reduced distance between two nodes"""
159-
if tree1.euclidean:
160-
return euclidean_dist_to_rdist(min_dist_dual(tree1, i_node1,
161-
tree2, i_node2))
162-
else:
163-
return tree1.dist_metric._dist_to_rdist(min_dist_dual(tree1, i_node1,
164-
tree2, i_node2))
165-
166-
167-
cdef inline DTYPE_t max_rdist_dual(BinaryTree tree1, ITYPE_t i_node1,
168-
BinaryTree tree2, ITYPE_t i_node2) except -1:
169-
"""compute the maximum reduced distance between two nodes"""
170-
if tree1.euclidean:
171-
return euclidean_dist_to_rdist(max_dist_dual(tree1, i_node1,
172-
tree2, i_node2))
173-
else:
174-
return tree1.dist_metric._dist_to_rdist(max_dist_dual(tree1, i_node1,
175-
tree2, i_node2))
30+
31+
valid_metrics = get_valid_metric_ids(VALID_METRICS)
32+
33+
# Implementations of abstract methods.
34+
#
35+
# Note that these functions use the concept of "reduced distance".
36+
# The reduced distance, defined for some metrics, is a quantity which
37+
# is more efficient to compute than the distance, but preserves the
38+
# relative rankings of the true distance. For example, the reduced
39+
# distance for the Euclidean metric is the squared-euclidean distance.
40+
# For some metrics, the reduced distance is simply the distance.
41+
42+
cdef int allocate_data(self, ITYPE_t n_nodes, ITYPE_t n_features) except -1:
43+
self.node_bounds_arr = np.zeros((1, n_nodes, n_features), dtype=DTYPE)
44+
self.node_bounds = self.node_bounds_arr
45+
return 0
46+
47+
cdef int init_node(self, ITYPE_t i_node,
48+
ITYPE_t idx_start, ITYPE_t idx_end) except -1:
49+
cdef ITYPE_t n_features = self.data.shape[1]
50+
cdef ITYPE_t n_points = idx_end - idx_start
51+
52+
cdef ITYPE_t i, j
53+
cdef DTYPE_t radius
54+
cdef DTYPE_t *this_pt
55+
56+
cdef ITYPE_t* idx_array = &self.idx_array[0]
57+
cdef DTYPE_t* data = &self.data[0, 0]
58+
cdef DTYPE_t* centroid = &self.node_bounds[0, i_node, 0]
59+
60+
# determine Node centroid
61+
for j in range(n_features):
62+
centroid[j] = 0
63+
64+
for i in range(idx_start, idx_end):
65+
this_pt = data + n_features * idx_array[i]
66+
for j from 0 <= j < n_features:
67+
centroid[j] += this_pt[j]
68+
69+
for j in range(n_features):
70+
centroid[j] /= n_points
71+
72+
# determine Node radius
73+
radius = 0
74+
for i in range(idx_start, idx_end):
75+
radius = fmax(radius,
76+
self.rdist(centroid,
77+
data + n_features * idx_array[i],
78+
n_features))
79+
80+
self.node_data[i_node].radius = self.dist_metric._rdist_to_dist(radius)
81+
self.node_data[i_node].idx_start = idx_start
82+
self.node_data[i_node].idx_end = idx_end
83+
return 0
84+
85+
cdef DTYPE_t min_dist(self, ITYPE_t i_node, DTYPE_t* pt) except -1:
86+
cdef DTYPE_t dist_pt = self.dist(pt, &self.node_bounds[0, i_node, 0],
87+
self.data.shape[1])
88+
return fmax(0, dist_pt - self.node_data[i_node].radius)
89+
90+
cdef DTYPE_t max_dist(self, ITYPE_t i_node, DTYPE_t* pt) except -1:
91+
cdef DTYPE_t dist_pt = self.dist(pt, &self.node_bounds[0, i_node, 0],
92+
self.data.shape[1])
93+
return dist_pt + self.node_data[i_node].radius
94+
95+
cdef int min_max_dist(self, ITYPE_t i_node, DTYPE_t* pt,
96+
DTYPE_t* min_dist, DTYPE_t* max_dist) except -1:
97+
cdef DTYPE_t dist_pt = self.dist(pt, &self.node_bounds[0, i_node, 0],
98+
self.data.shape[1])
99+
cdef DTYPE_t rad = self.node_data[i_node].radius
100+
min_dist[0] = fmax(0, dist_pt - rad)
101+
max_dist[0] = dist_pt + rad
102+
return 0
103+
104+
cdef DTYPE_t min_rdist(self, ITYPE_t i_node, DTYPE_t* pt) except -1:
105+
if self.euclidean:
106+
return euclidean_dist_to_rdist(self.min_dist(i_node, pt))
107+
else:
108+
return self.dist_metric._dist_to_rdist(self.min_dist(i_node, pt))
109+
110+
cdef DTYPE_t max_rdist(self, ITYPE_t i_node, DTYPE_t* pt) except -1:
111+
if self.euclidean:
112+
return euclidean_dist_to_rdist(self.max_dist(i_node, pt))
113+
else:
114+
return self.dist_metric._dist_to_rdist(self.max_dist(i_node, pt))
115+
116+
cdef DTYPE_t min_dist_dual(self, ITYPE_t i_node1,
117+
BinaryTree other, ITYPE_t i_node2) except -1:
118+
cdef DTYPE_t dist_pt = self.dist(&other.node_bounds[0, i_node2, 0],
119+
&self.node_bounds[0, i_node1, 0],
120+
self.data.shape[1])
121+
return fmax(0, (dist_pt - self.node_data[i_node1].radius
122+
- other.node_data[i_node2].radius))
123+
124+
cdef DTYPE_t max_dist_dual(self, ITYPE_t i_node1,
125+
BinaryTree other, ITYPE_t i_node2) except -1:
126+
cdef DTYPE_t dist_pt = self.dist(&other.node_bounds[0, i_node2, 0],
127+
&self.node_bounds[0, i_node1, 0],
128+
self.data.shape[1])
129+
return (dist_pt + self.node_data[i_node1].radius
130+
+ other.node_data[i_node2].radius)
131+
132+
cdef DTYPE_t min_rdist_dual(self, ITYPE_t i_node1,
133+
BinaryTree other, ITYPE_t i_node2) except -1:
134+
cdef DTYPE_t d = self.min_dist_dual(i_node1, other, i_node2)
135+
if self.euclidean:
136+
return euclidean_dist_to_rdist(d)
137+
else:
138+
return self.dist_metric._dist_to_rdist(d)
139+
140+
cdef DTYPE_t max_rdist_dual(self, ITYPE_t i_node1,
141+
BinaryTree other, ITYPE_t i_node2) except -1:
142+
cdef DTYPE_t d = self.max_dist_dual(i_node1, other, i_node2)
143+
if self.euclidean:
144+
return euclidean_dist_to_rdist(d)
145+
else:
146+
return self.dist_metric._dist_to_rdist(d)

0 commit comments

Comments
 (0)
0