8000 Merge pull request #25610 from r-devulap/avx2_arg · mattip/numpy@221427b · GitHub
[go: up one dir, main page]

Skip to content

Commit 221427b

Browse files
authored
Merge pull request numpy#25610 from r-devulap/avx2_arg
ENH: Vectorize argsort and argselect with AVX2
2 parents 0a4b2b8 + 680b682 commit 221427b

File tree

3 files changed

+42
-18
lines changed

3 files changed

+42
-18
lines changed

numpy/_core/meson.build

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,7 @@ foreach gen_mtargets : [
787787
[
788788
'x86_simd_argsort.dispatch.h',
789789
'src/npysort/x86_simd_argsort.dispatch.cpp',
790-
use_intel_sort ? [AVX512_SKX] : []
790+
use_intel_sort ? [AVX512_SKX, AVX2] : []
791791
],
792792
[
793793
'x86_simd_qsort.dispatch.h',

numpy/_core/src/npysort/x86_simd_argsort.dispatch.cpp

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,62 +2,86 @@
22
#ifndef __CYGWIN__
33

44
#if defined(NPY_HAVE_AVX512_SKX)
5-
#include "x86-simd-sort/src/avx512-64bit-argsort.hpp"
5+
#include "x86-simd-sort/src/avx512-64bit-argsort.hpp"
6+
#elif defined(NPY_HAVE_AVX2)
7+
#include "x86-simd-sort/src/avx2-32bit-half.hpp"
8+
#include "x86-simd-sort/src/avx2-32bit-qsort.hpp"
9+
#include "x86-simd-sort/src/avx2-64bit-qsort.hpp"
10+
#include "x86-simd-sort/src/xss-common-argsort.h"
611
#endif
712

8-
namespace np { namespace qsort_simd {
13+
namespace {
14+
template<typename T>
15+
void x86_argsort(T* arr, size_t* arg, npy_intp num)
16+
{
17+
#if defined(NPY_HAVE_AVX512_SKX)
18+
avx512_argsort(arr, arg, num, true);
19+
#elif defined(NPY_HAVE_AVX2)
20+
avx2_argsort(arr, arg, num, true);
21+
#endif
22+
}
923

10-
/* arg methods currently only have AVX-512 versions */
24+
template<typename T>
25+
void x86_argselect(T* arr, size_t* arg, npy_intp kth, npy_intp num)
26+
{
1127
#if defined(NPY_HAVE_AVX512_SKX)
28+
avx512_argselect(arr, arg, kth, num, true);
29+
#elif defined(NPY_HAVE_AVX2)
30+
avx2_argselect(arr, arg, kth, num, true);
31+
#endif
32+
}
33+
} // anonymous
34+
35+
namespace np { namespace qsort_simd {
36+
1237
template<> void NPY_CPU_DISPATCH_CURFX(ArgQSelect)(int32_t *arr, npy_intp* arg, npy_intp num, npy_intp kth)
1338
{
14-
avx512_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num);
39+
x86_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num);
1540
}
1641
template<> void NPY_CPU_DISPATCH_CURFX(ArgQSelect)(uint32_t *arr, npy_intp* arg, npy_intp num, npy_intp kth)
1742
{
18-
avx512_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num);
43+
x86_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num);
1944
}
2045
template<> void NPY_CPU_DISPATCH_CURFX(ArgQSelect)(int64_t*arr, npy_intp* arg, npy_intp num, npy_intp kth)
2146
{
22-
avx512_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num);
47+
x86_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num);
2348
}
2449
template<> void NPY_CPU_DISPATCH_CURFX(ArgQSelect)(uint64_t*arr, npy_intp* arg, npy_intp num, npy_intp kth)
2550
{
26-
avx512_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num);
51+
x86_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num);
2752
}
2853
template<> void NPY_CPU_DISPATCH_CURFX(ArgQSelect)(float *arr, npy_intp* arg, npy_intp num, npy_intp kth)
2954
{
30-
avx512_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num, true);
55+
x86_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num);
3156
}
3257
template<> void NPY_CPU_DISPATCH_CURFX(ArgQSelect)(double *arr, npy_intp* arg, npy_intp num, npy_intp kth)
3358
{
34-
avx512_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num, true);
59+
x86_argselect(arr, reinterpret_cast<size_t*>(arg), kth, num);
3560
}
3661
template<> void NPY_CPU_DISPATCH_CURFX(ArgQSort)(int32_t *arr, npy_intp *arg, npy_intp size)
3762
{
38-
avx512_argsort(arr, reinterpret_cast<size_t*>(arg), size);
63+
x86_argsort(arr, reinterpret_cast<size_t*>(arg), size);
3964
}
4065
template<> void NPY_CPU_DISPATCH_CURFX(ArgQSort)(uint32_t *arr, npy_intp *arg, npy_intp size)
4166
{
42-
avx512_argsort(arr, reinterpret_cast<size_t*>(arg), size);
67+
x86_argsort(arr, reinterpret_cast<size_t*>(arg), size);
4368
}
4469
template<> void NPY_CPU_DISPATCH_CURFX(ArgQSort)(int64_t *arr, npy_intp *arg, npy_intp size)
4570
{
46-
avx512_argsort(arr, reinterpret_cast<size_t*>(arg), size);
71+
x86_argsort(arr, reinterpret_cast<size_t*>(arg), size);
4772
}
4873
template<> void NPY_CPU_DISPATCH_CURFX(ArgQSort)(uint64_t *arr, npy_intp *arg, npy_intp size)
4974
{
50-
avx512_argsort(arr, reinterpret_cast<size_t*>(arg), size);
75+
x86_argsort(arr, reinterpret_cast<size_t*>(arg), size);
5176
}
5277
template<> void NPY_CPU_DISPATCH_CURFX(ArgQSort)(float *arr, npy_intp *arg, npy_intp size)
5378
{
54-
avx512_argsort(arr, reinterpret_cast<size_t*>(arg), size, true);
79+
x86_argsort(arr, reinterpret_cast<size_t*>(arg), size);
5580
}
5681
template<> void NPY_CPU_DISPATCH_CURFX(ArgQSort)(double *arr, npy_intp *arg, npy_intp size)
5782
{
58-
avx512_argsort(arr, reinterpret_cast<size_t*>(arg), size, true);
83+
x86_argsort(arr, reinterpret_cast<size_t*>(arg), size);
5984
}
60-
#endif // NPY_HAVE_AVX512_SKX
6185

6286
}} // namespace np::simd
6387

0 commit comments

Comments
 (0)
0