-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[MRG] Use Scipy cython BLAS API instead of bundled CBLAS #12732
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
c6cd0a9
Scipy cython_blas fused helpers
jeremiedbb cc81002
cython language_level=3
jeremiedbb d823f7c
rtol
jeremiedbb 5e0656b
rtol
jeremiedbb 73522a4
rtol
jeremiedbb d656a3c
alpha beta
jeremiedbb 24e0dd0
enum order & trans
jeremiedbb df86d94
fix numpy order type
jeremiedbb 9372276
change blas functions names
jeremiedbb 1868648
flake8
jeremiedbb 12b3660
clean up
jeremiedbb 86556d9
what's new
jeremiedbb 0019d25
infer memory layout
jeremiedbb 3564ffa
remove blank line
jeremiedbb File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,26 @@ | ||
import os | ||
import os.path | ||
|
||
import numpy | ||
from numpy.distutils.misc_util import Configuration | ||
|
||
from sklearn._build_utils import get_blas_info | ||
|
||
|
||
def configuration(parent_package="", top_path=None): | ||
config = Configuration("metrics", parent_package, top_path) | ||
|
||
cblas_libs, blas_info = get_blas_info() | ||
libraries = [] | ||
if os.name == 'posix': | ||
cblas_libs.append('m') | ||
libraries.append('m') | ||
|
||
config.add_subpackage('cluster') | ||
|
||
config.add_extension("pairwise_fast", | ||
sources=["pairwise_fast.pyx"], | ||
include_dirs=[os.path.join('..', 'src', 'cblas'), | ||
numpy.get_include(), | ||
blas_info.pop('include_dirs', [])], | ||
libraries=cblas_libs, | ||
extra_compile_args=blas_info.pop('extra_compile_args', | ||
[]), | ||
**blas_info) | ||
libraries=libraries) | ||
|
||
config.add_subpackage('tests') | ||
|
||
return config | ||
|
||
|
||
if __name__ == "__main__": | ||
from numpy.distutils.core import setup | ||
setup(**configuration().todict()) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# cython: language_level=3 | ||
|
||
from cython cimport floating | ||
|
||
|
||
cpdef enum BLAS_Order: | ||
RowMajor # C contiguous | ||
ColMajor # Fortran contiguous | ||
|
||
|
||
cpdef enum BLAS_Trans: | ||
NoTrans = 110 # correspond to 'n' | ||
Trans = 116 # correspond to 't' | ||
|
||
|
||
# BLAS Level 1 ################################################################ | ||
cdef floating _dot(int, floating*, int, floating*, int) nogil | ||
|
||
cdef floating _asum(int, floating*, int) nogil | ||
|
||
cdef void _axpy(int, floating, floating*, int, floating*, int) nogil | ||
|
||
cdef floating _nrm2(int, floating*, int) nogil | ||
|
||
cdef void _copy(int, floating*, int, floating*, int) nogil | ||
|
||
cdef void _scal(int, floating, floating*, int) nogil | ||
|
||
# BLAS Level 2 ################################################################ | ||
cdef void _gemv(BLAS_Order, BLAS_Trans, int, int, floating, floating*, int, | ||
floating*, int, floating, floating*, int) nogil | ||
|
||
cdef void _ger(BLAS_Order, int, int, floating, floating*, int, floating*, int, | ||
floating*, int) nogil | ||
|
||
# BLASLevel 3 ################################################################ | ||
cdef void _gemm(BLAS_Order, BLAS_Trans, BLAS_Trans, int, int, int, floating, | ||
floating*, int, floating*, int, floating, floating*, | ||
int) nogil |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
from cython cimport floating | ||
|
||
from scipy.linalg.cython_blas cimport sdot, ddot | ||
from scipy.linalg.cython_blas cimport sasum, dasum | ||
from scipy.linalg.cython_blas cimport saxpy, daxpy | ||
from scipy.linalg.cython_blas cimport snrm2, dnrm2 | ||
from scipy.linalg.cython_blas cimport scopy, dcopy | ||
from scipy.linalg.cython_blas cimport sscal, dscal | ||
from scipy.linalg.cython_blas cimport sgemv, dgemv | ||
from scipy.linalg.cython_blas cimport sger, dger | ||
from scipy.linalg.cython_blas cimport sgemm, dgemm | ||
|
||
|
||
################ | ||
# BLAS Level 1 # | ||
################ | ||
|
||
cdef floating _dot(int n, floating *x, int incx, | ||
floating *y, int incy) nogil: | ||
"""x.T.y""" | ||
if floating is float: | ||
return sdot(&n, x, &incx, y, &incy) | ||
else: | ||
return ddot(&n, x, &incx, y, &incy) | ||
|
||
|
||
cpdef _dot_memview(floating[::1] x, floating[::1] y): | ||
return _dot(x.shape[0], &x[0], 1, &y[0], 1) | ||
|
||
|
||
cdef floating _asum(int n, floating *x, int incx) nogil: | ||
"""sum(|x_i|)""" | ||
if floating is float: | ||
return sasum(&n, x, &incx) | ||
else: | ||
return dasum(&n, x, &incx) | ||
|
||
|
||
cpdef _asum_memview(floating[::1] x): | ||
return _asum(x.shape[0], &x[0], 1) | ||
|
||
|
||
cdef void _axpy(int n, floating alpha, floating *x, int incx, | ||
floating *y, int incy) nogil: | ||
"""y := alpha * x + y""" | ||
if floating is float: | ||
saxpy(&n, &alpha, x, &incx, y, &incy) | ||
else: | ||
daxpy(&n, &alpha, x, &incx, y, &incy) | ||
|
||
|
||
cpdef _axpy_memview(floating alpha, floating[::1] x, floating[::1] y): | ||
_axpy(x.shape[0], alpha, &x[0], 1, &y[0], 1) | ||
|
||
|
||
cdef floating _nrm2(int n, floating *x, int incx) nogil: | ||
"""sqrt(sum((x_i)^2))""" | ||
if floating is float: | ||
return snrm2(&n, x, &incx) | ||
else: | ||
return dnrm2(&n, x, &incx) | ||
|
||
|
||
cpdef _nrm2_memview(floating[::1] x): | ||
return _nrm2(x.shape[0], &x[0], 1) | ||
|
||
|
||
cdef void _copy(int n, floating *x, int incx, floating *y, int incy) nogil: | ||
"""y := x""" | ||
if floating is float: | ||
scopy(&n, x, &incx, y, &incy) | ||
else: | ||
dcopy(&n, x, &incx, y, &incy) | ||
|
||
|
||
cpdef _copy_memview(floating[::1] x, floating[::1] y): | ||
_copy(x.shape[0], &x[0], 1, &y[0], 1) | ||
|
||
|
||
cdef void _scal(int n, floating alpha, floating *x, int incx) nogil: | ||
"""x := alpha * x""" | ||
if floating is float: | ||
sscal(&n, &alpha, x, &incx) | ||
else: | ||
dscal(&n, &alpha, x, &incx) | ||
|
||
|
||
cpdef _scal_memview(floating alpha, floating[::1] x): | ||
_scal(x.shape[0], alpha, &x[0], 1) | ||
|
||
|
||
################ | ||
# BLAS Level 2 # | ||
################ | ||
|
||
cdef void _gemv(BLAS_Order order, BLAS_Trans ta, int m, int n, floating alpha, | ||
floating *A, int lda, floating *x, int incx, | ||
floating beta, floating *y, int incy) nogil: | ||
"""y := alpha * op(A).x + beta * y""" | ||
cdef char ta_ = ta | ||
if order == RowMajor: | ||
ta_ = NoTrans if ta == Trans else Trans | ||
if floating is float: | ||
sgemv(&ta_, &n, &m, &alpha, A, &lda, x, &incx, &beta, y, &incy) | ||
else: | ||
dgemv(&ta_, &n, &m, &alpha, A, &lda, x, &incx, &beta, y, &incy) | ||
else: | ||
if floating is float: | ||
sgemv(&ta_, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy) | ||
else: | ||
dgemv(&ta_, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy) | ||
|
||
|
||
cpdef _gemv_memview(BLAS_Trans ta, floating alpha, floating[:, :] A, | ||
floating[::1] x, floating beta, floating[::1] y): | ||
cdef: | ||
int m = A.shape[0] | ||
int n = A.shape[1] | ||
BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor | ||
int lda = m if order == ColMajor else n | ||
|
||
_gemv(order, ta, m, n, alpha, &A[0, 0], lda, &x[0], 1, beta, &y[0], 1) | ||
|
||
|
||
cdef void _ger(BLAS_Order order, int m, int n, floating alpha, floating *x, | ||
int incx, floating *y, int incy, floating *A, int lda) nogil: | ||
"""A := alpha * x.y.T + A""" | ||
if order == RowMajor: | ||
if floating is float: | ||
sger(&n, &m, &alpha, y, &incy, x, &incx, A, &lda) | ||
else: | ||
dger(&n, &m, &alpha, y, &incy, x, &incx, A, &lda) | ||
else: | ||
if floating is float: | ||
sger(&m, &n, &alpha, x, &incx, y, &incy, A, &lda) | ||
else: | ||
dger(&m, &n, &alpha, x, &incx, y, &incy, A, &lda) | ||
|
||
|
||
cpdef _ger_memview(floating alpha, floating[::1] x, floating[::] y, | ||
floating[:, :] A): | ||
cdef: | ||
int m = A.shape[0] | ||
int n = A.shape[1] | ||
BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor | ||
int lda = m if order == ColMajor else n | ||
|
||
_ger(order, m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda) | ||
|
||
|
||
################ | ||
# BLAS Level 3 # | ||
################ | ||
|
||
cdef void _gemm(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, int m, int n, | ||
int k, floating alpha, floating *A, int lda, floating *B, | ||
int ldb, floating beta, floating *C, int ldc) nogil: | ||
"""C := alpha * op(A).op(B) + beta * C""" | ||
cdef: | ||
char ta_ = ta | ||
char tb_ = tb | ||
if order == RowMajor: | ||
if floating is float: | ||
sgemm(&tb_, &ta_, &n, &m, &k, &alpha, B, | ||
&ldb, A, &lda, &beta, C, &ldc) | ||
else: | ||
dgemm(&tb_, &ta_, &n, &m, &k, &alpha, B, | ||
&ldb, A, &lda, &beta, C, &ldc) | ||
else: | ||
if floating is float: | ||
sgemm(&ta_, &tb_, &m, &n, &k, &alpha, A, | ||
&lda, B, &ldb, &beta, C, &ldc) | ||
else: | ||
dgemm(&ta_, &tb_, &m, &n, &k, &alpha, A, | ||
&lda, B, &ldb, &beta, C, &ldc) | ||
|
||
|
||
cpdef _gemm_memview(BLAS_Trans ta, BLAS_Trans tb, floating alpha, | ||
floating[:, :] A, floating[:, :] B, floating beta, | ||
floating[:, :] C): | ||
cdef: | ||
int m = A.shape[0] if ta == NoTrans else A.shape[1] | ||
int n = B.shape[1] if tb == NoTrans else B.shape[0] | ||
int k = A.shape[1] if ta == NoTrans else A.shape[0] | ||
int lda, ldb, ldc | ||
BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor | ||
|
||
if order == RowMajor: | ||
lda = k if ta == NoTrans else m | ||
ldb = n if tb == NoTrans else k | ||
ldc = n | ||
else: | ||
lda = m if ta == NoTrans else k | ||
ldb = k if tb == NoTrans else n | ||
ldc = m | ||
|
||
_gemm(order, ta, tb, m, n, k, alpha, &A[0, 0], | ||
lda, &B[0, 0], ldb, beta, &C[0, 0], ldc) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we be using const memoryviews to allow read-only input arrays?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fused typed const memoryviews does not work yet, see #10624
However, all the
_xxx_memview
functions are just python wrappers to be able to test the C functions with pytest. They are not meant to be used in the python code base (if we want to multiply matrices in python we just do numpy dot), we don't want to expose blas functions at the python level.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah of course. But is there a reason we should not be using memview interfaces when that would simplify the call?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There might be the small overhead of an additional function call (and I don't really know how memview behave versus pointers performance wise) but I agree it would simplify it.
Currently all blas functions are called within functions where we have access to the pointers, so I'm not sure it's worth making interfaces for what we don't need currently. Maybe we should reconsider doing it when the need comes ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think there's any function call overhead. Certainly not python functions. There will be cost in accessing members of the memoryview struct, but minimal.
No hurry, you're right, but I still find it strange that we are passed order rather than determining it from the memoryview strides.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think there's any function call overhead. Certainly not python functions. There will be cost in accessing members of the memoryview struct, but minimal.
No hurry, you're right, but I still find it strange that we are passed order rather than determining it from the memoryview strides.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You convinced me :) I updated the functions to infer the memory layout.