@@ -578,11 +578,31 @@ INLINE void ff_fma(flexfloat_t *dest, const flexfloat_t *a, const flexfloat_t *b
578
578
assert ((dest -> desc .exp_bits == a -> desc .exp_bits ) && (dest -> desc .frac_bits == a -> desc .frac_bits ) &&
579
579
(a -> desc .exp_bits == b -> desc .exp_bits ) && (a -> desc .frac_bits == b -> desc .frac_bits ) &&
580
580
(b -> desc .exp_bits == c -> desc .exp_bits ) && (b -> desc .frac_bits == c -> desc .frac_bits ));
581
- dest -> value = fma (a -> value , b -> value , c -> value );
581
+ #ifdef FLEXFLOAT_ROUNDING
582
+ // Change the rounding mode according to the error direction if we need to do manual rounding for RNE
583
+ int mode = fegetround ();
584
+ bool eff_sub = flexfloat_sign (a ) ^ flexfloat_sign (b ) ^ flexfloat_sign (c );
585
+ if (a -> desc .frac_bits < NUM_BITS_FRAC && mode == FE_TONEAREST ) {
586
+ if (!eff_sub ) { // in this case, we need to round away from zero
587
+ fexcept_t flags ;
588
+ fegetexceptflag (& flags , FE_ALL_EXCEPT ); // get accrued flags to not tarnish them here
589
+ double try = fma (a -> value , b -> value , c -> value );
590
+ (try >= 0 ) ? fesetround (FE_UPWARD ) : fesetround (FE_DOWNWARD );
591
+ fesetexceptflag (& flags , FE_ALL_EXCEPT ); // restore flags here
592
+ } else {
593
+ fesetround (FE_TOWARDZERO ); // just truncate
594
+ }
595
+ }
596
+ #endif
597
+ dest -> value = fma (a -> value , b -> value , c -> value ); // finally the actual operation
582
598
#ifdef FLEXFLOAT_TRACKING
583
599
dest -> exact_value = fma (a -> exact_value , b -> exact_value , c -> exact_value );
584
600
if (dest -> tracking_fn ) (dest -> tracking_fn )(dest , dest -> tracking_arg );
585
601
#endif
602
+ #ifdef FLEXFLOAT_ROUNDING
603
+ if (a -> desc .frac_bits < NUM_BITS_FRAC && mode == FE_TONEAREST )
604
+ fesetround (FE_TONEAREST ); // restore rounding
605
+ #endif
586
606
flexfloat_sanitize (dest );
587
607
#ifdef FLEXFLOAT_STATS
588
608
if (StatsEnabled ) getOpStats (dest -> desc )-> fma += 1 ;
0 commit comments