10000 ENH Support get_precision and get_covariance · scikit-learn/scikit-learn@44d5d75 · GitHub

Commit 44d5d75

ENH Support get_precision and get_covariance
1 parent dd4c9fc commit 44d5d75

File tree

3 files changed: +54 −28 lines changed

doc/whats_new/v1.3.rst

Lines changed: 3 additions & 0 deletions

@@ -230,6 +230,9 @@ Changelog
   :class:`decomposition.MiniBatchNMF` which can produce different results than previous
   versions. :pr:`25438` by :user:`Yotam Avidar-Constantini <yotamcons>`.
 
+- |Enhancement| :class:`decomposition.PCA` now supports
+  `PyTorch <https://pytorch.org/>`__ tensors for the `full` solver. See :pr:`26315`.
+
 :mod:`sklearn.discriminant_analysis`
 ....................................
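
For context, the sketch below shows how the enhancement described in this changelog entry might be exercised. It assumes scikit-learn's experimental array API dispatch (enabled via `sklearn.config_context(array_api_dispatch=True)`) and an installed `array-api-compat` package; it is illustrative, not part of this commit.

    import torch
    from sklearn import config_context
    from sklearn.decomposition import PCA

    X = torch.randn(100, 5)

    # Dispatch is assumed to route array operations to torch here.
    with config_context(array_api_dispatch=True):
        pca = PCA(n_components=2, svd_solver="full").fit(X)
        X_2d = pca.transform(X)          # expected to stay a torch.Tensor
        cov = pca.get_covariance()       # newly supported by this commit
        precision = pca.get_precision()  # newly supported by this commit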

sklearn/decomposition/_base.py

Lines changed: 36 additions & 22 deletions
@@ -9,10 +9,10 @@
 # License: BSD 3 clause
 
 import numpy as np
-from scipy import linalg
 
 from ..base import BaseEstimator, TransformerMixin, ClassNamePrefixFeaturesOutMixin
 from ..utils.validation import check_is_fitted
+from ..utils._array_api import get_namespace
 from abc import ABCMeta, abstractmethod
 
 
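The key new import is `get_namespace`. The real helper lives in `sklearn/utils/_array_api.py`; the sketch below only illustrates the dispatch idea (return the array's namespace if it speaks the array API standard, otherwise fall back to NumPy) and is not the actual implementation.

    import numpy as np

    def get_namespace_sketch(*arrays):
        # If any input exposes the standard __array_namespace__ hook, use that
        # namespace for subsequent array operations; otherwise fall back to NumPy.
        for arr in arrays:
            if hasattr(arr, "__array_namespace__"):
                return arr.__array_namespace__(), True
        return np, False

    xp, is_array_api = get_namespace_sketch(np.ones(3))
    print(xp.__name__, is_array_api)  # numpy False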

@@ -37,13 +37,18 @@ def get_covariance(self):
         cov : array of shape=(n_features, n_features)
             Estimated covariance of data.
         """
+        xp, _ = get_namespace(self.components_)
+
         components_ = self.components_
         exp_var = self.explained_variance_
         if self.whiten:
-            components_ = components_ * np.sqrt(exp_var[:, np.newaxis])
-        exp_var_diff = np.maximum(exp_var - self.noise_variance_, 0.0)
-        cov = np.dot(components_.T * exp_var_diff, components_)
-        cov.flat[:: len(cov) + 1] += self.noise_variance_  # modify diag inplace
+            components_ = components_ * xp.sqrt(exp_var[:, np.newaxis])
+        exp_var_diff = xp.maximum(
+            exp_var - self.noise_variance_, xp.zeros_like(exp_var)
+        )
+        cov = (components_.T * exp_var_diff) @ components_
+        # TODO use views instead?
+        cov.reshape(-1)[:: len(cov) + 1] += self.noise_variance_  # modify diag inplace
         return cov
 
     def get_precision(self):
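
A note on the `cov.reshape(-1)[:: len(cov) + 1]` idiom that replaces `.flat` (which the array API standard does not provide): flattening an (n, n) matrix puts the diagonal at every (n + 1)-th position, and `reshape(-1)` on a contiguous array returns a view, so the addition mutates `cov` in place; the "TODO use views instead?" comment is about making that guarantee explicit. A self-contained check, illustrative only:

    import numpy as np

    n = 4
    cov = np.zeros((n, n))
    # The diagonal of a flattened (n, n) matrix sits at stride n + 1.
    cov.reshape(-1)[:: n + 1] += 2.5
    assert np.allclose(cov, 2.5 * np.eye(n))
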
@@ -57,26 +62,33 @@ def get_precision(self):
         precision : array, shape=(n_features, n_features)
             Estimated precision of data.
         """
+        xp, _ = get_namespace(self.components_)
+
         n_features = self.components_.shape[1]
 
         # handle corner cases first
         if self.n_components_ == 0:
-            return np.eye(n_features) / self.noise_variance_
+            return xp.eye(n_features) / self.noise_variance_
 
-        if np.isclose(self.noise_variance_, 0.0, atol=0.0):
-            return linalg.inv(self.get_covariance())
+        if xp.isclose(
+            self.noise_variance_, xp.zeros_like(self.noise_variance_), atol=0.0
+        ):
+            return xp.linalg.inv(self.get_covariance())
 
         # Get precision using matrix inversion lemma
         components_ = self.components_
         exp_var = self.explained_variance_
         if self.whiten:
-            components_ = components_ * np.sqrt(exp_var[:, np.newaxis])
-        exp_var_diff = np.maximum(exp_var - self.noise_variance_, 0.0)
-        precision = np.dot(components_, components_.T) / self.noise_variance_
-        precision.flat[:: len(precision) + 1] += 1.0 / exp_var_diff
-        precision = np.dot(components_.T, np.dot(linalg.inv(precision), components_))
+            components_ = components_ * xp.sqrt(exp_var[:, np.newaxis])
+        exp_var_diff = xp.maximum(
+            exp_var - self.noise_variance_, xp.zeros_like(exp_var)
+        )
+        precision = components_ @ components_.T / self.noise_variance_
+        # TODO use views instead?
+        precision.reshape(-1)[:: len(precision) + 1] += 1.0 / exp_var_diff
+        precision = components_.T @ xp.linalg.inv(precision) @ components_
         precision /= -(self.noise_variance_**2)
-        precision.flat[:: len(precision) + 1] += 1.0 / self.noise_variance_
+        precision.reshape(-1)[:: len(precision) + 1] += 1.0 / self.noise_variance_
         return precision
 
     @abstractmethod
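
The "matrix inversion lemma" in the comment is the Woodbury identity, (A + U C V)^(-1) = A^(-1) − A^(-1) U (C^(-1) + V A^(-1) U)^(-1) V A^(-1), applied with A = noise_variance_ * I, so only an (n_components, n_components) matrix is inverted instead of the full (n_features, n_features) covariance. A quick numeric sanity check (illustrative, not from this commit) that get_precision() really inverts get_covariance():

    import numpy as np
    from sklearn.decomposition import PCA

    rng = np.random.RandomState(0)
    X = rng.randn(200, 6)

    pca = PCA(n_components=3, svd_solver="full").fit(X)
    cov = pca.get_covariance()
    precision = pca.get_precision()

    # precision should be the inverse of the (regularized) covariance model.
    assert np.allclose(precision @ cov, np.eye(6), atol=1e-8)
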
@@ -115,14 +127,16 @@ def transform(self, X):
             Projection of X in the first principal components, where `n_samples`
             is the number of samples and `n_components` is the number of the components.
         """
+        xp, _ = get_namespace(X)
+
         check_is_fitted(self)
 
-        X = self._validate_data(X, dtype=[np.float64, np.float32], reset=False)
+        X = self._validate_data(X, dtype=[xp.float64, xp.float32], reset=False)
         if self.mean_ is not None:
             X = X - self.mean_
-        X_transformed = np.dot(X, self.components_.T)
+        X_transformed = X @ self.components_.T
         if self.whiten:
-            X_transformed /= np.sqrt(self.explained_variance_)
+            X_transformed /= xp.sqrt(self.explained_variance_)
         return X_transformed
 
     def inverse_transform(self, X):
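
As a reminder of what the `whiten` branch does: dividing the projection by the square root of `explained_variance_` rescales each component to unit variance. An illustrative check (not from this commit):

    import numpy as np
    from sklearn.decomposition import PCA

    rng = np.random.RandomState(0)
    X = rng.randn(500, 4)

    X_w = PCA(n_components=2, whiten=True, svd_solver="full").fit_transform(X)
    print(X_w.std(axis=0, ddof=1))  # each component ~1.0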
@@ -147,16 +161,16 @@ def inverse_transform(self, X):
         If whitening is enabled, inverse_transform will compute the
         exact inverse operation, which includes reversing whitening.
         """
+        xp, _ = get_namespace(X)
+
         if self.whiten:
             return (
-                np.dot(
-                    X,
-                    np.sqrt(self.explained_variance_[:, np.newaxis]) * self.components_,
-                )
+                X
+                @ (xp.sqrt(self.explained_variance_[:, np.newaxis]) * self.components_)
                 + self.mean_
             )
         else:
-            return np.dot(X, self.components_) + self.mean_
+            return X @ self.components_ + self.mean_
 
     @property
     def _n_features_out(self):
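
Because the whitened branch multiplies back by the square root of `explained_variance_`, transform followed by inverse_transform is an exact round trip when all components are kept, whitened or not. Illustrative check (not from this commit):

    import numpy as np
    from sklearn.decomposition import PCA

    rng = np.random.RandomState(0)
    X = rng.randn(50, 3)

    for whiten in (False, True):
        pca = PCA(n_components=3, whiten=whiten, svd_solver="full").fit(X)
        X_back = pca.inverse_transform(pca.transform(X))
        assert np.allclose(X, X_back, atol=1e-10)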

sklearn/decomposition/tests/test_pca.py

Lines changed: 15 additions & 6 deletions
@@ -50,13 +50,22 @@ def test_pca_array_torch(device, dtype, svd_solver, n_components):
     X_transformed_torch = pca_torch.fit_transform(X_torch)
     X_transformed_np = pca_np.fit_transform(X_np)
 
-    assert type(X_transformed_np) == np.ndarray, "Invalid type"
-    assert type(X_transformed_torch) == torch.Tensor, "Invalid type"
-    assert_allclose(X_transformed_np, X_transformed_torch, atol=1e-3)
+    cov_np = pca_np.get_covariance()
+    cov_torch = pca_torch.get_covariance()
 
-    # TODO introduce pytorch support for below methods
-    # cov = pca.get_covariance()
-    # precision = pca.get_precision()
+    precision_np = pca_np.get_precision()
+    precision_torch = pca_torch.get_precision()
+
+    for name, arr_np, arr_torch in zip(
+        ["X", "cov", "prec"],
+        [X_transformed_np, cov_np, precision_np],
+        [X_transformed_torch, cov_torch, precision_torch],
+    ):
+        assert type(arr_np) == np.ndarray, f"Invalid type for {name}"
+        assert type(arr_torch) == torch.Tensor, f"Invalid type for {name}"
+        assert_allclose(
+            arr_np, arr_torch, atol=1e-3, err_msg=f"Divergent values for {name}"
+        )
 
 
 @pytest.mark.parametrize("svd_solver", PCA_SOLVERS)
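
To exercise just this test locally, something along the lines of `pytest sklearn/decomposition/tests/test_pca.py -k test_pca_array_torch` should work, assuming PyTorch is installed (without it the test is presumably skipped).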
