ENH Add `eigh` solver to `FastICA` (#22527) · scikit-learn/scikit-learn@54c1503 · GitHub
[go: up one dir, main page]

Skip to content

Commit 54c1503

Browse files
Micky774, pierreablin, thomasjpfan
authored
ENH Add eigh solver to FastICA (#22527)
Co-authored-by: Pierre Ablin <pierreablin@gmail.com> Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent bb0ea58 commit 54c1503

File tree

3 files changed

+164
-4
lines changed

3 files changed

+164
-4
lines changed

doc/whats_new/v1.2.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,20 @@ Changelog
105105
- |Efficiency| Improve runtime performance of :class:`ensemble.IsolationForest`
106106
by avoiding data copies. :pr:`23252` by :user:`Zhehao Liu <MaxwellLZH>`.
107107

108+
:mod:`sklearn.decomposition`
109+
............................
110+
111+
- |Enhancement| :class:`decomposition.FastICA` now allows the user to select
112+
how whitening is performed through the new `whiten_solver` parameter, which
113+
supports `svd` and `eigh`. `whiten_solver` defaults to `svd` although `eigh`
114+
may be faster and more memory efficient in cases where
115+
`num_samples > num_features`. An additional `sign_flip` parameter is added.
116+
When `sign_flip=True`, then the output of both solvers will be reconciled
117+
during `fit` so that their outputs match. This may change the output of the
118+
default solver, and hence may not be backwards compatible.
119+
:pr:`11860` by :user:`Pierre Ablin <pierreablin>`,
120+
:pr:`22527` by :user:`Meekail Zain <micky774>` and `Thomas Fan`_.
121+
108122
:mod:`sklearn.impute`
109123
.....................
110124

sklearn/decomposition/_fastica.py

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
import numpy as np
1515
from scipy import linalg
16-
1716
from ..base import BaseEstimator, TransformerMixin, _ClassNamePrefixFeaturesOutMixin
1817
from ..exceptions import ConvergenceWarning
1918

@@ -162,10 +161,12 @@ def fastica(
162161
max_iter=200,
163162
tol=1e-04,
164163
w_init=None,
164+
whiten_solver="svd",
165165
random_state=None,
166166
return_X_mean=False,
167167
compute_sources=True,
168168
return_n_iter=False,
169+
sign_flip=False,
169170
):
170171
"""Perform Fast Independent Component Analysis.
171172
@@ -228,6 +229,18 @@ def my_g(x):
228229
Initial un-mixing array. If `w_init=None`, then an array of values
229230
drawn from a normal distribution is used.
230231
232+
whiten_solver : {"eigh", "svd"}, default="svd"
233+
The solver to use for whitening.
234+
235+
- "svd" is more stable numerically if the problem is degenerate, and
236+
often faster when `n_samples <= n_features`.
237+
238+
- "eigh" is generally more memory efficient when
239+
`n_samples >= n_features`, and can be faster when
240+
`n_samples >= 50 * n_features`.
241+
242+
.. versionadded:: 1.2
243+
231244
random_state : int, RandomState instance or None, default=None
232245
Used to initialize ``w_init`` when not specified, with a
233246
normal distribution. Pass an int, for reproducible results
@@ -244,6 +257,21 @@ def my_g(x):
244257
return_n_iter : bool, default=False
245258
Whether or not to return the number of iterations.
246259
260+
sign_flip : bool, default=False
261+
Used to determine whether to enable sign flipping during whitening for
262+
consistency in output between solvers.
263+
264+
- If `sign_flip=False` then the output of different choices for
265+
`whiten_solver` may not be equal. Both outputs will still be correct,
266+
but may differ numerically.
267+
268+
- If `sign_flip=True` then the output of both solvers will be
269+
reconciled during fit so that their outputs match. This may produce
270+
a different output for each solver when compared to
271+
`sign_flip=False`.
272+
273+
.. versionadded:: 1.2
274+
247275
Returns
248276
-------
249277
K : ndarray of shape (n_components, n_features) or None
@@ -300,7 +328,9 @@ def my_g(x):
300328
max_iter=max_iter,
301329
tol=tol,
302330
w_init=w_init,
331+
whiten_solver=whiten_solver,
303332
random_state=random_state,
333+
sign_flip=sign_flip,
304334
)
305335
S = est._fit(X, compute_sources=compute_sources)
306336

@@ -378,12 +408,39 @@ def my_g(x):
378408
Initial un-mixing array. If `w_init=None`, then an array of values
379409
drawn from a normal distribution is used.
380410
411+
whiten_solver : {"eigh", "svd"}, default="svd"
412+
The solver to use for whitening.
413+
414+
- "svd" is more stable numerically if the problem is degenerate, and
415+
often faster when `n_samples <= n_features`.
416+
417+
- "eigh" is generally more memory efficient when
418+
`n_samples >= n_features`, and can be faster when
419+
`n_samples >= 50 * n_features`.
420+
421+
.. versionadded:: 1.2
422+
381423
random_state : int, RandomState instance or None, default=None
382424
Used to initialize ``w_init`` when not specified, with a
383425
normal distribution. Pass an int, for reproducible results
384426
across multiple function calls.
385427
See :term:`Glossary <random_state>`.
386428
429+
sign_flip : bool, default=False
430+
Used to determine whether to enable sign flipping during whitening for
431+
consistency in output between solvers.
432+
433+
- If `sign_flip=False` then the output of different choices for
434+
`whiten_solver` may not be equal. Both outputs will still be correct,
435+
but may differ numerically.
436+
437+
- If `sign_flip=True` then the output of both solvers will be
438+
reconciled during fit so that their outputs match. This may produce
439+
a different output for each solver when compared to
440+
`sign_flip=False`.
441+
442+
.. versionadded:: 1.2
443+
387444
Attributes
388445
----------
389446
components_ : ndarray of shape (n_components, n_features)
@@ -457,7 +514,9 @@ def __init__(
457514
max_iter=200,
458515
tol=1e-4,
459516
w_init=None,
517+
whiten_solver="svd",
460518
random_state=None,
519+
sign_flip=False,
461520
):
462521
super().__init__()
463522
self.n_components = n_components
@@ -468,7 +527,9 @@ def __init__(
468527
self.max_iter = max_iter
469528
self.tol = tol
470529
self.w_init = w_init
530+
self.whiten_solver = whiten_solver
471531
self.random_state = random_state
532+
self.sign_flip = sign_flip
472533

473534
def _fit(self, X, compute_sources=False):
474535
"""Fit the model.
@@ -557,9 +618,33 @@ def g(x, fun_args):
557618
XT -= X_mean[:, np.newaxis]
558619

559620
# Whitening and preprocessing by PCA
560-
u, d, _ = linalg.svd(XT, full_matrices=False, check_finite=False)
621+
if self.whiten_solver == "eigh":
622+
# Faster when num_samples >> n_features
623+
d, u = linalg.eigh(XT.dot(X))
624+
sort_indices = np.argsort(d)[::-1]
625+
eps = np.finfo(d.dtype).eps
626+
degenerate_idx = d < eps
627+
if np.any(degenerate_idx):
628+
warnings.warn(
629+
"There are some small singular values, using "
630+
"whiten_solver = 'svd' might lead to more "
631+
"accurate results."
632+
)
633+
d[degenerate_idx] = eps # For numerical issues
634+
np.sqrt(d, out=d)
635+
d, u = d[sort_indices], u[:, sort_indices]
636+
elif self.whiten_solver == "svd":
637+
u, d = linalg.svd(XT, full_matrices=False, check_finite=False)[:2]
638+
else:
639+
raise ValueError(
640+
"`whiten_solver` must be 'eigh' or 'svd' but got"
641+
f" {self.whiten_solver} instead"
642+
)
643+
644+
# Give consistent eigenvectors for both whitening solvers ("svd" and "eigh")
645+
if self.sign_flip:
646+
u *= np.sign(u[0])
561647

562-
del _
563648
K = (u / d).T[:n_components] # see (6.33) p.140
564649
del u, d
565650
X1 = np.dot(K, XT)

sklearn/decomposition/tests/test_fastica.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import numpy as np
99
from scipy import stats
10+
from sklearn.datasets import make_low_rank_matrix
1011

1112
from sklearn.utils._testing import assert_array_equal
1213
from sklearn.utils._testing import assert_allclose
@@ -422,7 +423,10 @@ def test_fastica_whiten_backwards_compatibility():
422423

423424
# No warning must be raised in this case.
424425
av_ica = FastICA(
425-
n_components=n_components, whiten="arbitrary-variance", random_state=0
426+
n_components=n_components,
427+
whiten="arbitrary-variance",
428+
random_state=0,
429+
whiten_solver="svd",
426430
)
427431
with warnings.catch_warnings():
428432
warnings.simplefilter("error", FutureWarning)
@@ -457,3 +461,60 @@ def test_fastica_output_shape(whiten, return_X_mean, return_n_iter):
457461
assert len(out) == expected_len
458462
if not whiten:
459463
assert out[0] is None
464+
465+
466+
@pytest.mark.parametrize("add_noise", [True, False])
467+
def test_fastica_simple_different_solvers(add_noise, global_random_seed):
468+
"""Test FastICA is consistent between whiten_solvers when `sign_flip=True`."""
469+
rng = np.random.RandomState(global_random_seed)
470+
n_samples = 1000
471+
# Generate two sources:
472+
s1 = (2 * np.sin(np.linspace(0, 100, n_samples)) > 0) - 1
473+
s2 = stats.t.rvs(1, size=n_samples, random_state=rng)
474+
s = np.c_[s1, s2].T
475+
center_and_norm(s)
476+
s1, s2 = s
477+
478+
# Mixing angle
479+
phi = rng.rand() * 2 * np.pi
480+
mixing = np.array([[np.cos(phi), np.sin(phi)], [np.sin(phi), -np.cos(phi)]])
481+
m = np.dot(mixing, s)
482+
483+
if add_noise:
484+
m += 0.1 * rng.randn(2, 1000)
485+
486+
center_and_norm(m)
487+
488+
outs = {}
489+
for solver in ("svd", "eigh"):
490+
ica = FastICA(
491+
random_state=0, whiten="unit-variance", whiten_solver=solver, sign_flip=True
492+
)
493+
sources = ica.fit_transform(m.T)
494+
outs[solver] = sources
495+
assert ica.components_.shape == (2, 2)
496+
assert sources.shape == (1000, 2)
497+
498+
assert_allclose(outs["eigh"], outs["svd"])
499+
500+
501+
def test_fastica_eigh_low_rank_warning(global_random_seed):
502+
"""Test FastICA eigh solver raises warning for low-rank data."""
503+
rng = np.random.RandomState(global_random_seed)
504+
X = make_low_rank_matrix(
505+
n_samples=10, n_features=10, random_state=rng, effective_rank=2
506+
)
507+
ica = FastICA(random_state=0, whiten="unit-variance", whiten_solver="eigh")
508+
msg = "There are some small singular values"
509+
with pytest.warns(UserWarning, match=msg):
510+
ica.fit(X)
511+
512+
513+
@pytest.mark.parametrize("whiten_solver", ["this_should_fail", "test", 1, None])
514+
def test_fastica_whiten_solver_validation(whiten_solver):
515+
rng = np.random.RandomState(0)
516+
X = rng.random_sample((10, 2))
517+
ica = FastICA(random_state=rng, whiten_solver=whiten_solver, whiten="unit-variance")
518+
msg = f"`whiten_solver` must be 'eigh' or 'svd' but got {whiten_solver} instead"
519+
with pytest.raises(ValueError, match=msg):
520+
ica.fit_transform(X)

0 commit comments

Comments
 (0)
0