TST refactor test_truncated_svd (#14140) · scikit-learn/scikit-learn@eade48e · GitHub
[go: up one dir, main page]

Skip to content

Commit eade48e

Browse files
rththomasjpfan
authored and committed
TST refactor test_truncated_svd (#14140)
1 parent be17713 commit eade48e

File tree

1 file changed

+112
-161
lines changed

1 file changed

+112
-161
lines changed

sklearn/decomposition/tests/test_truncated_svd.py

Lines changed: 112 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -7,227 +7,178 @@
77

88
from sklearn.decomposition import TruncatedSVD, PCA
99
from sklearn.utils import check_random_state
10-
from sklearn.utils.testing import (assert_array_almost_equal, assert_equal,
11-
assert_raises, assert_greater,
12-
assert_array_less, assert_allclose)
10+
from sklearn.utils.testing import assert_array_less, assert_allclose
1311

12+
SVD_SOLVERS = ['arpack', 'randomized']
1413

15-
# Make an X that looks somewhat like a small tf-idf matrix.
16-
# XXX newer versions of SciPy >0.16 have scipy.sparse.rand for this.
17-
shape = 60, 55
18-
n_samples, n_features = shape
19-
rng = check_random_state(42)
20-
X = rng.randint(-100, 20, np.product(shape)).reshape(shape)
21-
X = sp.csr_matrix(np.maximum(X, 0), dtype=np.float64)
22-
X.data[:] = 1 + np.log(X.data)
23-
Xdense = X.A
2414

15+
@pytest.fixture(scope='module')
16+
def X_sparse():
17+
# Make an X that looks somewhat like a small tf-idf matrix.
18+
rng = check_random_state(42)
19+
X = sp.random(60, 55, density=0.2, format="csr", random_state=rng)
20+
X.data[:] = 1 + np.log(X.data)
21+
return X
2522

26-
def test_algorithms():
23+
24+
@pytest.mark.parametrize("solver", ['randomized'])
25+
@pytest.mark.parametrize('kind', ('dense', 'sparse'))
26+
def test_solvers(X_sparse, solver, kind):
27+
X = X_sparse if kind == 'sparse' else X_sparse.toarray()
2728
svd_a = TruncatedSVD(30, algorithm="arpack")
28-
svd_r = TruncatedSVD(30, algorithm="randomized", random_state=42)
29+
svd = TruncatedSVD(30, algorithm=solver, random_state=42)
2930

3031
Xa = svd_a.fit_transform(X)[:, :6]
31-
Xr = svd_r.fit_transform(X)[:, :6]
32-
assert_array_almost_equal(Xa, Xr, decimal=5)
32+
Xr = svd.fit_transform(X)[:, :6]
33+
assert_allclose(Xa, Xr, rtol=2e-3)
3334

3435
comp_a = np.abs(svd_a.components_)
35-
comp_r = np.abs(svd_r.components_)
36+
comp = np.abs(svd.components_)
3637
# All elements are equal, but some elements are more equal than others.
37-
assert_array_almost_equal(comp_a[:9], comp_r[:9])
38-
assert_array_almost_equal(comp_a[9:], comp_r[9:], decimal=2)
38+
assert_allclose(comp_a[:9], comp[:9], rtol=1e-3)
39+
assert_allclose(comp_a[9:], comp[9:], atol=1e-2)
3940

4041

41-
def test_attributes():
42-
for n_components in (10, 25, 41):
43-
tsvd = TruncatedSVD(n_components).fit(X)
44-
assert_equal(tsvd.n_components, n_components)
45-
assert_equal(tsvd.components_.shape, (n_components, n_features))
42+
@pytest.mark.parametrize("n_components", (10, 25, 41))
43+
def test_attributes(n_components, X_sparse):
44+
n_features = X_sparse.shape[1]
45+
tsvd = TruncatedSVD(n_components).fit(X_sparse)
46+
assert tsvd.n_components == n_components
47+
assert tsvd.components_.shape == (n_components, n_features)
4648

4749

48-
@pytest.mark.parametrize('algorithm', ("arpack", "randomized"))
49-
def test_too_many_components(algorithm):
50+
@pytest.mark.parametrize('algorithm', SVD_SOLVERS)
51+
def test_too_many_components(algorithm, X_sparse):
52+
n_features = X_sparse.shape[1]
5053
for n_components in (n_features, n_features + 1):
5154
tsvd = TruncatedSVD(n_components=n_components, algorithm=algorithm)
52-
assert_raises(ValueError, tsvd.fit, X)
55+
with pytest.raises(ValueError):
56+
tsvd.fit(X_sparse)
5357

5458

5559
@pytest.mark.parametrize('fmt', ("array", "csr", "csc", "coo", "lil"))
56-
def test_sparse_formats(fmt):
57-
Xfmt = Xdense if fmt == "dense" else getattr(X, "to" + fmt)()
60+
def test_sparse_formats(fmt, X_sparse):
61+
n_samples = X_sparse.shape[0]
62+
Xfmt = (X_sparse.toarray()
63+
if fmt == "dense" else getattr(X_sparse, "to" + fmt)())
5864
tsvd = TruncatedSVD(n_components=11)
5965
Xtrans = tsvd.fit_transform(Xfmt)
60-
assert_equal(Xtrans.shape, (n_samples, 11))
66+
assert Xtrans.shape == (n_samples, 11)
6167
Xtrans = tsvd.transform(Xfmt)
62-
assert_equal(Xtrans.shape, (n_samples, 11))
68+
assert Xtrans.shape == (n_samples, 11)
6369

6470

65-
@pytest.mark.parametrize('algo', ("arpack", "randomized"))
66-
def test_inverse_transform(algo):
71+
@pytest.mark.parametrize('algo', SVD_SOLVERS)
72+
def test_inverse_transform(algo, X_sparse):
6773
# We need a lot of components for the reconstruction to be "almost
6874
# equal" in all positions. XXX Test means or sums instead?
6975
tsvd = TruncatedSVD(n_components=52, random_state=42, algorithm=algo)
70-
Xt = tsvd.fit_transform(X)
76+
Xt = tsvd.fit_transform(X_sparse)
7177
Xinv = tsvd.inverse_transform(Xt)
72-
assert_array_almost_equal(Xinv, Xdense, decimal=1)
78+
assert_allclose(Xinv, X_sparse.toarray(), rtol=1e-1, atol=2e-1)
7379

7480

75-
def test_integers():
76-
Xint = X.astype(np.int64)
81+
def test_integers(X_sparse):
82+
n_samples = X_sparse.shape[0]
83+
Xint = X_sparse.astype(np.int64)
7784
tsvd = TruncatedSVD(n_components=6)
7885
Xtrans = tsvd.fit_transform(Xint)
79-
assert_equal(Xtrans.shape, (n_samples, tsvd.n_components))
80-
81-
82-
def test_explained_variance():
83-
# Test sparse data
84-
svd_a_10_sp = TruncatedSVD(10, algorithm="arpack")
85-
svd_r_10_sp = TruncatedSVD(10, algorithm="randomized", random_state=42)
86-
svd_a_20_sp = TruncatedSVD(20, algorithm="arpack")
87-
svd_r_20_sp = TruncatedSVD(20, algorithm="randomized", random_state=42)
88-
X_trans_a_10_sp = svd_a_10_sp.fit_transform(X)
89-
X_trans_r_10_sp = svd_r_10_sp.fit_transform(X)
90-
X_trans_a_20_sp = svd_a_20_sp.fit_transform(X)
91-
X_trans_r_20_sp = svd_r_20_sp.fit_transform(X)
92-
93-
# Test dense data
94-
svd_a_10_de = TruncatedSVD(10, algorithm="arpack")
95-
svd_r_10_de = TruncatedSVD(10, algorithm="randomized", random_state=42)
96-
svd_a_20_de = TruncatedSVD(20, algorithm="arpack")
97-
svd_r_20_de = TruncatedSVD(20, algorithm="randomized", random_state=42)
98-
X_trans_a_10_de = svd_a_10_de.fit_transform(X.toarray())
99-
X_trans_r_10_de = svd_r_10_de.fit_transform(X.toarray())
100-
X_trans_a_20_de = svd_a_20_de.fit_transform(X.toarray())
101-
X_trans_r_20_de = svd_r_20_de.fit_transform(X.toarray())
102-
103-
# helper arrays for tests below
104-
svds = (svd_a_10_sp, svd_r_10_sp, svd_a_20_sp, svd_r_20_sp, svd_a_10_de,
105-
svd_r_10_de, svd_a_20_de, svd_r_20_de)
106-
svds_trans = (
107-
(svd_a_10_sp, X_trans_a_10_sp),
108-
(svd_r_10_sp, X_trans_r_10_sp),
109-
(svd_a_20_sp, X_trans_a_20_sp),
110-
(svd_r_20_sp, X_trans_r_20_sp),
111-
(svd_a_10_de, X_trans_a_10_de),
112-
(svd_r_10_de, X_trans_r_10_de),
113-
(svd_a_20_de, X_trans_a_20_de),
114-
(svd_r_20_de, X_trans_r_20_de),
115-
)
116-
svds_10_v_20 = (
117-
(svd_a_10_sp, svd_a_20_sp),
118-
(svd_r_10_sp, svd_r_20_sp),
119-
(svd_a_10_de, svd_a_20_de),
120-
(svd_r_10_de, svd_r_20_de),
121-
)
122-
svds_sparse_v_dense = (
123-
(svd_a_10_sp, svd_a_10_de),
124-
(svd_a_20_sp, svd_a_20_de),
125-
(svd_r_10_sp, svd_r_10_de),
126-
(svd_r_20_sp, svd_r_20_de),
127-
)
86+
assert Xtrans.shape == (n_samples, tsvd.n_components)
12887

129-
# Assert the 1st component is equal
130-
for svd_10, svd_20 in svds_10_v_20:
131-
assert_array_almost_equal(
132-
svd_10.explained_variance_ratio_,
133-
svd_20.explained_variance_ratio_[:10],
134-
decimal=5,
135-
)
136-
137-
# Assert that 20 components has higher explained variance than 10
138-
for svd_10, svd_20 in svds_10_v_20:
139-
assert_greater(
140-
svd_20.explained_variance_ratio_.sum(),
141-
svd_10.explained_variance_ratio_.sum(),
142-
)
14388

89+
@pytest.mark.parametrize('kind', ('dense', 'sparse'))
90+
@pytest.mark.parametrize('n_components', [10, 20])
91+
@pytest.mark.parametrize('solver', SVD_SOLVERS)
92+
def test_explained_variance(X_sparse, kind, n_components, solver):
93+
X = X_sparse if kind == 'sparse' else X_sparse.toarray()
94+
svd = TruncatedSVD(n_components, algorithm=solver)
95+
X_tr = svd.fit_transform(X)
14496
# Assert that all the values are greater than 0
145-
for svd in svds:
146-
assert_array_less(0.0, svd.explained_variance_ratio_)
97+
assert_array_less(0.0, svd.explained_variance_ratio_)
14798

14899
# Assert that total explained variance is less than 1
149-
for svd in svds:
150-
assert_array_less(svd.explained_variance_ratio_.sum(), 1.0)
151-
152-
# Compare sparse vs. dense
153-
for svd_sparse, svd_dense in svds_sparse_v_dense:
154-
assert_array_almost_equal(svd_sparse.explained_variance_ratio_,
155-
svd_dense.explained_variance_ratio_)
100+
assert_array_less(svd.explained_variance_ratio_.sum(), 1.0)
156101

157102
# Test that explained_variance is correct
158-
for svd, transformed in svds_trans:
159-
total_variance = np.var(X.toarray(), axis=0).sum()
160-
variances = np.var(transformed, axis=0)
161-
true_explained_variance_ratio = variances / total_variance
103+
total_variance = np.var(X_sparse.toarray(), axis=0).sum()
104+
variances = np.var(X_tr, axis=0)
105+
true_explained_variance_ratio = variances / total_variance
162106

163-
assert_array_almost_equal(
164-
svd.explained_variance_ratio_,
165-
true_explained_variance_ratio,
166-
)
107+
assert_allclose(
108+
svd.explained_variance_ratio_,
109+
true_explained_variance_ratio,
110+
)
167111

168112

169-
def test_singular_values():
170-
# Check that the TruncatedSVD output has the correct singular values
113+
@pytest.mark.parametrize('kind', ('dense', 'sparse'))
114+
@pytest.mark.parametrize('solver', SVD_SOLVERS)
115+
def test_explained_variance_components_10_20(X_sparse, kind, solver):
116+
X = X_sparse if kind == 'sparse' else X_sparse.toarray()
117+
svd_10 = TruncatedSVD(10, algorithm=solver).fit(X)
118+
svd_20 = TruncatedSVD(20, algorithm=solver).fit(X)
171119

172-
rng = np.random.RandomState(0)
173-
n_samples = 100
174-
n_features = 80
120+
# Assert the 1st component is equal
121+
assert_allclose(
122+
svd_10.explained_variance_ratio_,
123+
svd_20.explained_variance_ratio_[:10],
124+
rtol=3e-3,
125+
)
126+
127+
# Assert that 20 components has higher explained variance than 10
128+
assert (
129+
svd_20.explained_variance_ratio_.sum() >
130+
svd_10.explained_variance_ratio_.sum()
131+
)
175132

133+
134+
@pytest.mark.parametrize('solver', SVD_SOLVERS)
135+
def test_singular_values_consistency(solver):
136+
# Check that the TruncatedSVD output has the correct singular values
137+
rng = np.random.RandomState(0)
138+
n_samples, n_features = 100, 80
176139
X = rng.randn(n_samples, n_features)
177140

178-
apca = TruncatedSVD(n_components=2, algorithm='arpack',
179-
random_state=rng).fit(X)
180-
rpca = TruncatedSVD(n_components=2, algorithm='arpack',
181-
random_state=rng).fit(X)
182-
assert_array_almost_equal(apca.singular_values_, rpca.singular_values_, 12)
141+
pca = TruncatedSVD(n_components=2, algorithm=solver,
142+
random_state=rng).fit(X)
183143

184144
# Compare to the Frobenius norm
185-
X_apca = apca.transform(X)
186-
X_rpca = rpca.transform(X)
187-
assert_array_almost_equal(np.sum(apca.singular_values_**2.0),
188-
np.linalg.norm(X_apca, "fro")**2.0, 12)
189-
assert_array_almost_equal(np.sum(rpca.singular_values_**2.0),
190-
np.linalg.norm(X_rpca, "fro")**2.0, 12)
145+
X_pca = pca.transform(X)
146+
assert_allclose(np.sum(pca.singular_values_**2.0),
147+
np.linalg.norm(X_pca, "fro")**2.0, rtol=1e-2)
191148

192149
# Compare to the 2-norms of the score vectors
193-
assert_array_almost_equal(apca.singular_values_,
194-
np.sqrt(np.sum(X_apca**2.0, axis=0)), 12)
195-
assert_array_almost_equal(rpca.singular_values_,
196-
np.sqrt(np.sum(X_rpca**2.0, axis=0)), 12)
150+
assert_allclose(pca.singular_values_,
151+
np.sqrt(np.sum(X_pca**2.0, axis=0)), rtol=1e-2)
197152

153+
154+
@pytest.mark.parametrize('solver', SVD_SOLVERS)
155+
def test_singular_values_expected(solver):
198156
# Set the singular values and see what we get back
199157
rng = np.random.RandomState(0)
200158
n_samples = 100
201159
n_features = 110
202160

203161
X = rng.randn(n_samples, n_features)
204162

205-
apca = TruncatedSVD(n_components=3, algorithm='arpack',
206-
random_state=rng)
207-
rpca = TruncatedSVD(n_components=3, algorithm='randomized',
208-
random_state=rng)
209-
X_apca = apca.fit_transform(X)
210-
X_rpca = rpca.fit_transform(X)
211-
212-
X_apca /= np.sqrt(np.sum(X_apca**2.0, axis=0))
213-
X_rpca /= np.sqrt(np.sum(X_rpca**2.0, axis=0))
214-
X_apca[:, 0] *= 3.142
215-
X_apca[:, 1] *= 2.718
216-
X_rpca[:, 0] *= 3.142
217-
X_rpca[:, 1] *= 2.718
218-
219-
X_hat_apca = np.dot(X_apca, apca.components_)
220-
X_hat_rpca = np.dot(X_rpca, rpca.components_)
221-
apca.fit(X_hat_apca)
222-
rpca.fit(X_hat_rpca)
223-
assert_array_almost_equal(apca.singular_values_, [3.142, 2.718, 1.0], 14)
224-
assert_array_almost_equal(rpca.singular_values_, [3.142, 2.718, 1.0], 14)
225-
226-
227-
def test_truncated_svd_eq_pca():
163+
pca = TruncatedSVD(n_components=3, algorithm=solver,
164+
random_state=rng)
165+
X_pca = pca.fit_transform(X)
166+
167+
X_pca /= np.sqrt(np.sum(X_pca**2.0, axis=0))
168+
X_pca[:, 0] *= 3.142
169+
X_pca[:, 1] *= 2.718
170+
171+
X_hat_pca = np.dot(X_pca, pca.components_)
172+
pca.fit(X_hat_pca)
173+
assert_allclose(pca.singular_values_, [3.142, 2.718, 1.0], rtol=1e-14)
174+
175+
176+
def test_truncated_svd_eq_pca(X_sparse):
228177
# TruncatedSVD should be equal to PCA on centered data
229178

230-
X_c = X - X.mean(axis=0)
179+
X_dense = X_sparse.toarray()
180+
181+
X_c = X_dense - X_dense.mean(axis=0)
231182

232183
params = dict(n_components=10, random_state=42)
233184

0 commit comments

Comments
 (0)
0