1
+ # cython: boundscheck=False, wraparound=False, cdivision=True
2
+
3
+ from cython cimport floating
4
+
5
+ from scipy.linalg.cython_blas cimport sdot, ddot
6
+ from scipy.linalg.cython_blas cimport sasum, dasum
7
+ from scipy.linalg.cython_blas cimport saxpy, daxpy
8
+ from scipy.linalg.cython_blas cimport snrm2, dnrm2
9
+ from scipy.linalg.cython_blas cimport scopy, dcopy
10
+ from scipy.linalg.cython_blas cimport sscal, dscal
11
+ from scipy.linalg.cython_blas cimport sgemv, dgemv
12
+ from scipy.linalg.cython_blas cimport sger, dger
13
+ from scipy.linalg.cython_blas cimport sgemm, dgemm
14
+
15
+
16
+ # ###############
17
+ # BLAS Level 1 #
18
+ # ###############<
741A
/span>
19
+
20
+ cdef floating _xdot(int n, floating * x, int incx,
21
+ floating * y, int incy) nogil:
22
+ """ """
23
+ if floating is float :
24
+ return sdot(& n, x, & incx, y, & incy)
25
+ else :
26
+ return ddot(& n, x, & incx, y, & incy)
27
+
28
+
29
+ cpdef _xdot_memview(floating[::1 ] x, floating[::1 ] y):
30
+ return _xdot(x.shape[0 ], & x[0 ], 1 , & y[0 ], 1 )
31
+
32
+
33
+ cdef floating _xasum(int n, floating * x, int incx) nogil:
34
+ """ """
35
+ if floating is float :
36
+ return sasum(& n, x, & incx)
37
+ else :
38
+ return dasum(& n, x, & incx)
39
+
40
+
41
+ cpdef _xasum_memview(floating[::1 ] x):
42
+ return _xasum(x.shape[0 ], & x[0 ], 1 )
43
+
44
+
45
+ cdef void _xaxpy(int n, floating alpha, floating * x, int incx,
46
+ floating * y, int incy) nogil:
47
+ """ """
48
+ if floating is float :
49
+ saxpy(& n, & alpha, x, & incx, y, & incy)
50
+ else :
51
+ daxpy(& n, & alpha, x, & incx, y, & incy)
52
+
53
+
54
+ cpdef _xaxpy_memview(floating alpha, floating[::1 ] x, floating[::1 ] y):
55
+ _xaxpy(x.shape[0 ], alpha, & x[0 ], 1 , & y[0 ], 1 )
56
+
57
+
58
+ cdef floating _xnrm2(int n, floating * x, int incx) nogil:
59
+ """ """
60
+ if floating is float :
61
+ return snrm2(& n, x, & incx)
62
+ else :
63
+ return dnrm2(& n, x, & incx)
64
+
65
+
66
+ cpdef _xnrm2_memview(floating[::1 ] x):
67
+ return _xnrm2(x.shape[0 ], & x[0 ], 1 )
68
+
69
+
70
+ cdef void _xcopy(int n, floating * x, int incx, floating * y, int incy) nogil:
71
+ """ """
72
+ if floating is float :
73
+ scopy(& n, x, & incx, y, & incy)
74
+ else :
75
+ dcopy(& n, x, & incx, y, & incy)
76
+
77
+
78
+ cpdef _xcopy_memview(floating[::1 ] x, floating[::1 ] y):
79
+ _xcopy(x.shape[0 ], & x[0 ], 1 , & y[0 ], 1 )
80
+
81
+
82
+ cdef void _xscal(int n, floating alpha, floating * x, int incx) nogil:
83
+ """ """
84
+ if floating is float :
85
+ sscal(& n, & alpha, x, & incx)
86
+ else :
87
+ dscal(& n, & alpha, x, & incx)
88
+
89
+
90
+ cpdef _xscal_memview(floating alpha, floating[::1 ] x):
91
+ _xscal(x.shape[0 ], alpha, & x[0 ], 1 )
92
+
93
+
94
+ # ###############
95
+ # BLAS Level 2 #
96
+ # ###############
97
+
98
+ cdef void _xgemv(char layout, char ta, int m, int n, floating alpha,
99
+ floating * A, int lda, floating * x, int incx,
100
+ floating beta, floating * y, int incy) nogil:
101
+ """ """
102
+ if layout == ' C' :
103
+ ta = ' n' if ta == ' t' else ' t'
104
+ if floating is float :
105
+ sgemv(& ta, & n, & m, & alpha, A, & lda, x, & incx, & beta, y, & incy)
106
+ else :
107
+ dgemv(& ta, & n, & m, & alpha, A, & lda, x, & incx, & beta, y, & incy)
108
+ elif layout == ' F' :
109
+ if floating is float :
110
+ sgemv(& ta, & m, & n, & alpha, A, & lda, x, & incx, & beta, y, & incy)
111
+ else :
112
+ dgemv(& ta, & m, & n, & alpha, A, & lda, x, & incx, & beta, y, & incy)
113
+
114
+
115
+ cpdef _xgemv_memview(const unsigned char [:] layout, const unsigned char [:] ta,
116
+ floating alpha, floating[:, :] A, floating[::1 ] x,
117
+ floating beta, floating[::1 ] y):
118
+ cdef:
119
+ int m = A.shape[0 ]
120
+ int n = A.shape[1 ]
121
+ int lda = m if layout[0 ] == ' F' else n
122
+
123
+ _xgemv(layout[0 ], ta[0 ], m, n, alpha, & A[0 , 0 ], lda,
124
+ & x[0 ], 1 , beta, & y[0 ], 1 )
125
+
126
+
127
+ cdef void _xger(char layout, int m, int n, floating alpha, floating * x,
128
+ int incx, floating * y, int incy, floating * A, int lda) nogil:
129
+ """ """
130
+ if layout == ' C' :
131
+ if floating is float :
132
+ sger(& n, & m, & alpha, y, & incy, x, & incx, A, & lda)
133
+ else :
134
+ dger(& n, & m, & alpha, y, & incy, x, & incx, A, & lda)
135
+ elif layout == ' F' :
136
+ if floating is float :
137
+ sger(& m, & n, & alpha, x, & incx, y, & incy, A, & lda)
138
+ else :
139
+ dger(& m, & n, & alpha, x, & incx, y, & incy, A, & lda)
140
+
141
+
142
+ cpdef _xger_memview(const unsigned char [:] layout, floating alpha,
143
+ floating[::1 ] x, floating[::] y, floating[:, :] A):
144
+ cdef:
145
+ int m = A.shape[0 ]
146
+ int n = A.shape[1 ]
147
+ int lda = m if layout[0 ] == ' F' else n
148
+
149
+ _xger(layout[0 ], m, n, alpha, & x[0 ], 1 , & y[0 ], 1 , & A[0 , 0 ], lda)
150
+
151
+
152
+ # ###############
153
+ # BLAS Level 3 #
154
+ # ###############
155
+
156
+ cdef void _xgemm(char layout, char ta, char tb, int m, int n, int k,
157
+ floating alpha, floating * A, int lda, floating * B, int ldb,
158
+ floating beta, floating * C, int ldc) nogil:
159
+ """ """
160
+ if layout == ' C' :
161
+ if floating is float :
162
+ sgemm(& tb, & ta, & n, & m, & k, & alpha, B, & ldb, A, & lda, & beta, C, & ldc)
163
+ else :
164
+ dgemm(& tb, & ta, & n, & m, & k, & alpha, B, & ldb, A, & lda, & beta, C, & ldc)
165
+ elif layout == ' F' :
166
+ if floating is float :
167
+ sgemm(& ta, & tb, & m, & n, & k, & alpha, A, & lda, B, & ldb, & beta, C, & ldc)
168
+ else :
169
+ dgemm(& ta, & tb, & m, & n, & k, & alpha, A, & lda, B, & ldb, & beta, C, & ldc)
170
+
171
+
172
+ cpdef _xgemm_memview(const unsigned char [:] layout, const unsigned char [:] ta,
173
+ const unsigned char [:] tb, floating alpha,
174
+ floating[:, :] A, floating[:, :] B, floating beta,
175
+ floating[:, :] C):
176
+ cdef:
177
+ int m = A.shape[0 ] if ta[0 ] == ' n' else A.shape[1 ]
178
+ int n = B.shape[1 ] if tb[0 ] == ' n' else B.shape[0 ]
179
+ int k = A.shape[1 ] if ta[0 ] == ' n' else A.shape[0 ]
180
+ int lda, ldb, ldc
181
+
182
+ if layout[0 ] == ' F' :
183
+ lda = m if ta[0 ] == ' n' else k
184
+ ldb = k if tb[0 ] == ' n' else n
185
+ ldc = m
186
+ else :
187
+ lda = k if ta[0 ] == ' n' else m
188
+ ldb = n if tb[0 ] == ' n' else k
189
+ ldc = n
190
+
191
+ _xgemm(layout[0 ], ta[0 ], tb[0 ], m, n, k, alpha,
192
+ & A[0 , 0 ], lda, & B[0 , 0 ], ldb, beta, & C[0 , 0 ], ldc)
0 commit comments