SIMD: add fused multiply subtract/add intrinsics for all supported platforms by seiko2plus · Pull Request #17258 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

SIMD: add fused multiply subtract/add intrinsics for all supported platforms #17258

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions numpy/core/src/common/simd/avx2/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,48 @@
#define npyv_div_f32 _mm256_div_ps
#define npyv_div_f64 _mm256_div_pd

/***************************
* FUSED
***************************/
#ifdef NPY_HAVE_FMA3
// multiply and add, a*b + c
#define npyv_muladd_f32 _mm256_fmadd_ps
#define npyv_muladd_f64 _mm256_fmadd_pd
// multiply and subtract, a*b - c
#define npyv_mulsub_f32 _mm256_fmsub_ps
#define npyv_mulsub_f64 _mm256_fmsub_pd
// negate multiply and add, -(a*b) + c
#define npyv_nmuladd_f32 _mm256_fnmadd_ps
#define npyv_nmuladd_f64 _mm256_fnmadd_pd
// negate multiply and subtract, -(a*b) - c
#define npyv_nmulsub_f32 _mm256_fnmsub_ps
#define npyv_nmulsub_f64 _mm256_fnmsub_pd
#else
// multiply and add, a*b + c
NPY_FINLINE npyv_f32 npyv_muladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
{ return npyv_add_f32(npyv_mul_f32(a, b), c); }
NPY_FINLINE npyv_f64 npyv_muladd_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
{ return npyv_add_f64(npyv_mul_f64(a, b), c); }
// multiply and subtract, a*b - c
NPY_FINLINE npyv_f32 npyv_mulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
{ return npyv_sub_f32(npyv_mul_f32(a, b), c); }
NPY_FINLINE npyv_f64 npyv_mulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
{ return npyv_sub_f64(npyv_mul_f64(a, b), c); }
// negate multiply and add, -(a*b) + c
NPY_FINLINE npyv_f32 npyv_nmuladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
{ return npyv_sub_f32(c, npyv_mul_f32(a, b)); }
NPY_FINLINE npyv_f64 npyv_nmuladd_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
{ return npyv_sub_f64(c, npyv_mul_f64(a, b)); }
// negate multiply and subtract, -(a*b) - c
NPY_FINLINE npyv_f32 npyv_nmulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
{
npyv_f32 neg_a = npyv_xor_f32(a, npyv_setall_f32(-0.0f));
return npyv_sub_f32(npyv_mul_f32(neg_a, b), c);
}
NPY_FINLINE npyv_f64 npyv_nmulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
{
npyv_f64 neg_a = npyv_xor_f64(a, npyv_setall_f64(-0.0));
return npyv_sub_f64(npyv_mul_f64(neg_a, b), c);
}
#endif // !NPY_HAVE_FMA3
#endif // _NPY_SIMD_AVX2_ARITHMETIC_H
16 changes: 16 additions & 0 deletions numpy/core/src/common/simd/avx512/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,20 @@ NPY_FINLINE __m512i npyv_mul_u8(__m512i a, __m512i b)
#define npyv_div_f32 _mm512_div_ps
#define npyv_div_f64 _mm512_div_pd

/***************************
* FUSED
***************************/
// multiply and add, a*b + c
#define npyv_muladd_f32 _mm512_fmadd_ps
#define npyv_muladd_f64 _mm512_fmadd_pd
// multiply and subtract, a*b - c
#define npyv_mulsub_f32 _mm512_fmsub_ps
#define npyv_mulsub_f64 _mm512_fmsub_pd
// negate multiply and add, -(a*b) + c
#define npyv_nmuladd_f32 _mm512_fnmadd_ps
#define npyv_nmuladd_f64 _mm512_fnmadd_pd
// negate multiply and subtract, -(a*b) - c
#define npyv_nmulsub_f32 _mm512_fnmsub_ps
#define npyv_nmulsub_f64 _mm512_fnmsub_pd

#endif // _NPY_SIMD_AVX512_ARITHMETIC_H
43 changes: 43 additions & 0 deletions numpy/core/src/common/simd/neon/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,47 @@
#endif
#define npyv_div_f64 vdivq_f64

/***************************
* FUSED F32
***************************/
#ifdef NPY_HAVE_NEON_VFPV4 // FMA
// multiply and add, a*b + c
NPY_FINLINE npyv_f32 npyv_muladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
{ return vfmaq_f32(c, a, b); }
// multiply and subtract, a*b - c
NPY_FINLINE npyv_f32 npyv_mulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
{ return vfmaq_f32(vnegq_f32(c), a, b); }
// negate multiply and add, -(a*b) + c
NPY_FINLINE npyv_f32 npyv_nmuladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
{ return vfmsq_f32(c, a, b); }
// negate multiply and subtract, -(a*b) - c
NPY_FINLINE npyv_f32 npyv_nmulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
{ return vfmsq_f32(vnegq_f32(c), a, b); }
#else
// multiply and add, a*b + c
NPY_FINLINE npyv_f32 npyv_muladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
{ return vmlaq_f32(c, a, b); }
// multiply and subtract, a*b - c
NPY_FINLINE npyv_f32 npyv_mulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
{ return vmlaq_f32(vnegq_f32(c), a, b); }
// negate multiply and add, -(a*b) + c
NPY_FINLINE npyv_f32 npyv_nmuladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
{ return vmlsq_f32(c, a, b); }
// negate multiply and subtract, -(a*b) - c
NPY_FINLINE npyv_f32 npyv_nmulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
{ return vmlsq_f32(vnegq_f32(c), a, b); }
#endif
/***************************
* FUSED F64
***************************/
#if NPY_SIMD_F64
NPY_FINLINE npyv_f64 npyv_muladd_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
{ return vfmaq_f64(c, a, b); }
NPY_FINLINE npyv_f64 npyv_mulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
{ return vfmaq_f64(vnegq_f64(c), a, b); }
NPY_FINLINE npyv_f64 npyv_nmuladd_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
{ return vfmsq_f64(c, a, b); }
NPY_FINLINE npyv_f64 npyv_nmulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
{ return vfmsq_f64(vnegq_f64(c), a, b); }
#endif // NPY_SIMD_F64
#endif // _NPY_SIMD_NEON_ARITHMETIC_H
57 changes: 56 additions & 1 deletion numpy/core/src/common/simd/sse/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,60 @@ NPY_FINLINE __m128i npyv_mul_u8(__m128i a, __m128i b)
// TODO: emulate integer division
#define npyv_div_f32 _mm_div_ps
#define npyv_div_f64 _mm_div_pd

/***************************
* FUSED
***************************/
#ifdef NPY_HAVE_FMA3
// multiply and add, a*b + c
#define npyv_muladd_f32 _mm_fmadd_ps
#define npyv_muladd_f64 _mm_fmadd_pd
// multiply and subtract, a*b - c
#define npyv_mulsub_f32 _mm_fmsub_ps
#define npyv_mulsub_f64 _mm_fmsub_pd
// negate multiply and add, -(a*b) + c
#define npyv_nmuladd_f32 _mm_fnmadd_ps
#define npyv_nmuladd_f64 _mm_fnmadd_pd
// negate multiply and subtract, -(a*b) - c
#define npyv_nmulsub_f32 _mm_fnmsub_ps
#define npyv_nmulsub_f64 _mm_fnmsub_pd
#elif defined(NPY_HAVE_FMA4)
// multiply and add, a*b + c
#define npyv_muladd_f32 _mm_macc_ps
#define npyv_muladd_f64 _mm_macc_pd
// multiply and subtract, a*b - c
#define npyv_mulsub_f32 _mm_msub_ps
#define npyv_mulsub_f64 _mm_msub_pd
// negate multiply and add, -(a*b) + c
#define npyv_nmuladd_f32 _mm_nmacc_ps
#define npyv_nmuladd_f64 _mm_nmacc_pd
#else
// multiply and add, a*b + c
NPY_FINLINE npyv_f32 npyv_muladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
{ return npyv_add_f32(npyv_mul_f32(a, b), c); }
NPY_FINLINE npyv_f64 npyv_muladd_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
{ return npyv_add_f64(npyv_mul_f64(a, b), c); }
// multiply and subtract, a*b - c
NPY_FINLINE npyv_f32 npyv_mulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
{ return npyv_sub_f32(npyv_mul_f32(a, b), c); }
NPY_FINLINE npyv_f64 npyv_mulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
{ return npyv_sub_f64(npyv_mul_f64(a, b), c); }
// negate multiply and add, -(a*b) + c
NPY_FINLINE npyv_f32 npyv_nmuladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
{ return npyv_sub_f32(c, npyv_mul_f32(a, b)); }
NPY_FINLINE npyv_f64 npyv_nmuladd_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
{ return npyv_sub_f64(c, npyv_mul_f64(a, b)); }
#endif // NPY_HAVE_FMA3
#ifndef NPY_HAVE_FMA3 // for FMA4 and NON-FMA3
// negate multiply and subtract, -(a*b) - c
NPY_FINLINE npyv_f32 npyv_nmulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
{
npyv_f32 neg_a = npyv_xor_f32(a, npyv_setall_f32(-0.0f));
return npyv_sub_f32(npyv_mul_f32(neg_a, b), c);
}
NPY_FINLINE npyv_f64 npyv_nmulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
{
npyv_f64 neg_a = npyv_xor_f64(a, npyv_setall_f64(-0.0));
return npyv_sub_f64(npyv_mul_f64(neg_a, b), c);
}
#endif // !NPY_HAVE_FMA3
#endif // _NPY_SIMD_SSE_ARITHMETIC_H
16 changes: 16 additions & 0 deletions numpy/core/src/common/simd/vsx/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,20 @@
#define npyv_div_f32 vec_div
#define npyv_div_f64 vec_div

/***************************
* FUSED
***************************/
// multiply and add, a*b + c
#define npyv_muladd_f32 vec_madd
#define npyv_muladd_f64 vec_madd
// multiply and subtract, a*b - c
#define npyv_mulsub_f32 vec_msub
#define npyv_mulsub_f64 vec_msub
// negate multiply and add, -(a*b) + c
#define npyv_nmuladd_f32 vec_nmsub // equivalent to -(a*b - c)
#define npyv_nmuladd_f64 vec_nmsub
// negate multiply and subtract, -(a*b) - c
#define npyv_nmulsub_f32 vec_nmadd // equivalent to -(a*b + c)
#define npyv_nmulsub_f64 vec_nmadd

#endif // _NPY_SIMD_VSX_ARITHMETIC_H
0