|
2 | 2 | #ifndef __CYGWIN__
|
3 | 3 |
|
4 | 4 | #if defined(NPY_HAVE_AVX512_SKX)
|
5 |
| - #include "x86-simd-sort/src/avx512-64bit-argsort.hpp" |
| 5 | +#include "x86-simd-sort/src/avx512-64bit-argsort.hpp" |
| 6 | +#elif defined(NPY_HAVE_AVX2) |
| 7 | +#include "x86-simd-sort/src/avx2-32bit-half.hpp" |
| 8 | +#include "x86-simd-sort/src/avx2-32bit-qsort.hpp" |
| 9 | +#include "x86-simd-sort/src/avx2-64bit-qsort.hpp" |
| 10 | +#include "x86-simd-sort/src/xss-common-argsort.h" |
6 | 11 | #endif
|
7 | 12 |
|
8 |
| -namespace np { namespace qsort_simd { |
| 13 | +namespace { |
| 14 | +template<typename T> |
| 15 | +void x86_argsort(T* arr, size_t* arg, npy_intp num) |
| 16 | +{ |
| 17 | +#if defined(NPY_HAVE_AVX512_SKX) |
| 18 | + avx512_argsort(arr, arg, num, true); |
| 19 | +#elif defined(NPY_HAVE_AVX2) |
| 20 | + avx2_argsort(arr, arg, num, true); |
| 21 | +#endif |
| 22 | +} |
9 | 23 |
|
10 |
| -/* arg methods currently only have AVX-512 versions */ |
| 24 | +template<typename T> |
| 25 | +void x86_argselect(T* arr, size_t* arg, npy_intp kth, npy_intp num) |
| 26 | +{ |
11 | 27 | #if defined(NPY_HAVE_AVX512_SKX)
|
| 28 | + avx512_argselect(arr, arg, kth, num, true); |
| 29 | +#elif defined(NPY_HAVE_AVX2) |
| 30 | + avx2_argselect(arr, arg, kth, num, true); |
| 31 | +#endif |
| 32 | +} |
| 33 | +} // anonymous |
| 34 | + |
| 35 | +namespace np { namespace qsort_simd { |
| 36 | + |
12 | 37 | template<> void NPY_CPU_DISPATCH_CURFX(ArgQSelect)(int32_t *arr, npy_intp* arg, npy_intp num, npy_intp kth)
|
13 | 38 | {
|
14 |
| - avx512_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num); |
| 39 | + x86_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num); |
15 | 40 | }
|
16 | 41 | template<> void NPY_CPU_DISPATCH_CURFX(ArgQSelect)(uint32_t *arr, npy_intp* arg, npy_intp num, npy_intp kth)
|
17 | 42 | {
|
18 |
| - avx512_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num); |
| 43 | + x86_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num);
|
19 | 44 | }
|
20 | 45 | template<> void NPY_CPU_DISPATCH_CURFX(ArgQSelect)(int64_t*arr, npy_intp* arg, npy_intp num, npy_intp kth)
|
21 | 46 | {
|
22 |
| - avx512_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num); |
| 47 | + x86_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num); |
23 | 48 | }
|
24 | 49 | template<> void NPY_CPU_DISPATCH_CURFX(ArgQSelect)(uint64_t*arr, npy_intp* arg, npy_intp num, npy_intp kth)
|
25 | 50 | {
|
26 |
| - avx512_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num); |
| 51 | + x86_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num); |
27 | 52 | }
|
28 | 53 | template<> void NPY_CPU_DISPATCH_CURFX(ArgQSelect)(float *arr, npy_intp* arg, npy_intp num, npy_intp kth)
|
29 | 54 | {
|
30 |
| - avx512_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num, true); |
| 55 | + x86_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num); |
31 | 56 | }
|
32 | 57 | template<> void NPY_CPU_DISPATCH_CURFX(ArgQSelect)(double *arr, npy_intp* arg, npy_intp num, npy_intp kth)
|
33 | 58 | {
|
34 |
| - avx512_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num, true); |
| 59 | + x86_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num); |
35 | 60 | }
|
36 | 61 | template<> void NPY_CPU_DISPATCH_CURFX(ArgQSort)(int32_t *arr, npy_intp *arg, npy_intp size)
|
37 | 62 | {
|
38 |
| - avx512_argsort(arr, reinterpret_cast<size_t*>(arg), size); |
| 63 | + x86_argsort(arr, reinterpret_cast<size_t*>(arg), size); |
39 | 64 | }
|
40 | 65 | template<> void NPY_CPU_DISPATCH_CURFX(ArgQSort)(uint32_t *arr, npy_intp *arg, npy_intp size)
|
41 | 66 | {
|
42 |
| - avx512_argsort(arr, reinterpret_cast<size_t*>(arg), size); |
| 67 | + x86_argsort(arr, reinterpret_cast<size_t*>(arg), size); |
43 | 68 | }
|
44 | 69 | template<> void NPY_CPU_DISPATCH_CURFX(ArgQSort)(int64_t *arr, npy_intp *arg, npy_intp size)
|
45 | 70 | {
|
46 |
| - avx512_argsort(arr, reinterpret_cast<size_t*>(arg), size); |
| 71 | + x86_argsort(arr, reinterpret_cast<size_t*>(arg), size); |
47 | 72 | }
|
48 | 73 | template<> void NPY_CPU_DISPATCH_CURFX(ArgQSort)(uint64_t *arr, npy_intp *arg, npy_intp size)
|
49 | 74 | {
|
50 |
| - avx512_argsort(arr, reinterpret_cast<size_t*>(arg), size); |
| 75 | + x86_argsort(arr, reinterpret_cast<size_t*>(arg), size); |
51 | 76 | }
|
52 | 77 | template<> void NPY_CPU_DISPATCH_CURFX(ArgQSort)(float *arr, npy_intp *arg, npy_intp size)
|
53 | 78 | {
|
54 |
| - avx512_argsort(arr, reinterpret_cast<size_t*>(arg), size, true); |
| 79 | + x86_argsort(arr, reinterpret_cast<size_t*>(arg), size); |
55 | 80 | }
|
56 | 81 | template<> void NPY_CPU_DISPATCH_CURFX(ArgQSort)(double *arr, npy_intp *arg, npy_intp size)
|
57 | 82 | {
|
58 |
| - avx512_argsort(arr, reinterpret_cast<size_t*>(arg), size, true); |
| 83 | + x86_argsort(arr, reinterpret_cast<size_t*>(arg), size); |
59 | 84 | }
|
60 |
| -#endif // NPY_HAVE_AVX512_SKX |
61 | 85 |
|
62 | 86 | }} // namespace np::simd
|
63 | 87 |
|
|
0 commit comments