8000 FIX: amg requires sparse matrices input · satra/scikit-learn@d5ff04a · GitHub
[go: up one dir, main page]

Skip to content

Commit d5ff04a

Browse files
weilinearamueller
authored andcommitted
FIX: amg requires sparse matrices input
1 parent f09a579 commit d5ff04a

File tree

2 files changed

+15
-31
lines changed

2 files changed

+15
-31
lines changed

sklearn/manifold/spectral_embedding.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,9 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None,
264264
if eigen_solver == 'amg':
265265
# Use AMG to get a preconditioner and speed up the eigenvalue
266266
# problem.
267+
if not sparse.issparse(laplacian):
268+
warnings.warn("AMG works for sparse matrices better")
269+
laplacian = sparse.csr_matrix(laplacian)
267270
laplacian = laplacian.astype(np.float) # lobpcg needs native floats
268271
ml = smoothed_aggregation_solver(atleast2d_or_csr(laplacian))
269272
M = ml.aspreconditioner()
@@ -446,7 +449,7 @@ def fit(self, X, y=None):
446449
self.random_state = check_random_state(self.random_state)
447450
if isinstance(self.affinity, basestring):
448451
if self.affinity not in set(("nearest_neighbors", "rbf",
449-
"precomputed")):
452+
"precomputed")):
450453
raise ValueError(("%s is not a valid affinity. Expected "
451454
"'precomputed', 'rbf', 'nearest_neighbors' "
452455
"or a callable.") % self.affinity)

sklearn/manifold/tests/test_spectral_embedding.py

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
116116
except ImportError:
117117
raise SkipTest
118118

119-
gamma = 0.9
120-
se_amg = SpectralEmbedding(n_components=3, affinity="rbf",
121-
gamma=gamma, eigen_solver="amg",
119+
se_amg = SpectralEmbedding(n_components=3, affinity="nearest_neighbors",
120+
eigen_solver="amg", n_neighbors=5,
122121
random_state=np.random.RandomState(seed))
123-
se_arpack = SpectralEmbedding(n_components=3, affinity="rbf",
124-
gamma=gamma, eigen_solver="arpack",
122+
se_arpack = SpectralEmbedding(n_components=3, affinity="nearest_neighbors",
123+
eigen_solver="arpack", n_neighbors=5,
125124
random_state=np.random.RandomState(seed))
126125
embed_amg = se_amg.fit_transform(S)
127126
embed_arpack = se_arpack.fit_transform(S)
128-
assert_array_almost_equal(
129-
se_amg.affinity_matrix_, se_arpack.affinity_matrix_)
130127
assert_true(_check_with_col_sign_flipping(embed_amg, embed_arpack, 0.01))
131128

132129

@@ -151,33 +148,17 @@ def test_pipline_spectral_clustering(seed=36):
151148

152149
def test_spectral_embedding_unknown_eigensolver(seed=36):
153150
"""Test that SpectralClustering fails with an unknown eigensolver"""
154-
centers = np.array([
155-
[0., 0., 0.],
156-
[10., 10., 10.],
157-
[20., 20., 20.],
158-
])
159-
X, true_labels = make_blobs(n_samples=100, centers=centers,
160-
cluster_std=1., random_state=42)
161-
162-
se_precomp = SpectralEmbedding(n_components=1, affinity="precomputed",
163-
random_state=np.random.RandomState(seed),
164-
eigen_solver="<unknown>")
165-
assert_raises(ValueError, se_precomp.fit, S)
151+
se = SpectralEmbedding(n_components=1, affinity="precomputed",
152+
random_state=np.random.RandomState(seed),
153+
eigen_solver="<unknown>")
154+
assert_raises(ValueError, se.fit, S)
166155

167156

168157
def test_spectral_embedding_unknown_affinity(seed=36):
169158
"""Test that SpectralClustering fails with an unknown affinity type"""
170-
centers = np.array([
171-
[0., 0., 0.],
172-
[10., 10., 10.],
173-
[20., 20., 20.],
174-
])
175-
X, true_labels = make_blobs(n_samples=100, centers=centers,
176-
cluster_std=1., random_state=42)
177-
178-
se_precomp = SpectralEmbedding(n_components=1, affinity="<unknown>",
179-
random_state=np.random.RandomState(seed))
180-
assert_raises(ValueError, se_precomp.fit, S)
159+
se = SpectralEmbedding(n_components=1, affinity="<unknown>",
160+
random_state=np.random.RandomState(seed))
161+
assert_raises(ValueError, se.fit, S)
181162

182163

183164
def test_connectivity(seed=36):

0 commit comments

Comments
 (0)
0