8000 Nogil query_radius · scikit-learn/scikit-learn@4e29212 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4e29212

Browse files
author
Nikolay Mayorov
committed
Nogil query_radius
1 parent 88387bc commit 4e29212

File tree

9 files changed

+6839
-6151
lines changed

9 files changed

+6839
-6151
lines changed

sklearn/neighbors/ball_tree.c

Lines changed: 3351 additions & 3019 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sklearn/neighbors/ball_tree.pyx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,9 @@ cdef inline DTYPE_t max_dist(BinaryTree tree, ITYPE_t i_node,
104104
return dist_pt + tree.node_data[i_node].radius
105105

106106

107-
cdef inline int min_max_dist(BinaryTree tree, ITYPE_t i_node, DTYPE_t* pt,
108-
DTYPE_t* min_dist, DTYPE_t* max_dist) except -1:
107+
cdef inline int min_max_dist(
108+
BinaryTree tree, ITYPE_t i_node, DTYPE_t* pt,
109+
DTYPE_t* min_dist, DTYPE_t* max_dist) nogil except -1:
109110
"""Compute the minimum and maximum distance between a point and a node"""
110111
cdef DTYPE_t dist_pt = tree.dist(pt, &tree.node_bounds[0, i_node, 0],
111112
tree.data.shape[1])

sklearn/neighbors/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ class from an array representing our data set and ask who's
537537
X[s], radius, return_distance)
538538
for s in gen_even_slices(X.shape[0], n_jobs)
539539
)
540+
540541
if return_distance:
541542
neigh_ind, dist = tuple(zip(*result))
542543
result = np.hstack(neigh_ind), np.hstack(dist)

sklearn/neighbors/binary_tree.pxi

Lines changed: 71 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -144,10 +144,13 @@
144144
cimport cython
145145
cimport numpy as np
146146
from libc.math cimport fabs, sqrt, exp, cos, pow, log
147+
from libc.stdlib cimport malloc, free
148+
from libc.string cimport memcpy
147149
from sklearn.utils.lgamma cimport lgamma
148150

149151
import numpy as np
150152
import warnings
153+
import ctypes
151154
from ..utils import check_array
152155

153156
from typedefs cimport DTYPE_t, ITYPE_t, DITYPE_t
@@ -484,7 +487,7 @@ cdef inline void swap(DITYPE_t* arr, ITYPE_t i1, ITYPE_t i2):
484487

485488

486489
cdef inline void dual_swap(DTYPE_t* darr, ITYPE_t* iarr,
487-
ITYPE_t i1, ITYPE_t i2):
490+
ITYPE_t i1, ITYPE_t i2) nogil:
488491
"""swap the values at inex i1 and i2 of both darr and iarr"""
489492
cdef DTYPE_t dtmp = darr[i1]
490493
darr[i1] = darr[i2]
@@ -657,7 +660,7 @@ cdef class NeighborsHeap:
657660

658661

659662
cdef int _simultaneous_sort(DTYPE_t* dist, ITYPE_t* idx,
660-
ITYPE_t size) except -1:
663+
ITYPE_t size) nogil except -1:
661664
"""
662665
Perform a recursive quicksort on the dist array, simultaneously
663666
performing the same swaps on the idx array. The equivalent in
@@ -1446,10 +1449,15 @@ cdef class BinaryTree:
14461449
cdef DTYPE_t[::1] rarr = rarr_np
14471450

14481451
# prepare variables for iteration
1452+
# if not count_only:
1453+
# indices = np.zeros(Xarr.shape[0], dtype='object')
1454+
# if return_distance:
1455+
# distances = np.zeros(Xarr.shape[0], dtype='object')
14491456
if not count_only:
1450-
indices = np.zeros(Xarr.shape[0], dtype='object')
1457+
indices = <ITYPE_t**> malloc(sizeof(ITYPE_t*) * Xarr.shape[0])
14511458
if return_distance:
1452-
distances = np.zeros(Xarr.shape[0], dtype='object')
1459+
distances = <DTYPE_t**> malloc(sizeof(DTYPE_t*)
1460+
* Xarr.shape[0])
14531461

14541462
np_idx_arr = np.zeros(self.data.shape[0], dtype=ITYPE)
14551463
idx_arr_i = np_idx_arr
@@ -1461,33 +1469,65 @@ cdef class BinaryTree:
14611469
counts = counts_arr
14621470

14631471
pt = &Xarr[0, 0]
1464-
for i in range(Xarr.shape[0]):
1465-
counts[i] = self._query_radius_single(0, pt, rarr[i],
1466-
&idx_arr_i[0],
1467-
&dist_arr_i[0],
1468-
0, count_only,
1469-
return_distance)
1470-
pt += n_features
1472+
cdef int return_distance_c = return_distance
1473+
cdef int count_only_c = count_only
1474+
with nogil:
1475+
for i in range(Xarr.shape[0]):
1476+
counts[i] = self._query_radius_single(0, pt, rarr[i],
1477+
&idx_arr_i[0],
1478+
&dist_arr_i[0],
1479+
0, count_only_c,
1480+
return_distance_c)
1481+
pt += n_features
14711482

1472-
if count_only:
1473-
pass
1474-
else:
1475-
if sort_results:
1476-
_simultaneous_sort(&dist_arr_i[0], &idx_arr_i[0],
1477-
counts[i])
1483+
if not count_only:
1484+
if sort_results:
1485+
_simultaneous_sort(&dist_arr_i[0], &idx_arr_i[0],
1486+
counts[i])
1487+
indices[i] = \
1488+
<ITYPE_t*> malloc(sizeof(ITYPE_t) * counts[i])
1489+
memcpy(indices[i], &idx_arr_i[0],
1490+
sizeof(ITYPE_t) * counts[i])
1491+
if return_distance_c:
1492+
distances[i] = \
1493+
<DTYPE_t*> malloc(sizeof(DTYPE_t) * counts[i])
1494+
memcpy(distances[i], &dist_arr_i[0],
1495+
sizeof(DTYPE_t) * counts[i])
14781496

1479-
indices[i] = np_idx_arr[:counts[i]].copy()
1480-
if return_distance:
1481-
distances[i] = np_dist_arr[:counts[i]].copy()
14821497

14831498
# deflatten results
1499+
# if count_only:
1500+
# return counts_arr.reshape(X.shape[:X.ndim - 1])
1501+
# elif return_distance:
1502+
# return (indices.reshape(X.shape[:X.ndim - 1]),
1503+
# distances.reshape(X.shape[:X.ndim - 1]))
1504+
# else:
1505+
# return indices.reshape(X.shape[:X.ndim - 1])
14841506
if count_only:
14851507
return counts_arr.reshape(X.shape[:X.ndim - 1])
1486-
elif return_distance:
1487-
return (indices.reshape(X.shape[:X.ndim - 1]),
1488-
distances.reshape(X.shape[:X.ndim - 1]))
14891508
else:
1490-
return indices.reshape(X.shape[:X.ndim - 1])
1509+
indices_arr = np.empty(Xarr.shape[0], dtype='object')
1510+
for i in range(indices_arr.shape[0]):
1511+
arr = np.zeros(counts[i], ITYPE)
1512+
idx_arr_i = arr
1513+
memcpy(&idx_arr_i[0], indices[i], sizeof(ITYPE_t) * counts[i])
1514+
indices_arr[i] = arr
1515+
free(indices[i])
1516+
free(indices)
1517+
if return_distance:
1518+
distances_arr = np.empty(Xarr.shape[0], dtype='object')
1519+
for i in range(distances_arr.shape[0]):
1520+
arr = np.zeros(counts[i], DTYPE)
1521+
dist_arr_i = arr
1522+
memcpy(&dist_arr_i[0], distances[i],
1523+
sizeof(DTYPE_t) * counts[i])
1524+
distances_arr[i] = arr
1525+
free(distances[i])
1526+
free(distances)
1527+
return indices_arr, distances_arr
1528+
else:
1529+
return indices_arr
1530+
14911531

14921532
def kernel_density(self, X, h, kernel='gaussian',
14931533
atol=0, rtol=1E-8,
@@ -2010,7 +2050,7 @@ cdef class BinaryTree:
20102050
DTYPE_t* distances,
20112051
ITYPE_t count,
20122052
int count_only,
2013-
int return_distance) except -1:
2053+
int return_distance) nogil except -1:
20142054
"""recursive single-tree radius query, depth-first"""
20152055
cdef DTYPE_t* data = &self.data[0, 0]
20162056
cdef ITYPE_t* idx_array = &self.idx_array[0]
@@ -2038,8 +2078,9 @@ cdef class BinaryTree:
20382078
else:
20392079
for i in range(node_info.idx_start, node_info.idx_end):
20402080
if (count < 0) or (count >= self.data.shape[0]):
2041-
raise ValueError("Fatal: count too big: "
2042-
"this should never happen")
2081+
with gil:
2082+
raise ValueError("Fatal: count too big: "
2083+
"this should never happen")
20432084
indices[count] = idx_array[i]
20442085
if return_distance:
20452086
distances[count] = self.dist(pt, (data + n_features
@@ -2058,8 +2099,9 @@ cdef class BinaryTree:
20582099
n_features)
20592100
if dist_pt <= reduced_r:
20602101
if (count < 0) or (count >= self.data.shape[0]):
2061-
raise ValueError("Fatal: count out of range. "
2062-
"This should never happen.")
2102+
with gil:
2103+
raise ValueError("Fatal: count out of range. "
2104+
"This should never happen.")
20632105
if count_only:
20642106
pass
20652107
else:

0 commit comments

Comments
 (0)
0