ENH Adds PyTorch support for PCA · scikit-learn/scikit-learn@dd4c9fc · GitHub

Commit dd4c9fc
ENH Adds PyTorch support for PCA
1 parent 1882672

File tree: 4 files changed, +86 -21 lines changed

sklearn/decomposition/_pca.py

Lines changed: 26 additions & 12 deletions
@@ -14,7 +14,6 @@
 from numbers import Integral, Real
 
 import numpy as np
-from scipy import linalg
 from scipy.special import gammaln
 from scipy.sparse import issparse
 from scipy.sparse.linalg import svds
@@ -28,6 +27,7 @@
 from ..utils.validation import check_is_fitted
 from ..utils._param_validation import Interval, StrOptions
 from ..utils._param_validation import RealNotInt
+from ..utils._array_api import get_namespace, _is_torch_namespace
 
 
 def _assess_dimension(spectrum, rank, n_samples):
@@ -109,8 +109,10 @@ def _infer_dimension(spectrum, n_samples):
     The returned value will be in [1, n_features - 1].
     """
-    ll = np.empty_like(spectrum)
-    ll[0] = -np.inf  # we don't want to return n_components = 0
+    xp, _ = get_namespace(spectrum)
+
+    ll = xp.empty_like(spectrum)
+    ll[0] = -xp.inf  # we don't want to return n_components = 0
     for rank in range(1, spectrum.shape[0]):
         ll[rank] = _assess_dimension(spectrum, rank, n_samples)
     return ll.argmax()
@@ -380,6 +382,9 @@ class PCA(_BasePCA):
         "power_iteration_normalizer": [StrOptions({"auto", "QR", "LU", "none"})],
         "random_state": ["random_state"],
     }
+    _pca_torch_arpack_solver_error_message: str = (
+        "PCA with arpack solver does not support PyTorch tensors."
+    )
 
     def __init__(
         self,
@@ -474,6 +479,7 @@ def fit_transform(self, X, y=None):
 
     def _fit(self, X):
         """Dispatch to the right submethod depending on the chosen solver."""
+        xp, _ = get_namespace(X)
 
         # Raise an error for sparse input.
         # This is more informative than the generic one raised by check_array.
@@ -482,9 +488,13 @@ def _fit(self, X):
                 "PCA does not support sparse input. See "
                 "TruncatedSVD for a possible alternative."
             )
+        # Raise an error for torch input with the arpack or randomized solver.
+        # TODO: support the randomized solver for torch tensors.
+        if self.svd_solver in ["arpack", "randomized"] and _is_torch_namespace(xp):
+            raise TypeError(self._pca_torch_arpack_solver_error_message)
 
         X = self._validate_data(
-            X, dtype=[np.float64, np.float32], ensure_2d=True, copy=self.copy
+            X, dtype=[xp.float64, xp.float32], ensure_2d=True, copy=self.copy
        )
 
        # Handle n_components==None
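
The new guard above is the user-visible behavior change for the unsupported solvers. A minimal sketch of what a caller would see (assuming PyTorch and array-api-compat are installed; the input shape is illustrative):

    import torch
    from sklearn._config import config_context
    from sklearn.decomposition import PCA

    X = torch.rand(30, 4)
    with config_context(array_api_dispatch=True):
        # arpack and randomized are rejected for torch tensors for now,
        # so this raises the TypeError defined on the class above.
        PCA(n_components=2, svd_solver="arpack").fit(X)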
@@ -516,6 +526,8 @@ def _fit(self, X):
 
     def _fit_full(self, X, n_components):
         """Fit the model by computing full SVD on X."""
+        xp, _ = get_namespace(X)
+
         n_samples, n_features = X.shape
 
         if n_components == "mle":
@@ -531,10 +543,10 @@ def _fit_full(self, X, n_components):
             )
 
         # Center data
-        self.mean_ = np.mean(X, axis=0)
+        self.mean_ = xp.mean(X, axis=0)
         X -= self.mean_
 
-        U, S, Vt = linalg.svd(X, full_matrices=False)
+        U, S, Vt = xp.linalg.svd(X, full_matrices=False)
         # flip eigenvectors' sign to enforce deterministic output
         U, Vt = svd_flip(U, Vt)
 
@@ -544,7 +556,7 @@ def _fit_full(self, X, n_components):
         explained_variance_ = (S**2) / (n_samples - 1)
         total_var = explained_variance_.sum()
         explained_variance_ratio_ = explained_variance_ / total_var
-        singular_values_ = S.copy()  # Store the singular values.
+        singular_values_ = xp.asarray(S, copy=True)  # Store the singular values.
 
         # Postprocess the number of components required
         if n_components == "mle":
@@ -556,7 +568,7 @@ def _fit_full(self, X, n_components):
         # their variance is always greater than n_components float
         # passed. More discussion in issue: #15669
         ratio_cumsum = stable_cumsum(explained_variance_ratio_)
-        n_components = np.searchsorted(ratio_cumsum, n_components, side="right") + 1
+        n_components = xp.searchsorted(ratio_cumsum, n_components, side="right") + 1
         # Compute noise covariance using Probabilistic PCA model
         # The sigma2 maximum likelihood (cf. eq. 12.46)
         if n_components < min(n_features, n_samples):
@@ -577,6 +589,8 @@ def _fit_truncated(self, X, n_components, svd_solver):
         """Fit the model by computing truncated SVD (by ARPACK or randomized)
         on X.
         """
+        xp, _ = get_namespace(X)
+
         n_samples, n_features = X.shape
 
         if isinstance(n_components, str):
@@ -602,7 +616,7 @@ def _fit_truncated(self, X, n_components, svd_solver):
         random_state = check_random_state(self.random_state)
 
         # Center data
-        self.mean_ = np.mean(X, axis=0)
+        self.mean_ = xp.mean(X, axis=0)
         X -= self.mean_
 
         if svd_solver == "arpack":
@@ -636,12 +650,12 @@ def _fit_truncated(self, X, n_components, svd_solver):
         # Workaround in-place variance calculation since at the time numpy
         # did not have a way to calculate variance in-place.
         N = X.shape[0] - 1
-        np.square(X, out=X)
-        np.sum(X, axis=0, out=X[0])
+        xp.square(X, out=X)
+        xp.sum(X, axis=0, out=X[0])
         total_var = (X[0] / N).sum()
 
         self.explained_variance_ratio_ = self.explained_variance_ / total_var
-        self.singular_values_ = S.copy()  # Store the singular values.
+        self.singular_values_ = xp.asarray(S)  # Store the singular values.
 
         if self.n_components_ < min(n_features, n_samples):
             self.noise_variance_ = total_var - self.explained_variance_.sum()
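
Taken together, the changes to _fit, _fit_full, and the helpers let the full solver run end to end on torch tensors. A minimal usage sketch (assuming PyTorch and array-api-compat are installed; the random data is only illustrative):

    import torch
    from sklearn._config import config_context
    from sklearn.decomposition import PCA

    X = torch.rand(100, 5)
    with config_context(array_api_dispatch=True):
        pca = PCA(n_components=2, svd_solver="full")
        X_reduced = pca.fit_transform(X)
    # Per the accompanying test, the output stays a torch.Tensor.
    print(type(X_reduced), pca.explained_variance_ratio_)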

sklearn/decomposition/tests/test_pca.py

Lines changed: 43 additions & 1 deletion
@@ -5,9 +5,14 @@
 import pytest
 import warnings
 
-from sklearn.utils._testing import assert_allclose
+from sklearn.utils._testing import (
+    assert_allclose,
+    skip_if_array_api_compat_not_configured,
+)
 
 from sklearn import datasets
+from sklearn._config import config_context
+from sklearn.base import clone
 from sklearn.decomposition import PCA
 from sklearn.datasets import load_iris
 from sklearn.decomposition._pca import _assess_dimension
@@ -17,6 +22,43 @@
 PCA_SOLVERS = ["full", "arpack", "randomized", "auto"]
 
 
+@skip_if_array_api_compat_not_configured
+@pytest.mark.parametrize("device", ["cuda", "cpu"])
+@pytest.mark.parametrize("dtype", ["float32", "float64"])
+@pytest.mark.parametrize("svd_solver", PCA_SOLVERS)
+@pytest.mark.parametrize("n_components", range(1, iris.data.shape[1]))
+def test_pca_array_torch(device, dtype, svd_solver, n_components):
+    """Check that running on PyTorch tensors gives the same results as NumPy."""
+    torch = pytest.importorskip("torch")
+    if device == "cuda" and not torch.has_cuda:
+        pytest.skip("test requires cuda")
+
+    iris_data = iris.data.astype(dtype)
+    X_np = iris_data
+    X_torch = torch.asarray(iris_data, device=device)
+
+    pca_np = PCA(n_components=n_components, svd_solver=svd_solver)
+    pca_torch = clone(pca_np)
+
+    with config_context(array_api_dispatch=True):
+        if svd_solver in ["arpack", "randomized"]:
+            with pytest.raises(
+                TypeError, match=PCA._pca_torch_arpack_solver_error_message
+            ):
+                pca_torch.fit_transform(X_torch)
+        else:
+            X_transformed_torch = pca_torch.fit_transform(X_torch)
+            X_transformed_np = pca_np.fit_transform(X_np)
+
+            assert type(X_transformed_np) == np.ndarray, "Invalid type"
+            assert type(X_transformed_torch) == torch.Tensor, "Invalid type"
+            assert_allclose(X_transformed_np, X_transformed_torch, atol=1e-3)
+
+    # TODO: introduce PyTorch support for the methods below.
+    # cov = pca.get_covariance()
+    # precision = pca.get_precision()
+
+
 @pytest.mark.parametrize("svd_solver", PCA_SOLVERS)
 @pytest.mark.parametrize("n_components", range(1, iris.data.shape[1]))
 def test_pca(svd_solver, n_components):

sklearn/utils/_array_api.py

Lines changed: 5 additions & 0 deletions
@@ -72,6 +72,11 @@ def _is_numpy_namespace(xp):
     return xp.__name__ in {"numpy", "array_api_compat.numpy", "numpy.array_api"}
 
 
+def _is_torch_namespace(xp):
+    """Return True if xp is backed by PyTorch."""
+    return xp.__name__ in {"torch", "array_api_compat.torch"}
+
+
 def isdtype(dtype, kind, *, xp):
     """Returns a boolean indicating whether a provided dtype is of type "kind".
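
A quick sketch of how the two helpers resolve namespaces; this assumes torch and array-api-compat are installed and that get_namespace maps torch tensors to the array_api_compat.torch namespace when dispatch is enabled:

    import numpy as np
    import torch
    from sklearn._config import config_context
    from sklearn.utils._array_api import (
        get_namespace,
        _is_numpy_namespace,
        _is_torch_namespace,
    )

    with config_context(array_api_dispatch=True):
        xp, _ = get_namespace(torch.ones(3))
        print(_is_torch_namespace(xp))  # expected: True
        xp, _ = get_namespace(np.ones(3))
        print(_is_numpy_namespace(xp))  # expected: True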

sklearn/utils/extmath.py

Lines changed: 12 additions & 8 deletions
@@ -800,16 +800,18 @@ def svd_flip(u, v, u_based_decision=True):
     v_adjusted : ndarray
         Array v with adjusted rows and the same dimensions as v.
     """
+    xp, _ = get_namespace(u)
+
     if u_based_decision:
         # columns of u, rows of v
-        max_abs_cols = np.argmax(np.abs(u), axis=0)
-        signs = np.sign(u[max_abs_cols, range(u.shape[1])])
+        max_abs_cols = xp.argmax(xp.abs(u), axis=0)
+        signs = xp.sign(u[max_abs_cols, range(u.shape[1])])
         u *= signs
         v *= signs[:, np.newaxis]
     else:
         # rows of v, columns of u
-        max_abs_rows = np.argmax(np.abs(v), axis=1)
-        signs = np.sign(v[range(v.shape[0]), max_abs_rows])
+        max_abs_rows = xp.argmax(xp.abs(v), axis=1)
+        signs = xp.sign(v[range(v.shape[0]), max_abs_rows])
         u *= signs
         v *= signs[:, np.newaxis]
     return u, v
@@ -1139,10 +1141,12 @@ def stable_cumsum(arr, axis=None, rtol=1e-05, atol=1e-08):
     out : ndarray
         Array with the cumulative sums along the chosen axis.
     """
-    out = np.cumsum(arr, axis=axis, dtype=np.float64)
-    expected = np.sum(arr, axis=axis, dtype=np.float64)
-    if not np.all(
-        np.isclose(
+    xp, _ = get_namespace(arr)
+
+    out = xp.cumsum(arr, axis=axis, dtype=np.float64)
+    expected = xp.sum(arr, axis=axis, dtype=np.float64)
+    if not xp.all(
+        xp.isclose(
             out.take(-1, axis=axis), expected, rtol=rtol, atol=atol, equal_nan=True
         )
     ):
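
For reference, the sign convention svd_flip enforces is unchanged by this diff; only its array backend became pluggable. A small NumPy-only sketch of what it guarantees:

    import numpy as np
    from sklearn.utils.extmath import svd_flip

    rng = np.random.RandomState(0)
    X = rng.randn(6, 3)
    U, S, Vt = np.linalg.svd(X, full_matrices=False)
    U, Vt = svd_flip(U, Vt)
    # With u_based_decision=True (the default), the largest-magnitude entry
    # of each column of U is made positive; applying the same sign to the
    # matching row of Vt keeps the reconstruction exact.
    assert np.allclose(U @ np.diag(S) @ Vt, X)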
