@@ -91,5 +91,60 @@ NPY_FINLINE __m128i npyv_mul_u8(__m128i a, __m128i b)
9191// TODO: emulate integer division
// floating-point division: the npyv_div_* names are plain aliases of the
// packed divide intrinsics, so arguments are forwarded unchanged.
9292#define npyv_div_f32 _mm_div_ps
9393#define npyv_div_f64 _mm_div_pd
94-
94+ /***************************
95+ * FUSED
96+ ***************************/
97+ #ifdef NPY_HAVE_FMA3
// FMA3 build: every npyv_* name in this branch is an object-like alias of a
// single FMA3 intrinsic, so the (a, b, c) operands are forwarded unchanged.
98+ // multiply and add, a*b + c
99+ #define npyv_muladd_f32 _mm_fmadd_ps
100+ #define npyv_muladd_f64 _mm_fmadd_pd
101+ // multiply and subtract, a*b - c
102+ #define npyv_mulsub_f32 _mm_fmsub_ps
103+ #define npyv_mulsub_f64 _mm_fmsub_pd
104+ // negate multiply and add, -(a*b) + c
105+ #define npyv_nmuladd_f32 _mm_fnmadd_ps
106+ #define npyv_nmuladd_f64 _mm_fnmadd_pd
107+ // negate multiply and subtract, -(a*b) - c
108+ #define npyv_nmulsub_f32 _mm_fnmsub_ps
109+ #define npyv_nmulsub_f64 _mm_fnmsub_pd
110+ #elif defined(NPY_HAVE_FMA4 )
// FMA4 build: the npyv_* names in this branch are object-like aliases of the
// corresponding FMA4 intrinsics; operands (a, b, c) are forwarded unchanged.
111+ // multiply and add, a*b + c
112+ #define npyv_muladd_f32 _mm_macc_ps
113+ #define npyv_muladd_f64 _mm_macc_pd
114+ // multiply and subtract, a*b - c
115+ #define npyv_mulsub_f32 _mm_msub_ps
116+ #define npyv_mulsub_f64 _mm_msub_pd
117+ // negate multiply and add, -(a*b) + c
118+ #define npyv_nmuladd_f32 _mm_nmacc_ps
119+ #define npyv_nmuladd_f64 _mm_nmacc_pd
// NOTE: no npyv_nmulsub_* alias is defined in this branch; for FMA4 builds it
// is provided by the functions under the later `#ifndef NPY_HAVE_FMA3` guard.
120+ #else
121+ // multiply and add, a*b + c
122+ NPY_FINLINE npyv_f32 npyv_muladd_f32 (npyv_f32 a , npyv_f32 b , npyv_f32 c )
123+ { return npyv_add_f32 (npyv_mul_f32 (a , b ), c ); }
124+ NPY_FINLINE npyv_f64 npyv_muladd_f64 (npyv_f64 a , npyv_f64 b , npyv_f64 c )
125+ { return npyv_add_f64 (npyv_mul_f64 (a , b ), c ); }
126+ // multiply and subtract, a*b - c
127+ NPY_FINLINE npyv_f32 npyv_mulsub_f32 (npyv_f32 a , npyv_f32 b , npyv_f32 c )
128+ { return npyv_sub_f32 (npyv_mul_f32 (a , b ), c ); }
129+ NPY_FINLINE npyv_f64 npyv_mulsub_f64 (npyv_f64 a , npyv_f64 b , npyv_f64 c )
130+ { return npyv_sub_f64 (npyv_mul_f64 (a , b ), c ); }
131+ // negate multiply and add, -(a*b) + c
132+ NPY_FINLINE npyv_f32 npyv_nmuladd_f32 (npyv_f32 a , npyv_f32 b , npyv_f32 c )
133+ { return npyv_sub_f32 (c , npyv_mul_f32 (a , b )); }
134+ NPY_FINLINE npyv_f64 npyv_nmuladd_f64 (npyv_f64 a , npyv_f64 b , npyv_f64 c )
135+ { return npyv_sub_f64 (c , npyv_mul_f64 (a , b )); }
136+ #endif // NPY_HAVE_FMA3
137+ #ifndef NPY_HAVE_FMA3 // for FMA4 and NON-FMA3
138+ // negate multiply and subtract, -(a*b) - c
139+ NPY_FINLINE npyv_f32 npyv_nmulsub_f32 (npyv_f32 a , npyv_f32 b , npyv_f32 c )
140+ {
141+ npyv_f32 neg_a = npyv_xor_f32 (a , npyv_setall_f32 (-0.0f ));
142+ return npyv_sub_f32 (npyv_mul_f32 (neg_a , b ), c );
143+ }
144+ NPY_FINLINE npyv_f64 npyv_nmulsub_f64 (npyv_f64 a , npyv_f64 b , npyv_f64 c )
145+ {
146+ npyv_f64 neg_a = npyv_xor_f64 (a , npyv_setall_f64 (-0.0 ));
147+ return npyv_sub_f64 (npyv_mul_f64 (neg_a , b ), c );
148+ }
149+ #endif // !NPY_HAVE_FMA3
95150#endif // _NPY_SIMD_SSE_ARITHMETIC_H
0 commit comments