14
14
from numbers import Integral , Real
15
15
16
16
import numpy as np
17
- from scipy import linalg
18
17
from scipy .special import gammaln
19
18
from scipy .sparse import issparse
20
19
from scipy .sparse .linalg import svds
28
27
from ..utils .validation import check_is_fitted
29
28
from ..utils ._param_validation import Interval , StrOptions
30
29
from ..utils ._param_validation import RealNotInt
30
+ from ..utils ._array_api import get_namespace , _is_torch_namespace
31
31
32
32
33
33
def _assess_dimension (spectrum , rank , n_samples ):
@@ -109,8 +109,10 @@ def _infer_dimension(spectrum, n_samples):
109
109
110
110
The returned value will be in [1, n_features - 1].
111
111
"""
112
- ll = np .empty_like (spectrum )
113
- ll [0 ] = - np .inf # we don't want to return n_components = 0
112
+ xp , _ = get_namespace (spectrum )
113
+
114
+ ll = xp .empty_like (spectrum )
115
+ ll [0 ] = - xp .inf # we don't want to return n_components = 0
114
116
for rank in range (1 , spectrum .shape [0 ]):
115
117
ll [rank ] = _assess_dimension (spectrum , rank , n_samples )
116
118
return ll .argmax ()
@@ -380,6 +382,9 @@ class PCA(_BasePCA):
380
382
"power_iteration_normalizer" : [StrOptions ({"auto" , "QR" , "LU" , "none" })],
381
383
"random_state" : ["random_state" ],
382
384
}
385
+ _pca_torch_arpack_solver_error_message : str = (
386
+ "PCA with arpack solver does not support PyTorch tensors."
387
+ )
383
388
384
389
def __init__ (
385
390
self ,
@@ -474,6 +479,7 @@ def fit_transform(self, X, y=None):
474
479
475
480
def _fit (self , X ):
476
481
"""Dispatch to the right submethod depending on the chosen solver."""
482
+ xp , _ = get_namespace (X )
477
483
478
484
# Raise an error for sparse input.
479
485
# This is more informative than the generic one raised by check_array.
@@ -482,9 +488,13 @@ def _fit(self, X):
482
488
"PCA does not support sparse input. See "
483
489
"TruncatedSVD for a possible alternative."
484
490
)
491
+ # Raise an error for torch input and arpack or randomized solver.
492
+ # TODO support randomized solver for torch tensors
493
+ if self .svd_solver in ["arpack" , "randomized" ] and _is_torch_namespace (xp ):
494
+ raise TypeError (self ._pca_torch_arpack_solver_error_message )
485
495
486
496
X = self ._validate_data (
487
- X , dtype = [np .float64 , np .float32 ], ensure_2d = True , copy = self .copy
497
+ X , dtype = [xp .float64 , xp .float32 ], ensure_2d = True , copy = self .copy
488
498
)
489
499
490
500
# Handle n_components==None
@@ -516,6 +526,8 @@ def _fit(self, X):
516
526
517
527
def _fit_full (self , X , n_components ):
518
528
"""Fit the model by computing full SVD on X."""
529
+ xp , _ = get_namespace (X )
530
+
519
531
n_samples , n_features = X .shape
520
532
521
533
if n_components == "mle" :
@@ -531,10 +543,10 @@ def _fit_full(self, X, n_components):
531
543
)
532
544
533
545
# Center data
534
- self .mean_ = np .mean (X , axis = 0 )
546
+ self .mean_ = xp .mean (X , axis = 0 )
535
547
X -= self .mean_
536
548
537
- U , S , Vt = linalg .svd (X , full_matrices = False )
549
+ U , S , Vt = xp . linalg .svd (X , full_matrices = False )
538
550
# flip eigenvectors' sign to enforce deterministic output
539
551
U , Vt = svd_flip (U , Vt )
540
552
@@ -544,7 +556,7 @@ def _fit_full(self, X, n_components):
544
556
explained_variance_ = (S ** 2 ) / (n_samples - 1 )
545
557
total_var = explained_variance_ .sum ()
546
558
explained_variance_ratio_ = explained_variance_ / total_var
547
- singular_values_ = S . copy ( ) # Store the singular values.
559
+ singular_values_ = xp . asarray ( S , copy = True ) # Store the singular values.
548
560
549
561
# Postprocess the number of components required
550
562
if n_components == "mle" :
@@ -556,7 +568,7 @@ def _fit_full(self, X, n_components):
556
568
# their variance is always greater than n_components float
557
569
# passed. More discussion in issue: #15669
558
570
ratio_cumsum = stable_cumsum (explained_variance_ratio_ )
559
- n_components = np .searchsorted (ratio_cumsum , n_components , side = "right" ) + 1
571
+ n_components = xp .searchsorted (ratio_cumsum , n_components , side = "right" ) + 1
560
572
# Compute noise covariance using Probabilistic PCA model
561
573
# The sigma2 maximum likelihood (cf. eq. 12.46)
562
574
if n_components < min (n_features , n_samples ):
@@ -577,6 +589,8 @@ def _fit_truncated(self, X, n_components, svd_solver):
577
589
"""Fit the model by computing truncated SVD (by ARPACK or randomized)
578
590
on X.
579
591
"""
592
+ xp , _ = get_namespace (X )
593
+
580
594
n_samples , n_features = X .shape
581
595
582
596
if isinstance (n_components , str ):
@@ -602,7 +616,7 @@ def _fit_truncated(self, X, n_components, svd_solver):
602
616
random_state = check_random_state (self .random_state )
603
617
604
618
# Center data
605
- self .mean_ = np .mean (X , axis = 0 )
619
+ self .mean_ = xp .mean (X , axis = 0 )
606
620
X -= self .mean_
607
621
608
622
if svd_solver == "arpack" :
@@ -636,12 +650,12 @@ def _fit_truncated(self, X, n_components, svd_solver):
636
650
# Workaround in-place variance calculation since at the time numpy
637
651
# did not have a way to calculate variance in-place.
638
652
N = X .shape [0 ] - 1
639
- np .square (X , out = X )
640
- np .sum (X , axis = 0 , out = X [0 ])
653
+ xp .square (X , out = X )
654
+ xp .sum (X , axis = 0 , out = X [0 ])
641
655
total_var = (X [0 ] / N ).sum ()
642
656
643
657
self .explained_variance_ratio_ = self .explained_variance_ / total_var
644
- self .singular_values_ = S . copy ( ) # Store the singular values.
658
+ self .singular_values_ = xp . asarray ( S ) # Store the singular values.
645
659
646
660
if self .n_components_ < min (n_features , n_samples ):
647
661
self .noise_variance_ = total_var - self .explained_variance_ .sum ()
0 commit comments