8000 Alter the compute method of ArgKminClassMode to take in precomputed m… · scikit-learn/scikit-learn@e2ef246 · GitHub
[go: up one dir, main page]

Skip to content

Commit e2ef246

Browse files
kyrajeepjjerphan
authored andcommitted
Alter the compute method of ArgKminClassMode to take in precomputed matrix stored in datasets_pair.pyx
1 parent d961e3b commit e2ef246

File tree

3 files changed

+23
-20
lines changed

3 files changed

+23
-20
lines changed

sklearn/metrics/_pairwise_distances_reduction/_argkmin_classmode.pyx.tp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}):
3939
chunk_size=None,
4040
dict metric_kwargs=None,
4141
str strategy=None,
42+
precomputed_matrix = None
4243
):
4344
"""Compute the argkmin reduction with Y_labels.
4445

@@ -55,8 +56,14 @@ cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}):
5556
"""
5657
# Use a generic implementation that handles most scipy
5758
# metrics by computing the distances between 2 vectors at a time.
58-
if metric == 'precomputed':
59-
return PrecomputedDistanceMatrix.dist(i,j)
59+
60+
# If a precomputed matrix is provided at the class level, skip the
61+
# rest of the computation and just return the variable stored in the
62+
# datasets_pair file. TODO: check syntax
63+
64+
if precomputed_matrix:
65+
return PrecomputedDistanceMatrix.precomputed
66+
6067
pda = ArgKminClassMode{{name_suffix}}(
6168
datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs),
6269
k=k,

sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ cdef class BaseDistancesReduction{{name_suffix}}:
129129
minimal redundant code.
130130

131131
If metric is 'precomputed' and the precomputed matrix is provided,
132-
able to return it.
132+
a subclass must be able to access it through the compute method.
133133
"""
134134

135135

sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
BOOL_METRICS,
1313
METRIC_MAPPING64,
1414
DistanceMetric,
15-
1615
)
1716
from ._argkmin import (
1817
ArgKmin32,
@@ -32,6 +31,7 @@
3231
RadiusNeighborsClassMode64,
3332
)
3433

34+
3535
def sqeuclidean_row_norms(X, num_threads):
3636
"""Compute the squared euclidean norm of the rows of X in parallel.
3737
@@ -81,7 +81,9 @@ def valid_metrics(cls) -> List[str]:
8181
"hamming",
8282
*BOOL_METRICS,
8383
}
84-
return sorted(({"sqeuclidean", "precomputed"} | set(METRIC_MAPPING64.keys())) - excluded)
84+
return sorted(
85+
({"sqeuclidean", "precomputed"} | set(METRIC_MAPPING64.keys())) - excluded
86+
)
8587

8688
@classmethod
8789
def is_usable_for(cls, X, Y, metric) -> bool:
@@ -105,12 +107,8 @@ def is_usable_for(cls, X, Y, metric) -> bool:
105107
-------
106108
True if the dispatcher can be used, else False.
107109
"""
108-
109-
#TODO: have the option to use the dispatcher but pass on the compute function
110-
# instead call the precomputed distance for (i,j)
111-
#question: for the rest of the indices should the compute method be defined?
112-
# look at EuclideanRadiusNeighbors{32,64}
113-
if metric == 'precomputed':
110+
111+
if metric == "precomputed":
114112
return True
115113

116114
# FIXME: the current Cython implementation is too slow for a large number of
@@ -176,8 +174,7 @@ def compute(
176174
This method is an abstract class method: it has to be implemented
177175
for all subclasses.
178176
"""
179-
180-
177+
181178

182179

183180
class ArgKmin(BaseDistancesReductionDispatcher):
@@ -201,6 +198,7 @@ def compute(
201198
Y,
202199
k,
203200
metric="euclidean",
201+
precomputed_matrix = None,
204202
chunk_size=None,
205203
metric_kwargs=None,
206204
strategy=None,
@@ -430,10 +428,10 @@ def compute(
430428
for the concrete implementation are therefore freed when this classmethod
431429
returns.
432430
"""
433-
#TODO: to maintain RAII, look at the implementation of the compute method
434-
if metric == 'precomputed':
435-
return PrecomputedDistanceMatrix.precomputed_distance()
436-
431+
# TODO: to maintain RAII, look at the implementation of the compute method
432+
# if metric == 'precomputed':
433+
# return PrecomputedDistanceMatrix.precomputed_distance()
434+
437435
if X.dtype == Y.dtype == np.float64:
438436
return RadiusNeighbors64.compute(
439437
X=X,
@@ -591,9 +589,7 @@ def compute(
591589
for the concrete implementation are therefore freed when this classmethod
592590
returns.
593591
"""
594-
if metric == "precomputed":
595-
return PrecomputedDistanceMatrix.precomputed_distance()
596-
592+
597593
if weights not in {"uniform", "distance"}:
598594
raise ValueError(
599595
"Only the 'uniform' or 'distance' weights options are supported"

0 commit comments

Comments
 (0)
0