144
144
cimport cython
145
145
cimport numpy as np
146
146
from libc .math cimport fabs , sqrt , exp , cos , pow , log
147
+ from libc .stdlib cimport malloc , free
148
+ from libc .string cimport memcpy
147
149
from sklearn .utils .lgamma cimport lgamma
148
150
149
151
import numpy as np
150
152
import warnings
153
+ import ctypes
151
154
from ..utils import check_array
152
155
153
156
from typedefs cimport DTYPE_t , ITYPE_t , DITYPE_t
@@ -484,7 +487,7 @@ cdef inline void swap(DITYPE_t* arr, ITYPE_t i1, ITYPE_t i2):
484
487
485
488
486
489
cdef inline void dual_swap (DTYPE_t * darr , ITYPE_t * iarr ,
487
- ITYPE_t i1 , ITYPE_t i2 ):
490
+ ITYPE_t i1 , ITYPE_t i2 ) nogil :
488
491
"""swap the values at inex i1 and i2 of both darr and iarr"""
489
492
cdef DTYPE_t dtmp = darr [i1 ]
490
493
darr [i1 ] = darr [i2 ]
@@ -657,7 +660,7 @@ cdef class NeighborsHeap:
657
660
658
661
659
662
cdef int _simultaneous_sort (DTYPE_t * dist , ITYPE_t * idx ,
660
- ITYPE_t size ) except - 1 :
663
+ ITYPE_t size ) nogil except - 1 :
661
664
"""
662
665
Perform a recursive quicksort on the dist array, simultaneously
663
666
performing the same swaps on the idx array. The equivalent in
@@ -1446,10 +1449,15 @@ cdef class BinaryTree:
1446
1449
cdef DTYPE_t [::1 ] rarr = rarr_np
1447
1450
1448
1451
# prepare variables for iteration
1452
+ # if not count_only:
1453
+ # indices = np.zeros(Xarr.shape[0], dtype='object')
1454
+ # if return_distance:
1455
+ # distances = np.zeros(Xarr.shape[0], dtype='object')
1449
1456
if not count_only :
1450
- indices = np . zeros ( Xarr .shape [0 ], dtype = 'object' )
1457
+ indices = < ITYPE_t ** > malloc ( sizeof ( ITYPE_t * ) * Xarr .shape [0 ])
1451
1458
if return_distance :
1452
- distances = np .zeros (Xarr .shape [0 ], dtype = 'object' )
1459
+ distances = < DTYPE_t ** > malloc (sizeof (DTYPE_t * )
1460
+ * Xarr .shape [0 ])
1453
1461
1454
1462
np_idx_arr = np .zeros (self .data .shape [0 ], dtype = ITYPE )
1455
1463
idx_arr_i = np_idx_arr
@@ -1461,33 +1469,65 @@ cdef class BinaryTree:
1461
1469
counts = counts_arr
1462
1470
1463
1471
pt = & Xarr [0 , 0 ]
1464
- for i in range (Xarr .shape [0 ]):
1465
- counts [i ] = self ._query_radius_single (0 , pt , rarr [i ],
1466
- & idx_arr_i [0 ],
1467
- & dist_arr_i [0 ],
1468
- 0 , count_only ,
1469
- return_distance )
1470
- pt += n_features
1472
+ cdef int return_distance_c = return_distance
1473
+ cdef int count_only_c = count_only
1474
+ with nogil :
1475
+ for i in range (Xarr .shape [0 ]):
1476
+ counts [i ] = self ._query_radius_single (0 , pt , rarr [i ],
1477
+ & idx_arr_i [0 ],
1478
+ & dist_arr_i [0 ],
1479
+ 0 , count_only_c ,
1480
+ return_distance_c )
1481
+ pt += n_features
1471
1482
1472
- if count_only :
1473
- pass
1474
- else :
1475
- if sort_results :
1476
- _simultaneous_sort (& dist_arr_i [0 ], & idx_arr_i [0 ],
1477
- counts [i ])
1483
+ if not count_only :
1484
+ if sort_results :
1485
+ _simultaneous_sort (& dist_arr_i [0 ], & idx_arr_i [0 ],
1486
+ counts [i ])
1487
+ indices [i ] = \
1488
+ < ITYPE_t * > malloc (sizeof (ITYPE_t ) * counts [i ])
1489
+ memcpy (indices [i ], & idx_arr_i [0 ],
1490
+ sizeof (ITYPE_t ) * counts [i ])
1491
+ if return_distance_c :
1492
+ distances [i ] = \
1493
+ < DTYPE_t * > malloc (sizeof (DTYPE_t ) * counts [i ])
1494
+ memcpy (distances [i ], & dist_arr_i [0 ],
1495
+ sizeof (DTYPE_t ) * counts [i ])
1478
1496
1479
- indices [i ] = np_idx_arr [:counts [i ]].copy ()
1480
- if return_distance :
1481
- distances [i ] = np_dist_arr [:counts [i ]].copy ()
1482
1497
1483
1498
# deflatten results
1499
+ # if count_only:
1500
+ # return counts_arr.reshape(X.shape[:X.ndim - 1])
1501
+ # elif return_distance:
1502
+ # return (indices.reshape(X.shape[:X.ndim - 1]),
1503
+ # distances.reshape(X.shape[:X.ndim - 1]))
1504
+ # else:
1505
+ # return indices.reshape(X.shape[:X.ndim - 1])
1484
1506
if count_only :
1485
1507
return counts_arr .reshape (X .shape [:X .ndim - 1 ])
1486
- elif return_distance :
1487
- return (indices .reshape (X .shape [:X .ndim - 1 ]),
1488
- distances .reshape (X .shape [:X .ndim - 1 ]))
1489
1508
else :
1490
- return indices .reshape (X .shape [:X .ndim - 1 ])
1509
+ indices_arr = np .empty (Xarr .shape [0 ], dtype = 'object' )
1510
+ for i in range (indices_arr .shape [0 ]):
1511
+ arr = np .zeros (counts [i ], ITYPE )
1512
+ idx_arr_i = arr
1513
+ memcpy (& idx_arr_i [0 ], indices [i ], sizeof (ITYPE_t ) * counts [i ])
1514
+ indices_arr [i ] = arr
1515
+ free (indices [i ])
1516
+ free (indices )
1517
+ if return_distance :
1518
+ distances_arr = np .empty (Xarr .shape [0 ], dtype = 'object' )
1519
+ for i in range (distances_arr .shape [0 ]):
1520
+ arr = np .zeros (counts [i ], DTYPE )
1521
+ dist_arr_i = arr
1522
+ memcpy (& dist_arr_i [0 ], distances [i ],
1523
+ sizeof (DTYPE_t ) * counts [i ])
1524
+ distances_arr [i ] = arr
1525
+ free (distances [i ])
1526
+ free (distances )
1527
+ return indices_arr , distances_arr
1528
+ else :
1529
+ return indices_arr
1530
+
1491
1531
1492
1532
def kernel_density (self , X , h , kernel = 'gaussian' ,
1493
1533
atol = 0 , rtol = 1E-8 ,
@@ -2010,7 +2050,7 @@ cdef class BinaryTree:
2010
2050
DTYPE_t * distances ,
2011
2051
ITYPE_t count ,
2012
2052
int count_only ,
2013
- int return_distance ) except - 1 :
2053
+ int return_distance ) nogil except - 1 :
2014
2054
"""recursive single-tree radius query, depth-first"""
2015
2055
cdef DTYPE_t * data = & self .data [0 , 0 ]
2016
2056
cdef ITYPE_t * idx_array = & self .idx_array [0 ]
@@ -2038,8 +2078,9 @@ cdef class BinaryTree:
2038
2078
else :
2039
2079
for i in range (node_info .idx_start , node_info .idx_end ):
2040
2080
if (count < 0 ) or (count >= self .data .shape [0 ]):
2041
- raise ValueError ("Fatal: count too big: "
2042
- "this should never happen" )
2081
+ with gil :
2082
+ raise ValueError ("Fatal: count too big: "
2083
+ "this should never happen" )
2043
2084
indices [count ] = idx_array [i ]
2044
2085
if return_distance :
2045
2086
distances [count ] = self .dist (pt , (data + n_features
@@ -2058,8 +2099,9 @@ cdef class BinaryTree:
2058
2099
n_features )
2059
2100
if dist_pt <= reduced_r :
2060
2101
if (count < 0 ) or (count >= self .data .shape [0 ]):
2061
- raise ValueError ("Fatal: count out of range. "
2062
- "This should never happen." )
2102
+ with gil :
2103
+ raise ValueError ("Fatal: count out of range. "
2104
+ "This should never happen." )
2063
2105
if count_only :
2064
2106
pass
2065
2107
else :
0 commit comments