8000 Reimplement `Math.rint` without floating-point remainder. · sjrd/scala-js@31666da · GitHub
[go: up one dir, main page]

Skip to content

Commit 31666da

Browse files
committed
Reimplement Math.rint without floating-point remainder.
Now that I know what a `%` on floating-point numbers really looks like inside, it's worth avoiding it in other low-level operations. A cursory `git grep` highlighted `rint` as the only method in the javalib that used that operation, so we rewrite it in a different way.
1 parent 70a4164 commit 31666da

File tree

2 files changed

+162
-42
lines changed
  • javalib/src/main/scala/java/lang
  • test-suite/shared/src/test/scala/org/scalajs/testsuite/javalib/lang

2 files changed

+162
-42
lines changed

javalib/src/main/scala/java/lang/Math.scala

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,71 @@ object Math {
5353

5454
// Wasm intrinsic
5555
def rint(a: scala.Double): scala.Double = {
56-
val rounded = js.Math.round(a)
57-
val mod = a % 1.0
58-
// The following test is also false for specials (0's, Infinities and NaN)
59-
if (mod == 0.5 || mod == -0.5) {
60-
// js.Math.round(a) rounds up but we have to round to even
61-
if (rounded % 2.0 == 0.0) rounded
62-
else rounded - 1.0
63-
} else {
64-
rounded
65-
}
56+
/* Is the integer-valued `x` odd? Fused by hand of `(x.toLong & 1L) != 0L`.
57+
* Corner cases: returns false for Infinities and NaN.
58+
*/
59+
@inline def isOdd(x: scala.Double): scala.Boolean =
60+
(x.asInstanceOf[js.Dynamic] & 1.asInstanceOf[js.Dynamic]).asInstanceOf[Int] != 0
61+
62+
/* js.Math.round(a) does *almost* what we want. It rounds to nearest,
63+
* breaking ties *up*. We need to break ties to *even*. So we need to
64+
* detect ties, and for them, detect if we rounded to odd instead of even.
65+
*
66+
* The reasons why the apparently simple algorithm below works are subtle,
67+
* and vary a lot depending on the range of `a`:
68+
*
69+
* - a is NaN
70+
* r is NaN, then the == is false
71+
* -> return r
72+
*
73+
* - a is +-Infinity
74+
* r == a, then == is true! but isOdd(r) is false
75+
* -> return r
76+
*
77+
* - 2**53 <= abs(a) < Infinity
78+
* r == a, r - 0.5 rounds back to a so == is true!
79+
* fortunately, isOdd(r) is false because all a >= 2**53 are even
80+
* -> return r
81+
*
82+
* - 2**52 <= abs(a) < 2**53
83+
* r == a (because all a's are integers in that range)
84+
* - a is odd
85+
* r - 0.5 rounds down (towards even) to r - 1.0
86+
* so a == r - 0.5 is false
87+
* -> return r
88+
* - a is even
89+
* r - 0.5 rounds back up! (towards even) to r
90+
* so a == r - 0.5 is true!
91+
* but, isOdd(r) is false
92+
* -> return r
93+
*
94+
* - 0.5 < abs(a) < 2**52
95+
* then -2**52 + 0.5 <= a <= 2**52 - 0.5 (because values in-between are not representable)
96+
* since Math.round rounds *up* on ties, r is an integer in the range (-2**52, 2**52]
97+
* r - 0.5 is therefore lossless
98+
* so a == r - 0.5 accurately detects ties, and isOdd(r) breaks ties
99+
* -> return `r`` or `r - 1.0`
100+
*
101+
* - a == +0.5
102+
* r == 1.0
103+
* a == r - 0.5 is true and isOdd(r) is true
104+
* -> return `r - 1.0`, which is +0.0
105+
*
106+
* - a == -0.5
107+
* r == -0.0
108+
* a == r - 0.5 is true and isOdd(r) is false
109+
* -> return `r`, which is -0.0
110+
*
111+
* - 0.0 <= abs(a) < 0.5
112+
* r == 0.0 with the same sign as a
113+
* a == r - 0.5 is false
114+
* -> return r
115+
*/
116+
val r = js.Math.round(a)
117+
if ((a == r - 0.5) && isOdd(r))
118+
r - 1.0
119+
else
120+
r
66121
}
67122

68123
@inline def round(a: scala.Float): scala.Int = js.Math.round(a).toInt

test-suite/shared/src/test/scala/org/scalajs/testsuite/javalib/lang/MathTest.scala

Lines changed: 97 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -415,44 +415,109 @@ class MathTest {
415415
@Test def rintForDouble(): Unit = {
416416
import Math.rint
417417

418-
def isPosZero(x: Double): Boolean =
419-
x == 0.0 && (1.0 / x) == Double.PositiveInfinity
420-
421-
def isNegZero(x: Double): Boolean =
422-
x == 0.0 && (1.0 / x) == Double.NegativeInfinity
418+
val intLimit = (1L << 52).toDouble
419+
val halfIntLimit = (1L << 51).toDouble
420+
val doubleIntLimit = (1L << 53).toDouble
423421

424422
// Specials
425-
assertTrue(isPosZero(rint(+0.0)))
426-
assertTrue(isNegZero(rint(-0.0)))
427-
assertEquals(Double.PositiveInfinity, rint(Double.PositiveInfinity), 0.0)
428-
assertEquals(Double.NegativeInfinity, rint(Double.NegativeInfinity), 0.0)
429-
assertTrue(rint(Double.NaN).isNaN)
423+
assertSameDouble(+0.0, rint(+0.0))
424+
assertSameDouble(-0.0, rint(-0.0))
425+
assertSameDouble(Double.PositiveInfinity, rint(Double.PositiveInfinity))
426+
assertSameDouble(Double.NegativeInfinity, rint(Double.NegativeInfinity))
427+
assertSameDouble(Double.NaN, rint(Double.NaN))
430428

431429
// Positive values
432-
assertTrue(isPosZero(rint(0.1)))
433-
assertTrue(isPosZero(rint(0.5)))
434-
assertEquals(1.0, rint(0.5000000000000001), 0.0)
435-
assertEquals(1.0, rint(0.999), 0.0)
436-
assertEquals(1.0, rint(1.4999999999999998), 0.0)
437-
assertEquals(2.0, rint(1.5), 0.0)
438-
assertEquals(2.0, rint(2.0), 0.0)
439-
assertEquals(2.0, rint(2.1), 0.0)
440-
assertEquals(2.0, rint(2.5), 0.0)
441-
assertEquals(Double.MaxValue, rint(Double.MaxValue), 0.0)
442-
assertEquals(4503599627370496.0, rint(4503599627370495.5), 0.0) // MaxSafeInt / 2
430+
assertSameDouble(+0.0, rint(Double.MinPositiveValue))
431+
assertSameDouble(+0.0, rint(java.lang.Double.MIN_NORMAL))
432+
assertSameDouble(+0.0, rint(0.1))
433+
assertSameDouble(+0.0, rint(0.5))
434+
assertSameDouble(1.0, rint(0.5000000000000001))
435+
assertSameDouble(1.0, rint(0.999))
436+
assertSameDouble(1.0, rint(1.4999999999999998))
437+
assertSameDouble(2.0, rint(1.5))
438+
assertSameDouble(2.0, rint(2.0))
439+
assertSameDouble(2.0, rint(2.1))
440+
assertSameDouble(2.0, rint(2.5))
441+
assertSameDouble(3.0, rint(2.75))
442+
assertSameDouble(3.0, rint(3.25))
443+
assertSameDouble(4.0, rint(3.5))
444+
assertSameDouble(4.0, rint(3.75))
445+
assertSameDouble(halfIntLimit - 2.0, rint(halfIntLimit - 1.5))
446+
assertSameDouble(halfIntLimit - 1.0, rint(halfIntLimit - 1.25))
447+
assertSameDouble(halfIntLimit - 1.0, rint(halfIntLimit - 1.0))
448+
assertSameDouble(halfIntLimit - 1.0, rint(halfIntLimit - 0.75))
449+
assertSameDouble(halfIntLimit, rint(halfIntLimit - 0.5))
450+
assertSameDouble(halfIntLimit, rint(halfIntLimit - 0.25))
451+
assertSameDouble(halfIntLimit, rint(halfIntLimit))
452+
assertSameDouble(halfIntLimit, rint(halfIntLimit + 0.25))
453+
assertSameDouble(halfIntLimit, rint(halfIntLimit + 0.5))
454+
assertSameDouble(halfIntLimit + 1.0, rint(halfIntLimit + 0.75))
455+
assertSameDouble(halfIntLimit + 1.0, rint(halfIntLimit + 1.0))
456+
assertSameDouble(halfIntLimit + 1.0, rint(halfIntLimit + 1.25))
457+
assertSameDouble(halfIntLimit + 2.0, rint(halfIntLimit + 1.5))
458+
assertSameDouble(intLimit - 2.0, rint(intLimit - 1.5))
459+
assertSameDouble(intLimit - 1.0, rint(intLimit - 1.0))
460+
assertSameDouble(intLimit, rint(intLimit - 0.5))
461+
assertSameDouble(intLimit, rint(intLimit))
462+
463+
val largeIntegers = List(
464+
// corner cases just above intLimit
465+
intLimit + 1.0,
466+
intLimit + 2.0,
467+
intLimit + 3.0,
468+
intLimit + 4.0,
469+
// corner cases around doubleIntLimit
470+
doubleIntLimit - 4.0,
471+
doubleIntLimit - 3.0,
472+
doubleIntLimit - 2.0,
473+
doubleIntLimit - 1.0,
474+
doubleIntLimit,
475+
doubleIntLimit + 2.0,
476+
doubleIntLimit + 4.0,
477+
doubleIntLimit + 6.0,
478+
doubleIntLimit + 8.0,
479+
doubleIntLimit + 16.0,
480+
Double.MaxValue
481+
)
482+
for (x <- largeIntegers)
483+
assertSameDouble(x, rint(x))
443484

444485
// Negative values
445-
assertTrue(isNegZero(rint(-0.1)))
446-
assertTrue(isNegZero(rint(-0.5)))
447-
assertEquals(-1.0, rint(-0.5000000000000001), 0.0)
448-
assertEquals(-1.0, rint(-0.999), 0.0)
449-
assertEquals(-1.0, rint(-1.4999999999999998), 0.0)
450-
assertEquals(-2.0, rint(-1.5), 0.0)
451-
assertEquals(-2.0, rint(-2.0), 0.0)
452-
assertEquals(-2.0, rint(-2.1), 0.0)
453-
assertEquals(-2.0, rint(-2.5), 0.0)
454-
assertEquals(Double.MinValue, rint(Double.MinValue), 0.0)
455-
assertEquals(-4503599627370496.0, rint(-4503599627370495.5), 0.0) // -MaxSafeInt / 2
486+
assertSameDouble(-0.0, rint(-Double.MinPositiveValue))
487+
assertSameDouble(-0.0, rint(-java.lang.Double.MIN_NORMAL))
488+
assertSameDouble(-0.0, rint(-0.1))
489+
assertSameDouble(-0.0, rint(-0.5))
490+
assertSameDouble(-1.0, rint(-0.5000000000000001))
491+
assertSameDouble(-1.0, rint(-0.999))
492+
assertSameDouble(-1.0, rint(-1.4999999999999998))
493+
assertSameDouble(-2.0, rint(-1.5))
494+
assertSameDouble(-2.0, rint(-2.0))
495+
assertSameDouble(-2.0, rint(-2.1))
496+
assertSameDouble(-2.0, rint(-2.5))
497+
assertSameDouble(-3.0, rint(-2.75))
498+
assertSameDouble(-3.0, rint(-3.25))
499+
assertSameDouble(-4.0, rint(-3.5))
500+
assertSameDouble(-4.0, rint(-3.75))
501+
assertSameDouble(-(halfIntLimit - 2.0), rint(-(halfIntLimit - 1.5)))
502+
assertSameDouble(-(halfIntLimit - 1.0), rint(-(halfIntLimit - 1.25)))
503+
assertSameDouble(-(halfIntLimit - 1.0), rint(-(halfIntLimit - 1.0)))
504+
assertSameDouble(-(halfIntLimit - 1.0), rint(-(halfIntLimit - 0.75)))
505+
assertSameDouble(-halfIntLimit, rint(-(halfIntLimit - 0.5)))
506+
assertSameDouble(-halfIntLimit, rint(-(halfIntLimit - 0.25)))
507+
assertSameDouble(-halfIntLimit, rint(-halfIntLimit))
508+
assertSameDouble(-halfIntLimit, rint(-(halfIntLimit + 0.25)))
509+
assertSameDouble(-halfIntLimit, rint(-(halfIntLimit + 0.5)))
510+
assertSameDouble(-(halfIntLimit + 1.0), rint(-(halfIntLimit + 0.75)))
511+
assertSameDouble(-(halfIntLimit + 1.0), rint(-(halfIntLimit + 1.0)))
512+
assertSameDouble(-(halfIntLimit + 1.0), rint(-(halfIntLimit + 1.25)))
513+
assertSameDouble(-(halfIntLimit + 2.0), rint(-(halfIntLimit + 1.5)))
514+
assertSameDouble(-(intLimit - 2.0), rint(-(intLimit - 1.5)))
515+
assertSameDouble(-(intLimit - 1.0), rint(-(intLimit - 1.0)))
516+
assertSameDouble(-intLimit, rint(-(intLimit - 0.5)))
517+
assertSameDouble(-intLimit, rint(-intLimit))
518+
519+
for (x <- largeIntegers)
520+
assertSameDouble(-x, rint(-x))
456521
}
457522

458523
@Test def addExact(): Unit = {

0 commit comments

Comments
 (0)
0