9
9
# License: BSD 3 clause
10
10
11
11
import numpy as np
12
- from scipy import linalg
13
12
14
13
from ..base import BaseEstimator , TransformerMixin , ClassNamePrefixFeaturesOutMixin
15
14
from ..utils .validation import check_is_fitted
15
+ from ..utils ._array_api import get_namespace
16
16
from abc import ABCMeta , abstractmethod
17
17
18
18
@@ -37,13 +37,18 @@ def get_covariance(self):
37
37
cov : array of shape=(n_features, n_features)
38
38
Estimated covariance of data.
39
39
"""
40
+ xp , _ = get_namespace (self .components_ )
41
+
40
42
components_ = self .components_
41
43
exp_var = self .explained_variance_
42
44
if self .whiten :
43
- components_ = components_ * np .sqrt (exp_var [:, np .newaxis ])
44
- exp_var_diff = np .maximum (exp_var - self .noise_variance_ , 0.0 )
45
- cov = np .dot (components_ .T * exp_var_diff , components_ )
46
- cov .flat [:: len (cov ) + 1 ] += self .noise_variance_ # modify diag inplace
45
+ components_ = components_ * xp .sqrt (exp_var [:, np .newaxis ])
46
+ exp_var_diff = xp .maximum (
47
+ exp_var - self .noise_variance_ , xp .zeros_like (exp_var )
48
+ )
49
+ cov = (components_ .T * exp_var_diff ) @ components_
50
+ # TODO use views instead?
51
+ cov .reshape (- 1 )[:: len (cov ) + 1 ] += self .noise_variance_ # modify diag inplace
47
52
return cov
48
53
49
54
def get_precision (self ):
@@ -57,26 +62,33 @@ def get_precision(self):
57
62
precision : array, shape=(n_features, n_features)
58
63
Estimated precision of data.
59
64
"""
65
+ xp , _ = get_namespace (self .components_ )
66
+
60
67
n_features = self .components_ .shape [1 ]
61
68
62
69
# handle corner cases first
63
70
if self .n_components_ == 0 :
64
- return np .eye (n_features ) / self .noise_variance_
71
+ return xp .eye (n_features ) / self .noise_variance_
65
72
66
- if np .isclose (self .noise_variance_ , 0.0 , atol = 0.0 ):
67
- return linalg .inv (self .get_covariance ())
73
+ if xp .isclose (
74
+ self .noise_variance_ , xp .zeros_like (self .noise_variance_ ), atol = 0.0
75
+ ):
76
+ return xp .linalg .inv (self .get_covariance ())
68
77
69
78
# Get precision using matrix inversion lemma
70
79
components_ = self .components_
71
80
exp_var = self .explained_variance_
72
81
if self .whiten :
73
- components_ = components_ * np .sqrt (exp_var [:, np .newaxis ])
74
- exp_var_diff = np .maximum (exp_var - self .noise_variance_ , 0.0 )
75
- precision = np .dot (components_ , components_ .T ) / self .noise_variance_
76
- precision .flat [:: len (precision ) + 1 ] += 1.0 / exp_var_diff
77
- precision = np .dot (components_ .T , np .dot (linalg .inv (precision ), components_ ))
82
+ components_ = components_ * xp .sqrt (exp_var [:, np .newaxis ])
83
+ exp_var_diff = xp .maximum (
84
+ exp_var - self .noise_variance_ , xp .zeros_like (exp_var )
85
+ )
86
+ precision = components_ @ components_ .T / self .noise_variance_
87
+ # TODO use views instead?
88
+ precision .reshape (- 1 )[:: len (precision ) + 1 ] += 1.0 / exp_var_diff
89
+ precision = components_ .T @ xp .linalg .inv (precision ) @ components_
78
90
precisio
9E7A
n /= - (self .noise_variance_ ** 2 )
79
- precision .flat [:: len (precision ) + 1 ] += 1.0 / self .noise_variance_
91
+ precision .reshape ( - 1 ) [:: len (precision ) + 1 ] += 1.0 / self .noise_variance_
80
92
return precision
81
93
82
94
@abstractmethod
@@ -115,14 +127,16 @@ def transform(self, X):
115
127
Projection of X in the first principal components, where `n_samples`
116
128
is the number of samples and `n_components` is the number of the components.
117
129
"""
130
+ xp , _ = get_namespace (X )
131
+
118
132
check_is_fitted (self )
119
133
120
- X = self ._validate_data (X , dtype = [np .float64 , np .float32 ], reset = False )
134
+ X = self ._validate_data (X , dtype = [xp .float64 , xp .float32 ], reset = False )
121
135
if self .mean_ is not None :
122
136
X = X - self .mean_
123
- X_transformed = np . dot ( X , self .components_ .T )
137
+ X_transformed = X @ self .components_ .T
124
138
if self .whiten :
125
- X_transformed /= np .sqrt (self .explained_variance_ )
139
+ X_transformed /= xp .sqrt (self .explained_variance_ )
126
140
return X_transformed
127
141
128
142
def inverse_transform (self , X ):
@@ -147,16 +161,16 @@ def inverse_transform(self, X):
147
161
If whitening is enabled, inverse_transform will compute the
148
162
exact inverse operation, which includes reversing whitening.
149
163
"""
164
+ xp , _ = get_namespace (X )
165
+
150
166
if self .whiten :
151
167
return (
152
- np .dot (
153
- X ,
154
- np .sqrt (self .explained_variance_ [:, np .newaxis ]) * self .components_ ,
155
- )
168
+ X
169
+ @ (np .sqrt (self .explained_variance_ [:, np .newaxis ]) * self .components_ )
156
170
+ self .mean_
157
171
)
158
172
else :
159
- return np . dot ( X , self .components_ ) + self .mean_
173
+ return X @ self .components_ + self .mean_
160
174
161
175
@property
162
176
def _n_features_out (self ):
0 commit comments