@@ -19,7 +19,7 @@ from scipy.linalg.cython_blas cimport sgemm, dgemm
19
19
20
20
cdef floating _xdot(int n, floating * x, int incx,
21
21
floating * y, int incy) nogil:
22
- """ """
22
+ """ x.T.y """
23
23
if floating is float :
24
24
return sdot(& n, x, & incx, y, & incy)
25
25
else :
@@ -31,7 +31,7 @@ cpdef _xdot_memview(floating[::1] x, floating[::1] y):
31
31
32
32
33
33
cdef floating _xasum(int n, floating * x, int incx) nogil:
34
- """ """
34
+ """ sum(|x_i|) """
35
35
if floating is float :
36
36
return sasum(& n, x, & incx)
37
37
else :
@@ -44,7 +44,7 @@ cpdef _xasum_memview(floating[::1] x):
44
44
45
45
cdef void _xaxpy(int n, floating alpha, floating * x, int incx,
46
46
floating * y, int incy) nogil:
47
- """ """
47
+ """ y := alpha * x + y """
48
48
if floating is float :
49
49
saxpy(& n, & alpha, x, & incx, y, & incy)
50
50
else :
@@ -56,7 +56,7 @@ cpdef _xaxpy_memview(floating alpha, floating[::1] x, floating[::1] y):
56
56
57
57
58
58
cdef floating _xnrm2(int n, floating * x, int incx) nogil:
59
- """ """
59
+ """ sqrt(sum((x_i)^2)) """
60
60
if floating is float :
61
61
return snrm2(& n, x, & incx)
62
62
else :
@@ -68,7 +68,7 @@ cpdef _xnrm2_memview(floating[::1] x):
68
68
69
69
70
70
cdef void _xcopy(int n, floating * x, int incx, floating * y, int incy) nogil:
71
- """ """
71
+ """ y := x """
72
72
if floating is float :
73
73
scopy(& n, x, & incx, y, & incy)
74
74
else :
@@ -80,7 +80,7 @@ cpdef _xcopy_memview(floating[::1] x, floating[::1] y):
80
80
81
81
82
82
cdef void _xscal(int n, floating alpha, floating * x, int incx) nogil:
83
- """ """
83
+ """ x := alpha * x """
84
84
if floating is float :
85
85
sscal(& n, & alpha, x, & incx)
86
86
else :
@@ -98,7 +98,7 @@ cpdef _xscal_memview(floating alpha, floating[::1] x):
98
98
cdef void _xgemv(char layout, char ta, int m, int n, floating alpha,
99
99
floating * A, int lda, floating * x, int incx,
100
100
floating beta, floating * y, int incy) nogil:
101
- """ """
101
+ """ y := alpha * op(A).x + beta * y """
102
102
if layout == ' C' :
103
103
ta = ' n' if ta == ' t' else ' t'
104
104
if floating is float :
@@ -112,21 +112,22 @@ cdef void _xgemv(char layout, char ta, int m, int n, floating alpha,
112
112
dgemv(& ta, & m, & n, & alpha, A, & lda, x, & incx, & beta, y, & incy)
113
113
114
114
115
- cpdef _xgemv_memview(const unsigned char [:] layout, const unsigned char [:] ta,
116
- floating alpha, floating[:, :] A, floating[::1 ] x,
117
- floating beta, floating[::1 ] y):
115
+ cpdef _xgemv_memview(layout, ta, floating alpha, floating[:, :] A,
116
+ floating[::1 ] x, floating beta, floating[::1 ] y):
118
117
cdef:
118
+ char layout_ = ' F' if layout == ' F' else ' C'
119
+ char ta_ = ' n' if ta == ' n' else ' t'
119
120
int m = A.shape[0 ]
120
121
int n = A.shape[1 ]
121
- int lda = m if layout[ 0 ] == ' F' else n
122
+ int lda = m if layout == ' F' else n
122
123
123
- _xgemv(layout[ 0 ], ta[ 0 ] , m, n, alpha, & A[0 , 0 ], lda,
124
+ _xgemv(layout_, ta_ , m, n, alpha, & A[0 , 0 ], lda,
124
125
& x[0 ], 1 , beta, & y[0 ], 1 )
125
126
126
127
127
128
cdef void _xger(char layout, int m, int n, floating alpha, floating * x,
128
129
int incx, floating * y, int incy, floating * A, int lda) nogil:
129
- """ """
130
+ """ A := alpha * x.y.T + A """
130
131
if layout == ' C' :
131
132
if floating is float :
132
133
sger(& n, & m, & alpha, y, & incy, x, & incx, A, & lda)
@@ -139,14 +140,15 @@ cdef void _xger(char layout, int m, int n, floating alpha, floating *x,
139
140
dger(& m, & n, & alpha, x, & incx, y, & incy, A, & lda)
140
141
141
142
142
- cpdef _xger_memview(const unsigned char [:] layout , floating alpha ,
143
- floating[:: 1 ] x, floating[::] y, floating[: , :] A):
143
+ cpdef _xger_memview(layout, floating alpha, floating[:: 1 ] x , floating[::] y ,
144
+ floating[:, :] A):
144
145
cdef:
146
+ char layout_ = ' F' if layout == ' F' else ' C'
145
147
int m = A.shape[0 ]
146
148
int n = A.shape[1 ]
147
149
int lda = m if layout[0 ] == ' F' else n
148
150
149
- _xger(layout[ 0 ] , m, n, alpha, & x[0 ], 1 , & y[0 ], 1 , & A[0 , 0 ], lda)
151
+ _xger(layout_ , m, n, alpha, & x[0 ], 1 , & y[0 ], 1 , & A[0 , 0 ], lda)
150
152
151
153
152
154
# ###############
@@ -156,7 +158,7 @@ cpdef _xger_memview(const unsigned char[:] layout, floating alpha,
156
158
cdef void _xgemm(char layout, char ta, char tb, int m, int n, int k,
157
159
floating alpha, floating * A, int lda, floating * B, int ldb,
158
160
floating beta, floating * C, int ldc) nogil:
159
- """ """
161
+ """ C := alpha * op(A).op(B) + beta * C """
160
162
if layout == ' C' :
161
163
if floating is float :
162
164
sgemm(& tb, & ta, & n, & m, & k, & alpha, B, & ldb, A, & lda, & beta, C, & ldc)
@@ -169,24 +171,25 @@ cdef void _xgemm(char layout, char ta, char tb, int m, int n, int k,
169
171
dgemm(& ta, & tb, & m, & n, & k, & alpha, A, & lda, B, & ldb, & beta, C, & ldc)
170
172
171
173
172
- cpdef _xgemm_memview(const unsigned char [:] layout, const unsigned char [:] ta,
173
- const unsigned char [:] tb, floating alpha,
174
- floating[:, :] A, floating[:, :] B, floating beta,
175
- floating[:, :] C):
174
+ cpdef _xgemm_memview(layout, ta, tb, floating alpha, floating[:, :] A,
175
+ floating[:, :] B, floating beta, floating[:, :] C):
176
176
cdef:
177
+ char layout_ = ' F' if layout == ' F' else ' C'
178
+ char ta_ = ' n' if ta == ' n' else ' t'
179
+ char tb_ = ' n' if tb == ' n' else ' t'
177
180
int m = A.shape[0 ] if ta[0 ] == ' n' else A.shape[1 ]
178
181
int n = B.shape[1 ] if tb[0 ] == ' n' else B.shape[0 ]
179
182
int k = A.shape[1 ] if ta[0 ] == ' n' else A.shape[0 ]
180
183
int lda, ldb, ldc
181
184
182
- if layout[ 0 ] == ' F' :
183
- lda = m if ta[ 0 ] == ' n' else k
184
- ldb = k if tb[ 0 ] == ' n' else n
185
+ if layout == ' F' :
186
+ lda = m if ta == ' n' else k
187
+ ldb = k if tb == ' n' else n
185
188
ldc = m
186
189
else :
187
- lda = k if ta[ 0 ] == ' n' else m
188
- ldb = n if tb[ 0 ] == ' n' else k
190
+ lda = k if ta == ' n' else m
191
+ ldb = n if tb == ' n' else k
189
192
ldc = n
190
193
191
- _xgemm(layout[ 0 ], ta[ 0 ], tb[ 0 ] , m, n, k, alpha,
194
+ _xgemm(layout_, ta_, tb_ , m, n, k, alpha,
192
195
& A[0 , 0 ], lda, & B[0 , 0 ], ldb, beta, & C[0 , 0 ], ldc)
0 commit comments