@@ -162,11 +162,11 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
162
162
for i in range(n_samples_X):
163
163
for j in range(n_samples_Y):
164
164
heap_push(
165
- heaps_r_distances + i * self.k,
166
- heaps_indices + i * self.k,
167
- self.k,
168
- self.datasets_pair.surrogate_dist(X_start + i, Y_start + j),
169
- Y_start + j,
165
+ values= heaps_r_distances + i * self.k,
166
+ indices= heaps_indices + i * self.k,
167
+ size= self.k,
168
+ val= self.datasets_pair.surrogate_dist(X_start + i, Y_start + j),
169
+ val_idx= Y_start + j,
170
170
)
171
171
172
172
cdef void _parallel_on_X_init_chunk(
@@ -255,11 +255,11 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
255
255
for thread_num in range(self.chunks_n_threads):
256
256
for jdx in range(self.k):
257
257
heap_push(
258
- &self.argkmin_distances[X_start + idx, 0],
259
- &self.argkmin_indices[X_start + idx, 0],
260
- self.k,
261
- self.heaps_r_distances_chunks[thread_num][idx * self.k + jdx],
262
- self.heaps_indices_chunks[thread_num][idx * self.k + jdx],
258
+ values= &self.argkmin_distances[X_start + idx, 0],
259
+ indices= &self.argkmin_indices[X_start + idx, 0],
260
+ size= self.k,
261
+ val= self.heaps_r_distances_chunks[thread_num][idx * self.k + jdx],
262
+ val_idx= self.heaps_indices_chunks[thread_num][idx * self.k + jdx],
263
263
)
264
264
265
265
cdef void _parallel_on_Y_finalize(
@@ -292,7 +292,7 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
292
292
num_threads=self.effective_n_threads):
293
293
for j in range(self.k):
294
294
distances[i, j] = self.datasets_pair.distance_metric._rdist_to_dist(
295
- # Guard against eventual -0., causing nan production.
295
+ # Guard against potential -0., causing nan production.
296
296
max(distances[i, j], 0.)
297
297
)
298
298
@@ -304,7 +304,7 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
304
304
305
305
# Values are returned identically to the way `KNeighborsMixin.kneighbors`
306
306
# returns values. This is counter-intuitive but this allows not using
307
- # complex adaptations where `ArgKmin{{name_suffix}} .compute` is called.
307
+ # complex adaptations where `ArgKmin.compute` is called.
308
308
return np.asarray(self.argkmin_distances), np.asarray(self.argkmin_indices)
309
309
310
310
return np.asarray(self.argkmin_indices)
@@ -330,8 +330,10 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
330
330
):
331
331
if (
332
332
metric_kwargs is not None and
333
- len(metric_kwargs) > 0 and
334
- "Y_norm_squared" not in metric_kwargs
333
+ len(metric_kwargs) > 0 and (
334
+ "Y_norm_squared" not in metric_kwargs or
335
+ "X_norm_squared" not in metric_kwargs
336
+ )
335
337
):
336
338
warnings.warn(
337
339
f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't "
@@ -365,20 +367,31 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
365
367
metric_kwargs.pop("Y_norm_squared"),
366
368
ensure_2d=False,
367
369
input_name="Y_norm_squared",
368
- dtype=np.float64
370
+ dtype=np.float64,
369
371
)
370
372
else:
371
373
self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}(
372
- Y, self.effective_n_threads
374
+ Y,
375
+ self.effective_n_threads,
373
376
)
374
377
375
- # Do not recompute norms if datasets are identical.
376
- self.X_norm_squared = (
377
- self.Y_norm_squared if X is Y else
378
- _sqeuclidean_row_norms{{name_suffix}}(
379
- X, self.effective_n_threads
378
+ if metric_kwargs is not None and "X_norm_squared" in metric_kwargs:
379
+ self.X_norm_squared = check_array(
380
+ metric_kwargs.pop("X_norm_squared"),
381
+ ensure_2d=False,
382
+ input_name="X_norm_squared",
383
+ dtype=np.float64,
380
384
)
381
- )
385
+ else:
386
+ # Do not recompute norms if datasets are identical.
387
+ self.X_norm_squared = (
388
+ self.Y_norm_squared if X is Y else
389
+ _sqeuclidean_row_norms{{name_suffix}}(
390
+ X,
391
+ self.effective_n_threads,
392
+ )
393
+ )
394
+
382
395
self.use_squared_distances = use_squared_distances
383
396
384
397
@final
@@ -470,6 +483,7 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
470
483
) nogil:
471
484
cdef:
472
485
ITYPE_t i, j
486
+ DTYPE_t sqeuclidean_dist_i_j
473
487
ITYPE_t n_X = X_end - X_start
474
488
ITYPE_t n_Y = Y_end - Y_start
475
489
DTYPE_t * dist_middle_terms = self.middle_term_computer._compute_dist_middle_terms(
@@ -483,20 +497,22 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
483
497
# which keep tracks of the argkmin.
484
498
for i in range(n_X):
485
499
for j in range(n_Y):
500
+ sqeuclidean_dist_i_j = (
501
+ self.X_norm_squared[i + X_start] +
502
+ dist_middle_terms[i * n_Y + j] +
503
+ self.Y_norm_squared[j + Y_start]
504
+ )
505
+
506
+ # Catastrophic cancellation might cause -0. to be present,
507
+ # e.g. when computing d(x_i, y_i) when X is Y.
508
+ sqeuclidean_dist_i_j = max(0., sqeuclidean_dist_i_j)
509
+
486
510
heap_push(
487
- heaps_r_distances + i * self.k,
488
- heaps_indices + i * self.k,
489
- self.k,
490
- # Using the squared euclidean distance as the rank-preserving distance:
491
- #
492
- # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||²
493
- #
494
- (
495
- self.X_norm_squared[i + X_start] +
496
- dist_middle_terms[i * n_Y + j] +
497
- self.Y_norm_squared[j + Y_start]
498
- ),
499
- j + Y_start,
511
+ values=heaps_r_distances + i * self.k,
512
+ indices=heaps_indices + i * self.k,
513
+ size=self.k,
514
+ val=sqeuclidean_dist_i_j,
515
+ val_idx=j + Y_start,
500
516
)
501
517
502
518
{{endfor}}
0 commit comments