scikit-learn/scikit-learn, commit fe3d3a8
Add randomized solver for PCA that does not require centering
1 parent 633e3ca commit fe3d3a8

File tree: 3 files changed, +213 -23 lines changed

sklearn/decomposition/pca.py

Lines changed: 22 additions & 15 deletions
@@ -24,9 +24,10 @@
 from .base import _BasePCA
 from ..utils import check_random_state
 from ..utils import check_array
-from ..utils.extmath import fast_logdet, randomized_svd, svd_flip
+from ..utils.extmath import fast_logdet, randomized_pca, svd_flip
 from ..utils.extmath import stable_cumsum
 from ..utils.validation import check_is_fitted
+from ..utils.sparsefuncs import mean_variance_axis


 def _assess_dimension_(spectrum, rank, n_samples, n_features):

@@ -370,14 +371,8 @@ def fit_transform(self, X, y=None):

     def _fit(self, X):
         """Dispatch to the right submethod depending on the chosen solver."""
-        # Raise an error for sparse input.
-        # This is more informative than the generic one raised by check_array.
-        if issparse(X):
-            raise TypeError('PCA does not support sparse input. See '
-                            'TruncatedSVD for a possible alternative.')
-
-        X = check_array(X, dtype=[np.float64, np.float32], ensure_2d=True,
+        X = check_array(X, accept_sparse=['csr', 'csc'],
+                        dtype=[np.float64, np.float32], ensure_2d=True,
                         copy=self.copy)

         # Handle n_components==None

@@ -392,15 +387,24 @@ def _fit(self, X):
         # Handle svd_solver
         self._fit_svd_solver = self.svd_solver
         if self._fit_svd_solver == 'auto':
+            # Sparse data can only be handled with the randomized solver
+            if issparse(X):
+                self._fit_svd_solver = 'randomized'
             # Small problem or n_components == 'mle', just call full PCA
-            if max(X.shape) <= 500 or n_components == 'mle':
+            elif max(X.shape) <= 500 or n_components == 'mle':
                 self._fit_svd_solver = 'full'
             elif n_components >= 1 and n_components < .8 * min(X.shape):
                 self._fit_svd_solver = 'randomized'
             # This is also the case of n_components in (0,1)
             else:
                 self._fit_svd_solver = 'full'

+        # Ensure we don't try to call arpack or full on a sparse matrix
+        if issparse(X) and self._fit_svd_solver != 'randomized':
+            raise ValueError(
+                'only the randomized solver supports sparse matrices'
+            )
+
         # Call different fits for either full or truncated SVD
         if self._fit_svd_solver == 'full':
             return self._fit_full(X, n_components)

@@ -503,11 +507,15 @@ def _fit_truncated(self, X, n_components, svd_solver):

         random_state = check_random_state(self.random_state)

-        # Center data
-        self.mean_ = np.mean(X, axis=0)
-        X -= self.mean_
+        if issparse(X):
+            self.mean_, total_var = mean_variance_axis(X, axis=0, ddof=1)
+        else:
+            self.mean_ = np.mean(X, axis=0)
+            total_var = np.var(X, axis=0, ddof=1)

         if svd_solver == 'arpack':
+            # Center data
+            X -= self.mean_
             # random init solution, as ARPACK does it internally
             v0 = random_state.uniform(-1, 1, size=min(X.shape))
             U, S, V = svds(X, k=n_components, tol=self.tol, v0=v0)

@@ -519,7 +527,7 @@ def _fit_truncated(self, X, n_components, svd_solver):

         elif svd_solver == 'randomized':
             # sign flipping is done inside
-            U, S, V = randomized_svd(X, n_components=n_components,
+            U, S, V = randomized_pca(X, n_components=n_components,
                                      n_iter=self.iterated_power,
                                      flip_sign=True,
                                      random_state=random_state)

@@ -530,7 +538,6 @@ def _fit_truncated(self, X, n_components, svd_solver):

         # Get variance explained by singular values
         self.explained_variance_ = (S ** 2) / (n_samples - 1)
-        total_var = np.var(X, ddof=1, axis=0)
         self.explained_variance_ratio_ = \
             self.explained_variance_ / total_var.sum()
         self.singular_values_ = S.copy()  # Store the singular values.

sklearn/decomposition/tests/test_pca.py

Lines changed: 44 additions & 8 deletions
@@ -14,6 +14,7 @@
 from sklearn.utils.testing import assert_no_warnings
 from sklearn.utils.testing import ignore_warnings
 from sklearn.utils.testing import assert_less
+from sklearn.utils.testing import assert_allclose

 from sklearn import datasets
 from sklearn.decomposition import PCA

@@ -260,11 +261,11 @@ def test_singular_values():
               random_state=rng).fit(X)
     apca = PCA(n_components=2, svd_solver='arpack',
                random_state=rng).fit(X)
-    rpca = PCA(n_components=2, svd_solver='randomized',
+    rpca = PCA(n_components=2, svd_solver='randomized', iterated_power=40,
                random_state=rng).fit(X)
     assert_array_almost_equal(pca.singular_values_, apca.singular_values_, 12)
-    assert_array_almost_equal(pca.singular_values_, rpca.singular_values_, 1)
-    assert_array_almost_equal(apca.singular_values_, rpca.singular_values_, 1)
+    assert_array_almost_equal(pca.singular_values_, rpca.singular_values_, 12)
+    assert_array_almost_equal(apca.singular_values_, rpca.singular_values_, 12)

     # Compare to the Frobenius norm
     X_pca = pca.transform(X)

@@ -283,7 +284,7 @@ def test_singular_values():
     assert_array_almost_equal(apca.singular_values_,
                               np.sqrt(np.sum(X_apca**2.0, axis=0)), 12)
     assert_array_almost_equal(rpca.singular_values_,
-                              np.sqrt(np.sum(X_rpca**2.0, axis=0)), 2)
+                              np.sqrt(np.sum(X_rpca**2.0, axis=0)), 12)

     # Set the singular values and see what we get back
     rng = np.random.RandomState(0)

@@ -305,6 +306,7 @@ def test_singular_values():
     pca.fit(X_hat)
     apca.fit(X_hat)
     rpca.fit(X_hat)
+
     assert_array_almost_equal(pca.singular_values_, [3.142, 2.718, 1.0], 14)
     assert_array_almost_equal(apca.singular_values_, [3.142, 2.718, 1.0], 14)
     assert_array_almost_equal(rpca.singular_values_, [3.142, 2.718, 1.0], 14)

@@ -683,15 +685,49 @@ def test_svd_solver_auto():
     assert_array_almost_equal(pca.components_, pca_test.components_)


-@pytest.mark.parametrize('svd_solver', solver_list)
-def test_pca_sparse_input(svd_solver):
+def test_pca_sparse_input_randomized_solver():
+    rng = np.random.RandomState(0)
+    n_samples = 100
+    n_features = 80
+
+    X = rng.binomial(1, 0.1, (n_samples, n_features))
+    X_sp = sp.sparse.csr_matrix(X)
+
+    # Compute the complete decomposition on the dense matrix
+    pca = PCA(n_components=3, svd_solver='randomized',
+              random_state=0).fit(X)
+    # And compute a randomized decomposition on the sparse matrix. Increase
+    # the number of power iterations to account for the non-zero means
+    pca_sp = PCA(n_components=3, svd_solver='randomized',
+                 random_state=0).fit(X_sp)
+
+    # Ensure the singular values are close to the exact singular values
+    assert_allclose(pca_sp.singular_values_, pca.singular_values_)
+
+    # Ensure that the basis is close to the true basis
+    X_pca = pca.transform(X)
+    X_sppca = pca_sp.transform(X)
+    assert_allclose(X_sppca, X_pca)
+
+
+@pytest.mark.parametrize('svd_solver', ['full', 'arpack'])
+def test_pca_sparse_input_bad_solvers(svd_solver):
     X = np.random.RandomState(0).rand(5, 4)
     X = sp.sparse.csr_matrix(X)
-    assert(sp.sparse.issparse(X))

     pca = PCA(n_components=3, svd_solver=svd_solver)

-    assert_raises(TypeError, pca.fit, X)
+    assert_raises(ValueError, pca.fit, X)
+
+
+def test_pca_auto_solver_selects_randomized_solver_for_sparse_matrices():
+    X = np.random.RandomState(0).rand(5, 4)
+    X = sp.sparse.csr_matrix(X)
+
+    pca = PCA(n_components=3, svd_solver='auto')
+    pca.fit(X)
+
+    assert pca._fit_svd_solver == 'randomized'


 def test_pca_bad_solver():
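
The dense/sparse equivalence these tests assert can also be exercised directly against the new extmath helper. A standalone sketch, assuming randomized_pca is importable from sklearn.utils.extmath once this commit is applied (the data and tolerances are illustrative):

import numpy as np
import scipy.sparse as sp
from sklearn.utils.extmath import randomized_pca

rng = np.random.RandomState(0)
X = rng.binomial(1, 0.1, (100, 80)).astype(np.float64)
X_sp = sp.csr_matrix(X)

# Same seed for both calls, so the random range-finding matrices match
U, s, V = randomized_pca(X, n_components=3, random_state=0)
U_sp, s_sp, V_sp = randomized_pca(X_sp, n_components=3, random_state=0)

# Dense and sparse inputs should agree to floating-point accuracy;
# compare absolute values of V to sidestep any residual sign ambiguity
assert np.allclose(s, s_sp)
assert np.allclose(np.abs(V), np.abs(V_sp))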

sklearn/utils/extmath.py

Lines changed: 147 additions & 0 deletions
@@ -360,6 +360,153 @@ def randomized_svd(M, n_components, n_oversamples=10, n_iter='auto',
     return U[:, :n_components], s[:n_components], V[:n_components, :]


+def _normalize_power_iteration(x, power_iteration_normalizer):
+    """Normalize the matrix when doing power iterations for stability."""
+    if power_iteration_normalizer == "none":
+        return x
+    elif power_iteration_normalizer == "LU":
+        Q, _ = linalg.lu(x, permute_l=True)
+        return Q
+    elif power_iteration_normalizer == "QR":
+        Q, _ = linalg.qr(x, mode="economic")
+        return Q
+    else:
+        raise ValueError("Unrecognized normalization method `%s`" %
+                         power_iteration_normalizer)
+
+
+def randomized_pca(A, n_components, n_oversamples=10, n_iter="auto",
+                   power_iteration_normalizer="auto", flip_sign=True,
+                   random_state=0):
+    """Compute a truncated randomized PCA decomposition.
+
+    Parameters
+    ----------
+    A : ndarray or sparse matrix
+        Matrix to decompose.
+
+    n_components : int
+        Number of singular values and vectors to extract.
+
+    n_oversamples : int (default is 10)
+        Additional number of random vectors used to sample the range of A
+        so as to ensure proper conditioning. The total number of random
+        vectors used to find the range of A is n_components + n_oversamples.
+        A smaller number can improve speed but can negatively impact the
+        quality of the approximation of the singular vectors and values.
+
+    n_iter : int or 'auto' (default is 'auto')
+        Number of power iterations. It can be used to deal with very noisy
+        problems. When 'auto', it is set to 4, unless `n_components` is
+        small (< .1 * min(A.shape)), in which case `n_iter` is set to 7.
+        This improves precision with few components.
+
+    power_iteration_normalizer : 'auto' (default), 'QR', 'LU', 'none'
+        Whether the power iterations are normalized with step-by-step
+        QR factorization (the slowest but most accurate), 'none'
+        (the fastest but numerically unstable when `n_iter` is large, e.g.
+        typically 5 or larger), or 'LU' factorization (numerically stable
+        but can lose slightly in accuracy). The 'auto' mode applies no
+        normalization if `n_iter` <= 2 and switches to LU otherwise.
+
+    flip_sign : boolean (True by default)
+        The output of a singular value decomposition is only unique up to a
+        permutation of the signs of the singular vectors. If `flip_sign` is
+        set to `True`, the sign ambiguity is resolved by making the largest
+        loadings for each component in the left singular vectors positive.
+
+    random_state : int, RandomState instance or None, optional (default=0)
+        The seed of the pseudo random number generator to use when sampling
+        the random vectors. If int, random_state is the seed used by the
+        random number generator; if RandomState instance, random_state is
+        the random number generator; if None, the random number generator
+        is the RandomState instance used by `np.random`.
+
+    Notes
+    -----
+    This algorithm finds a (usually very good) approximate truncated
+    principal component analysis decomposition using randomized methods to
+    speed up the computations. It is particularly useful on large, sparse
+    matrices since this implementation does not require centering the
+    original matrix (centering would densify potentially large sparse
+    matrices, leading to memory issues). For a further speed-up, `n_iter`
+    can be set <= 2 (at the cost of a loss of precision).
+
+    References
+    ----------
+    * Algorithm 971: An implementation of a randomized algorithm for
+      principal component analysis
+      Li, Huamin, et al. 2017
+
+    """
+    # Accept int seeds and None, as randomized_svd does
+    random_state = check_random_state(random_state)
+
+    if n_iter == "auto":
+        # If the number of iterations is not explicitly specified, adjust
+        # n_iter. 7 was found a good compromise for PCA. See #5299
+        n_iter = 7 if n_components < .1 * min(A.shape) else 4
+
+    # Deal with "auto" mode
+    if power_iteration_normalizer == "auto":
+        if n_iter <= 2:
+            power_iteration_normalizer = "none"
+        else:
+            power_iteration_normalizer = "LU"
+
+    n_samples, n_features = A.shape
+
+    # Row vector of column means: centering A amounts to subtracting the
+    # rank-one matrix np.ones((n_samples, 1)).dot(c), so every product with
+    # the centered matrix below is written as a product with A plus a
+    # rank-one correction, keeping A sparse throughout
+    c = np.atleast_2d(A.mean(axis=0))
+
+    if n_samples >= n_features:
+        Q = random_state.normal(size=(n_features,
+                                      n_components + n_oversamples))
+        if A.dtype.kind == "f":
+            # Ensure f32 is preserved as f32
+            Q = Q.astype(A.dtype, copy=False)
+
+        Q = safe_sparse_dot(A, Q) - safe_sparse_dot(c, Q)
+
+        # Normalized power iterations
+        for _ in range(n_iter):
+            Q = safe_sparse_dot(A.T, Q) - \
+                safe_sparse_dot(c.T, Q.sum(axis=0)[None, :])
+            Q = _normalize_power_iteration(Q, power_iteration_normalizer)
+            Q = safe_sparse_dot(A, Q) - safe_sparse_dot(c, Q)
+            Q = _normalize_power_iteration(Q, power_iteration_normalizer)
+
+        Q, _ = linalg.qr(Q, mode="economic")
+
+        QA = safe_sparse_dot(A.T, Q) - \
+            safe_sparse_dot(c.T, Q.sum(axis=0)[None, :])
+        R, s, V = linalg.svd(QA.T, full_matrices=False)
+        U = Q.dot(R)
+
+    else:  # n_features > n_samples
+        Q = random_state.normal(size=(n_samples,
+                                      n_components + n_oversamples))
+        if A.dtype.kind == "f":
+            # Ensure f32 is preserved as f32
+            Q = Q.astype(A.dtype, copy=False)
+
+        Q = safe_sparse_dot(A.T, Q) - \
+            safe_sparse_dot(c.T, Q.sum(axis=0)[None, :])
+
+        # Normalized power iterations
+        for _ in range(n_iter):
+            Q = safe_sparse_dot(A, Q) - safe_sparse_dot(c, Q)
+            Q = _normalize_power_iteration(Q, power_iteration_normalizer)
+            Q = safe_sparse_dot(A.T, Q) - \
+                safe_sparse_dot(c.T, Q.sum(axis=0)[None, :])
+            Q = _normalize_power_iteration(Q, power_iteration_normalizer)
+
+        Q, _ = linalg.qr(Q, mode="economic")
+
+        QA = safe_sparse_dot(A, Q) - safe_sparse_dot(c, Q)
+        U, s, R = linalg.svd(QA, full_matrices=False)
+        V = R.dot(Q.T)
+
+    if flip_sign:
+        U, V = svd_flip(U, V)
+
+    return U[:, :n_components], s[:n_components], V[:n_components, :]
+
+
 def weighted_mode(a, w, axis=0):
     """Returns an array of the weighted modal (most common) value in a
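
The implementation above never forms the centered matrix explicitly: every product with it is rewritten as a product with A plus a rank-one correction built from the row vector of column means c. A few lines of NumPy verify that identity (a standalone sketch of the algebra, not code from the commit):

import numpy as np

rng = np.random.RandomState(0)
A = rng.rand(6, 4)
c = np.atleast_2d(A.mean(axis=0))      # (1, n_features) row of column means
ones = np.ones((A.shape[0], 1))

# Explicitly centered matrix: subtract the rank-one matrix ones.dot(c)
B = A - ones.dot(c)

# Forward product B.dot(Q): subtract the broadcast row vector c.dot(Q)
Q = rng.normal(size=(4, 3))
assert np.allclose(B.dot(Q), A.dot(Q) - c.dot(Q))

# Transposed product B.T.dot(P): subtract the outer product of the
# column means with the column sums of P
P = rng.normal(size=(6, 3))
assert np.allclose(B.T.dot(P),
                   A.T.dot(P) - c.T.dot(P.sum(axis=0)[None, :]))

Because only A itself ever appears in a matrix product, sparse inputs stay sparse through all the power iterations.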
