Formatting fixes for randomized svd/pca solver · scikit-learn/scikit-learn@201ff7d · GitHub

Commit 201ff7d

Formatting fixes for randomized svd/pca solver
1 parent b64c3ef commit 201ff7d

File tree: 3 files changed (+53, -64 lines)

3 files changed

+53
-64
lines changed

sklearn/decomposition/pca.py

Lines changed: 5 additions & 8 deletions
@@ -527,14 +527,11 @@ def _fit_truncated(self, X, n_components, svd_solver):
 
         elif svd_solver == 'randomized':
             # sign flipping is done inside
-            U, S, V = randomized_svd(
-                X,
-                n_components=n_components,
-                n_iter=self.iterated_power,
-                flip_sign=True,
-                subtract_mean=True,
-                random_state=random_state,
-            )
+            U, S, V = randomized_svd(X, n_components=n_components,
+                                     n_iter=self.iterated_power,
+                                     flip_sign=True,
+                                     random_state=random_state,
+                                     subtract_mean=True)
 
         self.n_samples_, self.n_features_ = n_samples, n_features
         self.components_ = V
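The reformatted call still forwards `subtract_mean=True`, which is what lets the randomized solver work on scipy sparse input without densifying it. A minimal usage sketch, assuming a scikit-learn build that contains this branch's changes (sparse input for the randomized solver is not supported in released versions):

import numpy as np
import scipy.sparse as sp
from sklearn.decomposition import PCA

rng = np.random.RandomState(0)
# Sparse 0/1 data, similar to the updated test below
X_sp = sp.csr_matrix(rng.binomial(1, 0.1, (100, 80)).astype(float))

# The randomized solver centers implicitly via subtract_mean, so the sparse
# matrix is never converted into a dense, centered array.
pca = PCA(n_components=3, svd_solver='randomized', random_state=0).fit(X_sp)
print(pca.singular_values_)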

sklearn/decomposition/tests/test_pca.py

Lines changed: 9 additions & 18 deletions
@@ -283,7 +283,7 @@ def test_singular_values():
     assert_array_almost_equal(apca.singular_values_,
                               np.sqrt(np.sum(X_apca**2.0, axis=0)), 12)
     assert_array_almost_equal(rpca.singular_values_,
-                              np.sqrt(np.sum(X_rpca**2.0, axis=0)), 2)
+                              np.sqrt(np.sum(X_rpca**2.0, axis=0)), 12)
 
     # Set the singular values and see what we get back
     rng = np.random.RandomState(0)
@@ -688,38 +688,30 @@ def test_pca_sparse_input_randomized_solver():
     n_samples = 100
     n_features = 80
 
-    # The randomized method produces larger errors whenever the means of the
-    # matrix are way off the origin
-    X = rng.normal(1000, 20, (n_samples, n_features))
-
+    X = rng.binomial(1, 0.1, (n_samples, n_features))
     X_sp = sp.sparse.csr_matrix(X)
-    assert sp.sparse.issparse(X_sp)
 
     # Compute the complete decomposition on the dense matrix
-    pca = PCA(n_components=3, svd_solver='full', random_state=rng).fit(X)
+    pca = PCA(n_components=3, svd_solver='randomized',
+              random_state=rng, iterated_power=30).fit(X)
     # And compute a randomized decomposition on the sparse matrix. Increase the
     # number of power iterations to account for the non-zero means
-    pca_sp = PCA(
-        n_components=3,
-        svd_solver='randomized',
-        random_state=rng,
-        iterated_power=20,
-    ).fit(X_sp)
+    pca_sp = PCA(n_components=3, svd_solver='randomized',
+                 random_state=rng, iterated_power=30).fit(X_sp)
 
     # Ensure the singular values are close to the exact singular values
-    assert_array_almost_equal(pca_sp.singular_values_, pca.singular_values_, 5)
+    np.testing.assert_allclose(pca_sp.singular_values_, pca.singular_values_)
 
     # Ensure that the basis is close to the true basis
     X_pca = pca.transform(X)
     X_sppca = pca_sp.transform(X)
-    assert_array_almost_equal(X_sppca, X_pca, 2)
+    np.testing.assert_allclose(X_sppca, X_pca, 1e-3)
 
 
 @pytest.mark.parametrize('svd_solver', ['full', 'arpack'])
 def test_pca_sparse_input_bad_solvers(svd_solver):
     X = np.random.RandomState(0).rand(5, 4)
     X = sp.sparse.csr_matrix(X)
-    assert sp.sparse.issparse(X)
 
     pca = PCA(n_components=3, svd_solver=svd_solver)
 
@@ -729,12 +721,11 @@ def test_pca_sparse_input_bad_solvers(svd_solver):
 def test_pca_auto_solver_selects_randomized_solver_for_sparse_matrices():
     X = np.random.RandomState(0).rand(5, 4)
     X = sp.sparse.csr_matrix(X)
-    assert sp.sparse.issparse(X)
 
     pca = PCA(n_components=3, svd_solver='auto')
     pca.fit(X)
 
-    assert_equal(pca._fit_svd_solver, 'randomized')
+    assert pca._fit_svd_solver == 'randomized'
 
 
 def test_pca_bad_solver():
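The assertions above switch from `assert_array_almost_equal`, whose third argument is a number of decimal places (roughly an absolute tolerance of `1.5 * 10**-decimal`), to `np.testing.assert_allclose`, whose third positional argument is a relative tolerance `rtol`. A small standalone illustration of the difference (not part of the commit):

import numpy as np

a = np.array([1000.0])
b = np.array([1000.4])

# decimal=2 means |a - b| < 1.5e-2, an absolute criterion: this pair fails it.
# np.testing.assert_array_almost_equal(a, b, 2)  # would raise AssertionError

# rtol=1e-3 means |a - b| <= 1e-3 * |b|, a relative criterion: this pair passes.
np.testing.assert_allclose(a, b, 1e-3)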

sklearn/utils/extmath.py

Lines changed: 39 additions & 38 deletions
@@ -147,8 +147,8 @@ def safe_sparse_dot(a, b, dense_output=False):
 
 def randomized_range_finder(A, size, n_iter,
                             power_iteration_normalizer='auto',
-                            subtract_mean=False,
-                            random_state=None):
+                            random_state=None,
+                            subtract_mean=False):
     """Computes an orthonormal matrix whose range approximates the range of A.
 
     Parameters
@@ -172,20 +172,20 @@ def randomized_range_finder(A, size, n_iter,
 
         .. versionadded:: 0.18
 
-    subtract_mean : bool
-        Whether the mean of `A` should be subtracted after each multiplication
-        by the `A` matrix. This is equivalent to multiplying matrices by a
-        centered `A` without ever having to explicitly center. This is
-        especially useful for performing PCA on large sparse matrices, so they
-        do not need to be centered.
-
     random_state : int, RandomState instance or None, optional (default=None)
         The seed of the pseudo random number generator to use when shuffling
         the data. If int, random_state is the seed used by the random number
         generator; If RandomState instance, random_state is the random number
         generator; If None, the random number generator is the RandomState
         instance used by `np.random`.
 
+    subtract_mean : bool
+        Whether the mean of `A` should be subtracted after each multiplication
+        by the `A` matrix. This is equivalent to multiplying matrices by a
+        centered `A` without ever having to explicitly center. This is
+        especially useful for performing PCA on large sparse matrices, so they
+        do not need to be centered.
+
     Returns
     -------
     Q : 2D array
@@ -219,39 +219,45 @@ def randomized_range_finder(A, size, n_iter,
     else:
         power_iteration_normalizer = 'LU'
 
+    # Prepare functions that will multiply `Q` with `A`
     if subtract_mean:
         c = A.mean(axis=0).reshape((1, -1))
-        applyA = lambda X: safe_sparse_dot(A, X) - safe_sparse_dot(c, X)
-        applyAT = lambda X: safe_sparse_dot(A.T, X) - \
-            safe_sparse_dot(c.T, Q.sum(axis=0).reshape((1, -1)))
+
+        def _apply_A(X):
+            return safe_sparse_dot(A, X) - safe_sparse_dot(c, X)
+
+        def _apply_AT(X):
+            return safe_sparse_dot(A.T, X) - \
+                safe_sparse_dot(c.T, Q.sum(axis=0).reshape((1, -1)))
     else:
-        applyA = lambda X: safe_sparse_dot(A, X)
-        applyAT = lambda X: safe_sparse_dot(A.T, X)
+        def _apply_A(X):
+            return safe_sparse_dot(A, X)
 
-    Q = applyA(Q)
+        def _apply_AT(X):
+            return safe_sparse_dot(A.T, X)
 
     # Perform power iterations with Q to further 'imprint' the top
     # singular vectors of A in Q
     for i in range(n_iter):
         if power_iteration_normalizer == 'none':
-            Q = applyAT(Q)
-            Q = applyA(Q)
+            Q = _apply_A(Q)
+            Q = _apply_AT(Q)
         elif power_iteration_normalizer == 'LU':
-            Q, _ = linalg.lu(applyAT(Q), permute_l=True)
-            Q, _ = linalg.lu(applyA(Q), permute_l=True)
+            Q, _ = linalg.lu(_apply_A(Q), permute_l=True)
+            Q, _ = linalg.lu(_apply_AT(Q), permute_l=True)
         elif power_iteration_normalizer == 'QR':
-            Q, _ = linalg.qr(applyAT(Q), mode='economic')
-            Q, _ = linalg.qr(applyA(Q), mode='economic')
+            Q, _ = linalg.qr(_apply_A(Q), mode='economic')
+            Q, _ = linalg.qr(_apply_AT(Q), mode='economic')
 
     # Sample the range of A by linear projection of Q
     # Extract an orthonormal basis
-    Q, _ = linalg.qr(Q, mode='economic')
+    Q, _ = linalg.qr(_apply_A(Q), mode='economic')
     return Q
 
 
 def randomized_svd(M, n_components, n_oversamples=10, n_iter='auto',
                    power_iteration_normalizer='auto', transpose='auto',
-                   flip_sign=True, subtract_mean=False, random_state=0):
+                   flip_sign=True, random_state=0, subtract_mean=False):
     """Computes a truncated randomized SVD
 
     Parameters
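The `subtract_mean` branch above relies on a simple identity: if `c` is the row vector of column means of `A` and `ones` a column of ones, then `(A - ones @ c) @ X == A @ X - ones @ (c @ X)`, so products with the centered matrix never require forming it. A standalone NumPy check of that identity (illustrative only; the names below are not part of the commit):

import numpy as np

rng = np.random.RandomState(0)
A = rng.rand(50, 20)
X = rng.rand(20, 5)

c = A.mean(axis=0).reshape(1, -1)   # column means, shape (1, 20)
ones = np.ones((A.shape[0], 1))     # column of ones, shape (50, 1)

explicit = (A - ones @ c) @ X       # multiply by the explicitly centered A
implicit = A @ X - ones @ (c @ X)   # same product without forming A - ones @ c

np.testing.assert_allclose(explicit, implicit)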
@@ -302,20 +308,20 @@ def randomized_svd(M, n_components, n_oversamples=10, n_iter='auto',
         set to `True`, the sign ambiguity is resolved by making the largest
         loadings for each component in the left singular vectors positive.
 
-    subtract_mean : bool
-        Whether the mean of `A` should be subtracted after each multiplication
-        by the `A` matrix. This is equivalent to multiplying matrices by a
-        centered `A` without ever having to explicitly center. This is
-        especially useful for performing PCA on large sparse matrices, so they
-        do not need to be centered.
-
     random_state : int, RandomState instance or None, optional (default=None)
         The seed of the pseudo random number generator to use when shuffling
         the data. If int, random_state is the seed used by the random number
         generator; If RandomState instance, random_state is the random number
        generator; If None, the random number generator is the RandomState
        instance used by `np.random`.
 
+    subtract_mean : bool
+        Whether the mean of `A` should be subtracted after each multiplication
+        by the `A` matrix. This is equivalent to multiplying matrices by a
+        centered `A` without ever having to explicitly center. This is
+        especially useful for performing PCA on large sparse matrices, so they
+        do not need to be centered.
+
     Notes
     -----
     This algorithm finds a (usually very good) approximate truncated
@@ -359,14 +365,9 @@ def randomized_svd(M, n_components, n_oversamples=10, n_iter='auto',
         # this implementation is a bit faster with smaller shape[1]
         M = M.T
 
-    Q = randomized_range_finder(
-        M,
-        size=n_random,
-        n_iter=n_iter,
-        power_iteration_normalizer=power_iteration_normalizer,
-        subtract_mean=subtract_mean,
-        random_state=random_state,
-    )
+    Q = randomized_range_finder(M, n_random, n_iter,
+                                power_iteration_normalizer, random_state,
+                                subtract_mean)
 
     # project M to the (k + p) dimensional space using the basis vectors
     B = safe_sparse_dot(Q.T, M)
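Taken together, these changes let `randomized_svd` behave as if `M` had been centered, without the caller ever forming the dense centered matrix. A minimal sketch of calling it that way, assuming a build that contains this commit (the `subtract_mean` keyword is not part of the released scikit-learn API):

import numpy as np
import scipy.sparse as sp
from sklearn.utils.extmath import randomized_svd

rng = np.random.RandomState(0)
M = sp.csr_matrix(rng.binomial(1, 0.1, (100, 80)).astype(float))

# subtract_mean=True makes the range finder use the implicitly centered M,
# which is what PCA needs; M itself stays sparse throughout.
U, S, V = randomized_svd(M, n_components=3, n_iter=10, flip_sign=True,
                         random_state=0, subtract_mean=True)
print(S)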

0 commit comments
