13
13
14
14
import numpy as np
15
15
from scipy import linalg
16
-
17
16
from ..base import BaseEstimator , TransformerMixin , _ClassNamePrefixFeaturesOutMixin
18
17
from ..exceptions import ConvergenceWarning
19
18
@@ -162,10 +161,12 @@ def fastica(
162
161
max_iter = 200 ,
163
162
tol = 1e-04 ,
164
163
w_init = None ,
164
+ whiten_solver = "svd" ,
165
165
random_state = None ,
166
166
return_X_mean = False ,
167
167
compute_sources = True ,
168
168
return_n_iter = False ,
169
+ sign_flip = False ,
169
170
):
170
171
"""Perform Fast Independent Component Analysis.
171
172
@@ -228,6 +229,18 @@ def my_g(x):
228
229
Initial un-mixing array. If `w_init=None`, then an array of values
229
230
drawn from a normal distribution is used.
230
231
232
+ whiten_solver : {"eigh", "svd"}, default="svd"
233
+ The solver to use for whitening.
234
+
235
+ - "svd" is more stable numerically if the problem is degenerate, and
236
+ often faster when `n_samples <= n_features`.
237
+
238
+ - "eigh" is generally more memory efficient when
239
+ `n_samples >= n_features`, and can be faster when
240
+ `n_samples >= 50 * n_features`.
241
+
242
+ .. versionadded:: 1.2
243
+
231
244
random_state : int, RandomState instance or None, default=None
232
245
Used to initialize ``w_init`` when not specified, with a
233
246
normal distribution. Pass an int, for reproducible results
@@ -244,6 +257,21 @@ def my_g(x):
244
257
return_n_iter : bool, default=False
245
258
Whether or not to return the number of iterations.
246
259
260
+ sign_flip : bool, default=False
261
+ Whether to apply sign flipping during whitening so that the output is
262
+ consistent between solvers.
263
+
264
+ - If `sign_flip=False` then the output of different choices for
265
+ `whiten_solver` may not be equal. Both outputs will still be correct,
266
+ but may differ numerically.
267
+
268
+ - If `sign_flip=True` then the output of both solvers will be
269
+ reconciled during fit so that their outputs match. This may produce
270
+ a different output for each solver when compared to
271
+ `sign_flip=False`.
272
+
273
+ .. versionadded:: 1.2
274
+
247
275
Returns
248
276
-------
249
277
K : ndarray of shape (n_components, n_features) or None
@@ -300,7 +328,9 @@ def my_g(x):
300
328
max_iter = max_iter ,
301
329
tol = tol ,
302
330
w_init = w_init ,
331
+ whiten_solver = whiten_solver ,
303
332
random_state = random_state ,
333
+ sign_flip = sign_flip ,
304
334
)
305
335
S = est ._fit (X , compute_sources = compute_sources )
306
336
@@ -378,12 +408,39 @@ def my_g(x):
378
408
Initial un-mixing array. If `w_init=None`, then an array of values
379
409
drawn from a normal distribution is used.
380
410
411
+ whiten_solver : {"eigh", "svd"}, default="svd"
412
+ The solver to use for whitening.
413
+
414
+ - "svd" is more stable numerically if the problem is degenerate, and
415
+ often faster when `n_samples <= n_features`.
416
+
417
+ - "eigh" is generally more memory efficient when
418
+ `n_samples >= n_features`, and can be faster when
419
+ `n_samples >= 50 * n_features`.
420
+
421
+ .. versionadded:: 1.2
422
+
381
423
random_state : int, RandomState instance or None, default=None
382
424
Used to initialize ``w_init`` when not specified, with a
383
425
normal distribution. Pass an int, for reproducible results
384
426
across multiple function calls.
385
427
See :term:`Glossary <random_state>`.
386
428
429
+ sign_flip : bool, default=False
430
+ Whether to apply sign flipping during whitening so that the output is
431
+ consistent between solvers.
432
+
433
+ - If `sign_flip=False` then the output of different choices for
434
+ `whiten_solver` may not be equal. Both outputs will still be correct,
435
+ but may differ numerically.
436
+
437
+ - If `sign_flip=True` then the output of both solvers will be
438
+ reconciled during fit so that their outputs match. This may produce
439
+ a different output for each solver when compared to
440
+ `sign_flip=False`.
441
+
442
+ .. versionadded:: 1.2
443
+
387
444
Attributes
388
445
----------
389
446
components_ : ndarray of shape (n_components, n_features)
@@ -457,7 +514,9 @@ def __init__(
457
514
max_iter = 200 ,
458
515
tol = 1e-4 ,
459
516
w_init = None ,
517
+ whiten_solver = "svd" ,
460
518
random_state = None ,
519
+ sign_flip = False ,
461
520
):
462
521
super ().__init__ ()
463
522
self .n_components = n_components
@@ -468,7 +527,9 @@ def __init__(
468
527
self .max_iter = max_iter
469
528
self .tol = tol
470
529
self .w_init = w_init
530
+ self .whiten_solver = whiten_solver
471
531
self .random_state = random_state
532
+ self .sign_flip = sign_flip
472
533
473
534
def _fit (self , X , compute_sources = False ):
474
535
"""Fit the model.
@@ -557,9 +618,33 @@ def g(x, fun_args):
557
618
XT -= X_mean [:, np .newaxis ]
558
619
559
620
# Whitening and preprocessing by PCA
560
- u , d , _ = linalg .svd (XT , full_matrices = False , check_finite = False )
621
+ if self .whiten_solver == "eigh" :
622
+ # Faster when n_samples >> n_features
623
+ d , u = linalg .eigh (XT .dot (X ))
624
+ sort_indices = np .argsort (d )[::- 1 ]
625
+ eps = np .finfo (d .dtype ).eps
626
+ degenerate_idx = d < eps
627
+ if np .any (degenerate_idx ):
628
+ warnings .warn (
629
+ "There are some small singular values, using "
630
+ "whiten_solver = 'svd' might lead to more "
631
+ "accurate results."
632
+ )
633
+ d [degenerate_idx ] = eps # For numerical issues
634
+ np .sqrt (d , out = d )
635
+ d , u = d [sort_indices ], u [:, sort_indices ]
636
+ elif self .whiten_solver == "svd" :
637
+ u , d = linalg .svd (XT , full_matrices = False , check_finite = False )[:2 ]
638
+ else :
639
+ raise ValueError (
640
+ "`whiten_solver` must be 'eigh' or 'svd' but got"
641
+ f" { self .whiten_solver } instead"
642
+ )
643
+
644
+ # Give consistent eigenvectors for both whitening solvers
645
+ if self .sign_flip :
646
+ u *= np .sign (u [0 ])
561
647
562
- del _
563
648
K = (u / d ).T [:n_components ] # see (6.33) p.140
564
649
del u , d
565
650
X1 = np .dot (K , XT )
0 commit comments