8000 MAINT Remove -Wcpp warnings when compiling `_kd_tree` and `_ball_tree… · npache/scikit-learn@5564541 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5564541

Browse files
MAINT Remove -Wcpp warnings when compiling _kd_tree and _ball_tree (scikit-learn#24965)
1 parent 9527920 commit 5564541

File tree

2 files changed

+38
-28
lines changed

2 files changed

+38
-28
lines changed

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@
9898
"sklearn.metrics._pairwise_distances_reduction._argkmin",
9999
"sklearn.metrics._pairwise_distances_reduction._radius_neighbors",
100100
"sklearn.metrics._pairwise_fast",
101+
"sklearn.neighbors._ball_tree",
102+
"sklearn.neighbors._kd_tree",
101103
"sklearn.neighbors._partition_nodes",
102104
"sklearn.tree._splitter",
103105
"sklearn.tree._utils",

sklearn/neighbors/_binary_tree.pxi

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ from ..utils._typedefs import DTYPE, ITYPE
166166
from ..utils._heap cimport heap_push
167167
from ..utils._sorting cimport simultaneous_sort as _simultaneous_sort
168168

169+
# TODO: use cnp.PyArray_ENABLEFLAGS when Cython>=3.0 is used.
169170
cdef extern from "numpy/arrayobject.h":
170171
void PyArray_ENABLEFLAGS(cnp.ndarray arr, int flags)
171172

@@ -511,8 +512,8 @@ cdef class NeighborsHeap:
511512
n_nbrs : int
512513
the size of each heap.
513514
"""
514-
cdef cnp.ndarray distances_arr
515-
cdef cnp.ndarray indices_arr
515+
cdef DTYPE_t[:, ::1] distances_arr
516+
cdef ITYPE_t[:, ::1] indices_arr
516517

517518
cdef DTYPE_t[:, ::1] distances
518519
cdef ITYPE_t[:, ::1] indices
@@ -538,7 +539,7 @@ cdef class NeighborsHeap:
538539
"""
539540
if sort:
540541
self._sort()
541-
return self.distances_arr, self.indices_arr
542+
return self.distances_arr.base, self.indices_arr.base
542543

543544
cdef inline DTYPE_t largest(self, ITYPE_t row) nogil except -1:
544545
"""Return the largest distance in the given row"""
@@ -643,8 +644,8 @@ cdef class NodeHeap:
643644
644645
heap[i].val < min(heap[2 * i + 1].val, heap[2 * i + 2].val)
645646
"""
646-
cdef cnp.ndarray data_arr
647-
cdef NodeHeapData_t[::1] data
647+
cdef NodeHeapData_t[:] data_arr
648+
cdef NodeHeapData_t[:] data
648649
cdef ITYPE_t n
649650

650651
def __cinit__(self):
@@ -660,13 +661,16 @@ cdef class NodeHeap:
660661

661662
cdef int resize(self, ITYPE_t new_size) except -1:
662663
"""Resize the heap to be either larger or smaller"""
663-
cdef NodeHeapData_t *data_ptr
664-
cdef NodeHeapData_t *new_data_ptr
665-
cdef ITYPE_t i
666-
cdef ITYPE_t size = self.data.shape[0]
667-
cdef cnp.ndarray new_data_arr = np.zeros(new_size,
668-
dtype=NodeHeapData)
669-
cdef NodeHeapData_t[::1] new_data = new_data_arr
664+
cdef:
665+
NodeHeapData_t *data_ptr
666+
NodeHeapData_t *new_data_ptr
667+
ITYPE_t i
668+
ITYPE_t size = self.data.shape[0]
669+
NodeHeapData_t[:] new_data_arr = np.zeros(
670+
new_size,
671+
dtype=NodeHeapData,
672+
)
673+
NodeHeapData_t[:] new_data = new_data_arr
670674

671675
if size > 0 and new_size > 0:
672676
data_ptr = &self.data[0]
@@ -769,11 +773,11 @@ VALID_METRIC_IDS = get_valid_metric_ids(VALID_METRICS)
769773
# Binary Tree class
770774
cdef class BinaryTree:
771775

772-
cdef cnp.ndarray data_arr
773-
cdef cnp.ndarray sample_weight_arr
774-
cdef cnp.ndarray idx_array_arr
775-
cdef cnp.ndarray node_data_arr
776-
cdef cnp.ndarray node_bounds_arr
776+
cdef const DTYPE_t[:, ::1] data_arr
777+
cdef const DTYPE_t[::1] sample_weight_arr
778+
cdef const ITYPE_t[::1] idx_array_arr
779+
cdef const NodeData_t[::1] node_data_arr
780+
cdef const DTYPE_t[:, :, ::1] node_bounds_arr
777781

778782
cdef readonly const DTYPE_t[:, ::1] data
779783
cdef readonly const DTYPE_t[::1] sample_weight
@@ -869,7 +873,7 @@ cdef class BinaryTree:
869873
# Allocate tree-specific data
870874
allocate_data(self, self.n_nodes, n_features)
871875
self._recursive_build(
872-
node_data=self.node_data_arr,
876+
node_data=self.node_data_arr.base,
873877
i_node=0,
874878
idx_start=0,
875879
idx_end=n_samples
@@ -905,15 +909,15 @@ cdef class BinaryTree:
905909
"""
906910
if self.sample_weight is not None:
907911
# pass the numpy array
908-
sample_weight_arr = self.sample_weight_arr
912+
sample_weight_arr = self.sample_weight_arr.base
909913
else:
910914
# pass None to avoid confusion with the empty place holder
911915
# of size 1 from __cinit__
912916
sample_weight_arr = None
913-
return (self.data_arr,
914-
self.idx_array_arr,
915-
self.node_data_arr,
916-
self.node_bounds_arr,
917+
return (self.data_arr.base,
918+
self.idx_array_arr.base,
919+
self.node_data_arr.base,
920+
self.node_bounds_arr.base,
917921
int(self.leaf_size),
918922
int(self.n_levels),
919923
int(self.n_nodes),
@@ -993,8 +997,12 @@ cdef class BinaryTree:
993997
arrays: tuple of array
994998
Arrays for storing tree data, index, node data and node bounds.
995999
"""
996-
return (self.data_arr, self.idx_array_arr,
997-
self.node_data_arr, self.node_bounds_arr)
1000+
return (
1001+
self.data_arr.base,
1002+
self.idx_array_arr.base,
1003+
self.node_data_arr.base,
1004+
self.node_bounds_arr.base,
1005+
)
9981006

9991007
cdef inline DTYPE_t dist(self, DTYPE_t* x1, DTYPE_t* x2,
10001008
ITYPE_t size) nogil except -1:
@@ -1340,14 +1348,14 @@ cdef class BinaryTree:
13401348
# make a new numpy array that wraps the existing data
13411349
indices_npy[i] = cnp.PyArray_SimpleNewFromData(1, &counts[i], cnp.NPY_INTP, indices[i])
1342< B8E0 /td>1350
# make sure the data will be freed when the numpy array is garbage collected
1343-
PyArray_ENABLEFLAGS(indices_npy[i], cnp.NPY_OWNDATA)
1351+
PyArray_ENABLEFLAGS(indices_npy[i], cnp.NPY_ARRAY_OWNDATA)
13441352
# make sure the data is not freed twice
13451353
indices[i] = NULL
13461354

13471355
# make a new numpy array that wraps the existing data
13481356
distances_npy[i] = cnp.PyArray_SimpleNewFromData(1, &counts[i], cnp.NPY_DOUBLE, distances[i])
13491357
# make sure the data will be freed when the numpy array is garbage collected
1350-
PyArray_ENABLEFLAGS(distances_npy[i], cnp.NPY_OWNDATA)
1358+
PyArray_ENABLEFLAGS(distances_npy[i], cnp.NPY_ARRAY_OWNDATA)
13511359
# make sure the data is not freed twice
13521360
distances[i] = NULL
13531361

@@ -1360,7 +1368,7 @@ cdef class BinaryTree:
13601368
# make a new numpy array that wraps the existing data
13611369
indices_npy[i] = cnp.PyArray_SimpleNewFromData(1, &counts[i], cnp.NPY_INTP, indices[i])
13621370
# make sure the data will be freed when the numpy array is garbage collected
1363-
PyArray_ENABLEFLAGS(indices_npy[i], cnp.NPY_OWNDATA)
1371+
PyArray_ENABLEFLAGS(indices_npy[i], cnp.NPY_ARRAY_OWNDATA)
13641372
# make sure the data is not freed twice
13651373
indices[i] = NULL
13661374

0 commit comments

Comments
 (0)
0