22
22
gen_even_slices ,
23
23
)
24
24
from ..utils ._array_api import (
25
+ _fill_or_add_to_diagonal ,
25
26
_find_matching_floating_dtype ,
26
27
_is_numpy_namespace ,
28
+ _max_precision_float_dtype ,
29
+ _modify_in_place_if_numpy ,
27
30
get_namespace ,
31
+ get_namespace_and_device ,
28
32
)
29
33
from ..utils ._chunking import get_chunk_n_rows
30
34
from ..utils ._mask import _get_mask
@@ -335,13 +339,14 @@ def euclidean_distances(
335
339
array([[1. ],
336
340
[1.41421356]])
337
341
"""
342
+ xp , _ = get_namespace (X , Y )
338
343
X , Y = check_pairwise_arrays (X , Y )
339
344
340
345
if X_norm_squared is not None :
341
346
X_norm_squared = check_array (X_norm_squared , ensure_2d = False )
342
347
original_shape = X_norm_squared .shape
343
348
if X_norm_squared .shape == (X .shape [0 ],):
344
- X_norm_squared = X_norm_squared .reshape (- 1 , 1 )
349
+ X_norm_squared = xp .reshape (X_norm_squared , ( - 1 , 1 ) )
345
350
if X_norm_squared .shape == (1 , X .shape [0 ]):
346
351
X_norm_squared = X_norm_squared .T
347
352
if X_norm_squared .shape != (X .shape [0 ], 1 ):
@@ -354,7 +359,7 @@ def euclidean_distances(
354
359
Y_norm_squared = check_array (Y_norm_squared , ensure_2d = False )
355
360
original_shape = Y_norm_squared .shape
356
361
if Y_norm_squared .shape == (Y .shape [0 ],):
357
- Y_norm_squared = Y_norm_squared .reshape (1 , - 1 )
362
+ Y_norm_squared = xp .reshape (Y_norm_squared , ( 1 , - 1 ) )
358
363
if Y_norm_squared .shape == (Y .shape [0 ], 1 ):
359
364
Y_norm_squared = Y_norm_squared .T
360
365
if Y_norm_squared .shape != (1 , Y .shape [0 ]):
@@ -375,24 +380,25 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared
375
380
float32, norms needs to be recomputed on upcast chunks.
376
381
TODO: use a float64 accumulator in row_norms to avoid the latter.
377
382
"""
378
- if X_norm_squared is not None and X_norm_squared .dtype != np .float32 :
379
- XX = X_norm_squared .reshape (- 1 , 1 )
380
- elif X .dtype != np .float32 :
381
- XX = row_norms (X , squared = True )[:, np .newaxis ]
383
+ xp , _ , device_ = get_namespace_and_device (X , Y )
384
+ if X_norm_squared is not None and X_norm_squared .dtype != xp .float32 :
385
+ XX = xp .reshape (X_norm_squared , (- 1 , 1 ))
386
+ elif X .dtype != xp .float32 :
387
+ XX = row_norms (X , squared = True )[:, None ]
382
388
else :
383
389
XX = None
384
390
385
391
if Y is X :
386
392
YY = None if XX is None else XX .T
387
393
else :
388
- if Y_norm_squared is not None and Y_norm_squared .dtype != np .float32 :
389
- YY = Y_norm_squared .reshape (1 , - 1 )
390
- elif Y .dtype != np .float32 :
391
- YY = row_norms (Y , squared = True )[np . newaxis , :]
394
+ if Y_norm_squared is not None and Y_norm_squared .dtype != xp .float32 :
395
+ YY = xp .reshape (Y_norm_squared , ( 1 , - 1 ) )
396
+ elif Y .dtype != xp .float32 :
397
+ YY = row_norms (Y , squared = True )[None , :]
392
398
else :
393
399
YY = None
394
400
395
- if X .dtype == np .float32 or Y .dtype == np .float32 :
401
+ if X .dtype == xp .float32 or Y .dtype == xp .float32 :
396
402
# To minimize precision issues with float32, we compute the distance
397
403
# matrix on chunks of X and Y upcast to float64
398
404
distances = _euclidean_distances_upcast (X , XX , Y , YY )
@@ -401,14 +407,22 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared
401
407
distances = - 2 * safe_sparse_dot (X , Y .T , dense_output = True )
402
408
distances += XX
403
409
distances += YY
404
- np .maximum (distances , 0 , out = distances )
410
+
411
+ xp_zero = xp .asarray (0 , device = device_ , dtype = distances .dtype )
412
+ distances = _modify_in_place_if_numpy (
413
+ xp , xp .maximum , distances , xp_zero , out = distances
414
+ )
405
415
406
416
# Ensure that distances between vectors and themselves are set to 0.0.
407
417
# This may not be the case due to floating point rounding errors.
408
418
if X is Y :
409
- np . fill_diagonal (distances , 0 )
419
+ _fill_or_add_to_diagonal (distances , 0 , xp = xp , add_value = False )
410
420
411
- return distances if squared else np .sqrt (distances , out = distances )
421
+ if squared :
422
+ return distances
423
+
424
+ distances = _modify_in_place_if_numpy (xp , xp .sqrt , distances , out = distances )
425
+ return distances
412
426
413
427
414
428
@validate_params (
@@ -552,15 +566,20 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None):
552
566
X and Y are upcast to float64 by chunks, which size is chosen to limit
553
567
memory increase by approximately 10% (at least 10MiB).
554
568
"""
569
+ xp , _ , device_ = get_namespace_and_device (X , Y )
555
570
n_samples_X = X .shape [0 ]
556
571
n_samples_Y = Y .shape [0 ]
557
572
n_features = X .shape [1 ]
558
573
559
- distances = np .empty ((n_samples_X , n_samples_Y ), dtype = np .float32 )
574
+ distances = xp .empty ((n_samples_X , n_samples_Y ), dtype = xp .float32 , device = device_ )
560
575
561
576
if batch_size is None :
562
- x_density = X .nnz / np .prod (X .shape ) if issparse (X ) else 1
563
- y_density = Y .nnz / np .prod (Y .shape ) if issparse (Y ) else 1
577
+ x_density = (
578
+ X .nnz / xp .prod (X .shape ) if issparse (X ) else xp .asarray (1 , device = device_ )
579
+ )
580
+ y_density = (
581
+ Y .nnz / xp .prod (Y .shape ) if issparse (Y ) else xp .asarray (1 , device = device_ )
582
+ )
564
583
565
584
# Allow 10% more memory than X, Y and the distance matrix take (at
566
585
# least 10MiB)
@@ -580,15 +599,15 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None):
580
599
# Hence x² + (xd+yd)kx = M, where x=batch_size, k=n_features, M=maxmem
581
600
# xd=x_density and yd=y_density
582
601
tmp = (x_density + y_density ) * n_features
583
- batch_size = (- tmp + np .sqrt (tmp ** 2 + 4 * maxmem )) / 2
602
+ batch_size = (- tmp + xp .sqrt (tmp ** 2 + 4 * maxmem )) / 2
584
603
batch_size = max (int (batch_size ), 1 )
585
604
586
605
x_batches = gen_batches (n_samples_X , batch_size )
587
-
606
+ xp_max_float = _max_precision_float_dtype ( xp = xp , device = device_ )
588
607
for i , x_slice in enumerate (x_batches ):
589
- X_chunk = X [x_slice ]. astype ( np . float64 )
608
+ X_chunk = xp . astype ( X [x_slice ], xp_max_float )
590
609
if XX is None :
591
- XX_chunk = row_norms (X_chunk , squared = True )[:, np . newaxis ]
610
+ XX_chunk = row_norms (X_chunk , squared = True )[:, None ]
592
611
else :
593
612
XX_chunk = XX [x_slice ]
594
613
@@ -601,17 +620,17 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None):
601
620
d = distances [y_slice , x_slice ].T
602
621
603
622
else :
604
- Y_chunk = Y [y_slice ]. astype ( np . float64 )
623
+ Y_chunk = xp . astype ( Y [y_slice ], xp_max_float )
605
624
if YY is None :
606
- YY_chunk = row_norms (Y_chunk , squared = True )[np . newaxis , :]
625
+ YY_chunk = row_norms (Y_chunk , squared = True )[None , :]
607
626
else :
608
627
YY_chunk = YY [:, y_slice ]
609
628
610
629
d = - 2 * safe_sparse_dot (X_chunk , Y_chunk .T , dense_output = True )
611
630
d += XX_chunk
612
631
d += YY_chunk
613
632
614
- distances [x_slice , y_slice ] = d .astype (np .float32 , copy = False )
633
+ distances [x_slice , y_slice ] = xp .astype (d , xp .float32 , copy = False )
615
634
616
635
return distances
617
636
@@ -1549,13 +1568,15 @@ def rbf_kernel(X, Y=None, gamma=None):
1549
1568
array([[0.71..., 0.51...],
1550
1569
[0.51..., 0.71...]])
1551
1570
"""
1571
+ xp , _ = get_namespace (X , Y )
1552
1572
X , Y = check_pairwise_arrays (X , Y )
1553
1573
if gamma is None :
1554
1574
gamma = 1.0 / X .shape [1 ]
1555
1575
1556
1576
K = euclidean_distances (X , Y , squared = True )
1557
1577
K *= - gamma
1558
- np .exp (K , K ) # exponentiate K in-place
1578
+ # exponentiate K in-place when using numpy
1579
+ K = _modify_in_place_if_numpy (xp , xp .exp , K , out = K )
1559
1580
return K
1560
1581
1561
1582
0 commit comments