[MRG] Change the default of precompute from "auto" to False by MechCoder · Pull Request #3249 · scikit-learn/scikit-learn

Merged: 2 commits merged into scikit-learn:master from MechCoder:change_auto on Oct 4, 2014

Conversation

MechCoder
Member

The Gram variant has been found to be slower than the normal update rules. See the discussion in #3220. So the default is changed from "auto" to False.

@MechCoder MechCoder changed the title Change the default of precompute from "auto" to False [MRG] Change the default of precompute from "auto" to False Jun 5, 2014
@ogrisel
Member
ogrisel commented Jun 5, 2014

Based on the benchmarks done in the parent PR, it appears that with the current state of the code base, the precomputation of the Gram matrix can never be used to accelerate the solvers of the Lasso and ElasticNet models, even when n_samples >> n_features (with or without noise).

The precompute="auto" option is checked by sklearn.linear_model.base._pre_fit, which sets precompute=True if n_samples > n_features and precompute=False otherwise.
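Roughly, the heuristic amounts to the following (a minimal sketch with hypothetical helper names, not the actual _pre_fit code):

import numpy as np

def resolve_precompute(precompute, X):
    # Hypothetical illustration of the 'auto' heuristic described above.
    n_samples, n_features = X.shape
    if precompute == 'auto':
        # Precompute the Gram matrix only when there are more samples than
        # features, i.e. when X.T @ X is the smaller object to work with.
        return n_samples > n_features
    return bool(precompute)

def gram_and_xy(X, y):
    # The quantities that are cached when precomputation is enabled.
    return np.dot(X.T, X), np.dot(X.T, y)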

As the _pre_fit helper is also used by the OrthogonalMatchingPursuit model, we should benchmark whether precompute=False is always faster for that model as well. If so, the "auto" mode is useless: we should then properly deprecate the support for precompute="auto" for all those models and switch to precompute=False as the new default. All the docstrings should be updated, and the _pre_fit function should be simplified to remove this useless heuristic test.
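A generic sketch of that deprecation pattern (not this PR's exact code): keep accepting 'auto' for a release or two, emit a DeprecationWarning, and fall back to the new default.

import warnings

def _check_precompute(precompute):
    # Hypothetical helper illustrating the deprecation path.
    if precompute == 'auto':
        warnings.warn("precompute='auto' is deprecated and will be removed; "
                      "falling back to precompute=False.",
                      DeprecationWarning)
        return False
    return precompute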

If precompute="auto" is actually useful for OrthogonalMatchingPursuit, then we should only deprecate its use in the ElasticNet and Lasso models but not for OrthogonalMatchingPursuit.

@ogrisel
Member
ogrisel commented Jun 5, 2014

BTW, I don't think this specific PR is a priority for the GSoC but of course you can decide about that with your mentors.

@agramfort
Member

And test when many y are provided, so the Gram is computed once for many independent targets.

+1 for not being a priority for the GSOC. It needs to be addressed but no rush.

the log reg is more important now.

@MechCoder
Member Author

The log reg CV is almost done. I just need to add a few more tests.

@ogrisel ogrisel changed the title [MRG] Change the default of precompute from "auto" to False [WIP] Change the default of precompute from "auto" to False Jun 9, 2014
@MechCoder
Member Author

@ogrisel @agramfort I am facing an issue which is blocking this PR: #3719. It slows down the setting stopping=dual_gap. I suggest we just set precompute=False as the new default for ElasticNet models right now. WDYT?

@agramfort
Member

What are you suggesting? That we merge this and have a regression in master before you fix it?

@agramfort
Member

This PR can stay as it is. No rush. You can integrate the changes once you've made enough progress with the other PR.

@MechCoder
Member Author

@agramfort I'm sorry for not being clear.

The point is that the benchmarks I reported in the description of #3719 are with precompute=False; the speedup is not that promising when I set precompute=True or precompute='auto'. By "integrate the changes", are you suggesting we raise an error when precompute='auto' and stopping=objective?

To summarize, I do not want to use the Gram trick when precompute is set to "auto" and stopping=objective, because there is no observable speedup.

@agramfort
Member

glmnet does the Gram trick, so how does it do it? It says it precomputes the dot products between features.

@MechCoder
Member Author

Are you referring to these papers?

http://web.stanford.edu/~hastie/glmnet/glmnet_alpha.html
http://cran.r-project.org/web/packages/glmnet/glmnet.pdf

I did a quick skim and a Ctrl+F for "dot" and "compute", but nothing seemed to surface :(.

Anyhow, I am not asking to remove precompute=True, which I understand would be useful if the dot products np.dot(X.T, X) and np.dot(X.T, y) have to be calculated repeatedly, but just to remove precompute="auto", which does not seem to be useful according to the benchmarks in the parent PR (this comment #3249 (comment) by @ogrisel) and actually slows down the implementation in the new PR.

@agramfort
Member

@MechCoder
Member Author

Sorry for being a noob, but it states that the inner products are cached in the form of a covariance matrix (if done at all), instead of doing a run across n_features when possible.

Don't the incremental updates in the existing cd code already take care of the "naive" updates they describe? For instance, I cannot see any run across n_features except for the outer one.

@MechCoder
Member Author

@agramfort It does seem that precompute=True is useful when fitting along a path, since the huge dot products are not computed multiple times. That is the reason for this comment by @ogrisel #3220 (comment), where precompute=True in the CV models seems to be faster. However, for a single fit, the computation of the dot products seems to outweigh the benefits.

>>> from sklearn.datasets import make_regression
>>> from sklearn.linear_model import ElasticNet, ElasticNetCV
>>> X, y = make_regression(n_samples=10000, n_features=500)
>>> %timeit ElasticNetCV(precompute=False).fit(X, y)
1 loops, best of 3: 10.4 s per loop
>>> %timeit ElasticNetCV(precompute=True).fit(X, y)
1 loops, best of 3: 1.74 s per loop
>>> %timeit ElasticNet(precompute=True).fit(X, y)
1 loops, best of 3: 308 ms per loop
>>> %timeit ElasticNet(precompute=False).fit(X, y)
10 loops, best of 3: 106 ms per loop

So I suggest we could just remove the "auto" option for ENet models and keep it for the CV models, since it is hugely beneficial.
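A back-of-the-envelope illustration of why (plain NumPy, just to show the orders of magnitude, not scikit-learn internals): forming the Gram matrix is itself an O(n_samples * n_features**2) operation, so it only pays off when it is reused across many alphas or CV fits.

import numpy as np

rng = np.random.RandomState(0)
n_samples, n_features = 10000, 500
X = rng.randn(n_samples, n_features)
y = rng.randn(n_samples)

# One-off cost of the precomputation: ~ n_samples * n_features**2 multiply-adds.
gram = np.dot(X.T, X)
Xy = np.dot(X.T, y)

# A single coordinate-descent fit only sweeps the data a few times, so the
# Gram construction above can dominate the runtime; along a path of many
# alphas (as in the CV models) the same gram/Xy are reused and the cost
# is amortised.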

@MechCoder
Member Author

In other words, this comment by you: #3220 (comment)

@agramfort
Member

Fine with setting it to False in Lasso and ElasticNet and keeping "auto" only in the CV classes.

make sure it does not affect performance of dict learning classes.

@MechCoder MechCoder force-pushed the change_auto branch 3 times, most recently from d3c82f6 to 0667e44 Compare October 3, 2014 15:00
@MechCoder
Member Author

@agramfort @ogrisel Can you review this now? Have I done the deprecation right?

@MechCoder MechCoder changed the title [WIP] Change the default of precompute from "auto" to False [MRG] Change the default of precompute from "auto" to False Oct 3, 2014
@MechCoder MechCoder force-pushed the change_auto branch 3 times, most recently from 995640d to 9904a04 Compare October 3, 2014 15:19
Setting precompute to "auto" was found to be slower when n_samples > n_features,
since the computation of the Gram matrix is expensive and outweighs the benefit
of fitting with it for just one alpha.
@agramfort
Member

Any tests that should be updated too? I suspect some now raise warnings.

@MechCoder
Member Author

@agramfort It does seem that precompute="auto" was never tested explicitly, since it was the default before. I've added a test for the DeprecationWarning in the last commit. Let me know if it's ok.
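For reference, a minimal sketch of such a check using only the standard warnings module (the actual test in the PR may use scikit-learn's own testing helpers; it assumes the deprecation path added in this PR):

import warnings

import numpy as np
from sklearn.linear_model import ElasticNet

def test_precompute_auto_is_deprecated():
    # Assumes precompute='auto' now triggers a DeprecationWarning (this PR).
    rng = np.random.RandomState(0)
    X = rng.randn(20, 5)
    y = rng.randn(20)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        ElasticNet(precompute='auto').fit(X, y)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)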

@coveralls

Coverage Status

Coverage increased (+0.0%) when pulling 8577e79 on MechCoder:change_auto into 8357f17 on scikit-learn:master.

y = np.asarray(y, dtype=np.float64)

if self.precompute == 'auto':
    warnings.warn("Setting precompute to 'auto', has found to be "
Member

has -> was

@agramfort
Member

can you also have a look at OrthogonalMatchingPursuit as suggested by @ogrisel ?

@MechCoder
Member Author

I will, but it is better to keep both PRs independent, don't you think?

@agramfort
Member

Do it here please

@MechCoder
Member Author

I played with this example for different values of n_samples, n_features and n_targets. I confirm that, unlike ElasticNet, for larger values of n_targets there is a huge speed gain for OMP when precompute is set to "auto".
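A sketch of one way to run that comparison (illustrative sizes, not the exact script or numbers from this thread):

from time import time

from sklearn.datasets import make_regression
from sklearn.linear_model import OrthogonalMatchingPursuit

X, Y = make_regression(n_samples=2000, n_features=200, n_targets=50,
                       random_state=0)

for precompute in (False, 'auto'):
    omp = OrthogonalMatchingPursuit(n_nonzero_coefs=10, precompute=precompute)
    tic = time()
    omp.fit(X, Y)
    # With many targets the precomputed Gram is reused once per target,
    # which is where 'auto'/True can win.
    print("precompute=%r: %.3fs" % (precompute, time() - tic))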

@MechCoder
Member Author

@agramfort
Member

Can you share the updated graphs from the lasso and OMP benchmarks: bench_lasso.py, bench_plot_lasso_path.py, etc.?

@MechCoder
Member Author

This is the updated graph for bench_lasso.py:
[figure: scikit-learn lasso benchmark results]

And for bench_plot_lasso_path.py:
[figure: scikit-learn lasso path benchmark results]

Any other qs? I haven't altered anything else.

@MechCoder
Member Author

I ran a quick plot using make_regression with n_features=10. Even when n_features=100, i.e. n_samples >> n_features, one can see that the Gram trick is slower.

[figure: precompute timing comparison]

Hope this finally convinces everyone.

@@ -1207,6 +1217,7 @@ def fit(self, X, y):
model.alpha = best_alpha
model.l1_ratio = best_l1_ratio
model.copy_X = copy_X
model.precompute = False
Member

why not model.precompute = self.precompute?

Member Author

This is the last fit.

  1. There might be a case where self.precompute="auto", which we have deprecated.
  2. If self.precompute=True and we set model.precompute=True, I think we would be going against our principle that computing the Gram is useless for a single fit.

This is basically to make the last fit as fast as possible.
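Schematically, something like this (an illustration, not the actual ElasticNetCV code): the search over alphas can amortise a precomputed Gram, but the refit at the end is a single fit, so precompute=False is the cheaper choice there.

from sklearn.datasets import make_regression
from sklearn.linear_model import ElasticNet, ElasticNetCV

X, y = make_regression(n_samples=200, n_features=20, random_state=0)

cv_model = ElasticNetCV(cv=3).fit(X, y)            # path search over many alphas
final = ElasticNet(alpha=cv_model.alpha_,
                   precompute=False).fit(X, y)     # single final fit: skip the Gram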

Member

ok got it.

@agramfort agramfort merged commit 7cb3ba1 into scikit-learn:master Oct 4, 2014
@MechCoder MechCoder deleted the change_auto branch October 4, 2014 19:38
@agramfort
Member

Merged by rebase. Thanks!

@MechCoder
Member Author

this PR next? ;) #3719
