8000 [MRG] Use Scipy cython BLAS API instead of bundled CBLAS (#12732) · scikit-learn/scikit-learn@d0f63a7 · GitHub
[go: up one dir, main page]

Skip to content

Commit d0f63a7

Browse files
jeremiedbb authored and ogrisel committed
[MRG] Use Scipy cython BLAS API instead of bundled CBLAS (#12732)
1 parent 3df0c19 commit d0f63a7

File tree

7 files changed

+432
-19
lines changed

7 files changed

+432
-19
lines changed

doc/whats_new/v0.21.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ Support for Python 3.4 and below has been officially dropped.
176176
and now it returns NaN and raises :class:`exceptions.UndefinedMetricWarning`.
177177
:issue:`12855` by :user:`Pawel Sendyk <psendyk>.`
178178

179+
- |Efficiency| The pairwise manhattan distances with sparse input now use the
180+
BLAS shipped with scipy instead of the bundled BLAS. :issue:`12732` by
181+
:user:`Jérémie du Boisberranger <jeremiedbb>`.
182+
179183
:mod:`sklearn.model_selection`
180184
..............................
181185

sklearn/metrics/pairwise_fast.pyx

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@ import numpy as np
1212
cimport numpy as np
1313
from cython cimport floating
1414

15-
16-
cdef extern from "cblas.h":
17-
double cblas_dasum(int, const double *, int) nogil
15+
from ..utils._cython_blas cimport _asum
1816

1917

2018
np.import_array()
@@ -67,4 +65,4 @@ def _sparse_manhattan(floating[::1] X_data, int[:] X_indices, int[:] X_indptr,
6765
for j in range(Y_indptr[iy], Y_indptr[iy + 1]):
6866
row[Y_indices[j]] -= Y_data[j]
6967

70-
D[ix, iy] = cblas_dasum(n_features, &row[0], 1)
68+
D[ix, iy] = _asum(n_features, &row[0], 1)

sklearn/metrics/setup.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,26 @@
11
import os
2-
import os.path
32

4-
import numpy
53
from numpy.distutils.misc_util import Configuration
64

7-
from sklearn._build_utils import get_blas_info
8-
95

106
def configuration(parent_package="", top_path=None):
117
config = Configuration("metrics", parent_package, top_path)
128

13-
cblas_libs, blas_info = get_blas_info()
9+
libraries = []
1410
if os.name == 'posix':
15-
cblas_libs.append('m')
11+
libraries.append('m')
1612

1713
config.add_subpackage('cluster')
14+
1815
config.add_extension("pairwise_fast",
1916
sources=["pairwise_fast.pyx"],
20-
include_dirs=[os.path.join('..', 'src', 'cblas'),
21-
numpy.get_include(),
22-
blas_info.pop('include_dirs', [])],
23-
libraries=cblas_libs,
24-
extra_compile_args=blas_info.pop('extra_compile_args',
25-
[]),
26-
**blas_info)
17+
libraries=libraries)
18+
2719
config.add_subpackage('tests')
2820

2921
return config
3022

23+
3124
if __name__ == "__main__":
3225
from numpy.distutils.core import setup
3326
setup(**configuration().todict())

sklearn/utils/_cython_blas.pxd

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# cython: language_level=3
2+
3+
from cython cimport floating
4+
5+
6+
cpdef enum BLAS_Order:
7+
RowMajor # C contiguous
8+
ColMajor # Fortran contiguous
9+
10+
11+
cpdef enum BLAS_Trans:
12+
NoTrans = 110 # corresponds to 'n'
13+
Trans = 116 # corresponds to 't'
14+
15+
16+
# BLAS Level 1 ################################################################
17+
cdef floating _dot(int, floating*, int, floating*, int) nogil
18+
19+
cdef floating _asum(int, floating*, int) nogil
20+
21+
cdef void _axpy(int, floating, floating*, int, floating*, int) nogil
22+
23+
cdef floating _nrm2(int, floating*, int) nogil
24+
25+
cdef void _copy(int, floating*, int, floating*, int) nogil
26+
27+
cdef void _scal(int, floating, floating*, int) nogil
28+
29+
# BLAS Level 2 ################################################################
30+
cdef void _gemv(BLAS_Order, BLAS_Trans, int, int, floating, floating*, int,
31+
floating*, int, floating, floating*, int) nogil
32+
33+
cdef void _ger(BLAS_Order, int, int, floating, floating*, int, floating*, int,
34+
floating*, int) nogil
35+
36+
# BLAS Level 3 ################################################################
37+
cdef void _gemm(BLAS_Order, BLAS_Trans, BLAS_Trans, int, int, int, floating,
38+
floating*, int, floating*, int, floating, floating*,
39+
int) nogil

sklearn/utils/_cython_blas.pyx

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
from cython cimport floating
2+
3+
from scipy.linalg.cython_blas cimport sdot, ddot
4+
from scipy.linalg.cython_blas cimport sasum, dasum
5+
from scipy.linalg.cython_blas cimport saxpy, daxpy
6+
from scipy.linalg.cython_blas cimport snrm2, dnrm2
7+
from scipy.linalg.cython_blas cimport scopy, dcopy
8+
from scipy.linalg.cython_blas cimport sscal, dscal
9+
from scipy.linalg.cython_blas cimport sgemv, dgemv
10+
from scipy.linalg.cython_blas cimport sger, dger
11+
from scipy.linalg.cython_blas cimport sgemm, dgemm
12+
13+
14+
################
15+
# BLAS Level 1 #
16+
################
17+
18+
cdef floating _dot(int n, floating *x, int incx,
19+
floating *y, int incy) nogil:
20+
"""x.T.y"""
21+
if floating is float:
22+
return sdot(&n, x, &incx, y, &incy)
23+
else:
24+
return ddot(&n, x, &incx, y, &incy)
25+
26+
27+
cpdef _dot_memview(floating[::1] x, floating[::1] y):
28+
return _dot(x.shape[0], &x[0], 1, &y[0], 1)
29+
30+
31+
cdef floating _asum(int n, floating *x, int incx) nogil:
32+
"""sum(|x_i|)"""
33+
if floating is float:
34+
return sasum(&n, x, &incx)
35+
else:
36+
return dasum(&n, x, &incx)
37+
38+
39+
cpdef _asum_memview(floating[::1] x):
40+
return _asum(x.shape[0], &x[0], 1)
41+
42+
43+
cdef void _axpy(int n, floating alpha, floating *x, int incx,
44+
floating *y, int incy) nogil:
45+
"""y := alpha * x + y"""
46+
if floating is float:
47+
saxpy(&n, &alpha, x, &incx, y, &incy)
48+
else:
49+
daxpy(&n, &alpha, x, &incx, y, &incy)
50+
51+
52+
cpdef _axpy_memview(floating alpha, floating[::1] x, floating[::1] y):
53+
_axpy(x.shape[0], alpha, &x[0], 1, &y[0], 1)
54+
55+
56+
cdef floating _nrm2(int n, floating *x, int incx) nogil:
57+
"""sqrt(sum((x_i)^2))"""
58+
if floating is float:
59+
return snrm2(&n, x, &incx)
60+
else:
61+
return dnrm2(&n, x, &incx)
62+
63+
64+
cpdef _nrm2_memview(floating[::1] x):
65+
return _nrm2(x.shape[0], &x[0], 1)
66+
67+
68+
cdef void _copy(int n, floating *x, int incx, floating *y, int incy) nogil:
69+
"""y := x"""
70+
if floating is float:
71+
scopy(&n, x, &incx, y, &incy)
72+
else:
73+
dcopy(&n, x, &incx, y, &incy)
74+
75+
76+
cpdef _copy_memview(floating[::1] x, floating[::1] y):
77+
_copy(x.shape[0], &x[0], 1, &y[0], 1)
78+
79+
80+
cdef void _scal(int n, floating alpha, floating *x, int incx) nogil:
81+
"""x := alpha * x"""
82+
if floating is float:
83+
sscal(&n, &alpha, x, &incx)
84+
else:
85+
dscal(&n, &alpha, x, &incx)
86+
87+
88+
cpdef _scal_memview(floating alpha, floating[::1] x):
89+
_scal(x.shape[0], alpha, &x[0], 1)
90+
91+
92+
################
93+
# BLAS Level 2 #
94+
################
95+
96+
cdef void _gemv(BLAS_Order order, BLAS_Trans ta, int m, int n, floating alpha,
97+
floating *A, int lda, floating *x, int incx,
98+
floating beta, floating *y, int incy) nogil:
99+
"""y := alpha * op(A).x + beta * y"""
100+
cdef char ta_ = ta
101+
if order == RowMajor:
102+
ta_ = NoTrans if ta == Trans else Trans
103+
if floating is float:
104+
sgemv(&ta_, &n, &m, &alpha, A, &lda, x, &incx, &beta, y, &incy)
105+
else:
106+
dgemv(&ta_, &n, &m, &alpha, A, &lda, x, &incx, &beta, y, &incy)
107+
else:
108+
if floating is float:
109+
sgemv(&ta_, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy)
110+
else:
111+
dgemv(&ta_, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy)
112+
113+
114+
cpdef _gemv_memview(BLAS_Trans ta, floating alpha, floating[:, :] A,
115+
floating[::1] x, floating beta, floating[::1] y):
116+
cdef:
117+
int m = A.shape[0]
118+
int n = A.shape[1]
119+
BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor
120+
int lda = m if order == ColMajor else n
121+
122+
_gemv(order, ta, m, n, alpha, &A[0, 0], lda, &x[0], 1, beta, &y[0], 1)
123+
124+
125+
cdef void _ger(BLAS_Order order, int m, int n, floating alpha, floating *x,
126+
int incx, floating *y, int incy, floating *A, int lda) nogil:
127+
"""A := alpha * x.y.T + A"""
128+
if order == RowMajor:
129+
if floating is float:
130+
sger(&n, &m, &alpha, y, &incy, x, &incx, A, &lda)
131+
else:
132+
dger(&n, &m, &alpha, y, &incy, x, &incx, A, &lda)
133+
else:
134+
if floating is float:
135+
sger(&m, &n, &alpha, x, &incx, y, &incy, A, &lda)
136+
else:
137+
dger(&m, &n, &alpha, x, &incx, y, &incy, A, &lda)
138+
139+
140+
cpdef _ger_memview(floating alpha, floating[::1] x, floating[::1] y,
141+
floating[:, :] A):
142+
cdef:
143+
int m = A.shape[0]
144+
int n = A.shape[1]
145+
BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor
146+
int lda = m if order == ColMajor else n
147+
148+
_ger(order, m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda)
149+
150+
151+
################
152+
# BLAS Level 3 #
153+
################
154+
155+
cdef void _gemm(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, int m, int n,
156+
int k, floating alpha, floating *A, int lda, floating *B,
157+
int ldb, floating beta, floating *C, int ldc) nogil:
158+
"""C := alpha * op(A).op(B) + beta * C"""
159+
cdef:
160+
char ta_ = ta
161+
char tb_ = tb
162+
if order == RowMajor:
163+
if floating is float:
164+
sgemm(&tb_, &ta_, &n, &m, &k, &alpha, B,
165+
&ldb, A, &lda, &beta, C, &ldc)
166+
else:
167+
dgemm(&tb_, &ta_, &n, &m, &k, &alpha, B,
168+
&ldb, A, &lda, &beta, C, &ldc)
169+
else:
170+
if floating is float:
171+
sgemm(&ta_, &tb_, &m, &n, &k, &alpha, A,
172+
&lda, B, &ldb, &beta, C, &ldc)
173+
else:
174+
dgemm(&ta_, &tb_, &m, &n, &k, &alpha, A,
175+
&lda, B, &ldb, &beta, C, &ldc)
176+
177+
178+
cpdef _gemm_memview(BLAS_Trans ta, BLAS_Trans tb, floating alpha,
179+
floating[:, :] A, floating[:, :] B, floating beta,
180+
floating[:, :] C):
181+
cdef:
182+
int m = A.shape[0] if ta == NoTrans else A.shape[1]
183+
int n = B.shape[1] if tb == NoTrans else B.shape[0]
184+
int k = A.shape[1] if ta == NoTrans else A.shape[0]
185+
int lda, ldb, ldc
186+
BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor
187+
188+
if order == RowMajor:
189+
lda = k if ta == NoTrans else m
190+
ldb = n if tb == NoTrans else k
191+
ldc = n
192+
else:
193+
lda = m if ta == NoTrans else k
194+
ldb = k if tb == NoTrans else n
195+
ldc = m
196+
197+
_gemm(order, ta, tb, m, n, k, alpha, &A[0, 0],
198+
lda, &B[0, 0], ldb, beta, &C[0, 0], ldc)

sklearn/utils/setup.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,17 @@ def configuration(parent_package='', top_path=None):
2424
config.add_extension('sparsefuncs_fast', sources=['sparsefuncs_fast.pyx'],
2525
libraries=libraries)
2626

27+
config.add_extension('_cython_blas',
28+
sources=['_cython_blas.pyx'],
29+
libraries=libraries)
30+
2731
config.add_extension('arrayfuncs',
2832
sources=['arrayfuncs.pyx'],
2933
depends=[join('src', 'cholesky_delete.h')],
3034
libraries=cblas_libs,
3135
include_dirs=cblas_includes,
3236
extra_compile_args=cblas_compile_args,
33-
**blas_info
34-
)
37+
**blas_info)
3538

3639
config.add_extension('murmurhash',
3740
sources=['murmurhash.pyx', join(

0 commit comments

Comments
 (0)
0