8000 Implement jl.Math.{multiplyFull, multiplyHigh, unsignedMultiplyHigh}. by sjrd · Pull Request #5175 · scala-js/scala-js · GitHub
[go: up one dir, main page]

Skip to content

Implement jl.Math.{multiplyFull, multiplyHigh, unsignedMultiplyHigh}. #5175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions javalib/src/main/scala/java/lang/Math.scala
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,43 @@ object Math {
if (a >= Integer.MIN_VALUE && a <= Integer.MAX_VALUE) a.toInt
else throw new ArithmeticException("Integer overflow")

// RuntimeLong intrinsic
@inline
def multiplyFull(x: scala.Int, y: scala.Int): scala.Long =
x.toLong * y.toLong

@inline
def multiplyHigh(x: scala.Long, y: scala.Long): scala.Long = {
/* Hacker's Delight, Section 8-2, Figure 8-2,
* where we have "inlined" all the variables used only once to help our
* optimizer perform simplifications.
*/

val x0 = x & 0xffffffffL
val x1 = x >> 32
val y0 = y & 0xffffffffL
val y1 = y >> 32

val t = x1 * y0 + ((x0 * y0) >>> 32)
x1 * y1 + (t >> 32) + (((t & 0xffffffffL) + x0 * y1) >> 32)
}

@inline
def unsignedMultiplyHigh(x: scala.Long, y: scala.Long): scala.Long = {
/* Hacker's Delight, Section 8-2:
* > For an unsigned version, simply change all the int declarations to unsigned.
* In Scala, that means changing all the >> into >>>.
*/

val x0 = x & 0xffffffffL
val x1 = x >>> 32
val y0 = y & 0xffffffffL
val y1 = y >>> 32

val t = x1 * y0 + ((x0 * y0) >>> 32)
x1 * y1 + (t >>> 32) + (((t & 0xffffffffL) + x0 * y1) >>> 32)
}

def floorDiv(a: scala.Int, b: scala.Int): scala.Int = {
val quot = a / b
if ((a < 0) == (b < 0) || quot * b == a) quot
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,47 @@ object RuntimeLong {
}
}

/** Intrinsic for Math.multiplyFull.
*
* Compared to the regular expansion of `x.toLong * y.toLong`, this
* intrinsic avoids 2 int multiplications.
*/
@inline
def multiplyFull(a: Int, b: Int): RuntimeLong = {
/* We use Hacker's Delight, Section 8-2, Figure 8-2, to compute the hi
* word of the result. We reuse intermediate products to compute the lo
* word, like we do in `RuntimeLong.*`.
*
* We swap the role of a1b0 and a0b1 compared to Hacker's Delight, to
* optimize for the case where a1b0 collapses to 0, like we do in
* `RuntimeLong.*`. The optimizer normalizes constants in multiplyFull to
* be on the left-hand-side (when it cannot do constant-folding to begin
* with). Therefore, `b` is never constant in practice.
*/

val a0 = a & 0xffff
val a1 = a >> 16
val b0 = b & 0xffff
val b1 = b >> 16

val a0b0 = a0 * b0
val a1b0 = a1 * b0 // collapses to 0 when a is constant and 0 <= a <= 0xffff
val a0b1 = a0 * b1

/* lo = a * b, but we compute the above 3 subproducts for hi anyway,
* so we reuse them to compute lo too, trading a * for 2 +'s and 1 <<.
*/
val lo = a0b0 + ((a1b0 + a0b1) << 16)

val t = a0b1 + (a0b0 >>> 16)
val hi = {
a1 * b1 + (t >> 16) +
(((t & 0xffff) + a1b0) >> 16) // collapses to 0 when a1b0 = 0
}

new RuntimeLong(lo, hi)
}

@inline
def divide(a: RuntimeLong, b: RuntimeLong): RuntimeLong = {
val lo = divideImpl(a.lo, a.hi, b.lo, b.hi)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ private[linker] object LongImpl {
val AllModuleMethods = Set(
fromInt, fromDouble)

// Methods on the companion used for intrinsics

final val multiplyFull = MethodName("multiplyFull", List(IntRef, IntRef), RTLongRef)

val AllIntrinsicModuleMethods = Set(
multiplyFull)

// Extract the parts to give to the initFromParts constructor

def extractParts(value: Long): (Int, Int) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ final class IncOptimizer private[optimizer] (config: CommonPhaseConfig, collOps:
multiple(
cond(!targetIsWebAssembly && !esFeatures.allowBigIntsForLongs) {
// Required by the intrinsics manipulating Longs
callMethods(LongImpl.RuntimeLongClass, LongImpl.AllIntrinsicMethods.toList)
multiple(
callMethods(LongImpl.RuntimeLongClass, LongImpl.AllIntrinsicMethods.toList),
callMethods(LongImpl.RuntimeLongModuleClass, LongImpl.AllIntrinsicModuleMethods.toList)
)
},
cond(targetIsWebAssembly) {
// Required by the intrinsic CharacterCodePointToString
Expand Down
F438
Original file line number Diff line number Diff line change
Expand Up @@ -3074,6 +3074,32 @@ private[optimizer] abstract class OptimizerCore(
case MathMaxDouble =>
contTree(wasmBinaryOp(WasmBinaryOp.F64Max, targs.head, targs.tail.head))

case MathMultiplyFull =>
def expand(targs: List[PreTransform]): TailRec[Tree] = {
import LongImpl.{RuntimeLongModuleClass => modCls}
val receiver =
makeCast(LoadModule(modCls), ClassType(modCls, nullable = false)).toPreTransform

pretransformApply(ApplyFlags.empty,
receiver,
MethodIdent(LongImpl.multiplyFull),
targs,
ClassType(LongImpl.RuntimeLongClass, nullable = true),
isStat, usePreTransform)(
cont)
}

targs match {
case List(PreTransLit(IntLiteral(x)), PreTransLit(IntLiteral(y))) =>
// cannot actually call multiplyHigh to constant-fold because it is JDK9+
contTree(LongLiteral(x.toLong * y.toLong))
case List(tlhs, trhs @ PreTransLit(_)) =>
// normalize a single constant on the left; the implementation is optimized for that case
expand(trhs :: tlhs :: Nil)
case _ =>
expand(targs)
}

// scala.collection.mutable.ArrayBuilder

case GenericArrayBuilderResult =>
Expand Down Expand Up @@ -4389,6 +4415,14 @@ private[optimizer] abstract class OptimizerCore(
PreTransLit(IntLiteral(_))) if (y & 31) != 0 =>
foldBinaryOp(Int_>>>, lhs, rhs)

case (PreTransBinaryOp(op @ (Int_| | Int_& | Int_^),
PreTransLit(IntLiteral(x)), y),
z @ PreTransLit(IntLiteral(zValue))) =>
foldBinaryOp(
op,
PreTransLit(IntLiteral(x >> zValue)),
foldBinaryOp(Int_>>, y, z))

case (_, PreTransLit(IntLiteral(y))) =>
val dist = y & 31
if (dist == 0)
Expand Down Expand Up @@ -6514,8 +6548,9 @@ private[optimizer] object OptimizerCore {
final val MathMinDouble = MathMinFloat + 1
final val MathMaxFloat = MathMinDouble + 1
final val MathMaxDouble = MathMaxFloat + 1
final val MathMultiplyFull = MathMaxDouble + 1

final val ArrayBuilderZeroOf = MathMaxDouble + 1
final val ArrayBuilderZeroOf = MathMultiplyFull + 1
final val GenericArrayBuilderResult = ArrayBuilderZeroOf + 1

final val ClassGetName = GenericArrayBuilderResult + 1
Expand Down Expand Up @@ -6609,6 +6644,9 @@ private[optimizer] object OptimizerCore {
m("compare", List(J, J), I) -> LongCompare,
m("divideUnsigned", List(J, J), J) -> LongDivideUnsigned,
m("remainderUnsigned", List(J, J), J) -> LongRemainderUnsigned
),
ClassName("java.lang.Math$") -> List(
m("multiplyFull", List(I, I), J) -> MathMultiplyFull
)
)

Expand Down
1 change: 0 additions & 1 deletion project/Build.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2126,7 +2126,6 @@ object Build {
List(sharedTestDir / "scala", sharedTestDir / "require-scala2") :::
collectionsEraDependentDirectory(scalaV, sharedTestDir) ::
includeIf(sharedTestDir / "require-jdk11", javaV >= 11) :::
includeIf(sharedTestDir / "require-jdk15", javaV >= 15) :::
includeIf(sharedTestDir / "require-jdk17", javaV >= 17) :::
includeIf(sharedTestDir / "require-jdk21", javaV >= 21) :::
includeIf(testDir / "require-scala2", isJSTest)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
/*
* Scala.js (https://www.scala-js.org/)
*
* Copyright EPFL.
*
* Licensed under Apache License 2.0
* (https://www.apache.org/licenses/LICENSE-2.0).
*
* See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*/

package org.scalajs.testsuite.javalib.lang

import java.math.BigInteger
import java.util.SplittableRandom

import org.junit.Test
import org.junit.Assert._

class MathTestOnJDK11 {

@noinline
private def hideFromOptimizer(x: Int): Int = x

@Test def testMultiplyFull(): Unit = {
@inline def test(expected: Long, x: Int, y: Int): Unit = {
assertEquals(expected, Math.multiplyFull(x, y))
assertEquals(expected, Math.multiplyFull(x, hideFromOptimizer(y)))
assertEquals(expected, Math.multiplyFull(hideFromOptimizer(x), y))
assertEquals(expected, Math.multiplyFull(hideFromOptimizer(x), hideFromOptimizer(y)))
}

test(2641928036408725662L, 1942041231, 1360387202)
test(54843908448922272L, 1565939409, 35023008)
test(510471553407128558L, 1283300489, 397780222)
test(-1211162085735907941L, -1990140693, 608581137)
test(-1197265696701533712L, -584098468, 2049766884)
test(203152587796496856L, -1809591416, -112264341)
test(-1869763755321108598L, 1235591906, -1513253483)
test(-737954189546644064L, 675415792, -1092592442)
test(-2570904460570261986L, 1639253754, -1568338309)
test(1106623967126000400L, 2088029790, 529984760)
test(1407516248272451352L, -869881054, -1618055988)
test(-2120367337662071940L, -1558894530, 1360173698)
test(-1464086284066637244L, -1417313902, 1033000722)
test(36729253163312334L, -1673852034, -21942951)
test(-3197007331876781046L, 1876799847, -1703435418)
test(461794994386945009L, -246001091, -1877207099)
test(-1206231192496917804L, 867896526, -1389832954)
test(-1739671893103255929L, -1083992841, 1604873969)
test(-409626127116780624L, 240101424, -1706054551)
test(-3083566560548370936L, -1568530113, 1965895672)
test(-1205028798380605000L, -1201743532, 1002733750)
test(-1328689065035027168L, 929349664, -1429697687)
test(-124212693522020684L, 80893862, -1535502082)
test(-82341860111074830L, -243230690, 338534007)
test(-846837059701860202L, 1959770926, -432110227)
test(335728245390354432L, 506816728, 662425344)
test(745294755971022170L, 1521993302, 489683335)
test(-2370525755201631608L, 2023520366, -1171485988)
test(-1039854583047715776L, 593162592, -1753068378)
test(-152985384388127808L, -635946432, 240563319)
test(-678107568956539050L, 649113254, -1044667575)
test(-3064094283703186444L, -1890896836, 1620444979)
test(1240687269228318870L, -1080325230, -1148438669)
test(-46551523496333580L, 27167878, -1713476610)
test(-2500430606368427103L, 2023288183, -1235825241)
test(92963399778762084L, 896198732, 103730787)
test(2469065794894324667L, 2105111101, 1172890967)
test(172558569988357136L, -142945148, -1207166332)
test(335684786634110970L, -1647598405, -203741874)
test(2406859843746696240L, 2049365815, 1174441296)
test(3100973294006114952L, 1991928152, 1556769651)
test(-335912134649077352L, 866240524, -387781598)
test(84303320581066207L, 75666091, 1114149277)
test(-2623126349572207976L, 1426933667, -1838295928)
test(59139945163750590L, 149344270, 395997417)
test(-105764175098643999L, 68726447, -1538915217)
test(8595303129864000L, 726092025, 11837760)
test(-2958527843471399088L, 1536412078, -1925608296)
test(1532625839159904477L, 867021537, 1767690621)
test(384402376484481316L, 1207235521, 318415396)
test(-219376614576542698L, 1816299166, -120782203)
test(-672138807810988440L, 531516745, -1264567512)
test(-193351903065245331L, 170858169, -1131651499)
test(71263251057597648L, 51058196, 1395725988)
test(-774312974742971385L, 1958551603, -395349795)
test(-1846593638370672048L, 1190143097, -1551572784)
test(240083094242536384L, 1404614968, 170924488)
test(-130950827889833280L, -115480554, 1133964320)
test(128954457719585228L, 735993884, 175211317)
test(364779990580792000L, -668489125, -545678272)
test(107252402494512045L, 759517757, 141211185)
test(3038084150893069044L, -1924640913, -1578519988)
test(760804294233336624L, -728394552, -1044494762)
test(1171051779605774913L, 848233701, 1380576813)
test(-1805862307837393080L, -1385644986, 1303264780)
test(172227703288618734L, -104999826, -1640266559)
test(150448013961014407L, 163398103, 920745169)
test(-671469201380991232L, 650262784, -1032612073)
test(-1325861126942924945L, -1773644581, 747534845)
test(987406376890116568L, -1626507773, -607071416)
test(2918138947401192144L, 1695881208, 1720721318)
test(-2590993826910153940L, -1397240042, 1854365570)
test(954644624447419276L, -1516139806, -629654746)
test(407510452326678620L, -384747652, -1059162935)
test(149866317537821404L, 1530355444, 97929091)
test(922044716091910632L, 968149268, 952378674)
test(-3508732521573808284L, 1825364562, -1922209182)
test(1701723136959404304L, 894776752, 1901841027)
test(-2435876799625512705L, -1276062909, 1908900245)
test(-516933170985379201L, 657063047, -786732983)
test(123334479976750576L, 313765817, 393078128)
test(-1072624004420456775L, -894199299, 1199535725)
test(301682711612188737L, 330918981, 911651277)
test(1790992996470651507L, -1115945231, -1604911197)
test(-2750453268538140155L, 1878389719, -1464261245)
test(758285757353272504L, 1259684942, 601964612)
test(-218581674312137400L, -161533394, 1353167100)
test(-1824007072461951836L, -1244277844, 1465916219)
test(-92753167730460334L, -65368843, 1418920138)
test(-2326636630979491248L, 1124395877, -2069232624)
test(-7380586257943446L, 29715454, -248375349)
test(31319707234597638L, 491995506, 63658523)
test(-1196559502630778250L, -1752963990, 682592175)
test(166065559841839548L, -911521074, -182185102)
test(-1222260378510810100L, 1071539812, -1140657925)
test(57800571165871464L, -257569032, -224408077)
test(332444627169725608L, 1247224172, 266547614)
test(217903869180130650L, 1069161915, 203808110)
test(920425054266935850L, -901689546, -1020778225)
test(-507632632656614388L, 864632142, -587108214)
}

@Test def testMultiplyHigh(): Unit = {
/* We fuzz-test by comparing to the "obvious" implementations based on
* BigIntegers. We use a SplittableRandom generator, because Random cannot
* generate all Long values.
*/

val Seed = 909209754851418882L
val Rounds = 1024

val gen = new SplittableRandom(Seed)

for (round <- 1 to Rounds) {
val x = gen.nextLong()
val y = gen.nextLong()

val expected = {
BigInteger.valueOf(x)
.multiply(BigInteger.valueOf(y))
.shiftRight(64)
.longValue()
}

assertEquals(s"round $round, x = $x, y = $y", expected, Math.multiplyHigh(x, y))
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.junit.Assume._
import org.scalajs.testsuite.utils.AssertThrows.assertThrows
import org.scalajs.testsuite.utils.Platform

class InputStreamTestOnJDK15 {
class InputStreamTestOnJDK17 {
/** InputStream that only ever skips max bytes at once */
def lowSkipStream(max: Int, seq: Seq[Int]): InputStream = new SeqInputStreamForTest(seq) {
require(max > 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import org.junit.Assert._

import org.scalajs.testsuite.utils.AssertThrows.assertThrows

class StringTestOnJDK15 {
class StringTestOnJDK17 {

// indent and transform are available since JDK 12 but we're not testing them separately

Expand Down
Loading
0