8000 Merge pull request #17681 from Qiyu8/sum_intrinsic · numpy/numpy@671e8a0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 671e8a0

Browse files
authored
Merge pull request #17681 from Qiyu8/sum_intrinsic
SIMD: Add sum intrinsics for float/double.
2 parents 1a12887 + 1f0298d commit 671e8a0

File tree

7 files changed

+138
-0
lines changed

7 files changed

+138
-0
lines changed

numpy/core/src/_simd/_simd.dispatch.c.src

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
* #mul_sup = 1, 1, 1, 1, 1, 1, 0, 0, 1, 1#
2121
* #div_sup = 0, 0, 0, 0, 0, 0, 0, 0, 1, 1#
2222
* #fused_sup = 0, 0, 0, 0, 0, 0, 0, 0, 1, 1#
23+
* #sum_sup = 0, 0, 0, 0, 0, 0, 0, 0, 1, 1#
2324
* #ncont_sup = 0, 0, 0, 0, 1, 1, 1, 1, 1, 1#
2425
* #shl_imm = 0, 0, 15, 15, 31, 31, 63, 63, 0, 0#
2526
* #shr_imm = 0, 0, 16, 16, 32, 32, 64, 64, 0, 0#
@@ -351,6 +352,10 @@ SIMD_IMPL_INTRIN_3(@intrin@_@sfx@, v@sfx@, v@sfx@, v@sfx@, v@sfx@)
351352
/**end repeat1**/
352353
#endif // fused_sup
353354

355+
#if @sum_sup@
356+
SIMD_IMPL_INTRIN_1(sum_@sfx@, @sfx@, v@sfx@)
357+
#endif // sum_sup
358+
354359
#endif // simd_sup
355360
/**end repeat**/
356361
/***************************
@@ -370,6 +375,7 @@ static PyMethodDef simd__intrinsics_methods[] = {
370375
* #mul_sup = 1, 1, 1, 1, 1, 1, 0, 0, 1, 1#
371376
* #div_sup = 0, 0, 0, 0, 0, 0, 0, 0, 1, 1#
372377
* #fused_sup = 0, 0, 0, 0, 0, 0, 0, 0, 1, 1#
378+
* #sum_sup = 0, 0, 0, 0, 0, 0, 0, 0, 1, 1#
373379
* #ncont_sup = 0, 0, 0, 0, 1, 1, 1, 1, 1, 1#
374380
* #shl_imm = 0, 0, 15, 15, 31, 31, 63, 63, 0, 0#
375381
* #shr_imm = 0, 0, 16, 16, 32, 32, 64, 64, 0, 0#
@@ -484,6 +490,10 @@ SIMD_INTRIN_DEF(@intrin@_@sfx@)
484490
/**end repeat1**/
485491
#endif // fused_sup
486492

493+
#if @sum_sup@
494+
SIMD_INTRIN_DEF(sum_@sfx@)
495+
#endif // sum_sup
496+
487497
#endif // simd_sup
488498
/**end repeat**/
489499

numpy/core/src/common/simd/avx2/arithmetic.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,27 @@
116116
return npyv_sub_f64(npyv_mul_f64(neg_a, b), c);
117117
}
118118
#endif // !NPY_HAVE_FMA3
119+
120+
// Horizontal add: Calculates the sum of all vector elements.
121+
NPY_FINLINE float npyv_sum_f32(__m256 a)
122+
{
123+
__m256 sum_halves = _mm256_hadd_ps(a, a);
124+
sum_halves = _mm256_hadd_ps(sum_halves, sum_halves);
125+
__m128 lo = _mm256_castps256_ps128(sum_halves);
126+
__m128 hi = _mm256_extractf128_ps(sum_halves, 1);
127+
__m128 sum = _mm_add_ps(lo, hi);
128+
return _mm_cvtss_f32(sum);
129+
}
130+
131+
NPY_FINLINE double npyv_sum_f64(__m256d a)
132+
{
133+
__m256d sum_halves = _mm256_hadd_pd(a, a);
134+
__m128d lo = _mm256_castpd256_pd128(sum_halves);
135+
__m128d hi = _mm256_extractf128_pd(sum_halves, 1);
136+
__m128d sum = _mm_add_pd(lo, hi);
137+
return _mm_cvtsd_f64(sum);
138+
}
139+
119140
#endif // _NPY_SIMD_AVX2_ARITHMETIC_H
141+
142+

numpy/core/src/common/simd/avx512/arithmetic.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,4 +129,47 @@ NPY_FINLINE __m512i npyv_mul_u8(__m512i a, __m512i b)
129129
#define npyv_nmulsub_f32 _mm512_fnmsub_ps
130130
#define npyv_nmulsub_f64 _mm512_fnmsub_pd
131131

132+
/***************************
133+
* Reduce Sum
134+
* there are three ways to implement reduce sum for AVX512:
135+
* 1- split(256) /add /split(128) /add /hadd /hadd /extract
136+
* 2- shuff(cross) /add /shuff(cross) /add /shuff /add /shuff /add /extract
137+
* 3- _mm512_reduce_add_ps/pd
138+
* The first one is been widely used by many projects
139+
*
140+
* the second one is used by Intel Compiler, maybe because the
141+
* latency of hadd increased by (2-3) starting from Skylake-X which makes two
142+
* extra shuffles(non-cross) cheaper. check https://godbolt.org/z/s3G9Er for more info.
143+
*
144+
* The third one is almost the same as the second one but only works for
145+
* intel compiler/GCC 7.1/Clang 4, we still need to support older GCC.
146+
***************************/
147+
#ifdef NPY_HAVE_AVX512F_REDUCE
148+
#define npyv_sum_f32 _mm512_reduce_add_ps
149+
#define npyv_sum_f64 _mm512_reduce_add_pd
150+
#else
151+
NPY_FINLINE float npyv_sum_f32(npyv_f32 a)
152+
{
153+
__m512 h64 = _mm512_shuffle_f32x4(a, a, _MM_SHUFFLE(3, 2, 3, 2));
154+
__m512 sum32 = _mm512_add_ps(a, h64);
155+
__m512 h32 = _mm512_shuffle_f32x4(sum32, sum32, _MM_SHUFFLE(1, 0, 3, 2));
156+
__m512 sum16 = _mm512_add_ps(sum32, h32);
157+
__m512 h16 = _mm512_permute_ps(sum16, _MM_SHUFFLE(1, 0, 3, 2));
158+
__m512 sum8 = _mm512_add_ps(sum16, h16);
159+
__m512 h4 = _mm512_permute_ps(sum8, _MM_SHUFFLE(2, 3, 0, 1));
160+
__m512 sum4 = _mm512_add_ps(sum8, h4);
161+
return _mm_cvtss_f32(_mm512_castps512_ps128(sum4));
162+
}
163+
NPY_FINLINE double npyv_sum_f64(npyv_f64 a)
164+
{
165+
__m512d h64 = _mm512_shuffle_f64x2(a, a, _MM_SHUFFLE(3, 2, 3, 2));
166+
__m512d sum32 = _mm512_add_pd(a, h64);
167+
__m512d h32 = _mm512_permutex_pd(sum32, _MM_SHUFFLE(1, 0, 3, 2));
168+
__m512d sum16 = _mm512_add_pd(sum32, h32);
169+
__m512d h16 = _mm512_permute_pd(sum16, _MM_SHUFFLE(2, 3, 0, 1));
170+
__m512d sum8 = _mm512_add_pd(sum16, h16);
171+
return _mm_cvtsd_f64(_mm512_castpd512_pd128(sum8));
172+
}
173+
#endif
174+
132175
#endif // _NPY_SIMD_AVX512_ARITHMETIC_H

numpy/core/src/common/simd/neon/arithmetic.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,17 @@
118118
NPY_FINLINE npyv_f64 npyv_nmulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
119119
{ return vfmsq_f64(vnegq_f64(c), a, b); }
120120
#endif // NPY_SIMD_F64
121+
122+
// Horizontal add: Calculates the sum of all vector elements.
123+
#if NPY_SIMD_F64
124+
#define npyv_sum_f32 vaddvq_f32
125+
#define npyv_sum_f64 vaddvq_f64
126+
#else
127+
NPY_FINLINE float npyv_sum_f32(npyv_f32 a)
128+
{
129+
float32x2_t r = vadd_f32(vget_high_f32(a), vget_low_f32(a));
130+
return vget_lane_f32(vpadd_f32(r, r), 0);
131+
}
132+
#endif
133+
121134
#endif // _NPY_SIMD_NEON_ARITHMETIC_H

numpy/core/src/common/simd/sse/arithmetic.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,4 +147,31 @@ NPY_FINLINE __m128i npyv_mul_u8(__m128i a, __m128i b)
147147
return npyv_sub_f64(npyv_mul_f64(neg_a, b), c);
148148
}
149149
#endif // !NPY_HAVE_FMA3
150+
151+
// Horizontal add: Calculates the sum of all vector elements.
152+
NPY_FINLINE float npyv_sum_f32(__m128 a)
153+
{
154+
#ifdef NPY_HAVE_SSE3
155+
__m128 sum_halves = _mm_hadd_ps(a, a);
156+
return _mm_cvtss_f32(_mm_hadd_ps(sum_halves, sum_halves));
157+
#else
158+
__m128 t1 = _mm_movehl_ps(a, a);
159+
__m128 t2 = _mm_add_ps(a, t1);
160+
__m128 t3 = _mm_shuffle_ps(t2, t2, 1);
161+
__m128 t4 = _mm_add_ss(t2, t3);
162+
return _mm_cvtss_f32(t4);
163+
#endif
164+
}
165+
166+
NPY_FINLINE double npyv_sum_f64(__m128d a)
167+
{
168+
#ifdef NPY_HAVE_SSE3
169+
return _mm_cvtsd_f64(_mm_hadd_pd(a, a));
170+
#else
171+
return _mm_cvtsd_f64(_mm_add_pd(a, _mm_unpackhi_pd(a, a)));
172+
#endif
173+
}
174+
150175
#endif // _NPY_SIMD_SSE_ARITHMETIC_H
176+
177+

numpy/core/src/common/simd/vsx/arithmetic.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,16 @@
116116
#define npyv_nmulsub_f32 vec_nmadd // equivalent to -(a*b + c)
117117
#define npyv_nmulsub_f64 vec_nmadd
118118

119+
// Horizontal add: Calculates the sum of all vector elements.
120+
NPY_FINLINE float npyv_sum_f32(npyv_f32 a)
121+
{
122+
npyv_f32 sum = vec_add(a, npyv_combineh_f32(a, a));
123+
return vec_extract(sum, 0) + vec_extract(sum, 1);
124+
}
125+
126+
NPY_FINLINE double npyv_sum_f64(npyv_f64 a)
127+
{
128+
return vec_extract(a, 0) + vec_extract(a, 1);
129+
}
130+
119131
#endif // _NPY_SIMD_VSX_ARITHMETIC_H

numpy/core/tests/test_simd.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,16 @@ def test_arithmetic_div(self):
512512
div = self.div(vdata_a, vdata_b)
513513
assert div == data_div
514514

515+
def test_arithmetic_reduce_sum(self):
516+
if not self._is_fp():
517+
return
518+
# reduce sum
519+
data = self._data()
520+
vdata = self.load(data)
521+
522+
data_sum = sum(data)
523+
vsum = self.sum(vdata)
524+
assert vsum == data_sum
515525

516526
int_sfx = ("u8", "s8", "u16", "s16", "u32", "s32", "u64", "s64")
517527
fp_sfx = ("f32", "f64")

0 commit comments

Comments
 (0)
0