8000 PERF Implement `PairwiseDistancesReduction` backend for `KNeighbors.p… · Veghit/scikit-learn@204fd87 · GitHub
[go: up one dir, main page]

Skip to content

Commit 204fd87

Browse files
Micky774jjerphanogrisel
authored andcommitted
PERF Implement PairwiseDistancesReduction backend for KNeighbors.predict_proba (scikit-learn#24076)
Signed-off-by: Julien Jerphanion <git@jjerphan.xyz> Co-authored-by: Julien Jerphanion <git@jjerphan.xyz> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent a4e94ff commit 204fd87

File tree

12 files changed

+648
-14
lines changed

12 files changed

+648
-14
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ sklearn/metrics/_dist_metrics.pyx
9090
sklearn/metrics/_dist_metrics.pxd
9191
sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd
9292
sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx
93+
sklearn/metrics/_pairwise_distances_reduction/_argkmin_classmode.pyx
9394
sklearn/metrics/_pairwise_distances_reduction/_base.pxd
9495
sklearn/metrics/_pairwise_distances_reduction/_base.pyx
9596
sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd

doc/whats_new/v1.3.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,11 @@ Changelog
292292
dissimilarity is not a metric and cannot be supported by the BallTree.
293293
:pr:`25417` by :user:`Guillaume Lemaitre <glemaitre>`.
294294

295+
- |Enhancement| The performance of :meth:`neighbors.KNeighborsClassifier.predict`
296+
and of :meth:`neighbors.KNeighborsClassifier.predict_proba` has been improved
297+
when `n_neighbors` is large and `algorithm="brute"` with non Euclidean metrics.
298+
:pr:`24076` by :user:`Meekail Zain <micky774>`, :user:`Julien Jerphanion <jjerphan>`.
299+
295300
:mod:`sklearn.neural_network`
296301
.............................
297302

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ ignore =
9090
sklearn/metrics/_dist_metrics.pxd
9191
sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd
9292
sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx
93+
sklearn/metrics/_pairwise_distances_reduction/_argkmin_classmode.pyx
9394
sklearn/metrics/_pairwise_distances_reduction/_base.pxd
9495
sklearn/metrics/_pairwise_distances_reduction/_base.pyx
9596
sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd

setup.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,12 @@ def check_package_status(package, min_version):
280280
"include_np": True,
281281
"extra_compile_args": ["-std=c++11"],
282282
},
283+
{
284+
"sources": ["_argkmin_classmode.pyx.tp"],
285+
"language": "c++",
286+
"include_np": True,
287+
"extra_compile_args": ["-std=c++11"],
288+
},
283289
{
284290
"sources": ["_radius_neighbors.pyx.tp", "_radius_neighbors.pxd.tp"],
285291
"language": "c++",

sklearn/metrics/_pairwise_distances_reduction/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,14 @@
9090
BaseDistancesReductionDispatcher,
9191
ArgKmin,
9292
RadiusNeighbors,
93+
ArgKminClassMode,
9394
sqeuclidean_row_norms,
9495
)
9596

9697
__all__ = [
9798
"BaseDistancesReductionDispatcher",
9899
"ArgKmin",
99100
"RadiusNeighbors",
101+
"ArgKminClassMode",
100102
"sqeuclidean_row_norms",
101103
]

sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,6 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
178178
self.heaps_r_distances_chunks[thread_num] = &self.argkmin_distances[X_start, 0]
179179
self.heaps_indices_chunks[thread_num] = &self.argkmin_indices[X_start, 0]
180180

181-
@final
182181
cdef void _parallel_on_X_prange_iter_finalize(
183182
self,
184183
ITYPE_t thread_num,
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
{{py:
2+
3+
implementation_specific_values = [
4+
# Values are the following ones:
5+
#
6+
# name_suffix, INPUT_DTYPE_t, INPUT_DTYPE
7+
#
8+
# We also use the float64 dtype and C-type names as defined in
9+
# `sklearn.utils._typedefs` to maintain consistency.
10+
#
11+
('64', 'DTYPE_t', 'DTYPE'),
12+
('32', 'cnp.float32_t', 'np.float32')
13+
]
14+
15+
}}
16+
17+
from cython cimport floating, integral
18+
from cython.parallel cimport parallel, prange
19+
from libcpp.map cimport map as cpp_map, pair as cpp_pair
20+
from libc.stdlib cimport free
21+
22+
cimport numpy as cnp
23+
24+
cnp.import_array()
25+
26+
from ...utils._typedefs cimport ITYPE_t, DTYPE_t
27+
from ...utils._typedefs import ITYPE, DTYPE
28+
import numpy as np
29+
from scipy.sparse import issparse
30+
from sklearn.utils.fixes import threadpool_limits
31+
32+
cpdef enum WeightingStrategy:
33+
uniform = 0
34+
# TODO: Implement the following options, most likely in
35+
# `weighted_histogram_mode`
36+
distance = 1
37+
callable = 2
38+
39+
{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}}
40+
from ._argkmin cimport ArgKmin{{name_suffix}}
41+
from ._datasets_pair cimport DatasetsPair{{name_suffix}}
42+
43+
cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}):
44+
"""
45+
{{name_suffix}}bit implementation of ArgKminClassMode.
46+
"""
47+
cdef:
48+
const ITYPE_t[:] class_membership,
49+
const ITYPE_t[:] unique_labels
50+
DTYPE_t[:, :] class_scores
51+
cpp_map[ITYPE_t, ITYPE_t] labels_to_index
52+
WeightingStrategy weight_type
53+
54+
@classmethod
55+
def compute(
56+
cls,
57+
X,
58+
Y,
59+
ITYPE_t k,
60+
weights,
61+
class_membership,
62+
unique_labels,
63+
str metric="euclidean",
64+
chunk_size=None,
65+
dict metric_kwargs=None,
66+
str strategy=None,
67+
):
68+
"""Compute the argkmin reduction with class_membership.
69+
70+
This classmethod is responsible for introspecting the arguments
71+
values to dispatch to the most appropriate implementation of
72+
:class:`ArgKminClassMode{{name_suffix}}`.
73+
74+
This allows decoupling the API entirely from the implementation details
75+
whilst maintaining RAII: all temporarily allocated datastructures necessary
76+
for the concrete implementation are therefore freed when this classmethod
77+
returns.
78+
79+
No instance _must_ directly be created outside of this class method.
80+
"""
81+
# Use a generic implementation that handles most scipy
82+
# metrics by computing the distances between 2 vectors at a time.
83+
pda = ArgKminClassMode{{name_suffix}}(
84+
datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs),
85+
k=k,
86+
chunk_size=chunk_size,
87+
strategy=strategy,
88+
weights=weights,
89+
class_membership=class_membership,
90+
unique_labels=unique_labels,
91+
)
92+
93+
# Limit the number of threads in second level of nested parallelism for BLAS
94+
# to avoid threads over-subscription (in GEMM for instance).
95+
with threadpool_limits(limits=1, user_api="blas"):
96+
if pda.execute_in_parallel_on_Y:
97+
pda._parallel_on_Y()
98+
else:
99+
pda._parallel_on_X()
100+
101+
return pda._finalize_results()
102+
103+
def __init__(
104+
self,
105+
DatasetsPair{{name_suffix}} datasets_pair,
106+
const ITYPE_t[:] class_membership,
107+
const ITYPE_t[:] unique_labels,
108+
chunk_size=None,
109+
strategy=None,
110+
ITYPE_t k=1,
111+
weights=None,
112+
):
113+
super().__init__(
114+
datasets_pair=datasets_pair,
115+
chunk_size=chunk_size,
116+
strategy=strategy,
117+
k=k,
118+
)
119+
120+
if weights == "uniform":
121+
self.weight_type = WeightingStrategy.uniform
122+
elif weights == "distance":
123+
self.weight_type = WeightingStrategy.distance
124+
else:
125+
self.weight_type = WeightingStrategy.callable
126+
self.class_membership = class_membership
127+
128+
self.unique_labels = unique_labels
129+
130+
cdef ITYPE_t idx, neighbor_class_idx
131+
# Map from set of unique labels to their indices in `class_scores`
132+
# Buffer used in building a histogram for one-pass weighted mode
133+
self.class_scores = np.zeros(
134+
(self.n_samples_X, unique_labels.shape[0]), dtype=DTYPE,
135+
)
136+
137+
def _finalize_results(self):
138+
probabilities = np.asarray(self.class_scores)
139+
probabilities /= probabilities.sum(axis=1, keepdims=True)
< BEA9 code>140+
return probabilities
141+
142+
cdef inline void weighted_histogram_mode(
143+
self,
144+
ITYPE_t sample_index,
145+
ITYPE_t* indices,
146+
DTYPE_t* distances,
147+
) noexcept nogil:
148+
cdef:
149+
ITYPE_t neighbor_idx, neighbor_class_idx, label_index, multi_output_index
150+
DTYPE_t score_incr = 1
151+
# TODO: Implement other WeightingStrategy values
152+
bint use_distance_weighting = (
153+
self.weight_type == WeightingStrategy.distance
154+
)
155+
156+
# Iterate through the sample k-nearest neighbours
157+
for neighbor_rank in range(self.k):
158+
# Absolute indice of the neighbor_rank-th Nearest Neighbors
159+
# in range [0, n_samples_Y)
160+
# TODO: inspect if it worth permuting this condition
161+
# and the for-loop above for improved branching.
162+
if use_distance_weighting:
163+
score_incr = 1 / distances[neighbor_rank]
164+
neighbor_idx = indices[neighbor_rank]
165+
neighbor_class_idx = self.class_membership[neighbor_idx]
166+
self.class_scores[sample_index][neighbor_class_idx] += score_incr
167+
return
168+
169+
cdef void _parallel_on_X_prange_iter_finalize(
170+
self,
171+
ITYPE_t thread_num,
172+
ITYPE_t X_start,
173+
ITYPE_t X_end,
174+
) noexcept nogil:
175+
cdef:
176+
ITYPE_t idx, sample_index
177+
for idx in range(X_end - X_start):
178+
# One-pass top-one weighted mode
179+
# Compute the absolute index in [0, n_samples_X)
180+
sample_index = X_start + idx
181+
self.weighted_histogram_mode(
182+
sample_index,
183+
&self.heaps_indices_chunks[thread_num][idx * self.k],
184+
&self.heaps_r_distances_chunks[thread_num][idx * self.k],
185+
)
186+
return
187+
188+
cdef void _parallel_on_Y_finalize(
189+
self,
190+
) noexcept nogil:
191+
cdef:
192+
ITYPE_t sample_index, thread_num
193+
194+
with nogil, parallel(num_threads=self.chunks_n_threads):
195+
# Deallocating temporary datastructures
196+
for thread_num in prange(self.chunks_n_threads, schedule='static'):
197+
free(self.heaps_r_distances_chunks[thread_num])
198+
free(self.heaps_indices_chunks[thread_num])
199+
200+
for sample_index in prange(self.n_samples_X, schedule='static'):
201+
self.weighted_histogram_mode(
202+
sample_index,
203+
&self.argkmin_indices[sample_index][0],
204+
&self.argkmin_distances[sample_index][0],
205+
)
206+
return
207+
208+
{{endfor}}

0 commit comments

Comments
 (0)
0