ENH Add `eigh` solver to `FastICA` by Micky774 · Pull Request #22527 · scikit-learn/scikit-learn · GitHub

ENH Add eigh solver to FastICA #22527


Merged: 92 commits into scikit-learn:main from Micky774:change_svd, Jun 14, 2022
Conversation

@Micky774 (Contributor) commented Feb 18, 2022

Reference Issues/PRs

Picks up stalled PR #11860

What does this implement/fix? Explain your changes.

PR #11860: Provides an alternative implementation that avoids SVD's extraneous calculations. Especially effective when num_samples >> num_features
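To illustrate the idea behind the PR: for whitening, only the singular values and left singular vectors of the (transposed) data are needed, and these can be recovered from an eigendecomposition of the much smaller `(n_features, n_features)` matrix `X.T @ X` when `num_samples >> num_features`. A minimal sketch of the equivalence (not the PR's actual code):

```python
import numpy as np
from scipy import linalg

rng = np.random.RandomState(0)
X = rng.randn(5000, 10)   # num_samples >> num_features
XT = X.T                  # FastICA operates on the transposed data

# SVD route: decomposes the full (n_features, n_samples) matrix
u_svd, d_svd, _ = linalg.svd(XT, full_matrices=False)

# eigh route: decomposes only the small (n_features, n_features) matrix
D, u_eig = linalg.eigh(XT.dot(XT.T))
D, u_eig = D[::-1], u_eig[:, ::-1]    # eigh sorts eigenvalues ascending
d_eig = np.sqrt(np.clip(D, 0, None))  # singular values = sqrt(eigenvalues)

print(np.allclose(d_svd, d_eig))  # → True
```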

This PR:

  • Merged with main
  • Resolved conflicts with the current fastica implementation in main
  • Added test coverage
  • Added benchmarks to establish the ideal/preferred conditions for each solver
  • Added consistency checks between the two solvers

Any other comments?

Ongoing problems/consideration:

  • Need to consider whether other tests must be extended to try both solvers
  • Need to figure out why the "correctness" test of the eigh solver is failing

Needs follow-up PRs to add whiten_solver="auto" and to begin a deprecation cycle making it the default.
May need a follow-up PR to change the default value of sign_flip to True.

@Micky774 (Contributor, Author):

Initial benchmarks:
(two benchmark screenshots)

@Micky774 (Contributor, Author):

Doing some more thorough benchmarks right now. Current insights:

  • eigh seems to perform better when (num_samples, num_features) < (10, 200)
  • eigh seems to perform better when (num_samples, num_features) > (4000, 200)
  • eigh is, at best, a 50x speedup (more often 2-10x) and, at worst, 8100x slower (more often 5-100x slower)
  • The decision boundary here is messy and complicated.

In the end, I think it's fine to support this, but the default should definitely be svd since the worst-case performance of eigh is awful. We can recommend that users try eigh when num_samples > 50 * num_features. Given how severe the slowdown can be, if we detect that num_features > num_samples and the user has specified eigh, I think it's worth raising a warning letting them know it may be a hefty slowdown.
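The heuristic above could be sketched roughly as follows. Note that `choose_whiten_solver` is a hypothetical helper invented for illustration, not scikit-learn API:

```python
import warnings

def choose_whiten_solver(n_samples, n_features, requested="svd"):
    """Hypothetical helper sketching the heuristic discussed above."""
    if requested == "eigh" and n_features > n_samples:
        # eigh can be dramatically slower in this regime
        warnings.warn(
            "eigh may be much slower than svd when "
            "n_features > n_samples; consider using 'svd'."
        )
    if requested == "auto":
        # eigh tends to win when samples greatly outnumber features
        return "eigh" if n_samples > 50 * n_features else "svd"
    return requested

print(choose_whiten_solver(10_000, 100, "auto"))  # → eigh
print(choose_whiten_solver(100, 100, "auto"))     # → svd
```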

To reproduce the benchmarks yourself, install streamlit, then run streamlit run bench.py -- --load saved_df; it will open the (extremely shabby) streamlit app in your browser. If, for some reason, you want to see the raw runtime-ratio results, I've included them below.

Raw markdown table (100 rows) of runtime ratios (sorted)

shape (n_samples, n_features)    eigh/svd time ratio
(20, 40) 0.027392656353497417
(40, 100) 0.049180007809449436
(20, 100) 0.057552231442334706
(100, 210) 0.09524925799194661
(40, 210) 0.14363688595231255
(10000, 2150) 0.15028339448301672
(10000, 1000) 0.16632894827428377
(10000, 4640) 0.2231837700834356
(210, 460) 0.255318340196389
(4640, 1000) 0.25760193383976254
(4640, 2150) 0.2705542398602077
(100, 460) 0.31250651342137
(10, 40) 0.4000095251702624
(10, 100) 0.4286492374727669
(10000, 10000) 0.4302058722478829
(4640, 460) 0.45040866172290944
(40, 40) 0.4512553865343217
(4640, 4640) 0.4970881378182798
(20, 210) 0.5005114838637351
(2150, 1000) 0.5133361242621924
(40, 460) 0.5368546912364608
(2150, 460) 0.558764971616147
(2150, 2150) 0.5784793357915168
(10000, 460) 0.581362587353498
(10000, 100) 0.5949108134957792
(210, 40) 0.6667923929010998
(10000, 20) 0.6785468187734477
(10, 210) 0.7140718562874252
(10000, 210) 0.7170820644256588
(1000, 1000) 0.7432485593638343
(1000, 460) 0.7585017085919287
(1000, 210) 0.7757380344363141
(10, 10) 0.7778704144830871
(2150, 210) 0.7843657658086848
(10, 20) 0.8006673021925643
(4640, 210) 0.8494370390852348
(1000, 100) 0.8531986590850484
(460, 210) 0.8709470864100771
(100, 100) 0.8805924514349099
(460, 100) 0.8902193173565722
(2150, 40) 0.9091096867046015
(210, 210) 0.9341971234160453
(210, 20) 0.9403313542148114
(460, 1000) 0.9470091699123392
(2150, 100) 0.9545480559375227
(210, 100) 0.9857112717274422
(460, 40) 0.9873451308755288
(460, 460) 1.0485398049989825
(4640, 20) 1.2996863617968435
(10000, 10) 1.3380504252428163
(210, 1000) 1.5976384259180922
(4640, 10) 1.7245353268266128
(20, 460) 1.796624385901847
(1000, 2150) 1.9342903545181518
(20, 1000) 1.941229856102695
(100, 1000) 1.9552005859013921
(4640, 100) 2.1782621749262523
(1000, 40) 2.1999714160186747
(4640, 10000) 2.3147969893025975
(2150, 4640) 2.3618363123805866
(20, 20) 2.4138472027139137
(4640, 40) 2.60901399115585
(10, 460) 2.857677647699428
(1000, 20) 2.9141486320947325
(460, 20) 2.91976943597561
(2150, 20) 3.1247816593886464
(40, 20) 3.5111649033100334
(460, 10) 3.599615732727821
(2150, 10) 3.7331004366812226
(10000, 40) 3.86277229757501
(100, 20) 4.333788612721353
(100, 10) 4.8001810385898045
(460, 2150) 4.908660288538512
(100, 40) 6.700262029537875
(1000, 10) 7.000102089430341
(20, 10) 8.447050725405061
(40, 1000) 9.06677668515431
(210, 2150) 9.524194969384608
(1000, 4640) 10.299849220810476
(2150, 10000) 14.012837424083301
(210, 10) 14.372849062220899
(100, 2150) 14.394944140638104
(40, 2150) 16.02257964738633
(20, 2150) 17.638397808561795
(40, 10) 18.84263375138911
(10, 1000) 33.761422529338176
(460, 4640) 35.79277485135849
(1000, 10000) 57.46434908415292
(210, 4640) 75.173407932718
(100, 4640) 116.85089251715708
(40, 4640) 155.7330396113295
(10, 2150) 185.7089703371322
(460, 10000) 216.22483466959886
(210, 10000) 518.3531549170498
(20, 4640) 581.4435082022284
(100, 10000) 895.9513969760258
(10, 4640) 1159.891123696607
(40, 10000) 1430.8166473045646
(20, 10000) 1482.6329039045597
(10, 10000) 8128.295908847292

@thomasjpfan (Member) left a comment:

It's best to include the raw numbers in a gist, so others can look and make graphs themselves to analyze the results. You can also include benchmark scripts in a gist, etc.

runtime ratios (sorted)

Assuming the ratio is eigh runtime / svd runtime, I think absolute values matter here: 0.1 s vs 0.2 s is very different from 10 s vs 20 s.

What does the runtime performance look like for eigh when svd spends >= 1 second on the same data?

@@ -547,9 +561,22 @@ def g(x, fun_args):

```python
        XT -= X_mean[:, np.newaxis]

        # Whitening and preprocessing by PCA
        u, d, _ = linalg.svd(XT, full_matrices=False, check_finite=False)
        if self.svd_solver == "eigh":
            D, u = linalg.eigh(X.T.dot(X))  # Faster when n < p
```
Member:

The overhead of eigh comes from the matrix multiply and calling the eigh solver. From your benchmarks, it may only be better for n_samples >> n_features.

Also there should be memory overhead for storing the result of the matrix multiply.
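To make the memory point concrete: the eigh path materializes the `(n_features, n_features)` Gram matrix `X.T @ X` on top of the input data. A rough back-of-the-envelope sketch (`gram_overhead_bytes` is a hypothetical helper, not part of the PR):

```python
import numpy as np

def gram_overhead_bytes(n_features, dtype=np.float64):
    # The eigh path allocates an (n_features, n_features) array
    # for X.T @ X in addition to the input data.
    return n_features * n_features * np.dtype(dtype).itemsize

# e.g. the (10000, 1000) case discussed in this thread: 1000x1000 float64
print(gram_overhead_bytes(1000) / 1e6)  # → 8.0 (MB)
```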

@Micky774 (Contributor, Author):

It's best to include the raw numbers in a gist, so others can look and make graphs themselves to analyze the results. You can also include benchmark scripts in a gist, etc.

Here's the gist for the benchmarks -- it includes the generating script and CSV of the results.

What does the runtime performance look like for eigh when svd spends >= 1 seconds for the same data?

Info on runs where svd time >= 1 second

| shape | svd time (s) | eigh time (s) | eigh/svd |
| --- | --- | --- | --- |
| (1000, 10000) | 1.637315273284912 | 98.84875655174255 | 60.37246348617045 |
| (2150, 2150) | 3.372101068496704 | 1.8216583728790283 | 0.5402146424072428 |
| (2150, 4640) | 5.041795492172241 | 12.133899450302124 | 2.4066623624740227 |
| (2150, 10000) | 7.4768290519714355 | 102.90255951881409 | 13.762861074332239 |
| (4640, 1000) | 2.192542791366577 | 0.552424430847168 | 0.2519560544142678 |
| (4640, 2150) | 8.363417148590088 | 2.205007076263428 | 0.2636490607951022 |
| (4640, 4640) | 32.12346792221069 | 14.978487730026243 | 0.4662786647536853 |
| (4640, 10000) | 45.00250458717346 | 104.07629084587096 | 2.312677745396744 |
| (10000, 1000) | 5.492422819137573 | 0.8891241550445557 | 0.16188195707485006 |
| (10000, 2150) | 20.327980518341064 | 2.90564489364624 | 0.14293819747734415 |
| (10000, 4640) | 79.02881979942322 | 16.81713080406189 | 0.21279744334717532 |
| (10000, 10000) | 279.7206346988678 | 121.5934534072876 | 0.4346960442807109 |

@thomasjpfan (Member) left a comment:

From the benchmarks, it looks like the only advantage of eigh is for n_samples >> n_features. For future reference, it's good to do order-of-magnitude jumps when benchmarking, i.e. 100, 1_000, 10_000, etc. Comparing 10 and 20 is not as interesting.
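The suggested benchmarking style could look like the sketch below. `solver_ratio` is a hypothetical harness timing just the core decompositions, not the PR's actual bench.py:

```python
import time
import numpy as np
from scipy import linalg

def solver_ratio(n_samples, n_features, rng):
    """Time the core decomposition of each whitening route and
    return the eigh/svd runtime ratio (hypothetical harness)."""
    X = rng.randn(n_samples, n_features)
    t0 = time.perf_counter()
    linalg.svd(X.T, full_matrices=False, check_finite=False)
    t_svd = time.perf_counter() - t0
    t0 = time.perf_counter()
    linalg.eigh(X.T.dot(X))
    t_eigh = time.perf_counter() - t0
    return t_eigh / t_svd

rng = np.random.RandomState(0)
for n in (100, 1_000, 10_000):  # order-of-magnitude jumps
    print(f"n_samples={n:>6}: eigh/svd = {solver_ratio(n, 100, rng):.2f}")
```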

What is the memory usage like for (10000, 1000) between the solvers?

@Micky774 (Contributor, Author) commented Feb 27, 2022

What is the memory usage like for (10000, 1000) between the solvers?

I tried to stay consistent with how it's calculated in FastICA since it internally transposes the data.

(memory usage screenshot)

More coarse benchmark results can be found here

@Micky774 Micky774 changed the title [WIP] No longer use SVD for fastica [WIP] Add eigh solver to FastICA Mar 3, 2022
@thomasjpfan (Member) left a comment:

CI failures look related

Comment on lines 596 to 597:

```python
# Give consistent eigenvectors for both svd solvers
u *= np.sign(u[0])
```
Member:

Does it seem a bit too much to give control over this flipping through a parameter?

This feels like a major change to me, so I am okay with a parameter to control this.

Although, the sign-flipping logic is a little different from svd_flip. This implementation forces the first row to be positive, which works for all seeds in the test_fastica_simple_different_solvers test. We tried the sign-flipping logic from svd_flip, but it did not work for some seeds in test_fastica_simple_different_solvers.

@Micky774 Can you investigate why the svd_flip-like logic did not work? (I have a snippet here: #22527 (comment))
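For context, the two sign conventions being compared can be sketched as follows. `flip_like_svd_flip` only approximates the column-based logic of sklearn's svd_flip; both function names here are illustrative, not library API:

```python
import numpy as np

def flip_by_first_row(u):
    # This PR's approach: make each eigenvector's first
    # coordinate non-negative.
    return u * np.sign(u[0])

def flip_like_svd_flip(u):
    # Rough sketch of svd_flip's u-based convention: take the sign
    # from the entry with the largest absolute value in each column.
    max_abs_rows = np.argmax(np.abs(u), axis=0)
    signs = np.sign(u[max_abs_rows, np.arange(u.shape[1])])
    return u * signs

rng = np.random.RandomState(0)
u = rng.randn(4, 3)
a = flip_by_first_row(u)
b = flip_like_svd_flip(u)
# Both resolve the per-column sign ambiguity, but they can disagree
# when a column's first entry has the opposite sign to its largest entry.
print(np.all(a[0] >= 0))  # → True
```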

@glemaitre (Member):

This feels like a major change to me, so I am okay with a parameter to control this.

OK let's go for adding a new parameter. We can always revisit this choice in the future if we wish to change the behaviour and deprecate this parameter.

@thomasjpfan (Member) left a comment:

Otherwise LGTM.

@glemaitre glemaitre self-requested a review June 13, 2022 17:58
@glemaitre glemaitre merged commit 54c1503 into scikit-learn:main Jun 14, 2022
@glemaitre (Member):

LGTM. Thanks @Micky774

@Micky774 Micky774 deleted the change_svd branch June 14, 2022 13:27
ogrisel pushed a commit to ogrisel/scikit-learn that referenced this pull request Jul 11, 2022
Co-authored-by: Pierre Ablin <pierreablin@gmail.com>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>