10000 fix cython string -> char · scikit-learn/scikit-learn@1660d7e · GitHub
[go: up one dir, main page]

Skip to content

Commit 1660d7e

Browse files
committed
fix cython string -> char
1 parent 9f6d45c commit 1660d7e

File tree

2 files changed

+33
-31
lines changed

2 files changed

+33
-31
lines changed

sklearn/utils/_cython_blas.pyx

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ from scipy.linalg.cython_blas cimport sgemm, dgemm
1919

2020
cdef floating _xdot(int n, floating *x, int incx,
2121
floating *y, int incy) nogil:
22-
""""""
22+
"""x.T.y"""
2323
if floating is float:
2424
return sdot(&n, x, &incx, y, &incy)
2525
else:
@@ -31,7 +31,7 @@ cpdef _xdot_memview(floating[::1] x, floating[::1] y):
3131

3232

3333
cdef floating _xasum(int n, floating *x, int incx) nogil:
34-
""""""
34+
"""sum(|x_i|)"""
3535
if floating is float:
3636
return sasum(&n, x, &incx)
3737
else:
@@ -44,7 +44,7 @@ cpdef _xasum_memview(floating[::1] x):
4444

4545
cdef void _xaxpy(int n, floating alpha, floating *x, int incx,
4646
floating *y, int incy) nogil:
47-
""""""
47+
"""y := alpha * x + y"""
4848
if floating is float:
4949
saxpy(&n, &alpha, x, &incx, y, &incy)
5050
else:
@@ -56,7 +56,7 @@ cpdef _xaxpy_memview(floating alpha, floating[::1] x, floating[::1] y):
5656

5757

5858
cdef floating _xnrm2(int n, floating *x, int incx) nogil:
59-
""""""
59+
"""sqrt(sum((x_i)^2))"""
6060
if floating is float:
6161
return snrm2(&n, x, &incx)
6262
else:
@@ -68,7 +68,7 @@ cpdef _xnrm2_memview(floating[::1] x):
6868

6969

7070
cdef void _xcopy(int n, floating *x, int incx, floating *y, int incy) nogil:
71-
""""""
71+
"""y := x"""
7272
if floating is float:
7373
scopy(&n, x, &incx, y, &incy)
7474
else:
@@ -80,7 +80,7 @@ cpdef _xcopy_memview(floating[::1] x, floating[::1] y):
8080

8181

8282
cdef void _xscal(int n, floating alpha, floating *x, int incx) nogil:
83-
""""""
83+
"""x := alpha * x"""
8484
if floating is float:
8585
sscal(&n, &alpha, x, &incx)
8686
else:
@@ -98,7 +98,7 @@ cpdef _xscal_memview(floating alpha, floating[::1] x):
9898
cdef void _xgemv(char layout, char ta, int m, int n, floating alpha,
9999
floating *A, int lda, floating *x, int incx,
100100
floating beta, floating *y, int incy) nogil:
101-
""""""
101+
"""y := alpha * op(A).x + beta * y"""
102102
if layout == 'C':
103103
ta = 'n' if ta == 't' else 't'
104104
if floating is float:
@@ -112,21 +112,22 @@ cdef void _xgemv(char layout, char ta, int m, int n, floating alpha,
112112
dgemv(&ta, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy)
113113

114114

115-
cpdef _xgemv_memview(const unsigned char[:] layout, const unsigned char[:] ta,
116-
floating alpha, floating[:, :] A, floating[::1] x,
117-
floating beta, floating[::1] y):
115+
cpdef _xgemv_memview(layout, ta, floating alpha, floating[:, :] A,
116+
floating[::1] x, floating beta, floating[::1] y):
118117
cdef:
118+
char layout_ = 'F' if layout == 'F' else 'C'
119+
char ta_ = 'n' if ta == 'n' else 't'
119120
int m = A.shape[0]
120121
int n = A.shape[1]
121-
int lda = m if layout[0] == 'F' else n
122+
int lda = m if layout == 'F' else n
122123

123-
_xgemv(layout[0], ta[0], m, n, alpha, &A[0, 0], lda,
124+
_xgemv(layout_, ta_, m, n, alpha, &A[0, 0], lda,
124125
&x[0], 1, beta, &y[0], 1)
125126

126127

127128
cdef void _xger(char layout, int m, int n, floating alpha, floating *x,
128129
int incx, floating *y, int incy, floating *A, int lda) nogil:
129-
""""""
130+
"""A := alpha * x.y.T + A"""
130131
if layout == 'C':
131132
if floating is float:
132133
sger(&n, &m, &alpha, y, &incy, x, &incx, A, &lda)
@@ -139,14 +140,15 @@ cdef void _xger(char layout, int m, int n, floating alpha, floating *x,
139140
dger(&m, &n, &alpha, x, &incx, y, &incy, A, &lda)
140141

141142

142-
cpdef _xger_memview(const unsigned char[:] layout, floating alpha,
143-
floating[::1] x, floating[::] y, floating[:, :] A):
143+
cpdef _xger_memview(layout, floating alpha, floating[::1] x, floating[::] y,
144+
floating[:, :] A):
144145
cdef:
146+
char layout_ = 'F' if layout == 'F' else 'C'
145147
int m = A.shape[0]
146148
int n = A.shape[1]
147149
int lda = m if layout[0] == 'F' else n
148150

149-
_xger(layout[0], m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda)
151+
_xger(layout_, m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda)
150152

151153

152154
################
@@ -156,7 +158,7 @@ cpdef _xger_memview(const unsigned char[:] layout, floating alpha,
156158
cdef void _xgemm(char layout, char ta, char tb, int m, int n, int k,
157159
floating alpha, floating *A, int lda, floating *B, int ldb,
158160
floating beta, floating *C, int ldc) nogil:
159-
""""""
161+
"""C := alpha * op(A).op(B) + beta * C"""
160162
if layout == 'C':
161163
if floating is float:
162164
sgemm(&tb, &ta, &n, &m, &k, &alpha, B, &ldb, A, &lda, &beta, C, &ldc)
@@ -169,24 +171,25 @@ cdef void _xgemm(char layout, char ta, char tb, int m, int n, int k,
169171
dgemm(&ta, &tb, &m, &n, &k, &alpha, A, &lda, B, &ldb, &beta, C, &ldc)
170172

171173

172-
cpdef _xgemm_memview(const unsigned char[:] layout, const unsigned char[:] ta,
173-
const unsigned char[:] tb, floating alpha,
174-
floating[:, :] A, floating[:, :] B, floating beta,
175-
floating[:, :] C):
174+
cpdef _xgemm_memview(layout, ta, tb, floating alpha, floating[:, :] A,
175+
floating[:, :] B, floating beta, floating[:, :] C):
176176
cdef:
177+
char layout_ = 'F' if layout == 'F' else 'C'
178+
char ta_ = 'n' if ta == 'n' else 't'
179+
char tb_ = 'n' if tb == 'n' else 't'
177180
int m = A.shape[0] if ta[0] == 'n' else A.shape[1]
178181
int n = B.shape[1] if tb[0] == 'n' else B.shape[0]
179182
int k = A.shape[1] if ta[0] == 'n' else A.shape[0]
180183
int lda, ldb, ldc
181184

182-
if layout[0] == 'F':
183-
lda = m if ta[0] == 'n' else k
184-
ldb = k if tb[0] == 'n' else n
185+
if layout == 'F':
186+
lda = m if ta == 'n' else k
187+
ldb = k if tb == 'n' else n
185188
ldc = m
186189
else:
187-
lda = k if ta[0] == 'n' else m
188-
ldb = n if tb[0] == 'n' else k
190+
lda = k if ta == 'n' else m
191+
ldb = n if tb == 'n' else k
189192
ldc = n
190193

191-
_xgemm(layout[0], ta[0], tb[0], m, n, k, alpha,
194+
_xgemm(layout_, ta_, tb_, m, n, k, alpha,
192195
&A[0, 0], lda, &B[0, 0], ldb, beta, &C[0, 0], ldc)

sklearn/utils/tests/test_cython_blas.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def test_gemv(dtype, opA, transA, layout):
122122
alpha, beta = 1.23, -3.21
123123

124124
expected = alpha * opA(A).dot(x) + beta * y
125-
gemv(layout.encode(), transA.encode(), alpha, A, x, beta, y)
125+
gemv(layout, transA, alpha, A, x, beta, y)
126126

127127
assert_allclose(y, expected, rtol=1e-4)
128128

@@ -140,7 +140,7 @@ def test_ger(dtype, layout):
140140
alpha = 1.23
141141

142142
expected = alpha * np.outer(x, y) + A
143-
ger(layout.encode(), alpha, x, y, A)
143+
ger(layout, alpha, x, y, A)
144144

145145
assert_allclose(A, expected, rtol=1e-4)
146146

@@ -166,7 +166,6 @@ def test_gemm(dtype, opA, transA, opB, transB, layout):
166166
alpha, beta = 1.23, -3.21
167167

168168
expected = alpha * opA(A).dot(opB(B)) + beta * C
169-
gemm(layout.encode(), transA.encode(), transB.encode(),
170-
alpha, A, B, beta, C)
169+
gemm(layout, transA, transB, alpha, A, B, beta, C)
171170

172171
assert_allclose(C, expected, rtol=1e-4)

0 commit comments

Comments
 (0)
0