Scipy cython_blas fused helpers · scikit-learn/scikit-learn@486cf22

Commit 486cf22
1 parent 1cb56ba
File tree: 9 files changed, +398 additions, -69 deletions

.circleci/config.yml

Lines changed: 1 addition & 36 deletions

@@ -29,41 +29,7 @@ jobs:
       - persist_to_workspace:
           root: doc/_build/html
           paths: .
-
-
-  python2:
-    docker:
-      # We use the python 3 docker image for simplicity. Python is installed
-      # through conda and the python version actually used is set via the
-      # PYTHON_VERSION environment variable.
-      - image: circleci/python:3.6.1
-    environment:
-      # Test examples run with minimal dependencies
-      - MINICONDA_PATH: ~/miniconda
-      - CONDA_ENV_NAME: testenv
-      - PYTHON_VERSION: "2"
-      - NUMPY_VERSION: "1.10"
-      - SCIPY_VERSION: "0.16"
-      - MATPLOTLIB_VERSION: "1.4"
-      - SCIKIT_IMAGE_VERSION: "0.11"
-      - PANDAS_VERSION: "0.17.1"
-    steps:
-      - checkout
-      - run: ./build_tools/circle/checkout_merge_commit.sh
-      - restore_cache:
-          key: v1-datasets-{{ .Branch }}-python2
-      - run: ./build_tools/circle/build_doc.sh
-      - save_cache:
-          key: v1-datasets-{{ .Branch }}-python2
-          paths:
-            - ~/scikit_learn_data
-      - store_artifacts:
-          path: doc/_build/html/stable
-          destination: doc
-      - store_artifacts:
-          path: ~/log.txt
-          destination: log.txt
-
+
   lint:
     docker:
       - image: circleci/python:3.6.1

@@ -115,7 +81,6 @@ workflows:
   build-doc-and-deploy:
     jobs:
       - python3
-      - python2
       - lint
       - pypy3:
           filters:

.travis.yml

Lines changed: 0 additions & 13 deletions

@@ -20,19 +20,6 @@ env:

 matrix:
   include:
-    # This environment tests that scikit-learn can be built against
-    # versions of numpy, scipy with ATLAS that comes with Ubuntu Trusty 14.04
-    # i.e. numpy 1.8.2 and scipy 0.13.3
-    - env: DISTRIB="ubuntu" PYTHON_VERSION="2.7" CYTHON_VERSION="0.23.5"
-           COVERAGE=true
-      if: type != cron
-      addons:
-        apt:
-          packages:
-            # these only required by the DISTRIB="ubuntu" builds:
-            - python-scipy
-            - libatlas3-base
-            - libatlas-dev
     # Python 3.4 build
     - env: DISTRIB="conda" PYTHON_VERSION="3.4" INSTALL_MKL="false"
            NUMPY_VERSION="1.10.4" SCIPY_VERSION="0.16.1" CYTHON_VERSION="0.25.2"

appveyor.yml

Lines changed: 0 additions & 4 deletions

@@ -22,10 +22,6 @@ environment:
       PYTHON_ARCH: "64"
       CHECK_WARNINGS: "true"

-    - PYTHON: "C:\\Python27"
-      PYTHON_VERSION: "2.7.8"
-      PYTHON_ARCH: "32"
-

  # Because we only have a single worker, we don't want to waste precious
  # appveyor CI time and make other PRs wait for repeated failures in a failing

sklearn/metrics/pairwise_fast.pyx

Lines changed: 3 additions & 3 deletions

@@ -11,8 +11,8 @@ from libc.string cimport memset
 import numpy as np
 cimport numpy as np

-cdef extern from "cblas.h":
-    double cblas_dasum(int, const double *, int) nogil
+from ..utils._cython_blas cimport _xasum
+

 ctypedef float [:, :] float_array_2d_t
 ctypedef double [:, :] double_array_2d_t

@@ -76,4 +76,4 @@ def _sparse_manhattan(floating1d X_data, int[:] X_indices, int[:] X_indptr,
             for j in range(Y_indptr[iy], Y_indptr[iy + 1]):
                 row[Y_indices[j]] -= Y_data[j]

-            D[ix, iy] = cblas_dasum(n_features, &row[0], 1)
+            D[ix, iy] = _xasum(n_features, &row[0], 1)
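For reference, a minimal sketch of what the replacement call computes, checked from Python through the cpdef _xasum_memview wrapper added in sklearn/utils/_cython_blas.pyx below. This assumes the compiled extension is importable (e.g. after an in-place build); the wrapper names come from this commit.

import numpy as np
from sklearn.utils._cython_blas import _xasum_memview

row = np.array([1.5, -2.0, 0.0, 3.5])                        # dense row difference buffer
assert np.isclose(_xasum_memview(row), np.abs(row).sum())    # 7.0

# The fused helper dispatches on dtype: a float32 buffer goes through sasum
# instead of dasum.
row32 = row.astype(np.float32)
assert np.isclose(_xasum_memview(row32), np.abs(row32).sum())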

sklearn/metrics/setup.py

Lines changed: 6 additions & 13 deletions

@@ -1,33 +1,26 @@
 import os
-import os.path

-import numpy
 from numpy.distutils.misc_util import Configuration

-from sklearn._build_utils import get_blas_info
-

 def configuration(parent_package="", top_path=None):
     config = Configuration("metrics", parent_package, top_path)

-    cblas_libs, blas_info = get_blas_info()
+    libraries = []
     if os.name == 'posix':
-        cblas_libs.append('m')
+        libraries.append('m')

     config.add_subpackage('cluster')
+
     config.add_extension("pairwise_fast",
                          sources=["pairwise_fast.pyx"],
-                         include_dirs=[os.path.join('..', 'src', 'cblas'),
-                                       numpy.get_include(),
-                                       blas_info.pop('include_dirs', [])],
-                         libraries=cblas_libs,
-                         extra_compile_args=blas_info.pop('extra_compile_args',
-                                                          []),
-                         **blas_info)
+                         libraries=libraries)
+
     config.add_subpackage('tests')

     return config

+
 if __name__ == "__main__":
     from numpy.distutils.core import setup
     setup(**configuration().todict())

sklearn/utils/_cython_blas.pxd

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+from cython cimport floating
+
+
+# BLAS Level 1 ################################################################
+cdef floating _xdot(int, floating*, int, floating*, int) nogil
+cdef floating _xasum(int, floating*, int) nogil
+cdef void _xaxpy(int, floating, floating*, int, floating*, int) nogil
+cdef floating _xnrm2(int, floating*, int) nogil
+cdef void _xcopy(int, floating*, int, floating*, int) nogil
+cdef void _xscal(int, floating, floating*, int) nogil
+
+# BLAS Level 2 ################################################################
+cdef void _xgemv(char, char, int, int, floating, floating*, int, floating*,
+                 int, floating, floating*, int) nogil
+cdef void _xger(char, int, int, floating, floating*, int, floating*, int,
+                floating*, int) nogil
+
+# BLAS Level 3 ################################################################
+cdef void _xgemm(char, char, char, int, int, int, floating, floating*, int,
+                 floating*, int, floating, floating*, int) nogil
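A minimal sketch of the Level 1 API from the Python side, using the cpdef *_memview test wrappers defined in _cython_blas.pyx (next file). This is not part of the commit; it assumes the compiled extension is importable and that the arrays are contiguous float64 buffers so the fused type resolves to double.

import numpy as np
from sklearn.utils._cython_blas import (_xdot_memview, _xnrm2_memview,
                                        _xaxpy_memview, _xscal_memview)

x = np.array([1.0, -2.0, 3.0])
y = np.array([4.0, 5.0, 6.0])

assert np.isclose(_xdot_memview(x, y), x.dot(y))           # 12.0
assert np.isclose(_xnrm2_memview(x), np.linalg.norm(x))
_xaxpy_memview(0.5, x, y)                                  # y := 0.5 * x + y, in place
assert np.allclose(y, [4.5, 4.0, 7.5])
_xscal_memview(2.0, x)                                     # x := 2 * x, in place
assert np.allclose(x, [2.0, -4.0, 6.0])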

sklearn/utils/_cython_blas.pyx

Lines changed: 192 additions & 0 deletions

@@ -0,0 +1,192 @@
+# cython: boundscheck=False, wraparound=False, cdivision=True
+
+from cython cimport floating
+
+from scipy.linalg.cython_blas cimport sdot, ddot
+from scipy.linalg.cython_blas cimport sasum, dasum
+from scipy.linalg.cython_blas cimport saxpy, daxpy
+from scipy.linalg.cython_blas cimport snrm2, dnrm2
+from scipy.linalg.cython_blas cimport scopy, dcopy
+from scipy.linalg.cython_blas cimport sscal, dscal
+from scipy.linalg.cython_blas cimport sgemv, dgemv
+from scipy.linalg.cython_blas cimport sger, dger
+from scipy.linalg.cython_blas cimport sgemm, dgemm
+
+
+################
+# BLAS Level 1 #
+################
+
+cdef floating _xdot(int n, floating *x, int incx,
+                    floating *y, int incy) nogil:
+    """x.T.y"""
+    if floating is float:
+        return sdot(&n, x, &incx, y, &incy)
+    else:
+        return ddot(&n, x, &incx, y, &incy)
+
+
+cpdef _xdot_memview(floating[::1] x, floating[::1] y):
+    return _xdot(x.shape[0], &x[0], 1, &y[0], 1)
+
+
+cdef floating _xasum(int n, floating *x, int incx) nogil:
+    """sum(|x_i|)"""
+    if floating is float:
+        return sasum(&n, x, &incx)
+    else:
+        return dasum(&n, x, &incx)
+
+
+cpdef _xasum_memview(floating[::1] x):
+    return _xasum(x.shape[0], &x[0], 1)
+
+
+cdef void _xaxpy(int n, floating alpha, floating *x, int incx,
+                 floating *y, int incy) nogil:
+    """y := alpha * x + y"""
+    if floating is float:
+        saxpy(&n, &alpha, x, &incx, y, &incy)
+    else:
+        daxpy(&n, &alpha, x, &incx, y, &incy)
+
+
+cpdef _xaxpy_memview(floating alpha, floating[::1] x, floating[::1] y):
+    _xaxpy(x.shape[0], alpha, &x[0], 1, &y[0], 1)
+
+
+cdef floating _xnrm2(int n, floating *x, int incx) nogil:
+    """sqrt(sum((x_i)^2))"""
+    if floating is float:
+        return snrm2(&n, x, &incx)
+    else:
+        return dnrm2(&n, x, &incx)
+
+
+cpdef _xnrm2_memview(floating[::1] x):
+    return _xnrm2(x.shape[0], &x[0], 1)
+
+
+cdef void _xcopy(int n, floating *x, int incx, floating *y, int incy) nogil:
+    """y := x"""
+    if floating is float:
+        scopy(&n, x, &incx, y, &incy)
+    else:
+        dcopy(&n, x, &incx, y, &incy)
+
+
+cpdef _xcopy_memview(floating[::1] x, floating[::1] y):
+    _xcopy(x.shape[0], &x[0], 1, &y[0], 1)
+
+
+cdef void _xscal(int n, floating alpha, floating *x, int incx) nogil:
+    """x := alpha * x"""
+    if floating is float:
+        sscal(&n, &alpha, x, &incx)
+    else:
+        dscal(&n, &alpha, x, &incx)
+
+
+cpdef _xscal_memview(floating alpha, floating[::1] x):
+    _xscal(x.shape[0], alpha, &x[0], 1)
+
+
+################
+# BLAS Level 2 #
+################
+
+cdef void _xgemv(char layout, char ta, int m, int n, floating alpha,
+                 floating *A, int lda, floating *x, int incx,
+                 floating beta, floating *y, int incy) nogil:
+    """y := alpha * op(A).x + beta * y"""
+    if layout == 'C':
+        ta = 'n' if ta == 't' else 't'
+        if floating is float:
+            sgemv(&ta, &n, &m, &alpha, A, &lda, x, &incx, &beta, y, &incy)
+        else:
+            dgemv(&ta, &n, &m, &alpha, A, &lda, x, &incx, &beta, y, &incy)
+    elif layout == 'F':
+        if floating is float:
+            sgemv(&ta, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy)
+        else:
+            dgemv(&ta, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy)
+
+
+cpdef _xgemv_memview(const unsigned char[:] layout, const unsigned char[:] ta,
+                     floating alpha, floating[:, :] A, floating[::1] x,
+                     floating beta, floating[::1] y):
+    cdef:
+        int m = A.shape[0]
+        int n = A.shape[1]
+        int lda = m if layout[0] == 'F' else n
+
+    _xgemv(layout[0], ta[0], m, n, alpha, &A[0, 0], lda,
+           &x[0], 1, beta, &y[0], 1)
+
+
+cdef void _xger(char layout, int m, int n, floating alpha, floating *x,
+                int incx, floating *y, int incy, floating *A, int lda) nogil:
+    """A := alpha * x.y.T + A"""
+    if layout == 'C':
+        if floating is float:
+            sger(&n, &m, &alpha, y, &incy, x, &incx, A, &lda)
+        else:
+            dger(&n, &m, &alpha, y, &incy, x, &incx, A, &lda)
+    elif layout == 'F':
+        if floating is float:
+            sger(&m, &n, &alpha, x, &incx, y, &incy, A, &lda)
+        else:
+            dger(&m, &n, &alpha, x, &incx, y, &incy, A, &lda)
+
+
+cpdef _xger_memview(const unsigned char[:] layout, floating alpha,
+                    floating[::1] x, floating[::] y, floating[:, :] A):
+    cdef:
+        int m = A.shape[0]
+        int n = A.shape[1]
+        int lda = m if layout[0] == 'F' else n
+
+    _xger(layout[0], m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda)
+
+
+################
+# BLAS Level 3 #
+################
+
+cdef void _xgemm(char layout, char ta, char tb, int m, int n, int k,
+                 floating alpha, floating *A, int lda, floating *B, int ldb,
+                 floating beta, floating *C, int ldc) nogil:
+    """C := alpha * op(A).op(B) + beta * C"""
+    if layout == 'C':
+        if floating is float:
+            sgemm(&tb, &ta, &n, &m, &k, &alpha, B, &ldb, A, &lda, &beta, C, &ldc)
+        else:
+            dgemm(&tb, &ta, &n, &m, &k, &alpha, B, &ldb, A, &lda, &beta, C, &ldc)
+    elif layout == 'F':
+        if floating is float:
+            sgemm(&ta, &tb, &m, &n, &k, &alpha, A, &lda, B, &ldb, &beta, C, &ldc)
+        else:
+            dgemm(&ta, &tb, &m, &n, &k, &alpha, A, &lda, B, &ldb, &beta, C, &ldc)
+
+
+cpdef _xgemm_memview(const unsigned char[:] layout, const unsigned char[:] ta,
+                     const unsigned char[:] tb, floating alpha,
+                     floating[:, :] A, floating[:, :] B, floating beta,
+                     floating[:, :] C):
+    cdef:
+        int m = A.shape[0] if ta[0] == 'n' else A.shape[1]
+        int n = B.shape[1] if tb[0] == 'n' else B.shape[0]
+        int k = A.shape[1] if ta[0] == 'n' else A.shape[0]
+        int lda, ldb, ldc
+
+    if layout[0] == 'F':
+        lda = m if ta[0] == 'n' else k
+        ldb = k if tb[0] == 'n' else n
+        ldc = m
+    else:
+        lda = k if ta[0] == 'n' else m
+        ldb = n if tb[0] == 'n' else k
+        ldc = n
+
+    _xgemm(layout[0], ta[0], tb[0], m, n, k, alpha,
+           &A[0, 0], lda, &B[0, 0], ldb, beta, &C[0, 0], ldc)
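A sketch of the Level 2/3 wrappers checked against NumPy from Python, exercising the row-major ('C') code path. This is illustrative only: it assumes the compiled extension is importable, that single-character bytes such as b'C' coerce to the const unsigned char[:] parameters, and that the arrays' memory layout (here the NumPy default, C order) matches the layout flag passed in.

import numpy as np
from sklearn.utils._cython_blas import _xgemv_memview, _xgemm_memview

rng = np.random.RandomState(0)
A = rng.rand(5, 3)          # C-ordered, so layout=b'C'
B = rng.rand(3, 4)
x = rng.rand(3)
y = rng.rand(5)
C = rng.rand(5, 4)

expected_y = 2.0 * A.dot(x) + 3.0 * y
_xgemv_memview(b'C', b'n', 2.0, A, x, 3.0, y)          # y updated in place
assert np.allclose(y, expected_y)

expected_C = A.dot(B)
_xgemm_memview(b'C', b'n', b'n', 1.0, A, B, 0.0, C)    # C updated in place
assert np.allclose(C, expected_C)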

sklearn/utils/setup.py

Lines changed: 4 additions & 0 deletions

@@ -24,6 +24,10 @@ def configuration(parent_package='', top_path=None):
     config.add_extension('sparsefuncs_fast', sources=['sparsefuncs_fast.pyx'],
                          libraries=libraries)

+    config.add_extension('_cython_blas',
+                         sources=['_cython_blas.pyx'],
+                         libraries=libraries)
+
     config.add_extension('arrayfuncs',
                          sources=['arrayfuncs.pyx'],
                          depends=[join('src', 'cholesky_delete.h')],
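A hypothetical post-build smoke check (assuming an in-place build, e.g. pip install --editable .): the new extension should import cleanly and expose the cpdef *_memview test wrappers.

from sklearn.utils import _cython_blas

print(sorted(name for name in dir(_cython_blas)
             if name.endswith('_memview')))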

0 commit comments