8000 NPYV: add fused multiply subtract/add intrinics for all supported pla… · numpy/numpy@5a642d2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5a642d2

Browse files
committed
NPYV: add fused multiply subtract/add intrinics for all supported platforms
1 parent c970c04 commit 5a642d2

File tree

5 files changed

+175
-1
lines changed

5 files changed

+175
-1
lines changed

numpy/core/src/common/simd/avx2/arithmetic.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,48 @@
7272
#define npyv_div_f32 _mm256_div_ps
7373
#define npyv_div_f64 _mm256_div_pd
7474

75+
/***************************
76+
* FUSED
77+
***************************/
78+
#ifdef NPY_HAVE_FMA3
79+
// multiply and add, a*b + c
80+
#define npyv_muladd_f32 _mm256_fmadd_ps
81+
#define npyv_muladd_f64 _mm256_fmadd_pd
82+
// multiply and subtract, a*b - c
83+
#define npyv_mulsub_f32 _mm256_fmsub_ps
84+
#define npyv_mulsub_f64 _mm256_fmsub_pd
85+
// negate multiply and add, -(a*b) + c
86+
#define npyv_nmuladd_f32 _mm256_fnmadd_ps
87+
#define npyv_nmuladd_f64 _mm256_fnmadd_pd
88+
// negate multiply and subtract, -(a*b) - c
89+
#define npyv_nmulsub_f32 _mm256_fnmsub_ps
90+
#define npyv_nmulsub_f64 _mm256_fnmsub_pd
91+
#else
92+
// multiply and add, a*b + c
93+
NPY_FINLINE npyv_f32 npyv_muladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
94+
{ return npyv_add_f32(npyv_mul_f32(a, b), c); }
95+
NPY_FINLINE npyv_f64 npyv_muladd_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
96+
{ return npyv_add_f64(npyv_mul_f64(a, b), c); }
97+
// multiply and subtract, a*b - c
98+
NPY_FINLINE npyv_f32 npyv_mulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
99+
{ return npyv_sub_f32(npyv_mul_f32(a, b), c); }
100+
NPY_FINLINE npyv_f64 npyv_mulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
101+
{ return npyv_sub_f64(npyv_mul_f64(a, b), c); }
102+
// negate multiply and add, -(a*b) + c
103+
NPY_FINLINE npyv_f32 npyv_nmuladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
104+
{ return npyv_sub_f32(c, npyv_mul_f32(a, b)); }
105+
NPY_FINLINE npyv_f64 npyv_nmuladd_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
106+
{ return npyv_sub_f64(c, npyv_mul_f64(a, b)); }
107+
// negate multiply and subtract, -(a*b) - c
108+
NPY_FINLINE npyv_f32 npyv_nmulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
109+
{
110+
npyv_f32 neg_a = npyv_xor_f32(a, npyv_setall_f32(-0.0f));
111+
return npyv_sub_f32(npyv_mul_f32(neg_a, b), c);
112+
}
113+
NPY_FINLINE npyv_f64 npyv_nmulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
114+
{
115+
npyv_f64 neg_a = npyv_xor_f64(a, npyv_setall_f64(-0.0));
116+
return npyv_sub_f64(npyv_mul_f64(neg_a, b), c);
117+
}
118+
#endif // !NPY_HAVE_FMA3
75119
< 8000 span class=pl-k>#endif // _NPY_SIMD_AVX2_ARITHMETIC_H

numpy/core/src/common/simd/avx512/arithmetic.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,4 +113,20 @@ NPY_FINLINE __m512i npyv_mul_u8(__m512i a, __m512i b)
113113
#define npyv_div_f32 _mm512_div_ps
114114
#define npyv_div_f64 _mm512_div_pd
115115

116+
/***************************
117+
* FUSED
118+
***************************/
119+
// multiply and add, a*b + c
120+
#define npyv_muladd_f32 _mm512_fmadd_ps
121+
#define npyv_muladd_f64 _mm512_fmadd_pd
122+
// multiply and subtract, a*b - c
123+
#define npyv_mulsub_f32 _mm512_fmsub_ps
124+
#define npyv_mulsub_f64 _mm512_fmsub_pd
125+
// negate multiply and add, -(a*b) + c
126+
#define npyv_nmuladd_f32 _mm512_fnmadd_ps
127+
#define npyv_nmuladd_f64 _mm512_fnmadd_pd
128+
// negate multiply and subtract, -(a*b) - c
129+
#define npyv_nmulsub_f32 _mm512_fnmsub_ps
130+
#define npyv_nmulsub_f64 _mm512_fnmsub_pd
131+
116132
#endif // _NPY_SIMD_AVX512_ARITHMETIC_H

numpy/core/src/common/simd/neon/arithmetic.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,47 @@
7575
#endif
7676
#define npyv_div_f64 vdivq_f64
7777

78+
/***************************
79+
* FUSED F32
80+
***************************/
81+
#ifdef NPY_HAVE_NEON_VFPV4 // FMA
82+
// multiply and add, a*b + c
83+
NPY_FINLINE npyv_f32 npyv_muladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
84+
{ return vfmaq_f32(c, a, b); }
85+
// multiply and subtract, a*b - c
86+
NPY_FINLINE npyv_f32 npyv_mulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
87+
{ return vfmaq_f32(vnegq_f32(c), a, b); }
88+
// negate multiply and add, -(a*b) + c
89+
NPY_FINLINE npyv_f32 npyv_nmuladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
90+
{ return vfmsq_f32(c, a, b); }
91+
// negate multiply and subtract, -(a*b) - c
92+
NPY_FINLINE npyv_f32 npyv_nmulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
93+
{ return vfmsq_f32(vnegq_f32(c), a, b); }
94+
#else
95+
// multiply and add, a*b + c
96+
NPY_FINLINE npyv_f32 npyv_muladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
97+
{ return vmlaq_f32(c, a, b); }
98+
// multiply and subtract, a*b - c
99+
NPY_FINLINE npyv_f32 npyv_mulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
100+
{ return vmlaq_f32(vnegq_f32(c), a, b); }
101+
// negate multiply and add, -(a*b) + c
102+
NPY_FINLINE npyv_f32 npyv_nmuladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
103+
{ return vmlsq_f32(c, a, b); }
104+
// negate multiply and subtract, -(a*b) - c
105+
NPY_FINLINE npyv_f32 npyv_nmulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
106+
{ return vmlsq_f32(vnegq_f32(c), a, b); }
107+
#endif
108+
/***************************
109+
* FUSED F64
110+
***************************/
111+
#if NPY_SIMD_F64
112+
NPY_FINLINE npyv_f64 npyv_muladd_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
113+
{ return vfmaq_f64(c, a, b); }
114+
NPY_FINLINE npyv_f64 npyv_mulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
115+
{ return vfmaq_f64(vnegq_f64(c), a, b); }
116+
NPY_FINLINE npyv_f64 npyv_nmuladd_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
117+
{ return vfmsq_f64(c, a, b); }
118+
NPY_FINLINE npyv_f64 npyv_nmulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
119+
{ return vfmsq_f64(vnegq_f64(c), a, b); }
120+
#endif // NPY_SIMD_F64
78121
#endif // _NPY_SIMD_NEON_ARITHMETIC_H

numpy/core/src/common/simd/sse/arithmetic.h

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,5 +91,60 @@ NPY_FINLINE __m128i npyv_mul_u8(__m128i a, __m128i b)
9191
// TODO: emulate integer division
9292
#define npyv_div_f32 _mm_div_ps
9393
#define npyv_div_f64 _mm_div_pd
94-
94+
/***************************
95+
* FUSED
96+
***************************/
97+
#ifdef NPY_HAVE_FMA3
98+
// multiply and add, a*b + c
99+
#define npyv_muladd_f32 _mm_fmadd_ps
100+
#define npyv_muladd_f64 _mm_fmadd_pd
101+
// multiply and subtract, a*b - c
102+
#define npyv_mulsub_f32 _mm_fmsub_ps
103+
#define npyv_mulsub_f64 _mm_fmsub_pd
104+
// negate multiply and add, -(a*b) + c
105+
#define npyv_nmuladd_f32 _mm_fnmadd_ps
106+
#define npyv_nmuladd_f64 _mm_fnmadd_pd
107+
// negate multiply and subtract, -(a*b) - c
108+
#define npyv_nmulsub_f32 _mm_fnmsub_ps
109+
#define npyv_nmulsub_f64 _mm_fnmsub_pd
110+
#elif defined(NPY_HAVE_FMA4)
111+
// multiply and add, a*b + c
112+
#define npyv_muladd_f32 _mm_macc_ps
113+
#define npyv_muladd_f64 _mm_macc_pd
114+
// multiply and subtract, a*b - c
115+
#define npyv_mulsub_f32 _mm_msub_ps
116+
#define npyv_mulsub_f64 _mm_msub_pd
117+
// negate multiply and add, -(a*b) + c
118+
#define npyv_nmuladd_f32 _mm_nmacc_ps
119+
#define npyv_nmuladd_f64 _mm_nmacc_pd
120+
#else
121+
// multiply and add, a*b + c
122+
NPY_FINLINE npyv_f32 npyv_muladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
123+
{ return npyv_add_f32(npyv_mul_f32(a, b), c); }
124+
NPY_FINLINE npyv_f64 npyv_muladd_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
125+
{ return npyv_add_f64(npyv_mul_f64(a, b), c); }
126+
// multiply and subtract, a*b - c
127+
NPY_FINLINE npyv_f32 npyv_mulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
128+
{ return npyv_sub_f32(npyv_mul_f32(a, b), c); }
129+
NPY_FINLINE npyv_f64 npyv_mulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
130+
{ return npyv_sub_f64(npyv_mul_f64(a, b), c); }
131+
// negate multiply and add, -(a*b) + c
132+
NPY_FINLINE npyv_f32 npyv_nmuladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
133+
{ return npyv_sub_f32(c, npyv_mul_f32(a, b)); }
134+
NPY_FINLINE npyv_f64 npyv_nmuladd_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
135+
{ return npyv_sub_f64(c, npyv_mul_f64(a, b)); }
136+
#endif // NPY_HAVE_FMA3
137+
#ifndef NPY_HAVE_FMA3 // for FMA4 and NON-FMA3
138+
// negate multiply and subtract, -(a*b) - c
139+
NPY_FINLINE npyv_f32 npyv_nmulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
140+
{
141+
npyv_f32 neg_a = npyv_xor_f32(a, npyv_setall_f32(-0.0f));
142+
return npyv_sub_f32(npyv_mul_f32(neg_a, b), c);
143+
}
144+
NPY_FINLINE npyv_f64 npyv_nmulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
145+
{
146+
npyv_f64 neg_a = npyv_xor_f64(a, npyv_setall_f64(-0.0));
147+
return npyv_sub_f64(npyv_mul_f64(neg_a, b), c);
148+
}
149+
#endif // !NPY_HAVE_FMA3
95150
#endif // _NPY_SIMD_SSE_ARITHMETIC_H

numpy/core/src/common/simd/vsx/arithmetic.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,20 @@
100100
#define npyv_div_f32 vec_div
101101
#define npyv_div_f64 vec_div
102102

103+
/***************************
104+
* FUSED
105+
***************************/
106+
// multiply and add, a*b + c
107+
#define npyv_muladd_f32 vec_madd
108+
#define npyv_muladd_f64 vec_madd
109+
// multiply and subtract, a*b - c
110+
#define npyv_mulsub_f32 vec_msub
111+
#define npyv_mulsub_f64 vec_msub
112+
// negate multiply and add, -(a*b) + c
113+
#define npyv_nmuladd_f32 vec_nmsub // equivalent to -(a*b - c)
114+
#define npyv_nmuladd_f64 vec_nmsub
115+
// negate multiply and subtract, -(a*b) - c
116+
#define npyv_nmulsub_f32 vec_nmadd // equivalent to -(a*b + c)
117+
#define npyv_nmulsub_f64 vec_nmadd
118+
103119
#endif // _NPY_SIMD_VSX_ARITHMETIC_H

0 commit comments

Comments
 (0)
0