ENH use more blas functions in cd solvers by lorentzenchr · Pull Request #22972 · scikit-learn/scikit-learn · GitHub

ENH use more blas functions in cd solvers #22972


Closed

Conversation

lorentzenchr
Member
@lorentzenchr lorentzenchr commented Mar 28, 2022

Reference Issues/PRs

Closes #13210.

What does this implement/fix? Explain your changes.

This PR tries to use more BLAS functions in the coordinate descent solvers in Cython.

Any other comments?

I did some benchmarking of different options for computing XtA in enet_coordinate_descent_multi_task, without reaching a clear conclusion. Therefore, I left it as is and added more comments.

from collections import OrderedDict
from itertools import product
import time

from neurtu import delayed, timeit, Benchmark
import numpy as np
import pandas as pd

from sklearn.datasets import make_regression
from sklearn.linear_model import MultiTaskLasso


alpha = 0.01
list_n_features = [300, 1000, 4000]
list_n_samples = [100, 500]
list_n_tasks = [2, 10, 20, 50]

X, Y, coef_ = make_regression(
    n_samples=max(list_n_samples),
    n_features=max(list_n_features),
    n_targets=max(list_n_tasks),
    n_informative=max(list_n_features) // 10,
    noise=0.1,
    coef=True
)

X /= np.sqrt(np.sum(X ** 2, axis=0))  # Normalize data


def benchmark_cases(X, Y):
    """Benchmark MultiTaskLasso"""
    for it, (ns, nf, nt) in enumerate(product(list_n_samples, list_n_features, list_n_tasks)):
        tags = OrderedDict(n_samples=ns, n_features=nf, n_targets=nt)
        clf = MultiTaskLasso(alpha=alpha, fit_intercept=False, max_iter=10_000)
        yield delayed(clf.fit, tags=tags)(X[:ns, :nf], Y[:ns, :nt])

bench = Benchmark(repeat=5)
print("Run benchmark for multi target lasso.")
start = time.time()
result = bench(benchmark_cases(X, Y))
end = time.time()
print(f"Finished after {end - start} seconds.")
print(result)
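For readers skimming the thread, the two XtA strategies under discussion can be sketched in plain NumPy (a sketch only; shapes, values, and the `l2_reg` name are illustrative, while the actual solver operates on Cython memoryviews):

```python
# Sketch (not the PR's code): the two ways of computing XtA discussed below.
# Shapes follow enet_coordinate_descent_multi_task: R is (n_samples, n_tasks),
# X is (n_samples, n_features), W is (n_tasks, n_features).
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, n_tasks = 100, 300, 10
l2_reg = 0.5  # illustrative regularization strength
X = rng.standard_normal((n_samples, n_features))
R = rng.standard_normal((n_samples, n_tasks))
W = rng.standard_normal((n_tasks, n_features))

# Option 1: one BLAS level-3 call (gemm) for the whole product.
XtA_gemm = R.T @ X - l2_reg * W

# Option 2: entry-by-entry BLAS level-1 dots, as in the double-loop variant.
XtA_loop = np.empty((n_tasks, n_features))
for t in range(n_tasks):
    for f in range(n_features):
        XtA_loop[t, f] = R[:, t] @ X[:, f] - l2_reg * W[t, f]

# Both strategies compute the same matrix; only the BLAS call pattern differs.
assert np.allclose(XtA_gemm, XtA_loop)
```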

Member
@jeremiedbb jeremiedbb left a comment

I'm surprised that BLAS level 3 does not outperform the double loop of BLAS level-1 calls.
I see in the comment that you're creating a new array each time, which is sub-optimal.

@lorentzenchr
Member Author

Benchmark with the script in the top post, based on 7623065. Speedup is main / PR.

PR with GEMM
                                 wall_time                      
                                      mean         max       std  speedup
n_samples n_features n_targets                                  
100       300        2            0.145484    0.147704  0.001475  1.08
                     10           0.677363    0.684821  0.010156  0.955
                     20           1.496502    1.514316  0.020742  0.927
                     50           5.971845    6.008925  0.038615  0.594
          1000       2            0.497537    0.509534  0.008607  1.27
                     10           2.488258    2.514491  0.017274  1.58
                     20           6.157018    6.181837  0.023664  1.31
                     50          20.179273   20.565162  0.235771  1.06
          4000       2            2.141750    2.175920  0.032265  1.36
                     10          14.191492   14.516755  0.186170  1.06
                     20          26.280743   26.790688  0.296696  1.13
                     50          71.058634   71.500954  0.315243  1.09
500       300        2            0.008196    0.008323  0.000087  1.26
                     10           0.041458    0.042422  0.000625  1.24
                     20           0.086655    0.095563  0.004997  1.24
                     50           0.238273    0.262718  0.015520  1.34
          1000       2            1.162102    1.173390  0.010608  1.11
                     10           5.397492    5.425889  0.021651  1.16
                     20          12.493163   12.701867  0.122206  1.18
                     50          35.322965   36.820687  1.244409  1.57
          4000       2            4.084729    4.101981  0.012868  1.03
                     10          21.156687   21.183405  0.017483  0.86
                     20          44.667367   44.732523  0.059359  0.97
                     50         143.688414  147.081807  4.331627  1.07

Timings of main branch

MAIN
                                 wall_time                      
                                      mean         max       std
n_samples n_features n_targets                                  
100       300        2            0.156672    0.186633  0.017599
                     10           0.647087    0.729176  0.049114
                     20           1.387310    1.493217  0.072670
                     50           3.546491    3.717894  0.106801
          1000       2            0.634045    0.688953  0.040473
                     10           3.937150    4.127279  0.126444
                     20           8.093542    8.752413  0.439550
                     50          21.394161   22.490225  0.808843
          4000       2            2.920357    3.397627  0.286309
                     10          14.978383   16.277204  0.920301
                     20          29.779723   31.012935  1.112678
                     50          77.519848   78.394667  1.187355
500       300        2            0.010294    0.012238  0.001227
                     10           0.051599    0.053901  0.001610
                     20           0.107360    0.116611  0.008979
                     50           0.318203    0.358433  0.035147
          1000       2            1.293451    1.574638  0.202206
                     10           6.266583    7.150905  0.760419
                     20          14.796782   16.219736  0.969471
                     50          55.599680   59.203180  3.049441
          4000       2            4.200956    4.325626  0.127612
                     10          18.290157   18.776878  0.340980
                     20          43.271756   46.699309  2.054371
                     50         153.149346  158.434174  3.035532

@jeremiedbb
Member

Ah, I missed the fact that one matrix is (n_targets, n_samples). It's true that gemm can't reach peak performance with small or very non-square matrices.

Still, the benchmark shows that the PR is almost always better or similar. The only really bad case is small n_samples, small n_features, large n_targets.
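A sketch of how one might observe this on a given machine (hypothetical shapes and repetition count, not the PR benchmark): with tiny n_targets the left operand is extremely non-square, so a level-3 gemm has little room to block and may not beat a simple row loop.

```python
# Sketch: timing one gemm call against a per-row loop for a very
# non-square product, the shape regime discussed above.
import time
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, n_targets = 100, 300, 2
RT = rng.standard_normal((n_targets, n_samples))  # R.T: very non-square
X = rng.standard_normal((n_samples, n_features))

t0 = time.perf_counter()
for _ in range(100):
    out_gemm = RT @ X                             # one level-3 gemm call
t_gemm = time.perf_counter() - t0

out_loop = np.empty((n_targets, n_features))
t0 = time.perf_counter()
for _ in range(100):
    for t in range(n_targets):
        out_loop[t] = RT[t] @ X                   # row-wise, level-2 style
t_loop = time.perf_counter() - t0

# Results agree; which variant is faster depends on the BLAS and shapes.
assert np.allclose(out_gemm, out_loop)
print(f"gemm: {t_gemm:.4f}s  loop: {t_loop:.4f}s")
```

Which side wins is machine- and BLAS-dependent, which is consistent with the mixed benchmark results reported in this thread.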

Member
@jeremiedbb jeremiedbb left a comment

LGTM.

Member
@thomasjpfan thomasjpfan left a comment

When I run the same benchmarks on a 16-core machine, I get runtime regressions. For reference, here is the threadpoolctl section of my sklearn.show_versions():

threadpoolctl info:
       user_api: blas
   internal_api: openblas
         prefix: libopenblas
       filepath: /home/thomasfan/mambaforge/envs/sk2/lib/python3.9/site-packages/numpy.libs/libopenblas64_p-r0-2f7c42d4.3.18.so
        version: 0.3.18
threading_layer: pthreads
   architecture: Zen
    num_threads: 32

       user_api: blas
   internal_api: openblas
         prefix: libopenblas
       filepath: /home/thomasfan/mambaforge/envs/sk2/lib/python3.9/site-packages/scipy.libs/libopenblasp-r0-8b9e111f.3.17.so
        version: 0.3.17
threading_layer: pthreads
   architecture: Zen
    num_threads: 32

       user_api: openmp
   internal_api: openmp
         prefix: libgomp
       filepath: /home/thomasfan/mambaforge/envs/sk2/lib/libgomp.so.1.0.0
        version: None
    num_threads: 32

@lorentzenchr What does sklearn.show_versions() look like for you?

@lorentzenchr
Member Author

What does sklearn.show_versions() look like for you?

import sklearn
sklearn.show_versions()
System:
    python: 3.9.9 (main, Nov 21 2021, 03:23:44)  [Clang 13.0.0 (clang-1300.0.29.3)]
executable: /Users/lorentzen/github/python3_sklearn/bin/python
   machine: macOS-12.3-x86_64-i386-64bit

Python dependencies:
      sklearn: 1.1.dev0
          pip: 22.0.4
   setuptools: 58.1.0
        numpy: 1.21.4
        scipy: 1.7.3
       Cython: 0.29.24
       pandas: 1.3.4
   matplotlib: 3.5.0
       joblib: 1.1.0
threadpoolctl: 3.1.0

Built with OpenMP: True

threadpoolctl info:
       user_api: openmp
   internal_api: openmp
         prefix: libomp
       filepath: /usr/local/Cellar/libomp/13.0.0/lib/libomp.dylib
        version: None
    num_threads: 8

       user_api: blas
   internal_api: openblas
         prefix: libopenblas
       filepath: /Users/lorentzen/github/python3_sklearn/lib/python3.9/site-packages/numpy/.dylibs/libopenblas.0.dylib
        version: 0.3.17
threading_layer: pthreads
   architecture: Haswell
    num_threads: 4

       user_api: blas
   internal_api: openblas
         prefix: libopenblas
       filepath: /Users/lorentzen/github/python3_sklearn/lib/python3.9/site-packages/scipy/.dylibs/libopenblas.0.dylib
        version: 0.3.17
threading_layer: pthreads
   architecture: Haswell
    num_threads: 4

Member
@jjerphan jjerphan left a comment

LGTM, thank you @lorentzenchr.

Can you provide a benchmark showing that there is no regression in performance?

Comment on lines 864 to 867
# Using numpy:
# XtA = np.dot(R.T, X) - l2_reg * W
# Using BLAS Level 3:
# XtA = np.dot(R.T, X)
Member

Suggested change
# Using numpy:
# XtA = np.dot(R.T, X) - l2_reg * W
# Using BLAS Level 3:
# XtA = np.dot(R.T, X)
# Using numpy:
#
# XtA = np.dot(R.T, X) - l2_reg * W
#
# Using BLAS Level 3:
#
# XtA = np.dot(R.T, X)
#

Member Author

I just followed the style that is already present in the file. How about keeping the indentation but dropping the empty comment lines?

Member

I like keeping some local consistency, but the indentation seems reasonable.

@lorentzenchr
Member Author

Can you provide a benchmark showing that there is no regression in performance?

See #22972 (comment). TLDR: It's a trade-off.

@thomasjpfan
Member

I ran the same benchmarks (16 cores) and the results are more mixed:

PR Results

                                 wall_time                                 
                                      mean         max        std   speedup
n_samples n_features n_targets                                             
100       300        2            0.126477    0.133409   0.005072  0.968208
                     10           0.406969    0.424240   0.018226  1.415034
                     20           0.869521    0.950086   0.050928  1.151000
                     50           2.727007    2.968354   0.241424  1.231588
          1000       2            0.387590    0.405099   0.013158  0.893656
                     10           2.045166    2.203917   0.134049  1.081720
                     20           3.990948    4.475209   0.309848  0.979089
                     50          14.276450   15.779340   1.426308  0.848775
          4000       2            1.846177    1.922902   0.046784  0.886903
                     10           7.613850    7.941037   0.298287  1.226365
                     20          19.718723   21.544081   1.389188  1.031470
                     50          50.751171   54.299573   3.520457  0.985871
500       300        2            0.009481    0.010009   0.000330  0.758570
                     10           0.049852    0.057965   0.005843  0.692570
                     20           0.099282    0.135009   0.020148  0.779940
                     50           0.254732    0.276009   0.016889  1.022278
          1000       2            1.319801    1.406223   0.065944  0.524015
                     10           5.737427    6.472419   0.589744  0.812979
                     20           9.116146   10.666551   0.873925  1.382002
                     50          25.771272   30.471168   2.637147  1.408798
          4000       2            1.815238    1.941360   0.077922  1.230866
                     10          13.432530   13.722961   0.237632  1.123602
                     20          34.402257   38.986120   3.829004  1.150028
                     50         105.409047  124.545587  10.747810  1.207296

main results
                                 wall_time
                                      mean         max        std
n_samples n_features n_targets
100       300        2            0.122456    0.129250   0.006159
                     10           0.575875    0.619054   0.050962
                     20           1.000819    1.057568   0.060326
                     50           3.358550    3.547139   0.196966
          1000       2            0.346372    0.355733   0.009064
                     10           2.212296    2.392588   0.115992
                     20           3.907493    4.185451   0.211641
                     50          12.117488   12.713651   0.406823
          4000       2            1.637380    1.727820   0.112734
                     10           9.337362   10.118820   0.705827
                     20          20.339279   21.667922   0.951616
                     50          50.034083   54.822024   3.693212
500       300        2            0.007192    0.007889   0.000589
                     10           0.034526    0.037122   0.002169
                     20           0.077434    0.083015   0.004440
                     50           0.260407    0.310219   0.029956
          1000       2            0.691596    0.742283   0.054577
                     10           4.664405    5.156468   0.446011
                     20          12.598534   14.264985   1.648169
                     50          36.306519   43.336331   4.830675
          4000       2            2.234314    2.501939   0.273651
                     10          15.092820   17.050780   1.122793
                     20          39.563568   43.195534   5.673444
                     50         127.259915  139.570600  11.125536

@lorentzenchr
Member Author

@thomasjpfan Thanks for your benchmark. For me, it's fine to close this PR and also close #13210.

I could open a new PR that just adds a comment in the code that _gemm is not good here for small number of tasks.

@thomasjpfan
Member

I'm okay with closing this PR and the issue as well. What do you think, @jeremiedbb?

@jeremiedbb
Member

I'm okay with closing this PR and the issue as well. What do you think @jeremiedbb

I'm fine with that. Thanks for the investigation, @lorentzenchr.

@jeremiedbb jeremiedbb closed this May 16, 2022
@lorentzenchr lorentzenchr deleted the cd_fast_with_blas branch May 16, 2022 14:53
Successfully merging this pull request may close these issues.

PERF: use higher level BLAS functions