ENH use more blas functions in cd solvers #22972
Conversation
443eff9 to 548cafc
I'm surprised that BLAS level 3 does not outperform a double loop of BLAS level 1 calls.
I see in the comment that you're creating a new array each time, which is sub-optimal.
Benchmark with the script in the top post, based on 7623065. Speedup is main / PR.
Timings of main branch
Ah, I missed the fact that one matrix is (n_targets, n_samples). It's true that gemm can't reach peak performance with small or very non-square matrices. Still, the benchmark shows that it's almost always better or similar. The only really bad case is small n_samples, small n_features, large n_targets.
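The two strategies under discussion can be sketched as follows. This is a minimal illustration, not the PR's Cython code: the actual solver uses typed BLAS wrappers, while here the scipy.linalg.blas interface stands in for them, with shapes chosen to mimic the residual matrix R and design matrix X.

```python
# Sketch: one BLAS level-3 gemm call vs. a double loop of level-1 ddot
# calls for computing XtA = R.T @ X. Shapes mimic an
# (n_samples, n_targets) residual R and (n_samples, n_features) X.
import numpy as np
from scipy.linalg.blas import dgemm, ddot

rng = np.random.default_rng(0)
n_samples, n_features, n_targets = 200, 50, 10
R = rng.standard_normal((n_samples, n_targets))
X = rng.standard_normal((n_samples, n_features))

# Level 3: a single gemm computes the whole (n_targets, n_features) block.
XtA_gemm = dgemm(alpha=1.0, a=R, b=X, trans_a=True)

# Level 1: one ddot per (target, feature) entry, as in the double loop.
XtA_loop = np.empty((n_targets, n_features))
for t in range(n_targets):
    for f in range(n_features):
        XtA_loop[t, f] = ddot(R[:, t], X[:, f])

assert np.allclose(XtA_gemm, XtA_loop)
```

Both compute the same matrix; the benchmark question is purely which dispatch pattern is faster for the small, non-square shapes that occur in the solver.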
LGTM
When I run the same benchmarks on a 16-core machine, I get runtime regressions. For reference, here is the threadpoolctl section of my sklearn.show_versions():
threadpoolctl info:
user_api: blas
internal_api: openblas
prefix: libopenblas
filepath: /home/thomasfan/mambaforge/envs/sk2/lib/python3.9/site-packages/numpy.libs/libopenblas64_p-r0-2f7c42d4.3.18.so
version: 0.3.18
threading_layer: pthreads
architecture: Zen
num_threads: 32
user_api: blas
internal_api: openblas
prefix: libopenblas
filepath: /home/thomasfan/mambaforge/envs/sk2/lib/python3.9/site-packages/scipy.libs/libopenblasp-r0-8b9e111f.3.17.so
version: 0.3.17
threading_layer: pthreads
architecture: Zen
num_threads: 32
user_api: openmp
internal_api: openmp
prefix: libgomp
filepath: /home/thomasfan/mambaforge/envs/sk2/lib/libgomp.so.1.0.0
version: None
num_threads: 32
@lorentzenchr What does sklearn.show_versions() look like for you?
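The threadpool report above can be reproduced, and the BLAS thread count pinned down while benchmarking, with the threadpoolctl package (the same one sklearn.show_versions() queries). A small sketch, assuming threadpoolctl is installed:

```python
# Sketch: inspect loaded threadpools and cap BLAS threads while timing,
# to rule out oversubscription effects on many-core machines.
import numpy as np
from threadpoolctl import threadpool_info, threadpool_limits

# Each dict describes one loaded runtime (OpenBLAS, MKL, OpenMP, ...),
# matching the fields shown above (user_api, internal_api, num_threads).
info = threadpool_info()
for pool in info:
    print(pool.get("user_api"), pool.get("internal_api"), pool.get("num_threads"))

# Limit BLAS to one thread inside the block, so gemm timings are
# comparable across machines with different core counts.
a = np.random.rand(100, 100)
with threadpool_limits(limits=1, user_api="blas"):
    b = a @ a
```

Running the benchmark script under such a limit would help separate a genuine algorithmic regression from a BLAS threading/oversubscription effect.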
LGTM, thank you @lorentzenchr.
Can you provide a benchmark showing that there is no regression in performance?
sklearn/linear_model/_cd_fast.pyx (outdated)
# Using numpy:
# XtA = np.dot(R.T, X) - l2_reg * W
# Using BLAS Level 3:
# XtA = np.dot(R.T, X)
Suggested change:

# Using numpy:
# XtA = np.dot(R.T, X) - l2_reg * W
# Using BLAS Level 3:
# XtA = np.dot(R.T, X)

becomes:

# Using numpy:
#
# XtA = np.dot(R.T, X) - l2_reg * W
#
# Using BLAS Level 3:
#
# XtA = np.dot(R.T, X)
#
I just followed the style that is already present in the file. How about doing the indentation but without the empty comment lines?
I like keeping some local consistency, but the indentation seems reasonable.
See #22972 (comment). TL;DR: it's a trade-off.
I ran the same benchmarks and the results are more mixed (16 cores):
PR results
main results
@thomasjpfan Thanks for your benchmark. For me, it's fine to close this PR and also close #13210. I could open a new PR that just adds a comment in the code that
I'm okay with closing this PR and the issue as well. What do you think @jeremiedbb?
I'm fine with that. Thanks for the investigation @lorentzenchr.
Reference Issues/PRs
Closes #13210.
What does this implement/fix? Explain your changes.
This PR tries to use more BLAS functions in the coordinate descent solvers in Cython.
Any other comments?
I did some benchmarking of different options for the computation of XtA in enet_coordinate_descent_multi_task without a clear conclusion. Therefore, I left it as is, with more comments added.
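For reference, the two options discussed for XtA can be written out in plain Python. This is an illustrative sketch using scipy.linalg.blas rather than the solver's typed Cython BLAS calls; it shows how a single gemm can fold the -l2_reg * W term in via its beta and c arguments.

```python
# Sketch: two ways to compute XtA = R.T @ X - l2_reg * W, as discussed
# for enet_coordinate_descent_multi_task.
import numpy as np
from scipy.linalg.blas import dgemm

rng = np.random.default_rng(0)
n_samples, n_features, n_targets = 100, 30, 5
l2_reg = 0.1
R = rng.standard_normal((n_samples, n_targets))   # residuals
X = rng.standard_normal((n_samples, n_features))  # design matrix
W = rng.standard_normal((n_targets, n_features))  # coefficients

# Option 1: plain numpy, two temporaries.
XtA_np = np.dot(R.T, X) - l2_reg * W

# Option 2: one BLAS level-3 call; gemm computes
# alpha * op(A) @ op(B) + beta * C, so beta=-l2_reg and c=W fold the
# regularization term into the same call.
XtA_blas = dgemm(alpha=1.0, a=R, b=X, trans_a=True, beta=-l2_reg, c=W)

assert np.allclose(XtA_np, XtA_blas)
```

Whether the single gemm wins in practice depends on the matrix shapes and the BLAS threading setup, which is exactly what the benchmarks above were probing.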