8000 [MRG+1] Initialize ARPACK eigsh by yanlend · Pull Request #5012 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[MRG+1] Initialize ARPACK eigsh #5012

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 23, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/modules/pipeline.rst
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ and ``value`` is an estimator object::
n_components=None, whiten=False)), ('kernel_pca', KernelPCA(alpha=1.0,
coef0=1, degree=3, eigen_solver='auto', fit_inverse_transform=False,
gamma=None, kernel='linear', kernel_params=None, max_iter=None,
n_components=None, remove_zero_eig=False, tol=0))],
n_components=None, random_state=None, remove_zero_eig=False, tol=0))],
transformer_weights=None)

Like pipelines, feature unions have a shorthand constructor called
Expand Down
4 changes: 4 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ Bug fixes
- Fixed bug in :func:`manifold.spectral_embedding` where diagonal of unnormalized
Laplacian matrix was incorrectly set to 1. By `Peter Fischer`_.

- Fixed incorrect initialization of :func:`utils.arpack.eigsh` on all
occurrences. Affects :class:`cluster.SpectralBiclustering`,
:class:`decomposition.KernelPCA`, :class:`manifold.LocallyLinearEmbedding`,
and :class:`manifold.SpectralEmbedding`. By `Peter Fischer`_.

API changes summary
-------------------
Expand Down
15 changes: 11 additions & 4 deletions sklearn/cluster/bicluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from . import KMeans, MiniBatchKMeans
from ..base import BaseEstimator, BiclusterMixin
from ..externals import six
from ..utils import check_random_state
from ..utils.arpack import eigsh, svds

from ..utils.extmath import (make_nonnegative, norm, randomized_svd,
Expand Down Expand Up @@ -140,12 +141,18 @@ def _svd(self, array, n_components, n_discard):
# some eigenvalues of A * A.T are negative, causing
# sqrt() to be np.nan. This causes some vectors in vt
# to be np.nan.
_, v = eigsh(safe_sparse_dot(array.T, array),
ncv=self.n_svd_vecs)
A = safe_sparse_dot(array.T, array)
random_state = check_random_state(self.random_state)
# initialize with [-1,1] as in ARPACK
v0 = random_state.uniform(-1, 1, A.shape[0])
_, v = eigsh(A, ncv=self.n_svd_vecs, v0=v0)
vt = v.T
if np.any(np.isnan(u)):
_, u = eigsh(safe_sparse_dot(array, array.T),
ncv=self.n_svd_vecs)
A = safe_sparse_dot(array, array.T)
random_state = check_random_state(self.random_state)
# initialize with [-1,1] as in ARPACK
v0 = random_state.uniform(-1, 1, A.shape[0])
_, u = eigsh(A, ncv=self.n_svd_vecs, v0=v0)

assert_all_finite(u)
assert_all_finite(vt)
Expand Down
15 changes: 13 additions & 2 deletions sklearn/decomposition/kernel_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from scipy import linalg

from ..utils import check_random_state
from ..utils.arpack import eigsh
from ..utils.validation import check_is_fitted
from ..exceptions import NotFittedError
Expand Down Expand Up @@ -76,6 +77,10 @@ class KernelPCA(BaseEstimator, TransformerMixin):
When n_components is None, this parameter is ignored and components
with zero eigenvalues are removed regardless.

random_state : int seed, RandomState instance, or None, default : None
A pseudo random number generator used for the initialization of the
residuals when eigen_solver == 'arpack'.

Attributes
----------

Expand Down Expand Up @@ -103,7 +108,8 @@ class KernelPCA(BaseEstimator, TransformerMixin):
def __init__(self, n_components=None, kernel="linear",
gamma=None, degree=3, coef0=1, kernel_params=None,
alpha=1.0, fit_inverse_transform=False, eigen_solver='auto',
tol=0, max_iter=None, remove_zero_eig=False):
tol=0, max_iter=None, remove_zero_eig=False,
random_state=None):
if fit_inverse_transform and kernel == 'precomputed':
raise ValueError(
"Cannot fit_inverse_transform with a precomputed kernel.")
Expand All @@ -120,6 +126,7 @@ def __init__(self, n_components=None, kernel="linear",
self.tol = tol
self.max_iter = max_iter
self._centerer = KernelCenterer()
self.random_state = random_state

@property
def _pairwise(self):
Expand Down Expand Up @@ -158,10 +165,14 @@ def _fit_transform(self, K):
self.lambdas_, self.alphas_ = linalg.eigh(
K, eigvals=(K.shape[0] - n_components, K.shape[0] - 1))
elif eigen_solver == 'arpack':
random_state = check_random_state(self.random_state)
# initialize with [-1,1] as in ARPACK
v0 = random_state.uniform(-1, 1, K.shape[0])
self.lambdas_, self.alphas_ = eigsh(K, n_components,
which="LA",
tol=self.tol,
maxiter=self.max_iter)
maxiter=self.max_iter,
v0=v0)

# sort eigenvectors in descending order
indices = self.lambdas_.argsort()[::-1]
Expand Down
3 changes: 2 additions & 1 deletion sklearn/manifold/locally_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ def null_space(M, k, k_skip=1, eigen_solver='arpack', tol=1E-6, max_iter=100,

if eigen_solver == 'arpack':
random_state = check_random_state(random_state)
v0 = random_state.rand(M.shape[0])
# initialize with [-1,1] as in ARPACK
v0 = random_state.uniform(-1, 1, M.shape[0])
try:
eigen_values, eigen_vectors = eigsh(M, k + k_skip, sigma=0.0,
tol=tol, maxiter=max_iter,
Expand Down
3 changes: 2 additions & 1 deletion sklearn/manifold/spectral_embedding_.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,10 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None,
# We are computing the opposite of the laplacian inplace so as
# to spare a memory allocation of a possibly very large array
laplacian *= -1
v0 = random_state.uniform(-1, 1, laplacian.shape[0])
lambdas, diffusion_map = eigsh(laplacian, k=n_components,
sigma=1.0, which='LM',
tol=eigen_tol)
tol=eigen_tol, v0=v0)
embedding = diffusion_map.T[n_components::-1] * dd
except RuntimeError:
# When submatrices are exactly singular, an LU decomposition
Expand Down
26 changes: 25 additions & 1 deletion sklearn/utils/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import numpy as np
import scipy.sparse as sp
from scipy.linalg import pinv2
from scipy.linalg import eigh
from itertools import chain

from sklearn.utils.testing import (assert_equal, assert_raises, assert_true,
assert_almost_equal, assert_array_equal,
SkipTest, assert_raises_regex)
SkipTest, assert_raises_regex,
assert_greater_equal)

from sklearn.utils import check_random_state
from sklearn.utils import deprecated
Expand All @@ -18,7 +20,9 @@
from sklearn.utils import shuffle
from sklearn.utils import gen_even_slices
from sklearn.utils.extmath import pinvh
from sklearn.utils.arpack import eigsh
from sklearn.utils.mocking import MockDataFrame
from sklearn.utils.graph import graph_laplacian


def test_make_rng():
Expand Down Expand Up @@ -126,6 +130,26 @@ def test_pinvh_simple_complex():
assert_almost_equal(np.dot(a, a_pinv), np.eye(3))


def test_arpack_eigsh_initialization():
# Non-regression test that shows null-space computation is better with
# initialization of eigsh from [-1,1] instead of [0,1]
random_state = check_random_state(42)

A = random_state.rand(50, 50)
A = np.dot(A.T, A) # create s.p.d. matrix
A = graph_laplacian(A) + 1e-7 * np.identity(A.shape[0])
k = 5

# Test if eigsh is working correctly
# New initialization [-1,1] (as in original ARPACK)
# Was [0,1] before, with which this test could fail
v0 = random_state.uniform(-1,1, A.shape[0])
w, _ = eigsh(A, k=k, sigma=0.0, v0=v0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this kind of test to be of interest I think we should run it many times with different seeds:

for seed in range(30):
    random_state = check_random_state(i)
    # put the content of the test here

However we need to make sure that it is still fast enough to run (e.g. < 100ms).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment was not addressed.


# Eigenvalues of s.p.d. matrix should be nonnegative, w[0] is smallest
assert_greater_equal(w[0], 0)


def test_column_or_1d():
EXAMPLES = [
("binary", ["spam", "egg", "spam"]),
Expand Down
0