From c6cd0a98a4d156d8f907bf69d17adda03407d266 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 8 Nov 2018 13:38:56 +0100 Subject: [PATCH 01/14] Scipy cython_blas fused helpers --- sklearn/metrics/pairwise_fast.pyx | 5 +- sklearn/metrics/setup.py | 19 +-- sklearn/utils/_cython_blas.pxd | 20 +++ sklearn/utils/_cython_blas.pyx | 195 ++++++++++++++++++++++++ sklearn/utils/setup.py | 4 + sklearn/utils/tests/test_cython_blas.py | 171 +++++++++++++++++++++ 6 files changed, 398 insertions(+), 16 deletions(-) create mode 100644 sklearn/utils/_cython_blas.pxd create mode 100644 sklearn/utils/_cython_blas.pyx create mode 100644 sklearn/utils/tests/test_cython_blas.py diff --git a/sklearn/metrics/pairwise_fast.pyx b/sklearn/metrics/pairwise_fast.pyx index 4d7ad411fa20a..901bedb145f15 100644 --- a/sklearn/metrics/pairwise_fast.pyx +++ b/sklearn/metrics/pairwise_fast.pyx @@ -13,8 +13,7 @@ cimport numpy as np from cython cimport floating -cdef extern from "cblas.h": - double cblas_dasum(int, const double *, int) nogil +from ..utils._cython_blas cimport _xasum np.import_array() @@ -67,4 +66,4 @@ def _sparse_manhattan(floating[::1] X_data, int[:] X_indices, int[:] X_indptr, for j in range(Y_indptr[iy], Y_indptr[iy + 1]): row[Y_indices[j]] -= Y_data[j] - D[ix, iy] = cblas_dasum(n_features, &row[0], 1) + D[ix, iy] = _xasum(n_features, &row[0], 1) diff --git a/sklearn/metrics/setup.py b/sklearn/metrics/setup.py index d9a10d1df3290..97175456220cd 100644 --- a/sklearn/metrics/setup.py +++ b/sklearn/metrics/setup.py @@ -1,33 +1,26 @@ import os -import os.path -import numpy from numpy.distutils.misc_util import Configuration -from sklearn._build_utils import get_blas_info - def configuration(parent_package="", top_path=None): config = Configuration("metrics", parent_package, top_path) - cblas_libs, blas_info = get_blas_info() + libraries = [] if os.name == 'posix': - cblas_libs.append('m') + libraries.append('m') config.add_subpackage('cluster') + config.add_extension("pairwise_fast", sources=["pairwise_fast.pyx"], - include_dirs=[os.path.join('..', 'src', 'cblas'), - numpy.get_include(), - blas_info.pop('include_dirs', [])], - libraries=cblas_libs, - extra_compile_args=blas_info.pop('extra_compile_args', - []), - **blas_info) + libraries=libraries) + config.add_subpackage('tests') return config + if __name__ == "__main__": from numpy.distutils.core import setup setup(**configuration().todict()) diff --git a/sklearn/utils/_cython_blas.pxd b/sklearn/utils/_cython_blas.pxd new file mode 100644 index 0000000000000..b9ed826950650 --- /dev/null +++ b/sklearn/utils/_cython_blas.pxd @@ -0,0 +1,20 @@ +from cython cimport floating + + +# BLAS Level 1 ################################################################ +cdef floating _xdot(int, floating*, int, floating*, int) nogil +cdef floating _xasum(int, floating*, int) nogil +cdef void _xaxpy(int, floating, floating*, int, floating*, int) nogil +cdef floating _xnrm2(int, floating*, int) nogil +cdef void _xcopy(int, floating*, int, floating*, int) nogil +cdef void _xscal(int, floating, floating*, int) nogil + +# BLAS Level 2 ################################################################ +cdef void _xgemv(char, char, int, int, floating, floating*, int, floating*, + int, floating, floating*, int) nogil +cdef void _xger(char, int, int, floating, floating*, int, floating*, int, + floating*, int) nogil + +# BLASLevel 3 ################################################################ +cdef void _xgemm(char, char, char, int, int, int, floating, floating*, int, 
+ floating*, int, floating, floating*, int) nogil diff --git a/sklearn/utils/_cython_blas.pyx b/sklearn/utils/_cython_blas.pyx new file mode 100644 index 0000000000000..99b9249023af3 --- /dev/null +++ b/sklearn/utils/_cython_blas.pyx @@ -0,0 +1,195 @@ +# cython: boundscheck=False, wraparound=False, cdivision=True + +from cython cimport floating + +from scipy.linalg.cython_blas cimport sdot, ddot +from scipy.linalg.cython_blas cimport sasum, dasum +from scipy.linalg.cython_blas cimport saxpy, daxpy +from scipy.linalg.cython_blas cimport snrm2, dnrm2 +from scipy.linalg.cython_blas cimport scopy, dcopy +from scipy.linalg.cython_blas cimport sscal, dscal +from scipy.linalg.cython_blas cimport sgemv, dgemv +from scipy.linalg.cython_blas cimport sger, dger +from scipy.linalg.cython_blas cimport sgemm, dgemm + + +################ +# BLAS Level 1 # +################ + +cdef floating _xdot(int n, floating *x, int incx, + floating *y, int incy) nogil: + """x.T.y""" + if floating is float: + return sdot(&n, x, &incx, y, &incy) + else: + return ddot(&n, x, &incx, y, &incy) + + +cpdef _xdot_memview(floating[::1] x, floating[::1] y): + return _xdot(x.shape[0], &x[0], 1, &y[0], 1) + + +cdef floating _xasum(int n, floating *x, int incx) nogil: + """sum(|x_i|)""" + if floating is float: + return sasum(&n, x, &incx) + else: + return dasum(&n, x, &incx) + + +cpdef _xasum_memview(floating[::1] x): + return _xasum(x.shape[0], &x[0], 1) + + +cdef void _xaxpy(int n, floating alpha, floating *x, int incx, + floating *y, int incy) nogil: + """y := alpha * x + y""" + if floating is float: + saxpy(&n, &alpha, x, &incx, y, &incy) + else: + daxpy(&n, &alpha, x, &incx, y, &incy) + + +cpdef _xaxpy_memview(floating alpha, floating[::1] x, floating[::1] y): + _xaxpy(x.shape[0], alpha, &x[0], 1, &y[0], 1) + + +cdef floating _xnrm2(int n, floating *x, int incx) nogil: + """sqrt(sum((x_i)^2))""" + if floating is float: + return snrm2(&n, x, &incx) + else: + return dnrm2(&n, x, &incx) + + +cpdef _xnrm2_memview(floating[::1] x): + return _xnrm2(x.shape[0], &x[0], 1) + + +cdef void _xcopy(int n, floating *x, int incx, floating *y, int incy) nogil: + """y := x""" + if floating is float: + scopy(&n, x, &incx, y, &incy) + else: + dcopy(&n, x, &incx, y, &incy) + + +cpdef _xcopy_memview(floating[::1] x, floating[::1] y): + _xcopy(x.shape[0], &x[0], 1, &y[0], 1) + + +cdef void _xscal(int n, floating alpha, floating *x, int incx) nogil: + """x := alpha * x""" + if floating is float: + sscal(&n, &alpha, x, &incx) + else: + dscal(&n, &alpha, x, &incx) + + +cpdef _xscal_memview(floating alpha, floating[::1] x): + _xscal(x.shape[0], alpha, &x[0], 1) + + +################ +# BLAS Level 2 # +################ + +cdef void _xgemv(char layout, char ta, int m, int n, floating alpha, + floating *A, int lda, floating *x, int incx, + floating beta, floating *y, int incy) nogil: + """y := alpha * op(A).x + beta * y""" + if layout == 'C': + ta = 'n' if ta == 't' else 't' + if floating is float: + sgemv(&ta, &n, &m, &alpha, A, &lda, x, &incx, &beta, y, &incy) + else: + dgemv(&ta, &n, &m, &alpha, A, &lda, x, &incx, &beta, y, &incy) + elif layout == 'F': + if floating is float: + sgemv(&ta, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy) + else: + dgemv(&ta, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy) + + +cpdef _xgemv_memview(layout, ta, floating alpha, floating[:, :] A, + floating[::1] x, floating beta, floating[::1] y): + cdef: + char layout_ = 'F' if layout == 'F' else 'C' + char ta_ = 'n' if ta == 'n' else 't' + int m = A.shape[0] + 
int n = A.shape[1] + int lda = m if layout == 'F' else n + + _xgemv(layout_, ta_, m, n, alpha, &A[0, 0], lda, + &x[0], 1, beta, &y[0], 1) + + +cdef void _xger(char layout, int m, int n, floating alpha, floating *x, + int incx, floating *y, int incy, floating *A, int lda) nogil: + """A := alpha * x.y.T + A""" + if layout == 'C': + if floating is float: + sger(&n, &m, &alpha, y, &incy, x, &incx, A, &lda) + else: + dger(&n, &m, &alpha, y, &incy, x, &incx, A, &lda) + elif layout == 'F': + if floating is float: + sger(&m, &n, &alpha, x, &incx, y, &incy, A, &lda) + else: + dger(&m, &n, &alpha, x, &incx, y, &incy, A, &lda) + + +cpdef _xger_memview(layout, floating alpha, floating[::1] x, floating[::] y, + floating[:, :] A): + cdef: + char layout_ = 'F' if layout == 'F' else 'C' + int m = A.shape[0] + int n = A.shape[1] + int lda = m if layout[0] == 'F' else n + + _xger(layout_, m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda) + + +################ +# BLAS Level 3 # +################ + +cdef void _xgemm(char layout, char ta, char tb, int m, int n, int k, + floating alpha, floating *A, int lda, floating *B, int ldb, + floating beta, floating *C, int ldc) nogil: + """C := alpha * op(A).op(B) + beta * C""" + if layout == 'C': + if floating is float: + sgemm(&tb, &ta, &n, &m, &k, &alpha, B, &ldb, A, &lda, &beta, C, &ldc) + else: + dgemm(&tb, &ta, &n, &m, &k, &alpha, B, &ldb, A, &lda, &beta, C, &ldc) + elif layout == 'F': + if floating is float: + sgemm(&ta, &tb, &m, &n, &k, &alpha, A, &lda, B, &ldb, &beta, C, &ldc) + else: + dgemm(&ta, &tb, &m, &n, &k, &alpha, A, &lda, B, &ldb, &beta, C, &ldc) + + +cpdef _xgemm_memview(layout, ta, tb, floating alpha, floating[:, :] A, + floating[:, :] B, floating beta, floating[:, :] C): + cdef: + char layout_ = 'F' if layout == 'F' else 'C' + char ta_ = 'n' if ta == 'n' else 't' + char tb_ = 'n' if tb == 'n' else 't' + int m = A.shape[0] if ta[0] == 'n' else A.shape[1] + int n = B.shape[1] if tb[0] == 'n' else B.shape[0] + int k = A.shape[1] if ta[0] == 'n' else A.shape[0] + int lda, ldb, ldc + + if layout == 'F': + lda = m if ta == 'n' else k + ldb = k if tb == 'n' else n + ldc = m + else: + lda = k if ta == 'n' else m + ldb = n if tb == 'n' else k + ldc = n + + _xgemm(layout_, ta_, tb_, m, n, k, alpha, + &A[0, 0], lda, &B[0, 0], ldb, beta, &C[0, 0], ldc) \ No newline at end of file diff --git a/sklearn/utils/setup.py b/sklearn/utils/setup.py index 13d772a5a53b7..4e0f4444e062c 100644 --- a/sklearn/utils/setup.py +++ b/sklearn/utils/setup.py @@ -24,6 +24,10 @@ def configuration(parent_package='', top_path=None): config.add_extension('sparsefuncs_fast', sources=['sparsefuncs_fast.pyx'], libraries=libraries) + config.add_extension('_cython_blas', + sources=['_cython_blas.pyx'], + libraries=libraries) + config.add_extension('arrayfuncs', sources=['arrayfuncs.pyx'], depends=[join('src', 'cholesky_delete.h')], diff --git a/sklearn/utils/tests/test_cython_blas.py b/sklearn/utils/tests/test_cython_blas.py new file mode 100644 index 0000000000000..73456831fd31a --- /dev/null +++ b/sklearn/utils/tests/test_cython_blas.py @@ -0,0 +1,171 @@ +import pytest +import cython + +import numpy as np + +from sklearn.utils.testing import assert_allclose + +from sklearn.utils._cython_blas import _xdot_memview +from sklearn.utils._cython_blas import _xasum_memview +from sklearn.utils._cython_blas import _xaxpy_memview +from sklearn.utils._cython_blas import _xnrm2_memview +from sklearn.utils._cython_blas import _xcopy_memview +from sklearn.utils._cython_blas import _xscal_memview +from 
sklearn.utils._cython_blas import _xgemv_memview +from sklearn.utils._cython_blas import _xger_memview +from sklearn.utils._cython_blas import _xgemm_memview + + +NUMPY_TO_CYTHON = {np.float32: cython.float, np.float64: cython.double} + + +def _no_op(x): + return x + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_dot(dtype): + dot = _xdot_memview[NUMPY_TO_CYTHON[dtype]] + + rng = np.random.RandomState(0) + x = rng.random_sample(10).astype(dtype, copy=False) + y = rng.random_sample(10).astype(dtype, copy=False) + + expected = x.dot(y) + actual = dot(x, y) + + assert_allclose(actual, expected) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_asum(dtype): + asum = _xasum_memview[NUMPY_TO_CYTHON[dtype]] + + rng = np.random.RandomState(0) + x = rng.random_sample(10).astype(dtype, copy=False) + + expected = np.abs(x).sum() + actual = asum(x) + + assert_allclose(actual, expected) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_axpy(dtype): + axpy = _xaxpy_memview[NUMPY_TO_CYTHON[dtype]] + + rng = np.random.RandomState(0) + x = rng.random_sample(10).astype(dtype, copy=False) + y = rng.random_sample(10).astype(dtype, copy=False) + alpha = 1.23 + + expected = alpha * x + y + axpy(alpha, x, y) + + assert_allclose(y, expected) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_nrm2(dtype): + nrm2 = _xnrm2_memview[NUMPY_TO_CYTHON[dtype]] + + rng = np.random.RandomState(0) + x = rng.random_sample(10).astype(dtype, copy=False) + + expected = np.linalg.norm(x) + actual = nrm2(x) + + assert_allclose(actual, expected) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_copy(dtype): + copy = _xcopy_memview[NUMPY_TO_CYTHON[dtype]] + + rng = np.random.RandomState(0) + x = rng.random_sample(10).astype(dtype, copy=False) + y = np.empty_like(x) + + expected = x.copy() + copy(x, y) + + assert_allclose(y, expected) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_scal(dtype): + scal = _xscal_memview[NUMPY_TO_CYTHON[dtype]] + + rng = np.random.RandomState(0) + x = rng.random_sample(10).astype(dtype, copy=False) + alpha = 1.23 + + expected = alpha * x + scal(alpha, x) + + assert_allclose(x, expected) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("opA, transA", + [(_no_op, 'n'), (np.transpose, 't')], + ids=["A", "A.T"]) +@pytest.mark.parametrize("layout", ['C', 'F']) +def test_gemv(dtype, opA, transA, layout): + gemv = _xgemv_memview[NUMPY_TO_CYTHON[dtype]] + + rng = np.random.RandomState(0) + A = np.asarray(opA(rng.random_sample((20, 10)).astype(dtype, copy=False)), + order=layout) + x = rng.random_sample(10).astype(dtype, copy=False) + y = rng.random_sample(20).astype(dtype, copy=False) + alpha, beta = 1.23, -3.21 + + expected = alpha * opA(A).dot(x) + beta * y + gemv(layout, transA, alpha, A, x, beta, y) + + assert_allclose(y, expected, rtol=1e-4) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("layout", ['C', 'F']) +def test_ger(dtype, layout): + ger = _xger_memview[NUMPY_TO_CYTHON[dtype]] + + rng = np.random.RandomState(0) + x = rng.random_sample(10).astype(dtype, copy=False) + y = rng.random_sample(20).astype(dtype, copy=False) + A = np.asarray(rng.random_sample((10, 20)).astype(dtype, copy=False), + order=layout) + alpha = 1.23 + + expected = alpha * np.outer(x, y) + A + ger(layout, alpha, x, y, A) + + assert_allclose(A, expected, rtol=1e-4) + + 
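#
# Aside (illustration, not part of the patch): the RowMajor branch of _xger
# above swaps x and y because a C-ordered A viewed column-major is A.T, and
# (alpha * outer(x, y) + A).T == alpha * outer(y, x) + A.T, so the Fortran
# ger can update the same buffer in place. A NumPy check of that identity:
import numpy as np

rng = np.random.RandomState(0)
x, y = rng.random_sample(10), rng.random_sample(20)
A = rng.random_sample((10, 20))        # C-ordered
lhs = 1.23 * np.outer(x, y) + A        # what _xger computes
rhs = (1.23 * np.outer(y, x) + A.T).T  # the swapped column-major call
assert np.allclose(lhs, rhs)
#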
+@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("opB, transB", + [(_no_op, 'n'), (np.transpose, 't')], + ids=["B", "B.T"]) +@pytest.mark.parametrize("opA, transA", + [(_no_op, 'n'), (np.transpose, 't')], + ids=["A", "A.T"]) +@pytest.mark.parametrize("layout", ['C', 'F']) +def test_gemm(dtype, opA, transA, opB, transB, layout): + gemm = _xgemm_memview[NUMPY_TO_CYTHON[dtype]] + + rng = np.random.RandomState(0) + A = np.asarray(opA(rng.random_sample((30, 10)).astype(dtype, copy=False)), + order=layout) + B = np.asarray(opB(rng.random_sample((10, 20)).astype(dtype, copy=False)), + order=layout) + C = np.asarray(rng.random_sample((30, 20)).astype(dtype, copy=False), + order=layout) + alpha, beta = 1.23, -3.21 + + expected = alpha * opA(A).dot(opB(B)) + beta * C + gemm(layout, transA, transB, alpha, A, B, beta, C) + + assert_allclose(C, expected, rtol=1e-4) From cc81002a5b4693a811b11e950b331bd82eb12b79 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 17 Dec 2018 13:25:59 +0100 Subject: [PATCH 02/14] cython language_level=3 --- sklearn/utils/_cython_blas.pxd | 2 ++ sklearn/utils/_cython_blas.pyx | 41 +++++++++++++++---------- sklearn/utils/tests/test_cython_blas.py | 7 +++-- 3 files changed, 30 insertions(+), 20 deletions(-) diff --git a/sklearn/utils/_cython_blas.pxd b/sklearn/utils/_cython_blas.pxd index b9ed826950650..d245db9f9f350 100644 --- a/sklearn/utils/_cython_blas.pxd +++ b/sklearn/utils/_cython_blas.pxd @@ -1,3 +1,5 @@ +# cython: language_level=3 + from cython cimport floating diff --git a/sklearn/utils/_cython_blas.pyx b/sklearn/utils/_cython_blas.pyx index 99b9249023af3..c68ab94dcd076 100644 --- a/sklearn/utils/_cython_blas.pyx +++ b/sklearn/utils/_cython_blas.pyx @@ -13,6 +13,13 @@ from scipy.linalg.cython_blas cimport sger, dger from scipy.linalg.cython_blas cimport sgemm, dgemm +cdef: + char ColMajor = b'F' + char RowMajor = b'C' + char Trans = b't' + char NoTrans = b'n' + + ################ # BLAS Level 1 # ################ @@ -99,13 +106,13 @@ cdef void _xgemv(char layout, char ta, int m, int n, floating alpha, floating *A, int lda, floating *x, int incx, floating beta, floating *y, int incy) nogil: """y := alpha * op(A).x + beta * y""" - if layout == 'C': - ta = 'n' if ta == 't' else 't' + if layout == RowMajor: + ta = NoTrans if ta == Trans else Trans if floating is float: sgemv(&ta, &n, &m, &alpha, A, &lda, x, &incx, &beta, y, &incy) else: dgemv(&ta, &n, &m, &alpha, A, &lda, x, &incx, &beta, y, &incy) - elif layout == 'F': + elif layout == ColMajor: if floating is float: sgemv(&ta, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy) else: @@ -115,8 +122,8 @@ cdef void _xgemv(char layout, char ta, int m, int n, floating alpha, cpdef _xgemv_memview(layout, ta, floating alpha, floating[:, :] A, floating[::1] x, floating beta, floating[::1] y): cdef: - char layout_ = 'F' if layout == 'F' else 'C' - char ta_ = 'n' if ta == 'n' else 't' + char layout_ = ColMajor if layout == 'F' else RowMajor + char ta_ = NoTrans if ta == 'n' else Trans int m = A.shape[0] int n = A.shape[1] int lda = m if layout == 'F' else n @@ -128,12 +135,12 @@ cpdef _xgemv_memview(layout, ta, floating alpha, floating[:, :] A, cdef void _xger(char layout, int m, int n, floating alpha, floating *x, int incx, floating *y, int incy, floating *A, int lda) nogil: """A := alpha * x.y.T + A""" - if layout == 'C': + if layout == RowMajor: if floating is float: sger(&n, &m, &alpha, y, &incy, x, &incx, A, &lda) else: dger(&n, &m, &alpha, y, &incy, x, &incx, A, &lda) 
- elif layout == 'F': + elif layout == ColMajor: if floating is float: sger(&m, &n, &alpha, x, &incx, y, &incy, A, &lda) else: @@ -143,10 +150,10 @@ cdef void _xger(char layout, int m, int n, floating alpha, floating *x, cpdef _xger_memview(layout, floating alpha, floating[::1] x, floating[::] y, floating[:, :] A): cdef: - char layout_ = 'F' if layout == 'F' else 'C' + char layout_ = ColMajor if layout == 'F' else RowMajor int m = A.shape[0] int n = A.shape[1] - int lda = m if layout[0] == 'F' else n + int lda = m if layout == 'F' else n _xger(layout_, m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda) @@ -159,12 +166,12 @@ cdef void _xgemm(char layout, char ta, char tb, int m, int n, int k, floating alpha, floating *A, int lda, floating *B, int ldb, floating beta, floating *C, int ldc) nogil: """C := alpha * op(A).op(B) + beta * C""" - if layout == 'C': + if layout == RowMajor: if floating is float: sgemm(&tb, &ta, &n, &m, &k, &alpha, B, &ldb, A, &lda, &beta, C, &ldc) else: dgemm(&tb, &ta, &n, &m, &k, &alpha, B, &ldb, A, &lda, &beta, C, &ldc) - elif layout == 'F': + elif layout == ColMajor: if floating is float: sgemm(&ta, &tb, &m, &n, &k, &alpha, A, &lda, B, &ldb, &beta, C, &ldc) else: @@ -174,12 +181,12 @@ cdef void _xgemm(char layout, char ta, char tb, int m, int n, int k, cpdef _xgemm_memview(layout, ta, tb, floating alpha, floating[:, :] A, floating[:, :] B, floating beta, floating[:, :] C): cdef: - char layout_ = 'F' if layout == 'F' else 'C' - char ta_ = 'n' if ta == 'n' else 't' - char tb_ = 'n' if tb == 'n' else 't' - int m = A.shape[0] if ta[0] == 'n' else A.shape[1] - int n = B.shape[1] if tb[0] == 'n' else B.shape[0] - int k = A.shape[1] if ta[0] == 'n' else A.shape[0] + char layout_ = ColMajor if layout == 'F' else RowMajor + char ta_ = NoTrans if ta == 'n' else Trans + char tb_ = NoTrans if tb == 'n' else Trans + int m = A.shape[0] if ta == 'n' else A.shape[1] + int n = B.shape[1] if tb == 'n' else B.shape[0] + int k = A.shape[1] if ta == 'n' else A.shape[0] int lda, ldb, ldc if layout == 'F': diff --git a/sklearn/utils/tests/test_cython_blas.py b/sklearn/utils/tests/test_cython_blas.py index 73456831fd31a..7232ca9bd3db2 100644 --- a/sklearn/utils/tests/test_cython_blas.py +++ b/sklearn/utils/tests/test_cython_blas.py @@ -17,6 +17,7 @@ NUMPY_TO_CYTHON = {np.float32: cython.float, np.float64: cython.double} +RTOL = {np.float32: 1e-4, np.float64: 1e-13} def _no_op(x): @@ -124,7 +125,7 @@ def test_gemv(dtype, opA, transA, layout): expected = alpha * opA(A).dot(x) + beta * y gemv(layout, transA, alpha, A, x, beta, y) - assert_allclose(y, expected, rtol=1e-4) + assert_allclose(y, expected, rtol=RTOL[dtype]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @@ -142,7 +143,7 @@ def test_ger(dtype, layout): expected = alpha * np.outer(x, y) + A ger(layout, alpha, x, y, A) - assert_allclose(A, expected, rtol=1e-4) + assert_allclose(A, expected, rtol=RTOL[dtype]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @@ -168,4 +169,4 @@ def test_gemm(dtype, opA, transA, opB, transB, layout): expected = alpha * opA(A).dot(opB(B)) + beta * C gemm(layout, transA, transB, alpha, A, B, beta, C) - assert_allclose(C, expected, rtol=1e-4) + assert_allclose(C, expected, rtol=RTOL[dtype]) From d823f7c54bfb9135a1bb32f30d24840af8d0154f Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 17 Dec 2018 13:51:29 +0100 Subject: [PATCH 03/14] rtol --- sklearn/utils/tests/test_cython_blas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/sklearn/utils/tests/test_cython_blas.py b/sklearn/utils/tests/test_cython_blas.py index 7232ca9bd3db2..4a3ca21c8182e 100644 --- a/sklearn/utils/tests/test_cython_blas.py +++ b/sklearn/utils/tests/test_cython_blas.py @@ -17,7 +17,7 @@ NUMPY_TO_CYTHON = {np.float32: cython.float, np.float64: cython.double} -RTOL = {np.float32: 1e-4, np.float64: 1e-13} +RTOL = {np.float32: 1e-3, np.float64: 1e-13} def _no_op(x): From 5e0656bb908f96f591bfc52d0b4575a1c3b4fa80 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 17 Dec 2018 14:18:57 +0100 Subject: [PATCH 04/14] rtol --- sklearn/utils/tests/test_cython_blas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/tests/test_cython_blas.py b/sklearn/utils/tests/test_cython_blas.py index 4a3ca21c8182e..e37d583433e1e 100644 --- a/sklearn/utils/tests/test_cython_blas.py +++ b/sklearn/utils/tests/test_cython_blas.py @@ -17,7 +17,7 @@ NUMPY_TO_CYTHON = {np.float32: cython.float, np.float64: cython.double} -RTOL = {np.float32: 1e-3, np.float64: 1e-13} +RTOL = {np.float32: 1e-4, np.float64: 1e-12} def _no_op(x): From 73522a497ced8e42565615ff425fa15dfee9b5c0 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 17 Dec 2018 14:54:49 +0100 Subject: [PATCH 05/14] rtol --- sklearn/utils/tests/test_cython_blas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/tests/test_cython_blas.py b/sklearn/utils/tests/test_cython_blas.py index e37d583433e1e..0aafaebc287e3 100644 --- a/sklearn/utils/tests/test_cython_blas.py +++ b/sklearn/utils/tests/test_cython_blas.py @@ -17,7 +17,7 @@ NUMPY_TO_CYTHON = {np.float32: cython.float, np.float64: cython.double} -RTOL = {np.float32: 1e-4, np.float64: 1e-12} +RTOL = {np.float32: 1e-3, np.float64: 1e-12} def _no_op(x): From d656a3ca1ef789cad5b738cf9b648466098763f7 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Tue, 18 Dec 2018 15:57:06 +0100 Subject: [PATCH 06/14] alpha beta --- sklearn/utils/setup.py | 3 +-- sklearn/utils/tests/test_cython_blas.py | 24 ++++++++++++------------ 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/sklearn/utils/setup.py b/sklearn/utils/setup.py index 4e0f4444e062c..97aeb602408c4 100644 --- a/sklearn/utils/setup.py +++ b/sklearn/utils/setup.py @@ -34,8 +34,7 @@ def configuration(parent_package='', top_path=None): libraries=cblas_libs, include_dirs=cblas_includes, extra_compile_args=cblas_compile_args, - **blas_info - ) + **blas_info) config.add_extension('murmurhash', sources=['murmurhash.pyx', join( diff --git a/sklearn/utils/tests/test_cython_blas.py b/sklearn/utils/tests/test_cython_blas.py index 0aafaebc287e3..9cbcd2599dd73 100644 --- a/sklearn/utils/tests/test_cython_blas.py +++ b/sklearn/utils/tests/test_cython_blas.py @@ -17,7 +17,7 @@ NUMPY_TO_CYTHON = {np.float32: cython.float, np.float64: cython.double} -RTOL = {np.float32: 1e-3, np.float64: 1e-12} +RTOL = {np.float32: 1e-6, np.float64: 1e-12} def _no_op(x): @@ -35,7 +35,7 @@ def test_dot(dtype): expected = x.dot(y) actual = dot(x, y) - assert_allclose(actual, expected) + assert_allclose(actual, expected, rtol=RTOL[dtype]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @@ -48,7 +48,7 @@ def test_asum(dtype): expected = np.abs(x).sum() actual = asum(x) - assert_allclose(actual, expected) + assert_allclose(actual, expected, rtol=RTOL[dtype]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @@ -58,12 +58,12 @@ def test_axpy(dtype): rng = np.random.RandomState(0) x = rng.random_sample(10).astype(dtype, 
copy=False) y = rng.random_sample(10).astype(dtype, copy=False) - alpha = 1.23 + alpha = 2.5 expected = alpha * x + y axpy(alpha, x, y) - assert_allclose(y, expected) + assert_allclose(y, expected, rtol=RTOL[dtype]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @@ -76,7 +76,7 @@ def test_nrm2(dtype): expected = np.linalg.norm(x) actual = nrm2(x) - assert_allclose(actual, expected) + assert_allclose(actual, expected, rtol=RTOL[dtype]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @@ -90,7 +90,7 @@ def test_copy(dtype): expected = x.copy() copy(x, y) - assert_allclose(y, expected) + assert_allclose(y, expected, rtol=RTOL[dtype]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @@ -99,12 +99,12 @@ def test_scal(dtype): rng = np.random.RandomState(0) x = rng.random_sample(10).astype(dtype, copy=False) - alpha = 1.23 + alpha = 2.5 expected = alpha * x scal(alpha, x) - assert_allclose(x, expected) + assert_allclose(x, expected, rtol=RTOL[dtype]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @@ -120,7 +120,7 @@ def test_gemv(dtype, opA, transA, layout): order=layout) x = rng.random_sample(10).astype(dtype, copy=False) y = rng.random_sample(20).astype(dtype, copy=False) - alpha, beta = 1.23, -3.21 + alpha, beta = 2.5, -0.5 expected = alpha * opA(A).dot(x) + beta * y gemv(layout, transA, alpha, A, x, beta, y) @@ -138,7 +138,7 @@ def test_ger(dtype, layout): y = rng.random_sample(20).astype(dtype, copy=False) A = np.asarray(rng.random_sample((10, 20)).astype(dtype, copy=False), order=layout) - alpha = 1.23 + alpha = 2.5 expected = alpha * np.outer(x, y) + A ger(layout, alpha, x, y, A) @@ -164,7 +164,7 @@ def test_gemm(dtype, opA, transA, opB, transB, layout): order=layout) C = np.asarray(rng.random_sample((30, 20)).astype(dtype, copy=False), order=layout) - alpha, beta = 1.23, -3.21 + alpha, beta = 2.5, -0.5 expected = alpha * opA(A).dot(opB(B)) + beta * C gemm(layout, transA, transB, alpha, A, B, beta, C) From 24e0dd0235991c4f5d25fa9f9660dd94e2597b9e Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 19 Dec 2018 16:22:13 +0100 Subject: [PATCH 07/14] enum order & trans --- sklearn/utils/_cython_blas.pxd | 21 +++-- sklearn/utils/_cython_blas.pyx | 103 ++++++++++++------------ sklearn/utils/tests/test_cython_blas.py | 45 ++++++----- 3 files changed, 91 insertions(+), 78 deletions(-) diff --git a/sklearn/utils/_cython_blas.pxd b/sklearn/utils/_cython_blas.pxd index d245db9f9f350..3b3c25579f94f 100644 --- a/sklearn/utils/_cython_blas.pxd +++ b/sklearn/utils/_cython_blas.pxd @@ -3,6 +3,16 @@ from cython cimport floating +cpdef enum BLAS_Order: + RowMajor + ColMajor + + +cpdef enum BLAS_Trans: + Trans = 116 + NoTrans = 110 + + # BLAS Level 1 ################################################################ cdef floating _xdot(int, floating*, int, floating*, int) nogil cdef floating _xasum(int, floating*, int) nogil @@ -12,11 +22,12 @@ cdef void _xcopy(int, floating*, int, floating*, int) nogil cdef void _xscal(int, floating, floating*, int) nogil # BLAS Level 2 ################################################################ -cdef void _xgemv(char, char, int, int, floating, floating*, int, floating*, - int, floating, floating*, int) nogil -cdef void _xger(char, int, int, floating, floating*, int, floating*, int, +cdef void _xgemv(BLAS_Order, BLAS_Trans, int, int, floating, floating*, int, + floating*, int, floating, floating*, int) nogil +cdef void _xger(BLAS_Order, int, int, floating, floating*, int, floating*, int, floating*, int) nogil 
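#
# Aside (not part of the patch): the enum values introduced here are the
# ASCII codes of the flags expected by the Fortran BLAS interface, so a
# BLAS_Trans can be cast directly to char -- 116 is 't' (Trans) and 110 is
# 'n' (NoTrans):
assert (chr(116), chr(110)) == ('t', 'n')
#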
# BLASLevel 3 ################################################################ -cdef void _xgemm(char, char, char, int, int, int, floating, floating*, int, - floating*, int, floating, floating*, int) nogil +cdef void _xgemm(BLAS_Order, BLAS_Trans, BLAS_Trans, int, int, int, floating, + floating*, int, floating*, int, floating, floating*, + int) nogil diff --git a/sklearn/utils/_cython_blas.pyx b/sklearn/utils/_cython_blas.pyx index c68ab94dcd076..ada1bad79e94e 100644 --- a/sklearn/utils/_cython_blas.pyx +++ b/sklearn/utils/_cython_blas.pyx @@ -13,13 +13,6 @@ from scipy.linalg.cython_blas cimport sger, dger from scipy.linalg.cython_blas cimport sgemm, dgemm -cdef: - char ColMajor = b'F' - char RowMajor = b'C' - char Trans = b't' - char NoTrans = b'n' - - ################ # BLAS Level 1 # ################ @@ -102,101 +95,105 @@ cpdef _xscal_memview(floating alpha, floating[::1] x): # BLAS Level 2 # ################ -cdef void _xgemv(char layout, char ta, int m, int n, floating alpha, +cdef void _xgemv(BLAS_Order order, BLAS_Trans ta, int m, int n, floating alpha, floating *A, int lda, floating *x, int incx, floating beta, floating *y, int incy) nogil: """y := alpha * op(A).x + beta * y""" - if layout == RowMajor: - ta = NoTrans if ta == Trans else Trans + cdef char ta_ = ta + if order == RowMajor: + ta_ = NoTrans if ta == Trans else Trans if floating is float: - sgemv(&ta, &n, &m, &alpha, A, &lda, x, &incx, &beta, y, &incy) + sgemv(&ta_, &n, &m, &alpha, A, &lda, x, &incx, &beta, y, &incy) else: - dgemv(&ta, &n, &m, &alpha, A, &lda, x, &incx, &beta, y, &incy) - elif layout == ColMajor: + dgemv(&ta_, &n, &m, &alpha, A, &lda, x, &incx, &beta, y, &incy) + elif order == ColMajor: if floating is float: - sgemv(&ta, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy) + sgemv(&ta_, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy) else: - dgemv(&ta, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy) + dgemv(&ta_, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy) -cpdef _xgemv_memview(layout, ta, floating alpha, floating[:, :] A, - floating[::1] x, floating beta, floating[::1] y): +cpdef _xgemv_memview(BLAS_Order order, BLAS_Trans ta, floating alpha, + floating[:, :] A, floating[::1] x, floating beta, + floating[::1] y): cdef: - char layout_ = ColMajor if layout == 'F' else RowMajor - char ta_ = NoTrans if ta == 'n' else Trans int m = A.shape[0] int n = A.shape[1] - int lda = m if layout == 'F' else n + int lda = m if order == ColMajor else n - _xgemv(layout_, ta_, m, n, alpha, &A[0, 0], lda, - &x[0], 1, beta, &y[0], 1) + _xgemv(order, ta, m, n, alpha, &A[0, 0], lda, &x[0], 1, beta, &y[0], 1) -cdef void _xger(char layout, int m, int n, floating alpha, floating *x, +cdef void _xger(BLAS_Order order, int m, int n, floating alpha, floating *x, int incx, floating *y, int incy, floating *A, int lda) nogil: """A := alpha * x.y.T + A""" - if layout == RowMajor: + if order == RowMajor: if floating is float: sger(&n, &m, &alpha, y, &incy, x, &incx, A, &lda) else: dger(&n, &m, &alpha, y, &incy, x, &incx, A, &lda) - elif layout == ColMajor: + elif order == ColMajor: if floating is float: sger(&m, &n, &alpha, x, &incx, y, &incy, A, &lda) else: dger(&m, &n, &alpha, x, &incx, y, &incy, A, &lda) -cpdef _xger_memview(layout, floating alpha, floating[::1] x, floating[::] y, +cpdef _xger_memview(BLAS_Order order, floating alpha, floating[::1] x, floating[::] y, floating[:, :] A): cdef: - char layout_ = ColMajor if layout == 'F' else RowMajor + BLAS_Order order_ = ColMajor if order == ColMajor else RowMajor int m 
= A.shape[0] int n = A.shape[1] - int lda = m if layout == 'F' else n + int lda = m if order == ColMajor else n - _xger(layout_, m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda) + _xger(order_, m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda) ################ # BLAS Level 3 # ################ -cdef void _xgemm(char layout, char ta, char tb, int m, int n, int k, - floating alpha, floating *A, int lda, floating *B, int ldb, - floating beta, floating *C, int ldc) nogil: +cdef void _xgemm(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, int m, int n, + int k, floating alpha, floating *A, int lda, floating *B, + int ldb, floating beta, floating *C, int ldc) nogil: """C := alpha * op(A).op(B) + beta * C""" - if layout == RowMajor: + cdef: + char ta_ = ta + char tb_ = tb + if order == RowMajor: if floating is float: - sgemm(&tb, &ta, &n, &m, &k, &alpha, B, &ldb, A, &lda, &beta, C, &ldc) + sgemm(&tb_, &ta_, &n, &m, &k, &alpha, B, + &ldb, A, &lda, &beta, C, &ldc) else: - dgemm(&tb, &ta, &n, &m, &k, &alpha, B, &ldb, A, &lda, &beta, C, &ldc) - elif layout == ColMajor: + dgemm(&tb_, &ta_, &n, &m, &k, &alpha, B, + &ldb, A, &lda, &beta, C, &ldc) + elif order == ColMajor: if floating is float: - sgemm(&ta, &tb, &m, &n, &k, &alpha, A, &lda, B, &ldb, &beta, C, &ldc) + sgemm(&ta_, &tb_, &m, &n, &k, &alpha, A, + &lda, B, &ldb, &beta, C, &ldc) else: - dgemm(&ta, &tb, &m, &n, &k, &alpha, A, &lda, B, &ldb, &beta, C, &ldc) + dgemm(&ta_, &tb_, &m, &n, &k, &alpha, A, + &lda, B, &ldb, &beta, C, &ldc) -cpdef _xgemm_memview(layout, ta, tb, floating alpha, floating[:, :] A, - floating[:, :] B, floating beta, floating[:, :] C): +cpdef _xgemm_memview(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, + floating alpha, floating[:, :] A, floating[:, :] B, + floating beta, floating[:, :] C): cdef: - char layout_ = ColMajor if layout == 'F' else RowMajor - char ta_ = NoTrans if ta == 'n' else Trans - char tb_ = NoTrans if tb == 'n' else Trans - int m = A.shape[0] if ta == 'n' else A.shape[1] - int n = B.shape[1] if tb == 'n' else B.shape[0] - int k = A.shape[1] if ta == 'n' else A.shape[0] + int m = A.shape[0] if ta == NoTrans else A.shape[1] + int n = B.shape[1] if tb == NoTrans else B.shape[0] + int k = A.shape[1] if ta == NoTrans else A.shape[0] int lda, ldb, ldc - if layout == 'F': - lda = m if ta == 'n' else k - ldb = k if tb == 'n' else n + if order == ColMajor: + lda = m if ta == NoTrans else k + ldb = k if tb == NoTrans else n ldc = m else: - lda = k if ta == 'n' else m - ldb = n if tb == 'n' else k + lda = k if ta == NoTrans else m + ldb = n if tb == NoTrans else k ldc = n - _xgemm(layout_, ta_, tb_, m, n, k, alpha, - &A[0, 0], lda, &B[0, 0], ldb, beta, &C[0, 0], ldc) \ No newline at end of file + _xgemm(order, ta, tb, m, n, k, alpha, &A[0, 0], + lda, &B[0, 0], ldb, beta, &C[0, 0], ldc) \ No newline at end of file diff --git a/sklearn/utils/tests/test_cython_blas.py b/sklearn/utils/tests/test_cython_blas.py index 9cbcd2599dd73..84a23ad1320ee 100644 --- a/sklearn/utils/tests/test_cython_blas.py +++ b/sklearn/utils/tests/test_cython_blas.py @@ -14,6 +14,8 @@ from sklearn.utils._cython_blas import _xgemv_memview from sklearn.utils._cython_blas import _xger_memview from sklearn.utils._cython_blas import _xgemm_memview +from sklearn.utils._cython_blas import RowMajor, ColMajor +from sklearn.utils._cython_blas import Trans, NoTrans NUMPY_TO_CYTHON = {np.float32: cython.float, np.float64: cython.double} @@ -109,64 +111,67 @@ def test_scal(dtype): @pytest.mark.parametrize("dtype", [np.float32, np.float64]) 
@pytest.mark.parametrize("opA, transA", - [(_no_op, 'n'), (np.transpose, 't')], - ids=["A", "A.T"]) -@pytest.mark.parametrize("layout", ['C', 'F']) -def test_gemv(dtype, opA, transA, layout): + [(_no_op, NoTrans), (np.transpose, Trans)], + ids=["NoTrans", "Trans"]) +@pytest.mark.parametrize("order", [RowMajor, ColMajor], + ids=["RowMajor", "ColMajor"]) +def test_gemv(dtype, opA, transA, order): gemv = _xgemv_memview[NUMPY_TO_CYTHON[dtype]] rng = np.random.RandomState(0) A = np.asarray(opA(rng.random_sample((20, 10)).astype(dtype, copy=False)), - order=layout) + order=order) x = rng.random_sample(10).astype(dtype, copy=False) y = rng.random_sample(20).astype(dtype, copy=False) alpha, beta = 2.5, -0.5 expected = alpha * opA(A).dot(x) + beta * y - gemv(layout, transA, alpha, A, x, beta, y) + gemv(order, transA, alpha, A, x, beta, y) assert_allclose(y, expected, rtol=RTOL[dtype]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -@pytest.mark.parametrize("layout", ['C', 'F']) -def test_ger(dtype, layout): +@pytest.mark.parametrize("order", [RowMajor, ColMajor], + ids=["RowMajor", "ColMajor"]) +def test_ger(dtype, order): ger = _xger_memview[NUMPY_TO_CYTHON[dtype]] rng = np.random.RandomState(0) x = rng.random_sample(10).astype(dtype, copy=False) y = rng.random_sample(20).astype(dtype, copy=False) A = np.asarray(rng.random_sample((10, 20)).astype(dtype, copy=False), - order=layout) + order=order) alpha = 2.5 expected = alpha * np.outer(x, y) + A - ger(layout, alpha, x, y, A) + ger(order, alpha, x, y, A) assert_allclose(A, expected, rtol=RTOL[dtype]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize("opB, transB", - [(_no_op, 'n'), (np.transpose, 't')], - ids=["B", "B.T"]) + [(_no_op, NoTrans), (np.transpose, Trans)], + ids=["NoTrans", "Trans"]) @pytest.mark.parametrize("opA, transA", - [(_no_op, 'n'), (np.transpose, 't')], - ids=["A", "A.T"]) -@pytest.mark.parametrize("layout", ['C', 'F']) -def test_gemm(dtype, opA, transA, opB, transB, layout): + [(_no_op, NoTrans), (np.transpose, Trans)], + ids=["NoTrans", "Trans"]) +@pytest.mark.parametrize("order", [RowMajor, ColMajor], + ids=["RowMajor", "ColMajor"]) +def test_gemm(dtype, opA, transA, opB, transB, order): gemm = _xgemm_memview[NUMPY_TO_CYTHON[dtype]] rng = np.random.RandomState(0) A = np.asarray(opA(rng.random_sample((30, 10)).astype(dtype, copy=False)), - order=layout) + order=order) B = np.asarray(opB(rng.random_sample((10, 20)).astype(dtype, copy=False)), - order=layout) + order=order) C = np.asarray(rng.random_sample((30, 20)).astype(dtype, copy=False), - order=layout) + order=order) alpha, beta = 2.5, -0.5 expected = alpha * opA(A).dot(opB(B)) + beta * C - gemm(layout, transA, transB, alpha, A, B, beta, C) + gemm(order, transA, transB, alpha, A, B, beta, C) assert_allclose(C, expected, rtol=RTOL[dtype]) From df86d949a3691bfbab05f851bc4e2ef1e7cef662 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 19 Dec 2018 16:58:16 +0100 Subject: [PATCH 08/14] fix numpy order type --- sklearn/utils/_cython_blas.pxd | 4 ++-- sklearn/utils/tests/test_cython_blas.py | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/sklearn/utils/_cython_blas.pxd b/sklearn/utils/_cython_blas.pxd index 3b3c25579f94f..8908bee01985e 100644 --- a/sklearn/utils/_cython_blas.pxd +++ b/sklearn/utils/_cython_blas.pxd @@ -9,8 +9,8 @@ cpdef enum BLAS_Order: cpdef enum BLAS_Trans: - Trans = 116 - NoTrans = 110 + Trans = 116 # correspond to 'n' + NoTrans = 110 # correspond to 't' # BLAS 
Level 1 ################################################################ diff --git a/sklearn/utils/tests/test_cython_blas.py b/sklearn/utils/tests/test_cython_blas.py index 84a23ad1320ee..5479408e944cd 100644 --- a/sklearn/utils/tests/test_cython_blas.py +++ b/sklearn/utils/tests/test_cython_blas.py @@ -20,6 +20,7 @@ NUMPY_TO_CYTHON = {np.float32: cython.float, np.float64: cython.double} RTOL = {np.float32: 1e-6, np.float64: 1e-12} +ORDER = {RowMajor: 'C', ColMajor: 'F'} def _no_op(x): @@ -120,7 +121,7 @@ def test_gemv(dtype, opA, transA, order): rng = np.random.RandomState(0) A = np.asarray(opA(rng.random_sample((20, 10)).astype(dtype, copy=False)), - order=order) + order=ORDER[order]) x = rng.random_sample(10).astype(dtype, copy=False) y = rng.random_sample(20).astype(dtype, copy=False) alpha, beta = 2.5, -0.5 @@ -141,7 +142,7 @@ def test_ger(dtype, order): x = rng.random_sample(10).astype(dtype, copy=False) y = rng.random_sample(20).astype(dtype, copy=False) A = np.asarray(rng.random_sample((10, 20)).astype(dtype, copy=False), - order=order) + order=ORDER[order]) alpha = 2.5 expected = alpha * np.outer(x, y) + A @@ -164,11 +165,11 @@ def test_gemm(dtype, opA, transA, opB, transB, order): rng = np.random.RandomState(0) A = np.asarray(opA(rng.random_sample((30, 10)).astype(dtype, copy=False)), - order=order) + order=ORDER[order]) B = np.asarray(opB(rng.random_sample((10, 20)).astype(dtype, copy=False)), - order=order) + order=ORDER[order]) C = np.asarray(rng.random_sample((30, 20)).astype(dtype, copy=False), - order=order) + order=ORDER[order]) alpha, beta = 2.5, -0.5 expected = alpha * opA(A).dot(opB(B)) + beta * C From 9372276c79156ca61ddc63ac15e0d03e7b54521d Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 19 Dec 2018 17:42:29 +0100 Subject: [PATCH 09/14] change blas functions names --- sklearn/metrics/pairwise_fast.pyx | 4 +- sklearn/utils/_cython_blas.pxd | 18 ++++----- sklearn/utils/_cython_blas.pyx | 54 ++++++++++++------------- sklearn/utils/tests/test_cython_blas.py | 36 ++++++++--------- 4 files changed, 56 insertions(+), 56 deletions(-) diff --git a/sklearn/metrics/pairwise_fast.pyx b/sklearn/metrics/pairwise_fast.pyx index 901bedb145f15..d465ab88fa6e3 100644 --- a/sklearn/metrics/pairwise_fast.pyx +++ b/sklearn/metrics/pairwise_fast.pyx @@ -13,7 +13,7 @@ cimport numpy as np from cython cimport floating -from ..utils._cython_blas cimport _xasum +from ..utils._cython_blas cimport _asum np.import_array() @@ -66,4 +66,4 @@ def _sparse_manhattan(floating[::1] X_data, int[:] X_indices, int[:] X_indptr, for j in range(Y_indptr[iy], Y_indptr[iy + 1]): row[Y_indices[j]] -= Y_data[j] - D[ix, iy] = _xasum(n_features, &row[0], 1) + D[ix, iy] = _asum(n_features, &row[0], 1) diff --git a/sklearn/utils/_cython_blas.pxd b/sklearn/utils/_cython_blas.pxd index 8908bee01985e..3968ae0f69c9e 100644 --- a/sklearn/utils/_cython_blas.pxd +++ b/sklearn/utils/_cython_blas.pxd @@ -14,20 +14,20 @@ cpdef enum BLAS_Trans: # BLAS Level 1 ################################################################ -cdef floating _xdot(int, floating*, int, floating*, int) nogil -cdef floating _xasum(int, floating*, int) nogil -cdef void _xaxpy(int, floating, floating*, int, floating*, int) nogil -cdef floating _xnrm2(int, floating*, int) nogil -cdef void _xcopy(int, floating*, int, floating*, int) nogil -cdef void _xscal(int, floating, floating*, int) nogil +cdef floating _dot(int, floating*, int, floating*, int) nogil +cdef floating _asum(int, floating*, int) nogil +cdef void _axpy(int, floating, 
floating*, int, floating*, int) nogil +cdef floating _nrm2(int, floating*, int) nogil +cdef void _copy(int, floating*, int, floating*, int) nogil +cdef void _scal(int, floating, floating*, int) nogil # BLAS Level 2 ################################################################ -cdef void _xgemv(BLAS_Order, BLAS_Trans, int, int, floating, floating*, int, +cdef void _gemv(BLAS_Order, BLAS_Trans, int, int, floating, floating*, int, floating*, int, floating, floating*, int) nogil -cdef void _xger(BLAS_Order, int, int, floating, floating*, int, floating*, int, +cdef void _ger(BLAS_Order, int, int, floating, floating*, int, floating*, int, floating*, int) nogil # BLASLevel 3 ################################################################ -cdef void _xgemm(BLAS_Order, BLAS_Trans, BLAS_Trans, int, int, int, floating, +cdef void _gemm(BLAS_Order, BLAS_Trans, BLAS_Trans, int, int, int, floating, floating*, int, floating*, int, floating, floating*, int) nogil diff --git a/sklearn/utils/_cython_blas.pyx b/sklearn/utils/_cython_blas.pyx index ada1bad79e94e..d1294b219ead4 100644 --- a/sklearn/utils/_cython_blas.pyx +++ b/sklearn/utils/_cython_blas.pyx @@ -17,7 +17,7 @@ from scipy.linalg.cython_blas cimport sgemm, dgemm # BLAS Level 1 # ################ -cdef floating _xdot(int n, floating *x, int incx, +cdef floating _dot(int n, floating *x, int incx, floating *y, int incy) nogil: """x.T.y""" if floating is float: @@ -26,11 +26,11 @@ cdef floating _xdot(int n, floating *x, int incx, return ddot(&n, x, &incx, y, &incy) -cpdef _xdot_memview(floating[::1] x, floating[::1] y): - return _xdot(x.shape[0], &x[0], 1, &y[0], 1) +cpdef _dot_memview(floating[::1] x, floating[::1] y): + return _dot(x.shape[0], &x[0], 1, &y[0], 1) -cdef floating _xasum(int n, floating *x, int incx) nogil: +cdef floating _asum(int n, floating *x, int incx) nogil: """sum(|x_i|)""" if floating is float: return sasum(&n, x, &incx) @@ -38,11 +38,11 @@ cdef floating _xasum(int n, floating *x, int incx) nogil: return dasum(&n, x, &incx) -cpdef _xasum_memview(floating[::1] x): - return _xasum(x.shape[0], &x[0], 1) +cpdef _asum_memview(floating[::1] x): + return _asum(x.shape[0], &x[0], 1) -cdef void _xaxpy(int n, floating alpha, floating *x, int incx, +cdef void _axpy(int n, floating alpha, floating *x, int incx, floating *y, int incy) nogil: """y := alpha * x + y""" if floating is float: @@ -51,11 +51,11 @@ cdef void _xaxpy(int n, floating alpha, floating *x, int incx, daxpy(&n, &alpha, x, &incx, y, &incy) -cpdef _xaxpy_memview(floating alpha, floating[::1] x, floating[::1] y): - _xaxpy(x.shape[0], alpha, &x[0], 1, &y[0], 1) +cpdef _axpy_memview(floating alpha, floating[::1] x, floating[::1] y): + _axpy(x.shape[0], alpha, &x[0], 1, &y[0], 1) -cdef floating _xnrm2(int n, floating *x, int incx) nogil: +cdef floating _nrm2(int n, floating *x, int incx) nogil: """sqrt(sum((x_i)^2))""" if floating is float: return snrm2(&n, x, &incx) @@ -63,11 +63,11 @@ cdef floating _xnrm2(int n, floating *x, int incx) nogil: return dnrm2(&n, x, &incx) -cpdef _xnrm2_memview(floating[::1] x): - return _xnrm2(x.shape[0], &x[0], 1) +cpdef _nrm2_memview(floating[::1] x): + return _nrm2(x.shape[0], &x[0], 1) -cdef void _xcopy(int n, floating *x, int incx, floating *y, int incy) nogil: +cdef void _copy(int n, floating *x, int incx, floating *y, int incy) nogil: """y := x""" if floating is float: scopy(&n, x, &incx, y, &incy) @@ -75,11 +75,11 @@ cdef void _xcopy(int n, floating *x, int incx, floating *y, int incy) nogil: dcopy(&n, x, &incx, y, &incy) -cpdef 
_xcopy_memview(floating[::1] x, floating[::1] y): - _xcopy(x.shape[0], &x[0], 1, &y[0], 1) +cpdef _copy_memview(floating[::1] x, floating[::1] y): + _copy(x.shape[0], &x[0], 1, &y[0], 1) -cdef void _xscal(int n, floating alpha, floating *x, int incx) nogil: +cdef void _scal(int n, floating alpha, floating *x, int incx) nogil: """x := alpha * x""" if floating is float: sscal(&n, &alpha, x, &incx) @@ -87,15 +87,15 @@ cdef void _xscal(int n, floating alpha, floating *x, int incx) nogil: dscal(&n, &alpha, x, &incx) -cpdef _xscal_memview(floating alpha, floating[::1] x): - _xscal(x.shape[0], alpha, &x[0], 1) +cpdef _scal_memview(floating alpha, floating[::1] x): + _scal(x.shape[0], alpha, &x[0], 1) ################ # BLAS Level 2 # ################ -cdef void _xgemv(BLAS_Order order, BLAS_Trans ta, int m, int n, floating alpha, +cdef void _gemv(BLAS_Order order, BLAS_Trans ta, int m, int n, floating alpha, floating *A, int lda, floating *x, int incx, floating beta, floating *y, int incy) nogil: """y := alpha * op(A).x + beta * y""" @@ -113,7 +113,7 @@ cdef void _xgemv(BLAS_Order order, BLAS_Trans ta, int m, int n, floating alpha, dgemv(&ta_, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy) -cpdef _xgemv_memview(BLAS_Order order, BLAS_Trans ta, floating alpha, +cpdef _gemv_memview(BLAS_Order order, BLAS_Trans ta, floating alpha, floating[:, :] A, floating[::1] x, floating beta, floating[::1] y): cdef: @@ -121,10 +121,10 @@ cpdef _xgemv_memview(BLAS_Order order, BLAS_Trans ta, floating alpha, int n = A.shape[1] int lda = m if order == ColMajor else n - _xgemv(order, ta, m, n, alpha, &A[0, 0], lda, &x[0], 1, beta, &y[0], 1) + _gemv(order, ta, m, n, alpha, &A[0, 0], lda, &x[0], 1, beta, &y[0], 1) -cdef void _xger(BLAS_Order order, int m, int n, floating alpha, floating *x, +cdef void _ger(BLAS_Order order, int m, int n, floating alpha, floating *x, int incx, floating *y, int incy, floating *A, int lda) nogil: """A := alpha * x.y.T + A""" if order == RowMajor: @@ -139,7 +139,7 @@ cdef void _xger(BLAS_Order order, int m, int n, floating alpha, floating *x, dger(&m, &n, &alpha, x, &incx, y, &incy, A, &lda) -cpdef _xger_memview(BLAS_Order order, floating alpha, floating[::1] x, floating[::] y, +cpdef _ger_memview(BLAS_Order order, floating alpha, floating[::1] x, floating[::] y, floating[:, :] A): cdef: BLAS_Order order_ = ColMajor if order == ColMajor else RowMajor @@ -147,14 +147,14 @@ cpdef _xger_memview(BLAS_Order order, floating alpha, floating[::1] x, floating[ int n = A.shape[1] int lda = m if order == ColMajor else n - _xger(order_, m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda) + _ger(order_, m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda) ################ # BLAS Level 3 # ################ -cdef void _xgemm(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, int m, int n, +cdef void _gemm(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, int m, int n, int k, floating alpha, floating *A, int lda, floating *B, int ldb, floating beta, floating *C, int ldc) nogil: """C := alpha * op(A).op(B) + beta * C""" @@ -177,7 +177,7 @@ cdef void _xgemm(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, int m, int n, &lda, B, &ldb, &beta, C, &ldc) -cpdef _xgemm_memview(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, +cpdef _gemm_memview(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, floating alpha, floating[:, :] A, floating[:, :] B, floating beta, floating[:, :] C): cdef: @@ -195,5 +195,5 @@ cpdef _xgemm_memview(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, ldb = n if tb == NoTrans else k ldc = n - 
_xgemm(order, ta, tb, m, n, k, alpha, &A[0, 0], + _gemm(order, ta, tb, m, n, k, alpha, &A[0, 0], lda, &B[0, 0], ldb, beta, &C[0, 0], ldc) \ No newline at end of file diff --git a/sklearn/utils/tests/test_cython_blas.py b/sklearn/utils/tests/test_cython_blas.py index 5479408e944cd..1df345fd32e24 100644 --- a/sklearn/utils/tests/test_cython_blas.py +++ b/sklearn/utils/tests/test_cython_blas.py @@ -5,15 +5,15 @@ from sklearn.utils.testing import assert_allclose -from sklearn.utils._cython_blas import _xdot_memview -from sklearn.utils._cython_blas import _xasum_memview -from sklearn.utils._cython_blas import _xaxpy_memview -from sklearn.utils._cython_blas import _xnrm2_memview -from sklearn.utils._cython_blas import _xcopy_memview -from sklearn.utils._cython_blas import _xscal_memview -from sklearn.utils._cython_blas import _xgemv_memview -from sklearn.utils._cython_blas import _xger_memview -from sklearn.utils._cython_blas import _xgemm_memview +from sklearn.utils._cython_blas import _dot_memview +from sklearn.utils._cython_blas import _asum_memview +from sklearn.utils._cython_blas import _axpy_memview +from sklearn.utils._cython_blas import _nrm2_memview +from sklearn.utils._cython_blas import _copy_memview +from sklearn.utils._cython_blas import _scal_memview +from sklearn.utils._cython_blas import _gemv_memview +from sklearn.utils._cython_blas import _ger_memview +from sklearn.utils._cython_blas import _gemm_memview from sklearn.utils._cython_blas import RowMajor, ColMajor from sklearn.utils._cython_blas import Trans, NoTrans @@ -29,7 +29,7 @@ def _no_op(x): @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_dot(dtype): - dot = _xdot_memview[NUMPY_TO_CYTHON[dtype]] + dot = _dot_memview[NUMPY_TO_CYTHON[dtype]] rng = np.random.RandomState(0) x = rng.random_sample(10).astype(dtype, copy=False) @@ -43,7 +43,7 @@ def test_dot(dtype): @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_asum(dtype): - asum = _xasum_memview[NUMPY_TO_CYTHON[dtype]] + asum = _asum_memview[NUMPY_TO_CYTHON[dtype]] rng = np.random.RandomState(0) x = rng.random_sample(10).astype(dtype, copy=False) @@ -56,7 +56,7 @@ def test_asum(dtype): @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_axpy(dtype): - axpy = _xaxpy_memview[NUMPY_TO_CYTHON[dtype]] + axpy = _axpy_memview[NUMPY_TO_CYTHON[dtype]] rng = np.random.RandomState(0) x = rng.random_sample(10).astype(dtype, copy=False) @@ -71,7 +71,7 @@ def test_axpy(dtype): @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_nrm2(dtype): - nrm2 = _xnrm2_memview[NUMPY_TO_CYTHON[dtype]] + nrm2 = _nrm2_memview[NUMPY_TO_CYTHON[dtype]] rng = np.random.RandomState(0) x = rng.random_sample(10).astype(dtype, copy=False) @@ -84,7 +84,7 @@ def test_nrm2(dtype): @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_copy(dtype): - copy = _xcopy_memview[NUMPY_TO_CYTHON[dtype]] + copy = _copy_memview[NUMPY_TO_CYTHON[dtype]] rng = np.random.RandomState(0) x = rng.random_sample(10).astype(dtype, copy=False) @@ -98,7 +98,7 @@ def test_copy(dtype): @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_scal(dtype): - scal = _xscal_memview[NUMPY_TO_CYTHON[dtype]] + scal = _scal_memview[NUMPY_TO_CYTHON[dtype]] rng = np.random.RandomState(0) x = rng.random_sample(10).astype(dtype, copy=False) @@ -117,7 +117,7 @@ def test_scal(dtype): @pytest.mark.parametrize("order", [RowMajor, ColMajor], ids=["RowMajor", "ColMajor"]) def test_gemv(dtype, opA, transA, order): - gemv = 
_xgemv_memview[NUMPY_TO_CYTHON[dtype]] + gemv = _gemv_memview[NUMPY_TO_CYTHON[dtype]] rng = np.random.RandomState(0) A = np.asarray(opA(rng.random_sample((20, 10)).astype(dtype, copy=False)), @@ -136,7 +136,7 @@ def test_gemv(dtype, opA, transA, order): @pytest.mark.parametrize("order", [RowMajor, ColMajor], ids=["RowMajor", "ColMajor"]) def test_ger(dtype, order): - ger = _xger_memview[NUMPY_TO_CYTHON[dtype]] + ger = _ger_memview[NUMPY_TO_CYTHON[dtype]] rng = np.random.RandomState(0) x = rng.random_sample(10).astype(dtype, copy=False) @@ -161,7 +161,7 @@ def test_ger(dtype, order): @pytest.mark.parametrize("order", [RowMajor, ColMajor], ids=["RowMajor", "ColMajor"]) def test_gemm(dtype, opA, transA, opB, transB, order): - gemm = _xgemm_memview[NUMPY_TO_CYTHON[dtype]] + gemm = _gemm_memview[NUMPY_TO_CYTHON[dtype]] rng = np.random.RandomState(0) A = np.asarray(opA(rng.random_sample((30, 10)).astype(dtype, copy=False)), From 18686487b133921d2da18f26e47ab4e93b1ad8b7 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 20 Dec 2018 11:31:35 +0100 Subject: [PATCH 10/14] flake8 --- sklearn/utils/_cython_blas.pxd | 8 ++++---- sklearn/utils/_cython_blas.pyx | 32 ++++++++++++++++---------------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/sklearn/utils/_cython_blas.pxd b/sklearn/utils/_cython_blas.pxd index 3968ae0f69c9e..dc0a4d89c3958 100644 --- a/sklearn/utils/_cython_blas.pxd +++ b/sklearn/utils/_cython_blas.pxd @@ -23,11 +23,11 @@ cdef void _scal(int, floating, floating*, int) nogil # BLAS Level 2 ################################################################ cdef void _gemv(BLAS_Order, BLAS_Trans, int, int, floating, floating*, int, - floating*, int, floating, floating*, int) nogil + floating*, int, floating, floating*, int) nogil cdef void _ger(BLAS_Order, int, int, floating, floating*, int, floating*, int, - floating*, int) nogil + floating*, int) nogil # BLASLevel 3 ################################################################ cdef void _gemm(BLAS_Order, BLAS_Trans, BLAS_Trans, int, int, int, floating, - floating*, int, floating*, int, floating, floating*, - int) nogil + floating*, int, floating*, int, floating, floating*, + int) nogil diff --git a/sklearn/utils/_cython_blas.pyx b/sklearn/utils/_cython_blas.pyx index d1294b219ead4..4f1594d7baeff 100644 --- a/sklearn/utils/_cython_blas.pyx +++ b/sklearn/utils/_cython_blas.pyx @@ -18,7 +18,7 @@ from scipy.linalg.cython_blas cimport sgemm, dgemm ################ cdef floating _dot(int n, floating *x, int incx, - floating *y, int incy) nogil: + floating *y, int incy) nogil: """x.T.y""" if floating is float: return sdot(&n, x, &incx, y, &incy) @@ -43,7 +43,7 @@ cpdef _asum_memview(floating[::1] x): cdef void _axpy(int n, floating alpha, floating *x, int incx, - floating *y, int incy) nogil: + floating *y, int incy) nogil: """y := alpha * x + y""" if floating is float: saxpy(&n, &alpha, x, &incx, y, &incy) @@ -96,12 +96,12 @@ cpdef _scal_memview(floating alpha, floating[::1] x): ################ cdef void _gemv(BLAS_Order order, BLAS_Trans ta, int m, int n, floating alpha, - floating *A, int lda, floating *x, int incx, - floating beta, floating *y, int incy) nogil: + floating *A, int lda, floating *x, int incx, + floating beta, floating *y, int incy) nogil: """y := alpha * op(A).x + beta * y""" cdef char ta_ = ta if order == RowMajor: - ta_ = NoTrans if ta == Trans else Trans + ta_ = NoTrans if ta == Trans else Trans if floating is float: sgemv(&ta_, &n, &m, &alpha, A, &lda, x, &incx, &beta, y, &incy) else: 
@@ -114,8 +114,8 @@ cdef void _gemv(BLAS_Order order, BLAS_Trans ta, int m, int n, floating alpha,
 
 
 cpdef _gemv_memview(BLAS_Order order, BLAS_Trans ta, floating alpha,
-                     floating[:, :] A, floating[::1] x, floating beta,
-                     floating[::1] y):
+                    floating[:, :] A, floating[::1] x, floating beta,
+                    floating[::1] y):
     cdef:
         int m = A.shape[0]
         int n = A.shape[1]
@@ -125,7 +125,7 @@ cpdef _gemv_memview(BLAS_Order order, BLAS_Trans ta, floating alpha,
 
 
 cdef void _ger(BLAS_Order order, int m, int n, floating alpha, floating *x,
-                int incx, floating *y, int incy, floating *A, int lda) nogil:
+               int incx, floating *y, int incy, floating *A, int lda) nogil:
     """A := alpha * x.y.T + A"""
     if order == RowMajor:
         if floating is float:
@@ -139,14 +139,14 @@ cdef void _ger(BLAS_Order order, int m, int n, floating alpha, floating *x,
             dger(&m, &n, &alpha, x, &incx, y, &incy, A, &lda)
 
 
-cpdef _ger_memview(BLAS_Order order, floating alpha, floating[::1] x, floating[::] y,
-                   floating[:, :] A):
+cpdef _ger_memview(BLAS_Order order, floating alpha, floating[::1] x,
+                   floating[::] y, floating[:, :] A):
     cdef:
         BLAS_Order order_ = ColMajor if order == ColMajor else RowMajor
         int m = A.shape[0]
         int n = A.shape[1]
         int lda = m if order == ColMajor else n
-
+
     _ger(order_, m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda)
 
 
@@ -155,8 +155,8 @@ cpdef _ger_memview(BLAS_Order order, floating alpha, floating[::1] x, floating[:
 ################
 
 cdef void _gemm(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, int m, int n,
-                 int k, floating alpha, floating *A, int lda, floating *B,
-                 int ldb, floating beta, floating *C, int ldc) nogil:
+                int k, floating alpha, floating *A, int lda, floating *B,
+                int ldb, floating beta, floating *C, int ldc) nogil:
     """C := alpha * op(A).op(B) + beta * C"""
     cdef:
         char ta_ = ta
@@ -178,8 +178,8 @@ cdef void _gemm(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, int m, int n,
 
 
 cpdef _gemm_memview(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb,
-                     floating alpha, floating[:, :] A, floating[:, :] B,
-                     floating beta, floating[:, :] C):
+                    floating alpha, floating[:, :] A, floating[:, :] B,
+                    floating beta, floating[:, :] C):
     cdef:
         int m = A.shape[0] if ta == NoTrans else A.shape[1]
         int n = B.shape[1] if tb == NoTrans else B.shape[0]
@@ -196,4 +196,4 @@ cpdef _gemm_memview(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb,
         ldc = n
 
     _gemm(order, ta, tb, m, n, k, alpha, &A[0, 0],
-          lda, &B[0, 0], ldb, beta, &C[0, 0], ldc)
\ No newline at end of file
+          lda, &B[0, 0], ldb, beta, &C[0, 0], ldc)

From 12b3660b2498fcba3be65176a0fb71b401736d22 Mon Sep 17 00:00:00 2001
From: jeremie du boisberranger
Date: Thu, 20 Dec 2018 11:37:55 +0100
Subject: [PATCH 11/14] clean up

---
 sklearn/utils/_cython_blas.pxd | 14 ++++++++++----
 sklearn/utils/_cython_blas.pyx |  2 --
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/sklearn/utils/_cython_blas.pxd b/sklearn/utils/_cython_blas.pxd
index dc0a4d89c3958..4d82c7b1aaf13 100644
--- a/sklearn/utils/_cython_blas.pxd
+++ b/sklearn/utils/_cython_blas.pxd
@@ -4,26 +4,32 @@ from cython cimport floating
 
 
 cpdef enum BLAS_Order:
-    RowMajor
-    ColMajor
+    RowMajor  # C contiguous
+    ColMajor  # Fortran contiguous
 
 
 cpdef enum BLAS_Trans:
-    Trans = 116    # correspond to 'n'
-    NoTrans = 110  # correspond to 't'
+    NoTrans = 110  # corresponds to 'n'
+    Trans = 116    # corresponds to 't'
 
 
 # BLAS Level 1 ################################################################
 cdef floating _dot(int, floating*, int, floating*, int) nogil
+
 cdef floating _asum(int, floating*, int) nogil
+
 cdef void _axpy(int, floating, floating*, int, floating*, int) nogil
+
 cdef floating _nrm2(int, floating*, int) nogil
+
 cdef void _copy(int, floating*, int, floating*, int) nogil
+
 cdef void _scal(int, floating, floating*, int) nogil
 
 # BLAS Level 2 ################################################################
 cdef void _gemv(BLAS_Order, BLAS_Trans, int, int, floating, floating*, int,
                 floating*, int, floating, floating*, int) nogil
+
 cdef void _ger(BLAS_Order, int, int, floating, floating*, int, floating*, int,
                floating*, int) nogil
 
diff --git a/sklearn/utils/_cython_blas.pyx b/sklearn/utils/_cython_blas.pyx
index 4f1594d7baeff..5e8790d8de1c4 100644
--- a/sklearn/utils/_cython_blas.pyx
+++ b/sklearn/utils/_cython_blas.pyx
@@ -1,5 +1,3 @@
-# cython: boundscheck=False, wraparound=False, cdivision=True
-
 from cython cimport floating
 
 from scipy.linalg.cython_blas cimport sdot, ddot

From 86556d9d694313ca912bb6c397759c8a0a3d0a07 Mon Sep 17 00:00:00 2001
From: jeremie du boisberranger
Date: Mon, 14 Jan 2019 15:10:43 +0100
Subject: [PATCH 12/14] what's new

---
 doc/whats_new/v0.21.rst | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst
index 2f359ca87463f..92ef22d1b1228 100644
--- a/doc/whats_new/v0.21.rst
+++ b/doc/whats_new/v0.21.rst
@@ -155,6 +155,10 @@ Support for Python 3.4 and below has been officially dropped.
   and now it returns NaN and raises :class:`exceptions.UndefinedMetricWarning`.
   :issue:`12855` by :user:`Pawel Sendyk .`
 
+- |Efficiency| The pairwise manhattan distances with sparse input now use the
+  BLAS shipped with scipy instead of the bundled BLAS. :issue:`12732` by
+  :user:`Jérémie du Boisberranger `
+
 :mod:`sklearn.model_selection`
 ..............................
 

From 0019d250b4f0bf7cce6f010e9f45a6173bad1f95 Mon Sep 17 00:00:00 2001
From: jeremie du boisberranger
Date: Thu, 31 Jan 2019 11:36:24 +0100
Subject: [PATCH 13/14] infer memory layout

---
 sklearn/utils/_cython_blas.pyx          | 37 +++++++++++++------------
 sklearn/utils/tests/test_cython_blas.py |  6 ++--
 2 files changed, 22 insertions(+), 21 deletions(-)

diff --git a/sklearn/utils/_cython_blas.pyx b/sklearn/utils/_cython_blas.pyx
index 5e8790d8de1c4..7585105227f9f 100644
--- a/sklearn/utils/_cython_blas.pyx
+++ b/sklearn/utils/_cython_blas.pyx
@@ -104,19 +104,19 @@ cdef void _gemv(BLAS_Order order, BLAS_Trans ta, int m, int n, floating alpha,
             sgemv(&ta_, &n, &m, &alpha, A, &lda, x, &incx, &beta, y, &incy)
         else:
             dgemv(&ta_, &n, &m, &alpha, A, &lda, x, &incx, &beta, y, &incy)
-    elif order == ColMajor:
+    else:
         if floating is float:
             sgemv(&ta_, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy)
         else:
             dgemv(&ta_, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy)
 
 
-cpdef _gemv_memview(BLAS_Order order, BLAS_Trans ta, floating alpha,
-                    floating[:, :] A, floating[::1] x, floating beta,
-                    floating[::1] y):
+cpdef _gemv_memview(BLAS_Trans ta, floating alpha, floating[:, :] A,
+                    floating[::1] x, floating beta, floating[::1] y):
     cdef:
         int m = A.shape[0]
         int n = A.shape[1]
+        BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor
         int lda = m if order == ColMajor else n
 
     _gemv(order, ta, m, n, alpha, &A[0, 0], lda, &x[0], 1, beta, &y[0], 1)
@@ -130,22 +130,22 @@ cdef void _ger(BLAS_Order order, int m, int n, floating alpha, floating *x,
             sger(&n, &m, &alpha, y, &incy, x, &incx, A, &lda)
         else:
             dger(&n, &m, &alpha, y, &incy, x, &incx, A, &lda)
-    elif order == ColMajor:
+    else:
         if floating is float:
             sger(&m, &n, &alpha, x, &incx, y, &incy, A, &lda)
         else:
             dger(&m, &n, &alpha, x, &incx, y, &incy, A, &lda)
 
 
-cpdef _ger_memview(BLAS_Order order, floating alpha, floating[::1] x,
-                   floating[::] y, floating[:, :] A):
+cpdef _ger_memview(floating alpha, floating[::1] x, floating[::] y,
+                   floating[:, :] A):
     cdef:
-        BLAS_Order order_ = ColMajor if order == ColMajor else RowMajor
         int m = A.shape[0]
         int n = A.shape[1]
+        BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor
         int lda = m if order == ColMajor else n
 
-    _ger(order_, m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda)
+    _ger(order, m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda)
 
 
 ################
@@ -166,7 +166,7 @@ cdef void _gemm(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, int m, int n,
         else:
             dgemm(&tb_, &ta_, &n, &m, &k, &alpha, B, &ldb, A, &lda, &beta, C,
                   &ldc)
-    elif order == ColMajor:
+    else:
         if floating is float:
             sgemm(&ta_, &tb_, &m, &n, &k, &alpha, A, &lda, B, &ldb, &beta, C,
                   &ldc)
@@ -175,23 +175,24 @@ cdef void _gemm(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, int m, int n,
                   &lda, B, &ldb, &beta, C, &ldc)
 
 
-cpdef _gemm_memview(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb,
-                    floating alpha, floating[:, :] A, floating[:, :] B,
-                    floating beta, floating[:, :] C):
+cpdef _gemm_memview(BLAS_Trans ta, BLAS_Trans tb, floating alpha,
+                    floating[:, :] A, floating[:, :] B, floating beta,
+                    floating[:, :] C):
     cdef:
         int m = A.shape[0] if ta == NoTrans else A.shape[1]
         int n = B.shape[1] if tb == NoTrans else B.shape[0]
        int k = A.shape[1] if ta == NoTrans else A.shape[0]
         int lda, ldb, ldc
+        BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor
 
-    if order == ColMajor:
-        lda = m if ta == NoTrans else k
-        ldb = k if tb == NoTrans else n
-        ldc = m
-    else:
+    if order == RowMajor:
         lda = k if ta == NoTrans else m
         ldb = n if tb == NoTrans else k
         ldc = n
+    else:
+        lda = m if ta == NoTrans else k
+        ldb = k if tb == NoTrans else n
+        ldc = m
 
     _gemm(order, ta, tb, m, n, k, alpha, &A[0, 0],
           lda, &B[0, 0], ldb, beta, &C[0, 0], ldc)
diff --git a/sklearn/utils/tests/test_cython_blas.py b/sklearn/utils/tests/test_cython_blas.py
index 1df345fd32e24..0305e5a5476dc 100644
--- a/sklearn/utils/tests/test_cython_blas.py
+++ b/sklearn/utils/tests/test_cython_blas.py
@@ -127,7 +127,7 @@ def test_gemv(dtype, opA, transA, order):
     alpha, beta = 2.5, -0.5
     expected = alpha * opA(A).dot(x) + beta * y
 
-    gemv(order, transA, alpha, A, x, beta, y)
+    gemv(transA, alpha, A, x, beta, y)
 
     assert_allclose(y, expected, rtol=RTOL[dtype])
 
@@ -146,7 +146,7 @@ def test_ger(dtype, order):
     alpha = 2.5
     expected = alpha * np.outer(x, y) + A
 
-    ger(order, alpha, x, y, A)
+    ger(alpha, x, y, A)
 
     assert_allclose(A, expected, rtol=RTOL[dtype])
 
@@ -173,6 +173,6 @@ def test_gemm(dtype, opA, transA, opB, transB, order):
     alpha, beta = 2.5, -0.5
     expected = alpha * opA(A).dot(opB(B)) + beta * C
 
-    gemm(order, transA, transB, alpha, A, B, beta, C)
+    gemm(transA, transB, alpha, A, B, beta, C)
 
     assert_allclose(C, expected, rtol=RTOL[dtype])

From 3564ffa2157c8437a35684de4368bc992d01d57c Mon Sep 17 00:00:00 2001
From: jeremie du boisberranger
Date: Fri, 1 Feb 2019 12:05:18 +0100
Subject: [PATCH 14/14] remove blank line

---
 sklearn/metrics/pairwise_fast.pyx | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sklearn/metrics/pairwise_fast.pyx b/sklearn/metrics/pairwise_fast.pyx
index d465ab88fa6e3..76ab64d9cd987 100644
--- a/sklearn/metrics/pairwise_fast.pyx
+++ b/sklearn/metrics/pairwise_fast.pyx
@@ -12,7 +12,6 @@ import numpy as np
 cimport numpy as np
 
 from cython cimport floating
-
 
 from ..utils._cython_blas cimport _asum
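
A short aside on the "infer memory layout" commit ([PATCH 13/14]) above, since it carries the one non-mechanical idea in this series: the memview wrappers now derive the BLAS_Order from the strides of A instead of taking an order argument from the caller. The Python sketch below restates that stride test and the transpose identity the RowMajor branches rely on. It is an illustration only, not part of the patches, and `infer_order` is a hypothetical helper name.

import numpy as np

def infer_order(A):
    # Mirrors the Cython expression used by the memview wrappers:
    #     ColMajor if A.strides[0] == A.itemsize else RowMajor
    # In a Fortran-ordered (column-major) 2-D array, stepping to the
    # next row advances the buffer by exactly one element, so the
    # first stride equals the item size.
    return "ColMajor" if A.strides[0] == A.itemsize else "RowMajor"

assert infer_order(np.zeros((20, 10), order='C')) == "RowMajor"
assert infer_order(np.zeros((20, 10), order='F')) == "ColMajor"

# The RowMajor branches of _gemv/_ger/_gemm exploit the fact that a
# C-ordered buffer read in Fortran order is the transpose of the same
# matrix, together with the identity (B.T @ A.T).T == A @ B, so a
# Fortran-convention BLAS can operate on C-ordered data without copying.
rng = np.random.RandomState(0)
A, B = rng.random_sample((3, 4)), rng.random_sample((4, 5))
np.testing.assert_allclose((B.T @ A.T).T, A @ B)

Inferring the layout at the call site also removes a failure mode: a caller can no longer pass an order flag that disagrees with the actual memory layout of A.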