10000 ENH Add `eigh` solver to `FastICA` by Micky774 · Pull Request #22527 · scikit-learn/scikit-learn · GitHub

Merged: 92 commits, merged on Jun 14, 2022

Commits
d55e9af
update pca
pierreablin Aug 20, 2018
ba1cec1
change algorithm depending on n
pierreablin Aug 21, 2018
7903b6c
added a choice between solvers for svd
pierreablin Aug 22, 2018
c5272ec
fix docstring
pierreablin Aug 22, 2018
762b701
Merge branch 'main' into change_svd
Micky774 Feb 18, 2022
ff167f9
Added debugging statements
Micky774 Feb 18, 2022
8f923a2
Slightly improved test coverage and corrected implementation
Micky774 Feb 18, 2022
2cd481a
Added to changelog and re-enabled failing test
Micky774 Feb 18, 2022
a8ceb9e
Removed old debugging code
Micky774 Feb 18, 2022
0bc7343
Added temporary benchmark file
Micky774 Feb 18, 2022
43c1e86
Update sklearn/decomposition/_fastica.py
Micky774 Feb 18, 2022
fcd2542
Updated benchmark file to use csv instead of pickle
Micky774 Feb 18, 2022
1616deb
Minor benchmark generator update
Micky774 Feb 19, 2022
e9934f5
Merge branch 'change_svd' of https://github.com/Micky774/scikit-learn…
Micky774 Feb 19, 2022
93e1375
Updated benchmark file and added csv
Micky774 Feb 20, 2022
cd3a0d2
Merge branch 'main' into change_svd
Micky774 Feb 25, 2022
61e0f03
Merge branch 'main' into change_svd
Micky774 Feb 27, 2022
af1a482
Removed old files, improved benchmark file
Micky774 Feb 27, 2022
202afa0
Added ratio column
Micky774 Feb 27, 2022
2f4bd2e
Added matrix reordering and reduced equality strictness (up to parity)
Micky774 Mar 3, 2022
8618cd7
Merge branch 'main' into change_svd
Micky774 Mar 3, 2022
035ae9c
Simplified reorder/flip in `eigh` solver
Micky774 Mar 5, 2022
74b762d
Merge branch 'main' into change_svd
Micky774 Mar 5, 2022
d5a7791
Merge branch 'main' into change_svd
Micky774 Mar 5, 2022
7676fac
Removed benchmark file (in provided gist links)
Micky774 Mar 5, 2022
559287f
Merge branch 'change_svd' of https://github.com/Micky774/scikit-learn…
Micky774 Mar 5, 2022
66dd707
TST Check solvers
thomasjpfan Mar 9, 2022
e6e9602
TST Create new tests
thomasjpfan Mar 9, 2022
9cdbef9
CLN Slightly better
thomasjpfan Mar 9, 2022
c5ef5a9
TST Adjust seed
thomasjpfan Mar 9, 2022
77422c8
DOC Adds comment
thomasjpfan Mar 9, 2022
deaab6e
FIX Give a random state
thomasjpfan Mar 9, 2022
c077bfe
FIX Give a random state
thomasjpfan Mar 9, 2022
7eec239
Merge pull request #1 from thomasjpfan/pr/22527_fix
Micky774 Mar 13, 2022
c085f5e
Updated changelog
Micky774 Mar 13, 2022
2c5beaa
Merge branch 'main' into change_svd
Micky774 Mar 13, 2022
7a0a130
Corrected `svd_solver`->`whiten_solver`
Micky774 Mar 13, 2022
38fcc3a
Apply suggestions from code review
Micky774 Mar 13, 2022
5916ab8
Update changelog
Micky774 Mar 13, 2022
c8a7b46
Merge branch 'main' into change_svd
Micky774 Mar 13, 2022
1f12010
Updated `svd_solver`->`whiten_solver` in tests
Micky774 Mar 14, 2022
6299607
Changed sign flip convention
Micky774 Mar 16, 2022
d016c09
Merge branch 'main' into change_svd
Micky774 Mar 18, 2022
bcadc76
Merge branch 'change_svd' of https://github.com/Micky774/scikit-learn…
Micky774 Mar 18, 2022
0298da3
Merge branch 'main' into change_svd
Micky774 Mar 18, 2022
a365dfa
Specify whiten to avoid future warning
Micky774 Mar 19, 2022
8774d37
Merge branch 'main' into change_svd
Micky774 Mar 19, 2022
785131a
Fixed changelog class reference
Micky774 Mar 19, 2022
1a59a39
Apply suggestions from code review
Micky774 Mar 25, 2022
ad54244
Merge branch 'main' into change_svd
Micky774 Mar 25, 2022
4284fde
Add test for catching low-rank warning in `eigh` solver
Micky774 Mar 25, 2022
841ea1c
Merge branch 'main' into change_svd
Micky774 Mar 25, 2022
c756d00
Merge branch 'main' into change_svd
Micky774 Mar 26, 2022
8b77ef2
Fixed sphinx lists
Micky774 Mar 26, 2022
8e6273e
Reformatted sphinx lists
Micky774 Mar 27, 2022
5e06542
Actually fix sphinx error...hopefully
Micky774 Apr 3, 2022
ac6e8ce
Merge branch 'main' into change_svd
Micky774 Apr 3, 2022
f02b720
Merge branch 'main' into change_svd
Micky774 Apr 3, 2022
8113b28
Merge branch 'change_svd' of https://github.com/Micky774/scikit-learn…
Micky774 Apr 3, 2022
24a4fc8
Fixed git sync issue
Micky774 Apr 3, 2022
eb6d72b
Merge branch 'main' into change_svd
Micky774 Apr 22, 2022
9b726be
Merge branch 'main' into change_svd
Micky774 May 8, 2022
3c8a446
Update sklearn/decomposition/_fastica.py
Micky774 May 8, 2022
066e83a
Undo format change (from a separate PR)
Micky774 May 8, 2022
7253518
Merge branch 'change_svd' of https://github.com/Micky774/scikit-learn…
Micky774 May 8, 2022
9ab3777
Merge branch 'main' into change_svd
Micky774 May 12, 2022
76ed7b6
Added "auto" option as new option, and added tests
Micky774 May 19, 2022
3a271b1
Began deprecation of new `whiten_solver` param in favor of auto
Micky774 May 19, 2022
c0368aa
Merge branch 'main' into change_svd
Micky774 May 19, 2022
8cbbb7c
Merge branch 'main' into change_svd
Micky774 May 23, 2022
b89c7ad
Merge branch 'main' into change_svd
Micky774 May 25, 2022
2591dea
Removed auto option, will reintroduce in future PR
Micky774 May 25, 2022
547a220
Added changed models entry for sign-flipping
Micky774 May 25, 2022
e4bf76e
Merge branch 'main' into change_svd
Micky774 May 26, 2022
07ca199
Reverted default value for whiten solver, pending follow-up PR
Micky774 May 26, 2022
e2ecad0
Changed erroneous default value
Micky774 May 30, 2022
1a04bb8
Merge branch 'main' into change_svd
Micky774 May 31, 2022
6cb877b
Merge branch 'main' into change_svd
Micky774 May 31, 2022
1ef1229
Merge branch 'main' into change_svd
Micky774 Jun 1, 2022
1e0cc7a
Fixed bad changelog and corrected test description
Micky774 Jun 1, 2022
e15b6d5
Added sign-flip parameter
Micky774 Jun 1, 2022
8297355
Merge branch 'main' into change_svd
Micky774 Jun 2, 2022
fac13d4
Fixed test
Micky774 Jun 2, 2022
f9c7fab
Merge branch 'main' into change_svd
Micky774 Jun 6, 2022
4c13aac
Apply suggestions from code review
Micky774 Jun 6, 2022
cef794d
Merge branch 'change_svd' of https://github.com/Micky774/scikit-learn…
Micky774 Jun 6, 2022
66c7080
Removed extra test
Micky774 Jun 6, 2022
a8d17b1
Merge branch 'main' into change_svd
Micky774 Jun 9, 2022
ae9ac99
Incorporated review feedback
Micky774 Jun 9, 2022
273581b
Merge branch 'main' into change_svd
Micky774 Jun 13, 2022
a515617
Updated changelog entry
Micky774 Jun 13, 2022
ba52b00
Linting
Micky774 Jun 13, 2022
14 changes: 14 additions & 0 deletions doc/whats_new/v1.2.rst
@@ -105,6 +105,20 @@ Changelog
- |Efficiency| Improve runtime performance of :class:`ensemble.IsolationForest`
by avoiding data copies. :pr:`23252` by :user:`Zhehao Liu <MaxwellLZH>`.

:mod:`sklearn.decomposition`
............................

- |Enhancement| :class:`decomposition.FastICA` now allows the user to select
how whitening is performed through the new `whiten_solver` parameter, which
supports `svd` and `eigh`. `whiten_solver` defaults to `svd`, although `eigh`
may be faster and more memory efficient in cases where
`num_samples > num_features`. An additional `sign_flip` parameter is added.
When `sign_flip=True`, the output of both solvers is reconciled
during `fit` so that their outputs match. This may change the output of the
default solver, and hence may not be backwards compatible.
:pr:`11860` by :user:`Pierre Ablin <pierreablin>`,
:pr:`22527` by :user:`Meekail Zain <micky774>` and `Thomas Fan`_.

:mod:`sklearn.impute`
.....................

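The PR's central claim — that an eigendecomposition of the Gram matrix recovers the same whitening basis as an SVD of the centered data, up to per-column signs — can be checked directly. Below is a minimal NumPy/SciPy sketch; the synthetic data and variable names are illustrative, not taken from the PR:

```python
import numpy as np
from scipy import linalg

rng = np.random.RandomState(0)
X = rng.randn(200, 5)
# Centered data, transposed to (n_features, n_samples) as in _fastica.py
XT = X.T - X.mean(axis=0)[:, None]

# SVD route: left singular vectors and singular values of the centered data
u_svd, d_svd, _ = linalg.svd(XT, full_matrices=False)

# eigh route: eigendecomposition of the (n_features, n_features) Gram matrix
d_eig, u_eig = linalg.eigh(XT.dot(XT.T))
order = np.argsort(d_eig)[::-1]          # eigh returns ascending order
d_eig = np.sqrt(d_eig[order])            # eigenvalues are squared singular values
u_eig = u_eig[:, order]

# The bases agree only up to column signs; flipping each column by the sign
# of its first entry (the PR's sign_flip convention) reconciles them
u_svd *= np.sign(u_svd[0])
u_eig *= np.sign(u_eig[0])
assert np.allclose(d_svd, d_eig)
assert np.allclose(u_svd, u_eig)
```

This is also why the `sign_flip` parameter exists: without the final flip, both routes are valid whitening bases but need not match elementwise.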
91 changes: 88 additions & 3 deletions sklearn/decomposition/_fastica.py
Expand Up @@ -13,7 +13,6 @@

import numpy as np
from scipy import linalg

from ..base import BaseEstimator, TransformerMixin, _ClassNamePrefixFeaturesOutMixin
from ..exceptions import ConvergenceWarning

@@ -162,10 +161,12 @@ def fastica(
max_iter=200,
tol=1e-04,
w_init=None,
whiten_solver="svd",
random_state=None,
return_X_mean=False,
compute_sources=True,
return_n_iter=False,
sign_flip=False,
):
"""Perform Fast Independent Component Analysis.

@@ -228,6 +229,18 @@ def my_g(x):
Initial un-mixing array. If `w_init=None`, then an array of values
drawn from a normal distribution is used.

whiten_solver : {"eigh", "svd"}, default="svd"
The solver to use for whitening.

- "svd" is more stable numerically if the problem is degenerate, and
often faster when `n_samples <= n_features`.

- "eigh" is generally more memory efficient when
`n_samples >= n_features`, and can be faster when
`n_samples >= 50 * n_features`.

.. versionadded:: 1.2

random_state : int, RandomState instance or None, default=None
Used to initialize ``w_init`` when not specified, with a
normal distribution. Pass an int, for reproducible results
@@ -244,6 +257,21 @@ def my_g(x):
return_n_iter : bool, default=False
Whether or not to return the number of iterations.

sign_flip : bool, default=False
Whether to apply a sign-flipping convention during whitening so that the
output is consistent across choices of `whiten_solver`.

- If `sign_flip=False` then the output of different choices for
`whiten_solver` may not be equal. Both outputs will still be correct,
but may differ numerically.

- If `sign_flip=True` then the output of both solvers will be
reconciled during fit so that their outputs match. This may produce
a different output for each solver when compared to
`sign_flip=False`.

.. versionadded:: 1.2

Returns
-------
K : ndarray of shape (n_components, n_features) or None
@@ -300,7 +328,9 @@ def my_g(x):
max_iter=max_iter,
tol=tol,
w_init=w_init,
whiten_solver=whiten_solver,
random_state=random_state,
sign_flip=sign_flip,
)
S = est._fit(X, compute_sources=compute_sources)
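The `sign_flip` option described above rests on a simple convention: scale every column of the whitening basis by the sign of its first entry, so any two bases that differ only by per-column signs collapse to the same representative. A deterministic sketch of that convention (the matrices here are made up for illustration):

```python
import numpy as np

# Two whitening bases that differ only by per-column signs
u_a = np.array([[0.8, -0.6],
                [0.6,  0.8]])
u_b = u_a * np.array([-1.0, 1.0])   # same basis, first column flipped

# The sign_flip convention: multiply each column by the sign of its first
# entry, making the first row non-negative for both bases
u_a_fixed = u_a * np.sign(u_a[0])
u_b_fixed = u_b * np.sign(u_b[0])
assert np.allclose(u_a_fixed, u_b_fixed)
```

The convention is cheap (one row of sign lookups) but, as the changelog warns, it can change the output of the default `svd` solver relative to earlier releases.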

@@ -378,12 +408,39 @@ def my_g(x):
Initial un-mixing array. If `w_init=None`, then an array of values
drawn from a normal distribution is used.

whiten_solver : {"eigh", "svd"}, default="svd"
The solver to use for whitening.

- "svd" is more stable numerically if the problem is degenerate, and
often faster when `n_samples <= n_features`.

- "eigh" is generally more memory efficient when
`n_samples >= n_features`, and can be faster when
`n_samples >= 50 * n_features`.

.. versionadded:: 1.2

random_state : int, RandomState instance or None, default=None
Used to initialize ``w_init`` when not specified, with a
normal distribution. Pass an int, for reproducible results
across multiple function calls.
See :term:`Glossary <random_state>`.

sign_flip : bool, default=False
Whether to apply a sign-flipping convention during whitening so that the
output is consistent across choices of `whiten_solver`.

- If `sign_flip=False` then the output of different choices for
`whiten_solver` may not be equal. Both outputs will still be correct,
but may differ numerically.

- If `sign_flip=True` then the output of both solvers will be
reconciled during fit so that their outputs match. This may produce
a different output for each solver when compared to
`sign_flip=False`.

.. versionadded:: 1.2

Attributes
----------
components_ : ndarray of shape (n_components, n_features)
@@ -457,7 +514,9 @@ def __init__(
max_iter=200,
tol=1e-4,
w_init=None,
whiten_solver="svd",
random_state=None,
sign_flip=False,
):
super().__init__()
self.n_components = n_components
@@ -468,7 +527,9 @@ def __init__(
self.max_iter = max_iter
self.tol = tol
self.w_init = w_init
self.whiten_solver = whiten_solver
self.random_state = random_state
self.sign_flip = sign_flip

def _fit(self, X, compute_sources=False):
"""Fit the model.
@@ -557,9 +618,33 @@ def g(x, fun_args):
XT -= X_mean[:, np.newaxis]

# Whitening and preprocessing by PCA
u, d, _ = linalg.svd(XT, full_matrices=False, check_finite=False)
if self.whiten_solver == "eigh":
# Faster when n_samples >> n_features
d, u = linalg.eigh(XT.dot(X))
sort_indices = np.argsort(d)[::-1]
eps = np.finfo(d.dtype).eps
degenerate_idx = d < eps
if np.any(degenerate_idx):
warnings.warn(
"There are some small singular values, using "
"whiten_solver = 'svd' might lead to more "
"accurate results."
)
d[degenerate_idx] = eps # For numerical issues
np.sqrt(d, out=d)
d, u = d[sort_indices], u[:, sort_indices]
elif self.whiten_solver == "svd":
u, d = linalg.svd(XT, full_matrices=False, check_finite=False)[:2]
else:
raise ValueError(
"`whiten_solver` must be 'eigh' or 'svd' but got"
f" {self.whiten_solver} instead"
)

# Give consistent eigenvectors for both svd solvers
if self.sign_flip:
u *= np.sign(u[0])

del _
K = (u / d).T[:n_components] # see (6.33) p.140
del u, d
X1 = np.dot(K, XT)
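Putting the `eigh` branch together outside the estimator: on rank-deficient data the Gram matrix has numerically-zero (possibly slightly negative) eigenvalues, which the solver clamps to machine epsilon before taking square roots, and the resulting `K` still whitens the retained components. A standalone sketch of this logic under the same conventions as the diff — synthetic data, not a scikit-learn call:

```python
import numpy as np
from scipy import linalg

# Rank-deficient data: the third feature is a linear combination of the others
rng = np.random.RandomState(0)
B = rng.randn(300, 2)
X = np.column_stack([B, B[:, 0] - B[:, 1]])
XT = (X - X.mean(axis=0)).T               # (n_features, n_samples), centered

d, u = linalg.eigh(XT.dot(XT.T))
sort_indices = np.argsort(d)[::-1]
eps = np.finfo(d.dtype).eps
d[d < eps] = eps                          # clamp degenerate eigenvalues
np.sqrt(d, out=d)                         # singular values of XT
d, u = d[sort_indices], u[:, sort_indices]

n_components = 2
K = (u / d).T[:n_components]              # whitening matrix, as in the solver
X1 = K @ XT                               # whitened data

# Rows of X1 are uncorrelated with unit (unnormalized) covariance
assert np.allclose(X1 @ X1.T, np.eye(n_components))
```

Without the clamp, a slightly negative eigenvalue would turn into a NaN under the square root; with it, the division `u / d` stays finite, at the cost of the accuracy warning the solver emits for such data.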
63 changes: 62 additions & 1 deletion sklearn/decomposition/tests/test_fastica.py
@@ -7,6 +7,7 @@

import numpy as np
from scipy import stats
from sklearn.datasets import make_low_rank_matrix

from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import assert_allclose
@@ -422,7 +423,10 @@ def test_fastica_whiten_backwards_compatibility():

# No warning must be raised in this case.
av_ica = FastICA(
n_components=n_components, whiten="arbitrary-variance", random_state=0
n_components=n_components,
whiten="arbitrary-variance",
random_state=0,
whiten_solver="svd",
)
with warnings.catch_warnings():
warnings.simplefilter("error", FutureWarning)
@@ -457,3 +461,60 @@ def test_fastica_output_shape(whiten, return_X_mean, return_n_iter):
assert len(out) == expected_len
if not whiten:
assert out[0] is None


@pytest.mark.parametrize("add_noise", [True, False])
def test_fastica_simple_different_solvers(add_noise, global_random_seed):
"""Test FastICA is consistent between whiten_solvers when `sign_flip=True`."""
rng = np.random.RandomState(global_random_seed)
n_samples = 1000
# Generate two sources:
s1 = (2 * np.sin(np.linspace(0, 100, n_samples)) > 0) - 1
s2 = stats.t.rvs(1, size=n_samples, random_state=rng)
s = np.c_[s1, s2].T
center_and_norm(s)
s1, s2 = s

# Mixing angle
phi = rng.rand() * 2 * np.pi
mixing = np.array([[np.cos(phi), np.sin(phi)], [np.sin(phi), -np.cos(phi)]])
m = np.dot(mixing, s)

if add_noise:
m += 0.1 * rng.randn(2, 1000)

center_and_norm(m)

outs = {}
for solver in ("svd", "eigh"):
ica = FastICA(
random_state=0, whiten="unit-variance", whiten_solver=solver, sign_flip=True
)
sources = ica.fit_transform(m.T)
outs[solver] = sources
assert ica.components_.shape == (2, 2)
assert sources.shape == (1000, 2)

assert_allclose(outs["eigh"], outs["svd"])
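A note on the test's mixing matrix: `[[cos(phi), sin(phi)], [sin(phi), -cos(phi)]]` is an orthogonal reflection (determinant -1), so mixing the centered, normalized sources preserves their scale, which keeps recovery up to permutation and sign well-posed. A quick check with an arbitrary angle:

```python
import numpy as np

phi = 0.7  # arbitrary mixing angle, for illustration
mixing = np.array([[np.cos(phi),  np.sin(phi)],
                   [np.sin(phi), -np.cos(phi)]])

# Orthogonal: mixing @ mixing.T == I, so it preserves norms of the sources
assert np.allclose(mixing @ mixing.T, np.eye(2))
# A reflection rather than a rotation: determinant is -1
assert np.isclose(np.linalg.det(mixing), -1.0)
```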


def test_fastica_eigh_low_rank_warning(global_random_seed):
"""Test FastICA eigh solver raises warning for low-rank data."""
rng = np.random.RandomState(global_random_seed)
X = make_low_rank_matrix(
n_samples=10, n_features=10, random_state=rng, effective_rank=2
)
ica = FastICA(random_state=0, whiten="unit-variance", whiten_solver="eigh")
msg = "There are some small singular values"
with pytest.warns(UserWarning, match=msg):
ica.fit(X)


@pytest.mark.parametrize("whiten_solver", ["this_should_fail", "test", 1, None])
def test_fastica_whiten_solver_validation(whiten_solver):
rng = np.random.RandomState(0)
X = rng.random_sample((10, 2))
ica = FastICA(random_state=rng, whiten_solver=whiten_solver, whiten="unit-variance")
msg = f"`whiten_solver` must be 'eigh' or 'svd' but got {whiten_solver} instead"
with pytest.raises(ValueError, match=msg):
ica.fit_transform(X)