8000 [MRG] MNT: Use `nrm2` to find the residuals squared by jakirkham · Pull Request #11923 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[MRG] MNT: Use nrm2 to find the residuals squared #11923

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 31, 2018

Conversation

jakirkham
Copy link
Contributor

Using BLAS's nrm2 is a bit faster than squaring the residuals in-place and summing them. So switch to using nrm2 instead. Interestingly it doesn't appear necessary to flatten the array first as the BLAS function interprets the array as flat under the hood.

@jakirkham jakirkham force-pushed the update_dict_r2_nrm2 branch from 51f75f7 to b988008 Compare August 27, 2018 18:23 8000
@jakirkham
Copy link
Contributor Author

Quick benchmark below for comparison.

In [1]: import numpy as np

In [2]: from scipy import linalg

In [3]: a = 2 * np.random.random((100, 110)) - 1

In [4]: %timeit b = a.copy();
4.66 µs ± 181 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [5]: %timeit b = a.copy(); b **= 2; b.sum()
19.5 µs ± 371 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [6]: nrm2, = linalg.get_blas_funcs(('nrm2',), (a,))

In [7]: %timeit nrm2(a) ** 2
10.7 µs ± 66.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

@jakirkham jakirkham force-pushed the update_dict_r2_nrm2 branch from b988008 to 78133b0 Compare August 27, 2018 19:07
Using BLAS's `nrm2` is a bit faster than squaring the residuals in-place
and summing them. So switch to using `nrm2` instead. Interestingly it
doesn't appear necessary to flatten the array first as the BLAS function
interprets the array as flat under the hood.
@jakirkham jakirkham changed the title MNT: Use nrm2 to find the residuals squared [MRG] MNT: Use nrm2 to find the residuals squared Aug 27, 2018
Copy link
Member
@jnothman jnothman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The gains here seem so tiny for something that's only called once per iteration (for a default 100 iterations) amidst much other logic. Why bother making this change?

@jnothman
Copy link
Member

And which is more readable?

@jakirkham
Copy link
Contributor Author

We already use nrm2 in the loop where it helps. So this seemed like a reasonable and nice change to me. Leave it up to you though.

@jnothman jnothman merged commit e00817d into scikit-learn:master Aug 31, 2018
@jakirkham jakirkham deleted the update_dict_r2_nrm2 branch August 31, 2018 20:59
jnothman pushed a commit to jnothman/scikit-learn that referenced this pull request Sep 2, 2018
jnothman pushed a commit to jnothman/scikit-learn that referenced this pull request Sep 17, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants
0