8000 MAINT Miscellaneous maintenance items the private `PairwiseDistancesR… · scikit-learn/scikit-learn@e179277 · GitHub
[go: up one dir, main page]

Skip to content

Commit e179277

Browse files
jjerphanogriselthomasjpfan
authored
MAINT Miscellaneous maintenance items the private PairwiseDistancesReductions submodule (#24542)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent f702e97 commit e179277

File tree

2 files changed

+87
-53
lines changed

2 files changed

+87
-53
lines changed

sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp

Lines changed: 51 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,11 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
162162
for i in range(n_samples_X):
163163
for j in range(n_samples_Y):
164164
heap_push(
165-
heaps_r_distances + i * self.k,
166-
heaps_indices + i * self.k,
167-
self.k,
168-
self.datasets_pair.surrogate_dist(X_start + i, Y_start + j),
169-
Y_start + j,
165+
values=heaps_r_distances + i * self.k,
166+
indices=heaps_indices + i * self.k,
167+
size=self.k,
168+
val=self.datasets_pair.surrogate_dist(X_start + i, Y_start + j),
169+
val_idx=Y_start + j,
170170
)
171171

172172
cdef void _parallel_on_X_init_chunk(
@@ -255,11 +255,11 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
255255
for thread_num in range(self.chunks_n_threads):
256256
for jdx in range(self.k):
257257
heap_push(
258-
&self.argkmin_distances[X_start + idx, 0],
259-
&self.argkmin_indices[X_start + idx, 0],
260-
self.k,
261-
self.heaps_r_distances_chunks[thread_num][idx * self.k + jdx],
262-
self.heaps_indices_chunks[thread_num][idx * self.k + jdx],
258+
values=&self.argkmin_distances[X_start + idx, 0],
259+
indices=&self.argkmin_indices[X_start + idx, 0],
260+
size=self.k,
261+
val=self.heaps_r_distances_chunks[thread_num][idx * self.k + jdx],
262+
val_idx=self.heaps_indices_chunks[thread_num][idx * self.k + jdx],
263263
)
264264

265265
cdef void _parallel_on_Y_finalize(
@@ -292,7 +292,7 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
292292
num_threads=self.effective_n_threads):
293293
for j in range(self.k):
294294
distances[i, j] = self.datasets_pair.distance_metric._rdist_to_dist(
295-
# Guard against eventual -0., causing nan production.
295+
# Guard against potential -0., causing nan production.
296296
max(distances[i, j], 0.)
297297
)
298298

@@ -304,7 +304,7 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
304304

305305
# Values are returned identically to the way `KNeighborsMixin.kneighbors`
306306
# returns values. This is 10000 counter-intuitive but this allows not using
307-
# complex adaptations where `ArgKmin{{name_suffix}}.compute` is called.
307+
# complex adaptations where `ArgKmin.compute` is called.
308308
return np.asarray(self.argkmin_distances), np.asarray(self.argkmin_indices)
309309

310310
return np.asarray(self.argkmin_indices)
@@ -330,8 +330,10 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
330330
):
331331
if (
332332
metric_kwargs is not None and
333-
len(metric_kwargs) > 0 and
334-
"Y_norm_squared" not in metric_kwargs
333+
len(metric_kwargs) > 0 and (
334+
"Y_norm_squared" not in metric_kwargs or
335+
"X_norm_squared" not in metric_kwargs
336+
)
335337
):
336338
warnings.warn(
337339
f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't "
@@ -365,20 +367,31 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
365367
metric_kwargs.pop("Y_norm_squared"),
366368
ensure_2d=False,
367369
input_name="Y_norm_squared",
368-
dtype=np.float64
370+
dtype=np.float64,
369371
)
370372
else:
371373
self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}(
372-
Y, self.effective_n_threads
374+
Y,
375+
self.effective_n_threads,
373376
)
374377

375-
# Do not recompute norms if datasets are identical.
376-
self.X_norm_squared = (
377-
self.Y_norm_squared if X is Y else
378-
_sqeuclidean_row_norms{{name_suffix}}(
379-
X, self.effective_n_threads
378+
if metric_kwargs is not None and "X_norm_squared" in metric_kwargs:
379+
self.X_norm_squared = check_array(
380+
metric_kwargs.pop("X_norm_squared"),
381+
ensure_2d=False,
382+
input_name="X_norm_squared",
383+
dtype=np.float64,
380384
)
381-
)
385+
else:
386+
# Do not recompute norms if datasets are identical.
387+
self.X_norm_squared = (
388+
self.Y_norm_squared if X is Y else
389+
_sqeuclidean_row_norms{{name_suffix}}(
390+
X,
391+
self.effective_n_threads,
392+
)
393+
)
394+
382395
self.use_squared_distances = use_squared_distances
383396

384397
@final
@@ -470,6 +483,7 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
470483
) nogil:
471484
cdef:
472485
ITYPE_t i, j
486+
DTYPE_t sqeuclidean_dist_i_j
473487
ITYPE_t n_X = X_end - X_start
474488
ITYPE_t n_Y = Y_end - Y_start
475489
DTYPE_t * dist_middle_terms = self.middle_term_computer._compute_dist_middle_terms(
@@ -483,20 +497,22 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
483497
# which keep tracks of the argkmin.
484498
for i in range(n_X):
485499
for j in range(n_Y):
500+
sqeuclidean_dist_i_j = (
501+
self.X_norm_squared[i + X_start] +
502+
dist_middle_terms[i * n_Y + j] +
503+
self.Y_norm_squared[j + Y_start]
504+
)
505+
506+
# Catastrophic cancellation might cause -0. to be present,
507+
# e.g. when computing d(x_i, y_i) when X is Y.
508+
sqeuclidean_dist_i_j = max(0., sqeuclidean_dist_i_j)
509+
486510
heap_push(
487-
heaps_r_distances + i * self.k,
488-
heaps_indices + i * self.k,
489-
self.k,
490-
# Using the squared euclidean distance as the rank-preserving distance:
491-
#
492-
# ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||²
493-
#
494-
(
495-
self.X_norm_squared[i + X_start] +
496-
dist_middle_terms[i * n_Y + j] +
497-
self.Y_norm_squared[j + Y_start]
498-
),
499-
j + Y_start,
511+
values=heaps_r_distances + i * self.k,
512+
indices=heaps_indices + i * self.k,
513+
size=self.k,
514+
val=sqeuclidean_dist_i_j,
515+
val_idx=j + Y_start,
500516
)
501517

502518
{{endfor}}

sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}})
311311
for j in range(deref(self.neigh_indices)[i].size()):
312312
deref(self.neigh_distances)[i][j] = (
313313
self.datasets_pair.distance_metric._rdist_to_dist(
314-
# Guard against eventual -0., causing nan production.
314+
# Guard against potential -0., causing nan production.
315315
max(deref(self.neigh_distances)[i][j], 0.)
316316
)
317317
)
@@ -338,8 +338,10 @@ cdef class EuclideanRadiusNeighbors{{name_suffix}}(RadiusNeighbors{{name_suffix}
338338
):
339339
if (
340340
metric_kwargs is not None and
341-
len(metric_kwargs) > 0 and
342-
"Y_norm_squared" not in metric_kwargs
341+
len(metric_kwargs) > 0 and (
342+
"Y_norm_squared" not in metric_kwargs or
343+
"X_norm_squared" not in metric_kwargs
344+
)
343345
):
344346
warnings.warn(
345347
f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't "
@@ -374,16 +376,31 @@ cdef class EuclideanRadiusNeighbors{{name_suffix}}(RadiusNeighbors{{name_suffix}
374376
metric_kwargs.pop("Y_norm_squared"),
375377
ensure_2d=False,
376378
input_name="Y_norm_squared",
377-
dtype=np.float64
379+
dtype=np.float64,
378380
)
379381
else:
380-
self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}(Y, self.effective_n_threads)
382+
self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}(
383+
Y,
384+
self.effective_n_threads,
385+
)
386+
387+
if metric_kwargs is not None and "X_norm_squared" in metric_kwargs:
388+
self.X_norm_squared = check_array(
389+
metric_kwargs.pop("X_norm_squared"),
390+
ensure_2d=False,
391+
input_name="X_norm_squared",
392+
dtype=np.float64,
393+
)
394+
else:
395+
# Do not recompute norms if datasets are identical.
396+
self.X_norm_squared = (
397+
self.Y_norm_squared if X is Y else
398+
_sqeuclidean_row_norms{{name_suffix}}(
399+
X,
400+
self.effective_n_threads,
401+
)
402+
)
381403

382-
# Do not recompute norms if datasets are identical.
383-
self.X_norm_squared = (
384-
self.Y_norm_squared if X is Y else
385-
_sqeuclidean_row_norms{{name_suffix}}(X, self.effective_n_threads)
386-
)
387404
self.use_squared_distances = use_squared_distances
388405

389406
if use_squared_distances:
@@ -480,7 +497,7 @@ cdef class EuclideanRadiusNeighbors{{name_suffix}}(RadiusNeighbors{{name_suffix}
480497
) nogil:
481498
cdef:
482499
ITYPE_t i, j
483-
DTYPE_t squared_dist_i_j
500+
DTYPE_t sqeuclidean_dist_i_j
484501
ITYPE_t n_X = X_end - X_start
485502
ITYPE_t n_Y = Y_end - Y_start
486503
DTYPE_t *dist_middle_terms = self.middle_term_computer._compute_dist_middle_terms(
@@ -490,17 +507,18 @@ cdef class EuclideanRadiusNeighbors{{name_suffix}}(RadiusNeighbors{{name_suffix}
490507
# Pushing the distance and their associated indices in vectors.
491508
for i in range(n_X):
492509
for j in range(n_Y):
493-
# Using the squared euclidean distance as the rank-preserving distance:
494-
#
495-
# ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||²
496-
#
497-
squared_dist_i_j = (
510+
sqeuclidean_dist_i_j = (
498511
self.X_norm_squared[i + X_start]
499512
+ dist_middle_terms[i * n_Y + j]
500513
+ self.Y_norm_squared[j + Y_start]
501514
)
502-
if squared_dist_i_j <= self.r_radius:
503-
deref(self.neigh_distances_chunks[thread_num])[i + X_start].push_back(squared_dist_i_j)
515+
516+
# Catastrophic cancellation might cause -0. to be present,
517+
# e.g. when computing d(x_i, y_i) when X is Y.
518+
sqeuclidean_dist_i_j = max(0., sqeuclidean_dist_i_j)
519+
520+
if sqeuclidean_dist_i_j <= self.r_radius:
521+
deref(self.neigh_distances_chunks[thread_num])[i + X_start].push_back(sqeuclidean_dist_i_j)
504522
deref(self.neigh_indices_chunks[thread_num])[i + X_start].push_back(j + Y_start)
505523

506524
{{endfor}}

0 commit comments

Comments
 (0)
0