|
| 1 | +from cython cimport floating |
| 2 | + |
| 3 | +from scipy.linalg.cython_blas cimport sdot, ddot |
| 4 | +from scipy.linalg.cython_blas cimport sasum, dasum |
| 5 | +from scipy.linalg.cython_blas cimport saxpy, daxpy |
| 6 | +from scipy.linalg.cython_blas cimport snrm2, dnrm2 |
| 7 | +from scipy.linalg.cython_blas cimport scopy, dcopy |
| 8 | +from scipy.linalg.cython_blas cimport sscal, dscal |
| 9 | +from scipy.linalg.cython_blas cimport sgemv, dgemv |
| 10 | +from scipy.linalg.cython_blas cimport sger, dger |
| 11 | +from scipy.linalg.cython_blas cimport sgemm, dgemm |
| 12 | + |
| 13 | + |
| 14 | +################ |
| 15 | +# BLAS Level 1 # |
| 16 | +################ |
| 17 | + |
cdef floating _dot(int n, floating *x, int incx,
                   floating *y, int incy) nogil:
    """Return x.T.y, the dot product of two strided vectors.

    n, incx and incy are forwarded unchanged to the BLAS routine that
    matches the fused type: ddot for double, sdot for float.
    """
    if floating is double:
        return ddot(&n, x, &incx, y, &incy)
    else:
        return sdot(&n, x, &incx, y, &incy)
| 25 | + |
| 26 | + |
cpdef _dot_memview(floating[::1] x, floating[::1] y):
    """Return the dot product of two contiguous 1-D memoryviews.

    Python-callable wrapper around _dot; both vectors are passed with a
    unit stride.

    Raises
    ------
    ValueError
        If x and y do not have the same length. The element count handed
        to BLAS is taken from x, so a shorter y would otherwise be read
        past its end.
    """
    if x.shape[0] != y.shape[0]:
        raise ValueError("x and y must have the same length, got %d and %d"
                         % (x.shape[0], y.shape[0]))
    return _dot(x.shape[0], &x[0], 1, &y[0], 1)
| 29 | + |
| 30 | + |
cdef floating _asum(int n, floating *x, int incx) nogil:
    """Return sum(|x_i|) for a strided vector.

    Forwards n and incx unchanged to dasum (double) or sasum (float).
    """
    if floating is double:
        return dasum(&n, x, &incx)
    else:
        return sasum(&n, x, &incx)
| 37 | + |
| 38 | + |
cpdef _asum_memview(floating[::1] x):
    """Return sum(|x_i|) of a contiguous 1-D memoryview.

    Python-callable wrapper around _asum using a unit stride.
    """
    cdef int size = x.shape[0]
    return _asum(size, &x[0], 1)
| 41 | + |
| 42 | + |
cdef void _axpy(int n, floating alpha, floating *x, int incx,
                floating *y, int incy) nogil:
    """Update y in place: y := alpha * x + y.

    Forwards all arguments to daxpy (double) or saxpy (float); alpha and
    the integer arguments are passed by address as the BLAS interface
    expects.
    """
    if floating is double:
        daxpy(&n, &alpha, x, &incx, y, &incy)
    else:
        saxpy(&n, &alpha, x, &incx, y, &incy)
| 50 | + |
| 51 | + |
cpdef _axpy_memview(floating alpha, floating[::1] x, floating[::1] y):
    """Compute y := alpha * x + y in place on contiguous 1-D memoryviews.

    Python-callable wrapper around _axpy; both vectors are passed with a
    unit stride.

    Raises
    ------
    ValueError
        If x and y do not have the same length. The element count handed
        to BLAS is taken from x, so a shorter y would otherwise be written
        past its end.
    """
    if x.shape[0] != y.shape[0]:
        raise ValueError("x and y must have the same length, got %d and %d"
                         % (x.shape[0], y.shape[0]))
    _axpy(x.shape[0], alpha, &x[0], 1, &y[0], 1)
| 54 | + |
| 55 | + |
cdef floating _nrm2(int n, floating *x, int incx) nogil:
    """Return sqrt(sum((x_i)^2)), the Euclidean norm of a strided vector.

    Forwards n and incx unchanged to dnrm2 (double) or snrm2 (float).
    """
    if floating is double:
        return dnrm2(&n, x, &incx)
    else:
        return snrm2(&n, x, &incx)
| 62 | + |
| 63 | + |
cpdef _nrm2_memview(floating[::1] x):
    """Return the Euclidean norm of a contiguous 1-D memoryview.

    Python-callable wrapper around _nrm2 using a unit stride.
    """
    cdef int size = x.shape[0]
    return _nrm2(size, &x[0], 1)
| 66 | + |
| 67 | + |
cdef void _copy(int n, floating *x, int incx, floating *y, int incy) nogil:
    """Copy x into y: y := x.

    Forwards n and both strides unchanged to dcopy (double) or scopy
    (float).
    """
    if floating is double:
        dcopy(&n, x, &incx, y, &incy)
    else:
        scopy(&n, x, &incx, y, &incy)
| 74 | + |
| 75 | + |
cpdef _copy_memview(floating[::1] x, floating[::1] y):
    """Copy the contiguous 1-D memoryview x into y (y := x).

    Python-callable wrapper around _copy; both vectors are passed with a
    unit stride.

    Raises
    ------
    ValueError
        If x and y do not have the same length. The element count handed
        to BLAS is taken from x, so a shorter y would otherwise be written
        past its end.
    """
    if x.shape[0] != y.shape[0]:
        raise ValueError("x and y must have the same length, got %d and %d"
                         % (x.shape[0], y.shape[0]))
    _copy(x.shape[0], &x[0], 1, &y[0], 1)
| 78 | + |
| 79 | + |
cdef void _scal(int n, floating alpha, floating *x, int incx) nogil:
    """Scale x in place: x := alpha * x.

    Forwards n, alpha and incx to dscal (double) or sscal (float).
    """
    if floating is double:
        dscal(&n, &alpha, x, &incx)
    else:
        sscal(&n, &alpha, x, &incx)
| 86 | + |
| 87 | + |
cpdef _scal_memview(floating alpha, floating[::1] x):
    """Scale a contiguous 1-D memoryview in place: x := alpha * x.

    Python-callable wrapper around _scal using a unit stride.
    """
    cdef int size = x.shape[0]
    _scal(size, alpha, &x[0], 1)
| 90 | + |
| 91 | + |
| 92 | +################ |
| 93 | +# BLAS Level 2 # |
| 94 | +################ |
| 95 | + |
cdef void _gemv(BLAS_Order order, BLAS_Trans ta, int m, int n, floating alpha,
                floating *A, int lda, floating *x, int incx,
                floating beta, floating *y, int incy) nogil:
    """Compute y := alpha * op(A).x + beta * y.

    For RowMajor input the transposition flag is flipped and the two
    matrix dimensions are exchanged before the call to sgemv / dgemv;
    for any other order the arguments are passed through as given.
    """
    cdef:
        char trans = ta
        int rows = m
        int cols = n

    if order == RowMajor:
        # Hand BLAS the opposite flag with swapped dimensions.
        trans = NoTrans if ta == Trans else Trans
        rows = n
        cols = m

    if floating is float:
        sgemv(&trans, &rows, &cols, &alpha, A, &lda, x, &incx,
              &beta, y, &incy)
    else:
        dgemv(&trans, &rows, &cols, &alpha, A, &lda, x, &incx,
              &beta, y, &incy)
| 112 | + |
| 113 | + |
cpdef _gemv_memview(BLAS_Trans ta, floating alpha, floating[:, :] A,
                    floating[::1] x, floating beta, floating[::1] y):
    """Compute y := alpha * op(A).x + beta * y in place on memoryviews.

    Python-callable wrapper around _gemv. A is treated as ColMajor when
    its first stride equals the item size and as RowMajor otherwise; the
    leading dimension is derived from A's shape, so A is expected to be
    contiguous in the detected order.

    Raises
    ------
    ValueError
        If the lengths of x or y do not match the shape of op(A). Without
        this check BLAS would read or write past the end of the shorter
        buffer.
    """
    # With A of shape (m, n): op(A) = A needs len(x) == n and len(y) == m,
    # op(A) = A.T needs len(x) == m and len(y) == n.
    cdef:
        int m = A.shape[0]
        int n = A.shape[1]
        int x_len = m if ta == Trans else n
        int y_len = n if ta == Trans else m
        BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor
        int lda = m if order == ColMajor else n

    if x.shape[0] != x_len:
        raise ValueError("x has length %d but op(A) has %d columns"
                         % (x.shape[0], x_len))
    if y.shape[0] != y_len:
        raise ValueError("y has length %d but op(A) has %d rows"
                         % (y.shape[0], y_len))

    _gemv(order, ta, m, n, alpha, &A[0, 0], lda, &x[0], 1, beta, &y[0], 1)
| 123 | + |
| 124 | + |
cdef void _ger(BLAS_Order order, int m, int n, floating alpha, floating *x,
               int incx, floating *y, int incy, floating *A, int lda) nogil:
    """Rank-1 update in place: A := alpha * x.y.T + A.

    For RowMajor input the roles of (m, x, incx) and (n, y, incy) are
    exchanged in the call to sger / dger; for any other order the
    arguments are passed through in their given positions.
    """
    if floating is float:
        if order == RowMajor:
            sger(&n, &m, &alpha, y, &incy, x, &incx, A, &lda)
        else:
            sger(&m, &n, &alpha, x, &incx, y, &incy, A, &lda)
    else:
        if order == RowMajor:
            dger(&n, &m, &alpha, y, &incy, x, &incx, A, &lda)
        else:
            dger(&m, &n, &alpha, x, &incx, y, &incy, A, &lda)
| 138 | + |
| 139 | + |
cpdef _ger_memview(floating alpha, floating[::1] x, floating[::1] y,
                   floating[:, :] A):
    """Rank-1 update A := alpha * x.y.T + A in place on memoryviews.

    Python-callable wrapper around _ger. Both vectors are passed to BLAS
    with a unit stride, so both must be contiguous; y is therefore
    declared [::1] like x (a strided y would be read with the wrong
    stride). A is treated as ColMajor when its first stride equals the
    item size and as RowMajor otherwise; the leading dimension is derived
    from A's shape, so A is expected to be contiguous in the detected
    order.

    Raises
    ------
    ValueError
        If len(x) != A.shape[0] or len(y) != A.shape[1]. Without this
        check BLAS would read past the end of the shorter vector.
    """
    cdef:
        int m = A.shape[0]
        int n = A.shape[1]
        BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor
        int lda = m if order == ColMajor else n

    if x.shape[0] != m:
        raise ValueError("x has length %d but A has %d rows"
                         % (x.shape[0], m))
    if y.shape[0] != n:
        raise ValueError("y has length %d but A has %d columns"
                         % (y.shape[0], n))

    _ger(order, m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda)
| 149 | + |
| 150 | + |
| 151 | +################ |
| 152 | +# BLAS Level 3 # |
| 153 | +################ |
| 154 | + |
cdef void _gemm(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, int m, int n,
                int k, floating alpha, floating *A, int lda, floating *B,
                int ldb, floating beta, floating *C, int ldc) nogil:
    """Compute C := alpha * op(A).op(B) + beta * C.

    For RowMajor input the two operands are exchanged in the call to
    sgemm / dgemm, together with their transposition flags, leading
    dimensions and the m / n sizes; for any other order the arguments are
    passed through in their given positions.
    """
    cdef:
        char trans_a = ta
        char trans_b = tb

    if floating is float:
        if order == RowMajor:
            sgemm(&trans_b, &trans_a, &n, &m, &k, &alpha, B,
                  &ldb, A, &lda, &beta, C, &ldc)
        else:
            sgemm(&trans_a, &trans_b, &m, &n, &k, &alpha, A,
                  &lda, B, &ldb, &beta, C, &ldc)
    else:
        if order == RowMajor:
            dgemm(&trans_b, &trans_a, &n, &m, &k, &alpha, B,
                  &ldb, A, &lda, &beta, C, &ldc)
        else:
            dgemm(&trans_a, &trans_b, &m, &n, &k, &alpha, A,
                  &lda, B, &ldb, &beta, C, &ldc)
| 176 | + |
| 177 | + |
cpdef _gemm_memview(BLAS_Trans ta, BLAS_Trans tb, floating alpha,
                    floating[:, :] A, floating[:, :] B, floating beta,
                    floating[:, :] C):
    """Compute C := alpha * op(A).op(B) + beta * C in place on memoryviews.

    Python-callable wrapper around _gemm. op(A) is (m, k), op(B) is (k, n)
    and C is (m, n).

    Raises
    ------
    ValueError
        If the inner dimensions of op(A) and op(B) differ, or if C is not
        of shape (m, n). Without these checks BLAS would read or write
        outside the supplied buffers.
    """
    # NOTE(review): the storage order is inferred from A alone and the
    # leading dimensions are derived from the shapes, so B and C are
    # assumed to be contiguous and to share A's order -- confirm callers
    # guarantee this.
    cdef:
        int m = A.shape[0] if ta == NoTrans else A.shape[1]
        int n = B.shape[1] if tb == NoTrans else B.shape[0]
        int k = A.shape[1] if ta == NoTrans else A.shape[0]
        int k_b = B.shape[0] if tb == NoTrans else B.shape[1]
        int lda, ldb, ldc
        BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor

    if k_b != k:
        raise ValueError("op(A) has %d columns but op(B) has %d rows"
                         % (k, k_b))
    if C.shape[0] != m or C.shape[1] != n:
        raise ValueError("C has shape (%d, %d), expected (%d, %d)"
                         % (C.shape[0], C.shape[1], m, n))

    if order == RowMajor:
        lda = k if ta == NoTrans else m
        ldb = n if tb == NoTrans else k
        ldc = n
    else:
        lda = m if ta == NoTrans else k
        ldb = k if tb == NoTrans else n
        ldc = m

    _gemm(order, ta, tb, m, n, k, alpha, &A[0, 0],
          lda, &B[0, 0], ldb, beta, &C[0, 0], ldc)
0 commit comments