8000 ENH, SIMD: Add sqrt, abs, recip and square intrinsics for f32/64 · numpy/numpy@ade6638 · GitHub
[go: up one dir, main page]

Skip to content

Commit ade6638

Browse files
committed
ENH, SIMD: Add sqrt, abs, recip and square intrinsics for f32/64
this patch also improves division precision for NEON/A32
1 parent 671e8a0 commit ade6638

File tree

11 files changed

+273
-5
lines changed

11 files changed

+273
-5
lines changed

numpy/core/src/common/simd/avx2/avx2.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,4 @@ typedef struct { __m256d val[3]; } npyv_f64x3;
6767
#include "operators.h"
6868
#include "conversion.h"
6969
#include "arithmetic.h"
70+
#include "math.h"
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#ifndef NPY_SIMD
2+
#error "Not a standalone header"
3+
#endif
4+
5+
#ifndef _NPY_SIMD_AVX2_MATH_H
6+
#define _NPY_SIMD_AVX2_MATH_H
7+
/***************************
8+
* Elementary
9+
***************************/
10+
// Square root
11+
#define npyv_sqrt_f32 _mm256_sqrt_ps
12+
#define npyv_sqrt_f64 _mm256_sqrt_pd
13+
14+
// Reciprocal
15+
NPY_FINLINE npyv_f32 npyv_recip_f32(npyv_f32 a)
16+
{ return _mm256_div_ps(_mm256_set1_ps(1.0f), a); }
17+
NPY_FINLINE npyv_f64 npyv_recip_f64(npyv_f64 a)
18+
{ return _mm256_div_pd(_mm256_set1_pd(1.0), a); }
19+
20+
// Absolute
21+
NPY_FINLINE npyv_f32 npyv_abs_f32(npyv_f32 a)
22+
{
23+
return _mm256_and_ps(
24+
a, _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffff))
25+
);
26+
}
27+
NPY_FINLINE npyv_f64 npyv_abs_f64(npyv_f64 a)
28+
{
29+
return _mm256_and_pd(
30+
a, _mm256_castsi256_pd(npyv_setall_s64(0x7fffffffffffffffLL))
31+
);
32+
}
33+
34+
// Square
35+
NPY_FINLINE npyv_f32 npyv_square_f32(npyv_f32 a)
36+
{ return _mm256_mul_ps(a, a); }
37+
NPY_FINLINE npyv_f64 npyv_square_f64(npyv_f64 a)
38+
{ return _mm256_mul_pd(a, a); }
39+
40+
#endif

numpy/core/src/common/simd/avx512/avx512.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,4 @@ typedef struct { __m512d val[3]; } npyv_f64x3;
7272
#include "operators.h"
7373
#include "conversion.h"
7474
#include "arithmetic.h"
75+
#include "math.h"
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#ifndef NPY_SIMD
2+
#error "Not a standalone header"
3+
#endif
4+
5+
#ifndef _NPY_SIMD_AVX512_MATH_H
6+
#define _NPY_SIMD_AVX512_MATH_H
7+
8+
/***************************
9+
* Elementary
10+
***************************/
11+
// Square root
12+
#define npyv_sqrt_f32 _mm512_sqrt_ps
13+
#define npyv_sqrt_f64 _mm512_sqrt_pd
14+
15+
// Reciprocal
16+
NPY_FINLINE npyv_f32 npyv_recip_f32(npyv_f32 a)
17+
{ return _mm512_div_ps(_mm512_set1_ps(1.0f), a); }
18+
NPY_FINLINE npyv_f64 npyv_recip_f64(npyv_f64 a)
19+
{ return _mm512_div_pd(_mm512_set1_pd(1.0), a); }
20+
21+
// Absolute
22+
NPY_FINLINE npyv_f32 npyv_abs_f32(npyv_f32 a)
23+
{
24+
#if 0 // def NPY_HAVE_AVX512DQ
25+
return _mm512_range_ps(a, a, 8);
26+
#else
27+
return npyv_and_f32(
28+
a, _mm512_castsi512_ps(_mm512_set1_epi32(0x7fffffff))
29+
);
30+
#endif
31+
}
32+
NPY_FINLINE npyv_f64 npyv_abs_f64(npyv_f64 a)
33+
{
34+
#if 0 // def NPY_HAVE_AVX512DQ
35+
return _mm512_range_pd(a, a, 8);
36+
#else
37+
return npyv_and_f64(
38+
a, _mm512_castsi512_pd(_mm512_set1_epi64(0x7fffffffffffffffLL))
39+
);
40+
#endif
41+
}
42+
43+
// Square
44+
NPY_FINLINE npyv_f32 npyv_square_f32(npyv_f32 a)
45+
{ return _mm512_mul_ps(a, a); }
46+
NPY_FINLINE npyv_f64 npyv_square_f64(npyv_f64 a)
47+
{ return _mm512_mul_pd(a, a); }
48+
49+
#endif

numpy/core/src/common/simd/neon/arithmetic.h

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,26 @@
6363
/***************************
6464
* Division
6565
***************************/
66-
#ifdef __aarch64__
66+
#if NPY_SIMD_F64
6767
#define npyv_div_f32 vdivq_f32
6868
#else
69-
NPY_FINLINE float32x4_t npyv_div_f32(float32x4_t a, float32x4_t b)
69+
NPY_FINLINE npyv_f32 npyv_div_f32(npyv_f32 a, npyv_f32 b)
7070
{
71-
float32x4_t recip = vrecpeq_f32(b);
72-
recip = vmulq_f32(vrecpsq_f32(b, recip), recip);
73-
return vmulq_f32(a, recip);
71+
// Based on ARM doc, see https://developer.arm.com/documentation/dui0204/j/CIHDIACI
72+
// estimate to 1/b
73+
npyv_f32 recipe = vrecpeq_f32(b);
74+
/**
75+
* Newton-Raphson iteration:
76+
* x[n+1] = x[n] * (2-d * x[n])
77+
* converges to (1/d) if x0 is the result of VRECPE applied to d.
78+
*
79+
* NOTE: at least 3 iterations is needed to improve precision
80+
*/
81+
recipe = vmulq_f32(vrecpsq_f32(b, recipe), recipe);
82+
recipe = vmulq_f32(vrecpsq_f32(b, recipe), recipe);
83+
recipe = vmulq_f32(vrecpsq_f32(b, recipe), recipe);
84+
// a/b = a*recip(b)
85+
return vmulq_f32(a, recipe);
7486
}
7587
#endif
7688
#define npyv_div_f64 vdivq_f64
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#ifndef NPY_SIMD
2+
#error "Not a standalone header"
3+
#endif
4+
5+
#ifndef _NPY_SIMD_NEON_MATH_H
6+
#define _NPY_SIMD_NEON_MATH_H
7+
8+
/***************************
9+
* Elementary
10+
***************************/
11+
// Absolute
12+
#define npyv_abs_f32 vabsq_f32
13+
#define npyv_abs_f64 vabsq_f64
14+
15+
// Square
16+
NPY_FINLINE npyv_f32 npyv_square_f32(npyv_f32 a)
17+
{ return vmulq_f32(a, a); }
18+
#if NPY_SIMD_F64
19+
NPY_FINLINE npyv_f64 npyv_square_f64(npyv_f64 a)
20+
{ return vmulq_f64(a, a); }
21+
#endif
22+
23+
// Square root
24+
#if NPY_SIMD_F64
25+
#define npyv_sqrt_f32 vsqrtq_f32
26+
#define npyv_sqrt_f64 vsqrtq_f64
27+
#else
28+
// Based on ARM doc, see https://developer.arm.com/documentation/dui0204/j/CIHDIACI
29+
NPY_FINLINE npyv_f32 npyv_sqrt_f32(npyv_f32 a)
30+
{
31+
const npyv_f32 zero = vdupq_n_f32(0.0f);
32+
const npyv_u32 pinf = vdupq_n_u32(0x7f800000);
33+
npyv_u32 is_zero = vceqq_f32(a, zero), is_inf = vceqq_u32(vreinterpretq_u32_f32(a), pinf);
34+
// guard agianst floating-point division-by-zero error
35+
npyv_f32 guard_byz = vbslq_f32(is_zero, vreinterpretq_f32_u32(pinf), a);
36+
// estimate to (1/√a)
37+
npyv_f32 rsqrte = vrsqrteq_f32(guard_byz);
38+
/**
39+
* Newton-Raphson iteration:
40+
* x[n+1] = x[n] * (3-d * (x[n]*x[n]) )/2)
41+
* converges to (1/√d)if x0 is the result of VRSQRTE applied to d.
42+
*
43+
* NOTE: at least 3 iterations is needed to improve precision
44+
*/
45+
rsqrte = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, rsqrte), rsqrte), rsqrte);
46+
rsqrte = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, rsqrte), rsqrte), rsqrte);
47+
rsqrte = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, rsqrte), rsqrte), rsqrte);
48+
// a * (1/√a)
49+
npyv_f32 sqrt = vmulq_f32(a, rsqrte);
50+
// return zero if the a is zero
51+
// - return zero if a is zero.
52+
// - return positive infinity if a is positive infinity
53+
return vbslq_f32(vorrq_u32(is_zero, is_inf), a, sqrt);
54+
}
55+
#endif // NPY_SIMD_F64
56+
57+
// Reciprocal
58+
NPY_FINLINE npyv_f32 npyv_recip_f32(npyv_f32 a)
59+
{
60+
#if NPY_SIMD_F64
61+
const npyv_f32 one = vdupq_n_f32(1.0f);
62+
return npyv_div_f32(one, a);
63+
#else
64+
npyv_f32 recipe = vrecpeq_f32(a);
65+
/**
66+
* Newton-Raphson iteration:
67+
* x[n+1] = x[n] * (2-d * x[n])
68+
* converges to (1/d) if x0 is the result of VRECPE applied to d.
69+
*
70+
* NOTE: at least 3 iterations is needed to improve precision
71+
*/
72+
recipe = vmulq_f32(vrecpsq_f32(a, recipe), recipe);
73+
recipe = vmulq_f32(vrecpsq_f32(a, recipe), recipe);
74+
recipe = vmulq_f32(vrecpsq_f32(a, recipe), recipe);
75+
return recipe;
76+
#endif
77+
}
78+
#if NPY_SIMD_F64
79+
NPY_FINLINE npyv_f64 npyv_recip_f64(npyv_f64 a)
80+
{
81+
const npyv_f64 one = vdupq_n_f64(1.0);
82+
return npyv_div_f64(one, a);
83+
}
84+
#endif // NPY_SIMD_F64
85+
86+
#endif // _NPY_SIMD_SSE_MATH_H

numpy/core/src/common/simd/neon/neon.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,4 @@ typedef float64x2x3_t npyv_f64x3;
7272
#include "operators.h"
7373
#include "conversion.h"
7474
#include "arithmetic.h"
75+
#include "math.h"

numpy/core/src/common/simd/sse/math.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#ifndef NPY_SIMD
2+
#error "Not a standalone header"
3+
#endif
4+
5+
#ifndef _NPY_SIMD_SSE_MATH_H
6+
#define _NPY_SIMD_SSE_MATH_H
7+
/***************************
8+
* Elementary
9+
***************************/
10+
// Square root
11+
#define npyv_sqrt_f32 _mm_sqrt_ps
12+
#define npyv_sqrt_f64 _mm_sqrt_pd
13+
14+
// Reciprocal
15+
NPY_FINLINE npyv_f32 npyv_recip_f32(npyv_f32 a)
16+
{ return _mm_div_ps(_mm_set1_ps(1.0f), a); }
17+
NPY_FINLINE npyv_f64 npyv_recip_f64(npyv_f64 a)
18+
{ return _mm_div_pd(_mm_set1_pd(1.0), a); }
19+
20+
// Absolute
21+
NPY_FINLINE npyv_f32 npyv_abs_f32(npyv_f32 a)
22+
{
23+
return _mm_and_ps(
24+
a, _mm_castsi128_ps(_mm_set1_epi32(0x7fffffff))
25+
);
26+
}
27+
NPY_FINLINE npyv_f64 npyv_abs_f64(npyv_f64 a)
28+
{
29+
return _mm_and_pd(
30+
a, _mm_castsi128_pd(npyv_setall_s64(0x7fffffffffffffffLL))
31+
);
32+
}
33+
34+
// Square
35+
NPY_FINLINE npyv_f32 npyv_square_f32(npyv_f32 a)
36+
{ return _mm_mul_ps(a, a); }
37+
NPY_FINLINE npyv_f64 npyv_square_f64(npyv_f64 a)
38+
{ return _mm_mul_pd(a, a); }
39+
40+
#endif

numpy/core/src/common/simd/sse/sse.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,4 @@ typedef struct { __m128d val[3]; } npyv_f64x3;
6464
#include "operators.h"
6565
#include "conversion.h"
6666
#include "arithmetic.h"
67+
#include "math.h"

numpy/core/src/common/simd/vsx/math.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#ifndef NPY_SIMD
2+
#error "Not a standalone header"
3+
#endif
4+
5+
#ifndef _NPY_SIMD_VSX_MATH_H
6+
#define _NPY_SIMD_VSX_MATH_H
7+
/***************************
8+
* Elementary
9+
***************************/
10+
// Square root
11+
#define npyv_sqrt_f32 vec_sqrt
12+
#define npyv_sqrt_f64 vec_sqrt
13+
14+
// Reciprocal
15+
NPY_FINLINE npyv_f32 npyv_recip_f32(npyv_f32 a)
16+
{
17+
const npyv_f32 one = npyv_setall_f32(1.0f);
18+
return vec_div(one, a);
19+
}
20+
NPY_FINLINE npyv_f64 npyv_recip_f64(npyv_f64 a)
21+
{
22+
const npyv_f64 one = npyv_setall_f64(1.0);
23+
return vec_div(one, a);
24+
}
25+
26+
// Absolute
27+
#define npyv_abs_f32 vec_abs
28+
#define npyv_abs_f64 vec_abs
29+
30+
// Square
31+
NPY_FINLINE npyv_f32 npyv_square_f32(npyv_f32 a)
32+
{ return vec_mul(a, a); }
33+
NPY_FINLINE npyv_f64 npyv_square_f64(npyv_f64 a)
34+
{ return vec_mul(a, a); }
35+
36+
#endif // _NPY_SIMD_VSX_MATH_H

numpy/core/src/common/simd/vsx/vsx.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,4 @@ typedef struct { npyv_f64 val[3]; } npyv_f64x3;
6262
#include "operators.h"
6363
#include "conversion.h"
6464
#include "arithmetic.h"
65+
#include "math.h"

0 commit comments

Comments
 (0)
0