@@ -343,49 +343,57 @@ def test_paired_distances_callable():
343
343
def test_pairwise_distances_argmin_min ():
344
344
# Check pairwise minimum distances computation for any metric
345
345
X = [[0 ], [1 ]]
346
- Y = [[- 1 ], [2 ]]
346
+ Y = [[- 2 ], [3 ]]
347
347
348
348
Xsp = dok_matrix (X )
349
349
Ysp = csr_matrix (Y , dtype = np .float32 )
350
350
351
- # euclidean metric
352
- D , E = pairwise_distances_argmin_min (X , Y , metric = "euclidean" )
353
- D2 = pairwise_distances_argmin (X , Y , metric = "euclidean" )
354
- assert_array_almost_equal (D , [0 , 1 ])
355
- assert_array_almost_equal (D2 , [0 , 1 ])
356
- assert_array_almost_equal (D , [0 , 1 ])
357
- assert_array_almost_equal (E , [1. , 1. ])
351
+ expected_idx = [0 , 1 ]
352
+ expected_vals = [2 , 2 ]
353
+ expected_vals_sq = [4 , 4 ]
358
354
355
+ # euclidean metric
356
+ idx , vals = pairwise_distances_argmin_min (X , Y , metric = "euclidean" )
357
+ idx2 = pairwise_distances_argmin (X , Y , metric = "euclidean" )
358
+ assert_array_almost_equal (idx , expected_idx )
359
+ assert_array_almost_equal (idx2 , expected_idx )
360
+ assert_array_almost_equal (vals , expected_vals )
359
361
# sparse matrix case
360
- Dsp , Esp = pairwise_distances_argmin_min (Xsp , Ysp , metric = "euclidean" )
361
- assert_array_equal ( Dsp , D )
362
- assert_array_equal ( Esp , E )
362
+ idxsp , valssp = pairwise_distances_argmin_min (Xsp , Ysp , metric = "euclidean" )
363
+ assert_array_almost_equal ( idxsp , expected_idx )
364
+ assert_array_almost_equal ( valssp , expected_vals )
363
365
# We don't want np.matrix here
364
- assert_equal (type (Dsp ), np .ndarray )
365
- assert_equal (type (Esp ), np .ndarray )
366
+ assert_equal (type (idxsp ), np .ndarray )
367
+ assert_equal (type (valssp ), np .ndarray )
368
+
369
+ # euclidean metric squared
370
+ idx , vals = pairwise_distances_argmin_min (X , Y , metric = "euclidean" ,
371
+ metric_kwargs = {"squared" : True })
372
+ assert_array_almost_equal (idx , expected_idx )
373
+ assert_array_almost_equal (vals , expected_vals_sq )
366
374
367
375
# Non-euclidean scikit-learn metric
368
- D , E = pairwise_distances_argmin_min (X , Y , metric = "manhattan" )
369
- D2 = pairwise_distances_argmin (X , Y , metric = "manhattan" )
370
- assert_array_almost_equal (D , [ 0 , 1 ] )
371
- assert_array_almost_equal (D2 , [ 0 , 1 ] )
372
- assert_array_almost_equal (E , [ 1. , 1. ] )
373
- D , E = pairwise_distances_argmin_min ( Xsp , Ysp , metric = "manhattan" )
374
- D2 = pairwise_distances_argmin (Xsp , Ysp , metric = "manhattan" )
375
- assert_array_almost_equal (D , [ 0 , 1 ] )
376
- assert_array_almost_equal (E , [ 1. , 1. ] )
376
+ idx , vals = pairwise_distances_argmin_min (X , Y , metric = "manhattan" )
377
+ idx2 = pairwise_distances_argmin (X , Y , metric = "manhattan" )
378
+ assert_array_almost_equal (idx , expected_idx )
379
+ assert_array_almost_equal (idx2 , expected_idx )
380
+ assert_array_almost_equal (vals , expected_vals )
381
+ # sparse matrix case
382
+ idxsp , valssp = pairwise_distances_argmin_min (Xsp , Ysp , metric = "manhattan" )
383
+ assert_array_almost_equal (idxsp , expected_idx )
384
+ assert_array_almost_equal (valssp , expected_vals )
377
385
378
386
# Non-euclidean Scipy distance (callable)
379
- D , E = pairwise_distances_argmin_min (X , Y , metric = minkowski ,
380
- metric_kwargs = {"p" : 2 })
381
- assert_array_almost_equal (D , [ 0 , 1 ] )
382
- assert_array_almost_equal (E , [ 1. , 1. ] )
387
+ idx , vals = pairwise_distances_argmin_min (X , Y , metric = minkowski ,
388
+ metric_kwargs = {"p" : 2 })
389
+ assert_array_almost_equal (idx , expected_idx )
390
+ assert_array_almost_equal (vals , expected_vals )
383
391
384
392
# Non-euclidean Scipy distance (string)
385
- D , E = pairwise_distances_argmin_min (X , Y , metric = "minkowski" ,
386
- metric_kwargs = {"p" : 2 })
387
- assert_array_almost_equal (D , [ 0 , 1 ] )
388
- assert_array_almost_equal (E , [ 1. , 1. ] )
393
+ idx , vals = pairwise_distances_argmin_min (X , Y , metric = "minkowski" ,
394
+ metric_kwargs = {"p" : 2 })
395
+ assert_array_almost_equal (idx , expected_idx )
396
+ assert_array_almost_equal (vals , expected_vals )
389
397
390
398
# Compare with naive implementation
391
399
rng = np .random .RandomState (0 )
0 commit comments