8000 Add randomized solver for PCA that does not require centering · scikit-learn/scikit-learn@b5964c1 · GitHub
[go: up one dir, main page]

Skip to content

Commit b5964c1

Browse files
Add randomized solver for PCA that does not require centering
1 parent 7e4e167 commit b5964c1

File tree

3 files changed

+121
-32
lines changed

3 files changed

+121
-32
lines changed

sklearn/decomposition/pca.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from ..utils.extmath import fast_logdet, randomized_svd, svd_flip
2828
from ..utils.extmath import stable_cumsum
2929
from ..utils.validation import check_is_fitted
30+
from ..utils.sparsefuncs import mean_variance_axis
3031

3132

3233
def _assess_dimension_(spectrum, rank, n_samples, n_features):
@@ -370,14 +371,8 @@ def fit_transform(self, X, y=None):
370371

371372
def _fit(self, X):
372373
"""Dispatch to the right submethod depending on the chosen solver."""
373-
374-
# Raise an error for sparse input.
375-
# This is more informative than the generic one raised by check_array.
376-
if issparse(X):
377-
raise TypeError('PCA does not support sparse input. See '
378-
'TruncatedSVD for a possible alternative.')
379-
380-
X = check_array(X, dtype=[np.float64, np.float32], ensure_2d=True,
374+
X = check_array(X, accept_sparse=['csr', 'csc'],
375+
dtype=[np.float64, np.float32], ensure_2d=True,
381376
copy=self.copy)
382377

383378
# Handle n_components==None
@@ -392,15 +387,24 @@ def _fit(self, X):
392387
# Handle svd_solver
393388
self._fit_svd_solver = self.svd_solver
394389
if self._fit_svd_solver == 'auto':
390+
# Sparse data can only be handled with the randomized solver
391+
if issparse(X):
392+
self._fit_svd_solver = 'randomized'
395393
# Small problem or n_components == 'mle', just call full PCA
396-
if max(X.shape) <= 500 or n_components == 'mle':
394+
elif max(X.shape) <= 500 or n_components == 'mle':
397395
self._fit_svd_solver = 'full'
398396
elif n_components >= 1 and n_components < .8 * min(X.shape):
399397
self._fit_svd_solver = 'randomized'
400398
# This is also the case of n_components in (0,1)
401399
else:
402400
self._fit_svd_solver = 'full'
403401

402+
# Ensure we don't try call arpack or full on a sparse matrix
403+
if issparse(X) and self._fit_svd_solver != 'randomized':
404+
raise ValueError(
405+
'only the randomized solver supports sparse matrices'
406+
)
407+
404408
# Call different fits for either full or truncated SVD
405409
if self._fit_svd_solver == 'full':
406410
return self._fit_full(X, n_components)
@@ -503,11 +507,15 @@ def _fit_truncated(self, X, n_components, svd_solver):
503507

504508
random_state = check_random_state(self.random_state)
505509

506-
# Center data
507-
self.mean_ = np.mean(X, axis=0)
508-
X -= self.mean_
510+
if issparse(X):
511+
self.mean_, total_var = mean_variance_axis(X, axis=0, ddof=1)
512+
else:
513+
self.mean_ = np.mean(X, axis=0)
514+
total_var = np.var(X, axis=0, ddof=1)
509515

510516
if svd_solver == 'arpack':
517+
# Center data
518+
X -= self.mean_
511519
# random init solution, as ARPACK does it internally
512520
v0 = random_state.uniform(-1, 1, size=min(X.shape))
513521
U, S, V = svds(X, k=n_components, tol=self.tol, v0=v0)
@@ -519,18 +527,21 @@ def _fit_truncated(self, X, n_components, svd_solver):
519527

520528
elif svd_solver == 'randomized':
521529
# sign flipping is done inside
522-
U, S, V = randomized_svd(X, n_components=n_components,
523-
n_iter=self.iterated_power,
524-
flip_sign=True,
525-
random_state=random_state)
530+
U, S, V = randomized_svd(
531+
X,
532+
n_components=n_components,
533+
n_iter=self.iterated_power,
534+
flip_sign=True,
535+
subtract_mean=True,
536+
random_state=random_state,
537+
)
526538

527539
self.n_samples_, self.n_features_ = n_samples, n_features
528540
self.components_ = V
529541
self.n_components_ = n_components
530542

531543
# Get variance explained by singular values
532544
self.explained_variance_ = (S ** 2) / (n_samples - 1)
533-
total_var = np.var(X, ddof=1, axis=0)
534545
self.explained_variance_ratio_ = \
535546
self.explained_variance_ / total_var.sum()
536547
self.singular_values_ = S.copy() # Store the singular values.

sklearn/decomposition/tests/test_pca.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def test_singular_values():
307307
rpca.fit(X_hat)
308308
assert_array_almost_equal(pca.singular_values_, [3.142, 2.718, 1.0], 14)
309309
assert_array_almost_equal(apca.singular_values_, [3.142, 2.718, 1.0], 14)
310-
assert_array_almost_equal(rpca.singular_values_, [3.142, 2.718, 1.0], 14)
310+
assert_array_almost_equal(rpca.singular_values_, [3.142, 2.718, 1.0], 2)
311311

312312

313313
def test_pca_check_projection():
@@ -683,15 +683,58 @@ def test_svd_solver_auto():
683683
assert_array_almost_equal(pca.components_, pca_test.components_)
684684

685685

686-
@pytest.mark.parametrize('svd_solver', solver_list)
687-
def test_pca_sparse_input(svd_solver):
686+
def test_pca_sparse_input_randomized_solver():
687+
rng = np.random.RandomState(0)
688+
n_samples = 100
689+
n_features = 80
690+
691+
# The randomized method produces larger errors whenever the means of the
692+
# matrix are way off the origin
693+
X = rng.normal(1000, 20, (n_samples, n_features))
694+
695+
X_sp = sp.sparse.csr_matrix(X)
696+
assert sp.sparse.issparse(X_sp)
697+
698+
# Compute the complete decomposition on the dense matrix
699+
pca = PCA(n_components=3, svd_solver='full', random_state=rng).fit(X)
700+
# And compute a randomized decomposition on the sparse matrix. Increase the
701+
# number of power iterations to account for the non-zero means
702+
pca_sp = PCA(
703+
n_components=3,
704+
svd_solver='randomized',
705+
random_state=rng,
706+
iterated_power=20,
707+
).fit(X_sp)
708+
709+
# Ensure the singular values are close to the exact singular values
710+
assert_array_almost_equal(pca_sp.singular_values_, pca.singular_values_, 5)
711+
712+
# Ensure that the basis is close to the true basis
713+
X_pca = pca.transform(X)
714+
X_sppca = pca_sp.transform(X)
715+
assert_array_almost_equal(X_sppca, X_pca, 2)
716+
717+
718+
@pytest.mark.parametrize('svd_solver', ['full', 'arpack'])
719+
def test_pca_sparse_input_bad_solvers(svd_solver):
688720
X = np.random.RandomState(0).rand(5, 4)
689721
X = sp.sparse.csr_matrix(X)
690-
assert(sp.sparse.issparse(X))
722+
assert sp.sparse.issparse(X)
691723

692724
pca = PCA(n_components=3, svd_solver=svd_solver)
693725

694-
assert_raises(TypeError, pca.fit, X)
726+
assert_raises(ValueError, pca.fit, X)
727+
728+
729+
def test_pca_auto_solver_selects_randomized_solver_for_sparse_matrices():
730+
X = np.random.RandomState(0).rand(5, 4)
731+
X = sp.sparse.csr_matrix(X)
732+
assert sp.sparse.issparse(X)
733+
734+
pca = PCA(n_components=3, svd_solver='auto')
735+
pca.fit(X)
736+
737+
assert_equal(pca._fit_svd_solver, 'randomized')
695738

696739

697740
def test_pca_bad_solver():

sklearn/utils/extmath.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def safe_sparse_dot(a, b, dense_output=False):
147147

148148
def randomized_range_finder(A, size, n_iter,
149149
power_iteration_normalizer='auto',
150+
subtract_mean=False,
150151
random_state=None):
151152
"""Computes an orthonormal matrix whose range approximates the range of A.
152153
@@ -171,6 +172,13 @@ def randomized_range_finder(A, size, n_iter,
171172
172173
.. versionadded:: 0.18
173174
175+
subtract_mean : bool
176+
Whether the mean of `A` should be subtracted after each multiplication
177+
by the `A` matrix. This is equivalent to multiplying matrices by a
178+
centered `A` without ever having to explicitly center. This is
179+
especially useful for performing PCA on large sparse matrices, so they
180+
do not need to be centered.
181+
174182
random_state : int, RandomState instance or None, optional (default=None)
175183
The seed of the pseudo random number generator to use when shuffling
176184
the data. If int, random_state is the seed used by the random number
@@ -211,28 +219,39 @@ def randomized_range_finder(A, size, n_iter,
211219
else:
212220
power_iteration_normalizer = 'LU'
213221

222+
if subtract_mean:
223+
c = np.mean(A, axis=0).reshape((1, -1))
224+
applyA = lambda X: safe_sparse_dot(A, X) - safe_sparse_dot(c, X)
225+
applyAT = lambda X: safe_sparse_dot(A.T, X) - \
226+
safe_sparse_dot(c.T, Q.sum(axis=0).reshape((1, -1)))
227+
else:
228+
applyA = lambda X: safe_sparse_dot(A, X)
229+
applyAT = lambda X: safe_sparse_dot(A.T, X)
230+
231+
Q = applyA(Q)
232+
214233
# Perform power iterations with Q to further 'imprint' the top
215234
# singular vectors of A in Q
216235
for i in range(n_iter):
217236
if power_iteration_normalizer == 'none':
218-
Q = safe_sparse_dot(A, Q)
219-
Q = safe_sparse_dot(A.T, Q)
237+
Q = applyAT(Q)
238+
Q = applyA(Q)
220239
elif power_iteration_normalizer == 'LU':
221-
Q, _ = linalg.lu(safe_sparse_dot(A, Q), permute_l=True)
222-
Q, _ = linalg.lu(safe_sparse_dot(A.T, Q), permute_l=True)
240+
Q, _ = linalg.lu(applyAT(Q), permute_l=True)
241+
Q, _ = linalg.lu(applyA(Q), permute_l=True)
223242
elif power_iteration_normalizer == 'QR':
224-
Q, _ = linalg.qr(safe_sparse_dot(A, Q), mode='economic')
225-
Q, _ = linalg.qr(safe_sparse_dot(A.T, Q), mode='economic')
243+
Q, _ = linalg.qr(applyAT(Q), mode='economic')
244+
Q, _ = linalg.qr(applyA(Q), mode='economic')
226245

227246
# Sample the range of A using by linear projection of Q
228247
# Extract an orthonormal basis
229-
Q, _ = linalg.qr(safe_sparse_dot(A, Q), mode='economic')
248+
Q, _ = linalg.qr(Q, mode='economic')
230249
return Q
231250

232251

233252
def randomized_svd(M, n_components, n_oversamples=10, n_iter='auto',
234253
power_iteration_normalizer='auto', transpose='auto',
235-
flip_sign=True, random_state=0):
254+
flip_sign=True, subtract_mean=False, random_state=0):
236255
"""Computes a truncated randomized SVD
237256
238257
Parameters
@@ -283,6 +302,13 @@ def randomized_svd(M, n_components, n_oversamples=10, n_iter='auto',
283302
set to `True`, the sign ambiguity is resolved by making the largest
284303
loadings for each component in the left singular vectors positive.
285304
305+
subtract_mean : bool
306+
Whether the mean of `A` should be subtracted after each multiplication
307+
by the `A` matrix. This is equivalent to multiplying matrices by a
308+
centered `A` without ever having to explicitly center. This is
309+
especially useful for performing PCA on large sparse matrices, so they
310+
do not need to be centered.
311+
286312
random_state : int, RandomState instance or None, optional (default=None)
287313
The seed of the pseudo random number generator to use when shuffling
288314
the data. If int, random_state is the seed used by the random number
@@ -333,11 +359,20 @@ def randomized_svd(M, n_components, n_oversamples=10, n_iter='auto',
333359
# this implementation is a bit faster with smaller shape[1]
334360
M = M.T
335361

336-
Q = randomized_range_finder(M, n_random, n_iter,
337-
power_iteration_normalizer, random_state)
362+
Q = randomized_range_finder(
363+
M,
364+
size=n_random,
365+
n_iter=n_iter,
366+
power_iteration_normalizer=power_iteration_normalizer,
367+
subtract_mean=subtract_mean,
368+
random_state=random_state,
369+
)
338370

339371
# project M to the (k + p) dimensional space using the basis vectors
340372
B = safe_sparse_dot(Q.T, M)
373+
if subtract_mean:
374+
c = M.mean(axis=0).reshape((1, -1))
375+
B -= np.dot(c.T, Q.sum(axis=0).reshape((1, -1))).T
341376

342377
# compute the SVD on the thin matrix: (k + p) wide
343378
Uhat, s, V = linalg.svd(B, full_matrices=False)

0 commit comments

Comments
 (0)
0