12
12
BOOL_METRICS ,
13
13
METRIC_MAPPING64 ,
14
14
DistanceMetric ,
15
-
16
15
)
17
16
from ._argkmin import (
18
17
ArgKmin32 ,
32
31
RadiusNeighborsClassMode64 ,
33
32
)
34
33
34
+
35
35
def sqeuclidean_row_norms (X , num_threads ):
36
36
"""Compute the squared euclidean norm of the rows of X in parallel.
37
37
@@ -81,7 +81,9 @@ def valid_metrics(cls) -> List[str]:
81
81
"hamming" ,
82
82
* BOOL_METRICS ,
83
83
}
84
- return sorted (({"sqeuclidean" , "precomputed" } | set (METRIC_MAPPING64 .keys ())) - excluded )
84
+ return sorted (
85
+ ({"sqeuclidean" , "precomputed" } | set (METRIC_MAPPING64 .keys ())) - excluded
86
+ )
85
87
86
88
@classmethod
87
89
def is_usable_for (cls , X , Y , metric ) -> bool :
@@ -105,12 +107,8 @@ def is_usable_for(cls, X, Y, metric) -> bool:
105
107
-------
106
108
True if the dispatcher can be used, else False.
107
109
"""
108
-
109
- #TODO: have the option to use the dispatcher but pass on the compute function
110
- # instead call the precomputed distance for (i,j)
111
- #question: for the rest of the indices should the compute method be defined?
112
- # look at EuclideanRadiusNeighbors{32,64}
113
- if metric == 'precomputed' :
110
+
111
+ if metric == "precomputed" :
114
112
return True
115
113
116
114
# FIXME: the current Cython implementation is too slow for a large number of
@@ -176,8 +174,7 @@ def compute(
176
174
This method is an abstract class method: it has to be implemented
177
175
for all subclasses.
178
176
"""
179
-
180
-
177
+
181
178
182
179
183
180
class ArgKmin (BaseDistancesReductionDispatcher ):
@@ -201,6 +198,7 @@ def compute(
201
198
Y ,
202
199
k ,
203
200
metric = "euclidean" ,
201
+ precomputed_matrix = None ,
204
202
chunk_size = None ,
205
203
metric_kwargs = None ,
206
204
strategy = None ,
@@ -430,10 +428,10 @@ def compute(
430
428
for the concrete implementation are therefore freed when this classmethod
431
429
returns.
432
430
"""
433
- #TODO: to maintain RAII, look at the implementation of the compute method
434
- if metric == 'precomputed' :
435
- return PrecomputedDistanceMatrix .precomputed_distance ()
436
-
431
+ # TODO: to maintain RAII, look at the implementation of the compute method
432
+ # if metric == 'precomputed':
433
+ # return PrecomputedDistanceMatrix.precomputed_distance()
434
+
437
435
if X .dtype == Y .dtype == np .float64 :
438
436
return RadiusNeighbors64 .compute (
439
437
X = X ,
@@ -591,9 +589,7 @@ def compute(
591
589
for the concrete implementation are therefore freed when this classmethod
592
590
returns.
593
591
"""
594
- if metric == "precomputed" :
595
- return PrecomputedDistanceMatrix .precomputed_distance ()
596
-
592
+
597
593
if weights not in {"uniform" , "distance" }:
598
594
raise ValueError (
599
595
"Only the 'uniform' or 'distance' weights options are supported"
0 commit comments