From a2c04415079d7342de3f642c6f0ceae4aa25ed35 Mon Sep 17 00:00:00 2001 From: Rikito Taniguchi Date: Wed, 26 Mar 2025 15:24:36 +0900 Subject: [PATCH] Experiment: Avoiding `JSArrayConstr` for Varargs to Optimize the Wasm Backend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently in Scala.js, a varargs call like `List(1, 2, 3)` is translated into the IR form `js.WrappedArray(JSArrayConstr(...))`. That requires JS interop for constructing the array and accessing its elements. Since Wasm-to-JS calls are expensive, this is undesirable for performance. This commit experiments with avoiding `JSArrayConstr` for varargs. Instead, varargs are transformed into something like `new WrappedArray$ofInt(ArrayValue(1, 2, 3))` (or `new ArraySeq$ofInt(...)` on 2.13) to explore potential Wasm-specific optimizations. Note 1: While reducing JS interop can improve performance on the Wasm backend, the same does not apply to the JS backend. We'd need to re-optimize back to `JSArrayConstr`-based code in the Optimizer for the JS backend. Note 2: How about `WrappedArray.make` instead of directly instantiating a specialized `WrappedArray`? I found that the runtime type checks in `make` appear to be very slow, and in some micro-benchmarks, using `make` performed slightly worse than the original `JSArrayConstr`-based code. --- Unfortunately, the performance improvements were negligible. For example, in the following code: ```scala def main(args: Array[String]): Unit = { val startTime = System.nanoTime() val xs = Seq(1, 2, ..., 20) xs.foreach(x => assert(x > 0)) val endTime = System.nanoTime() println(s"elapsed: ${endTime - startTime} ns") } ``` Both versions (with and without `JSArrayConstr`) reported similar timings of ~540000–580000 ns. Benchmarks run using [`sjrd/scalajs-benchmarks/wasm`](https://github.com/sjrd/scalajs-benchmarks/tree/wasm) also did not show any significant performance differences. 
| Benchmark | before | after | Ratio (after / before) | |-------------|--------------|--------------|----------------------| | sha512 | 12403.95816 | 12737.42497 | 1.0269 | | sha512int | 12259.02363 | 13313.69655 | 1.0860 | | queens | 2.954778067 | 2.920396237 | 0.9873 | | list | 60.56316878 | 60.52163829 | 0.9993 | | richards | 87.49714448 | 87.77807725 | 1.0032 | | cd | 32866.35461 | 31672.2486 | 0.9637 | | gcbench | 104672.8678 | 121351.2553 | 1.1588 | | tracerFloat | 870.5680015 | 876.4962162 | 1.0068 | | tracer | 784.3968099 | 783.2297365 | 0.9985 | | sudoku | 3634.165046 | 3609.813857 | 0.9933 | | nbody | 23722.03084 | 24192.39211 | 1.0198 | | permute | 262.9023071 | 265.949228 | 1.0116 | | deltaBlue | 525.4683864 | 511.8722724 | 0.9742 | | kmeans | 206339.5187 | 202615.0022 | 0.9820 | | brainfuck | 2352.518883 | 2357.064959 | 1.0019 | | json | 288.1723513 | 280.1001847 | 0.9720 | | bounce | 33.12443674 | 33.01821262 | 0.9978 | There may still be room for further optimization in the non-`JSArrayConstr` implementation. 
--- .../org/scalajs/nscplugin/GenJSCode.scala | 30 +++++++++++++---- .../org/scalajs/nscplugin/JSDefinitions.scala | 14 ++++++++ .../scala/scalajs/runtime/Compat.scala | 31 ++++++++++++++++++ .../scala/scalajs/runtime/Compat.scala | 32 +++++++++++++++++++ .../scala/scala/scalajs/runtime/package.scala | 30 +++++++++++++++++ 5 files changed, 130 insertions(+), 7 deletions(-) diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala index 4f8387fe57..ed5cdd907b 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala @@ -5775,8 +5775,22 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) if (wasRepeated) { tryGenRepeatedParamAsJSArray(arg, handleNil = false).fold { genExpr(arg) - } { genArgs => - genJSArrayToVarArgs(js.JSArrayConstr(genArgs)) + } { case (elemType, genArgs) => + val arrayTypeRef = jstpe.ArrayTypeRef.of(toTypeRef(elemType)) + val arrayRef = js.ArrayValue(arrayTypeRef, genArgs.map(genExpr(_))) + val methodSym = toIRType(elemType) match { + case jstpe.IntType => Runtime_toScalaVarArgsFromScalaArrayInt + case jstpe.DoubleType => Runtime_toScalaVarArgsFromScalaArrayDouble + case jstpe.LongType => Runtime_toScalaVarArgsFromScalaArrayLong + case jstpe.FloatType => Runtime_toScalaVarArgsFromScalaArrayFloat + case jstpe.CharType => Runtime_toScalaVarArgsFromScalaArrayChar + case jstpe.ByteType => Runtime_toScalaVarArgsFromScalaArrayByte + case jstpe.ShortType => Runtime_toScalaVarArgsFromScalaArrayShort + case jstpe.BooleanType => Runtime_toScalaVarArgsFromScalaArrayBoolean + case jstpe.VoidType => Runtime_toScalaVarArgsFromScalaArrayUnit + case _ => Runtime_toScalaVarArgsFromScalaArrayAnyRef + } + genApplyMethod(genLoadModule(RuntimePackageModule), methodSym, List(arrayRef)) } } else { genExpr(arg) @@ -5928,7 +5942,7 @@ abstract class GenJSCode[G <: Global with Singleton](val 
global: G) * Otherwise, it returns a JSSpread with the Seq converted to a js.Array. */ private def genPrimitiveJSRepeatedParam(arg: Tree): List[js.TreeOrJSSpread] = { - tryGenRepeatedParamAsJSArray(arg, handleNil = true) getOrElse { + tryGenRepeatedParamAsJSArray(arg, handleNil = true).fold { /* Fall back to calling runtime.toJSVarArgs to perform the conversion * to js.Array, then wrap in a Spread operator. */ @@ -5937,7 +5951,9 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) genLoadModule(RuntimePackageModule), Runtime_toJSVarArgs, List(genExpr(arg))) - List(js.JSSpread(jsArrayArg)) + List(js.JSSpread(jsArrayArg).asInstanceOf[js.TreeOrJSSpread]) + } { case (elemTpe, genArgs) => + genArgs.map(e => ensureBoxed(genExpr(e), elemTpe)(arg.pos).asInstanceOf[js.TreeOrJSSpread]) } } @@ -5948,7 +5964,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) * method returns `None`. */ private def tryGenRepeatedParamAsJSArray(arg: Tree, - handleNil: Boolean): Option[List[js.Tree]] = { + handleNil: Boolean): Option[(Type, List[Tree])] = { implicit val pos = arg.pos // Given a method `def foo(args: T*)` @@ -5960,11 +5976,11 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) * the type before erasure. 
*/ val elemTpe = tpt.tpe - Some(elems.map(e => ensureBoxed(genExpr(e), elemTpe))) + Some((tpt.tpe, elems)) // foo() case Select(_, _) if handleNil && arg.symbol == NilModule => - Some(Nil) + Some((NoType, Nil)) // foo(argSeq:_*) - cannot be optimized case _ => diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala b/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala index 5a46388543..c33e623950 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala @@ -127,6 +127,17 @@ trait JSDefinitions { lazy val Runtime_identityHashCode = getMemberMethod(RuntimePackageModule, newTermName("identityHashCode")) lazy val Runtime_dynamicImport = getMemberMethod(RuntimePackageModule, newTermName("dynamicImport")) + lazy val Runtime_toScalaVarArgsFromScalaArrayAnyRef = getMemberMethod(RuntimePackageModule, newTermName("toScalaVarArgsFromScalaArrayAnyRef")) + lazy val Runtime_toScalaVarArgsFromScalaArrayInt = getMemberMethod(RuntimePackageModule, newTermName("toScalaVarArgsFromScalaArrayInt")) + lazy val Runtime_toScalaVarArgsFromScalaArrayDouble = getMemberMethod(RuntimePackageModule, newTermName("toScalaVarArgsFromScalaArrayDouble")) + lazy val Runtime_toScalaVarArgsFromScalaArrayLong = getMemberMethod(RuntimePackageModule, newTermName("toScalaVarArgsFromScalaArrayLong")) + lazy val Runtime_toScalaVarArgsFromScalaArrayFloat = getMemberMethod(RuntimePackageModule, newTermName("toScalaVarArgsFromScalaArrayFloat")) + lazy val Runtime_toScalaVarArgsFromScalaArrayChar = getMemberMethod(RuntimePackageModule, newTermName("toScalaVarArgsFromScalaArrayChar")) + lazy val Runtime_toScalaVarArgsFromScalaArrayByte = getMemberMethod(RuntimePackageModule, newTermName("toScalaVarArgsFromScalaArrayByte")) + lazy val Runtime_toScalaVarArgsFromScalaArrayShort = getMemberMethod(RuntimePackageModule, newTermName("toScalaVarArgsFromScalaArrayShort")) + lazy val 
Runtime_toScalaVarArgsFromScalaArrayBoolean = getMemberMethod(RuntimePackageModule, newTermName("toScalaVarArgsFromScalaArrayBoolean")) + lazy val Runtime_toScalaVarArgsFromScalaArrayUnit = getMemberMethod(RuntimePackageModule, newTermName("toScalaVarArgsFromScalaArrayUnit")) + lazy val LinkingInfoModule = getRequiredModule("scala.scalajs.LinkingInfo") lazy val LinkingInfo_linkTimePropertyBoolean = getMemberMethod(LinkingInfoModule, newTermName("linkTimePropertyBoolean")) lazy val LinkingInfo_linkTimePropertyInt = getMemberMethod(LinkingInfoModule, newTermName("linkTimePropertyInt")) @@ -152,6 +163,9 @@ trait JSDefinitions { lazy val ExecutionContextImplicitsModule = getRequiredModule("scala.concurrent.ExecutionContext.Implicits") lazy val ExecutionContextImplicits_global = getMemberMethod(ExecutionContextImplicitsModule, newTermName("global")) + + lazy val WrappedArrayOfIntClass = getRequiredClass("scala.collection.mutable.WrappedArray$ofInt") + lazy val WrappedArrayOfDoubleClass = getRequiredClass("scala.collection.mutable.WrappedArray$ofDouble") } // scalastyle:on line.size.limit diff --git a/library/src/main/scala-new-collections/scala/scalajs/runtime/Compat.scala b/library/src/main/scala-new-collections/scala/scalajs/runtime/Compat.scala index c95be9a685..ae655273ed 100644 --- a/library/src/main/scala-new-collections/scala/scalajs/runtime/Compat.scala +++ b/library/src/main/scala-new-collections/scala/scalajs/runtime/Compat.scala @@ -12,6 +12,7 @@ package scala.scalajs.runtime +import scala.collection.immutable.ArraySeq import scala.collection.IterableOnce import scala.scalajs.js @@ -21,6 +22,36 @@ private[runtime] object Compat { @inline def toScalaVarArgsImpl[A](array: js.Array[A]): Seq[A] = WrappedVarArgs.wrap(array) + @inline def toScalaVarArgsFromScalaArrayAnyRefImpl(array: Array[AnyRef]): Seq[AnyRef] = + new ArraySeq.ofRef(array) + + @inline def toScalaVarArgsFromScalaArrayIntImpl(array: Array[Int]): Seq[Int] = + new ArraySeq.ofInt(array) + + @inline def 
toScalaVarArgsFromScalaArrayDoubleImpl(array: Array[Double]): Seq[Double] = + new ArraySeq.ofDouble(array) + + @inline def toScalaVarArgsFromScalaArrayLongImpl(array: Array[Long]): Seq[Long] = + new ArraySeq.ofLong(array) + + @inline def toScalaVarArgsFromScalaArrayFloatImpl(array: Array[Float]): Seq[Float] = + new ArraySeq.ofFloat(array) + + @inline def toScalaVarArgsFromScalaArrayCharImpl(array: Array[Char]): Seq[Char] = + new ArraySeq.ofChar(array) + + @inline def toScalaVarArgsFromScalaArrayByteImpl(array: Array[Byte]): Seq[Byte] = + new ArraySeq.ofByte(array) + + @inline def toScalaVarArgsFromScalaArrayShortImpl(array: Array[Short]): Seq[Short] = + new ArraySeq.ofShort(array) + + @inline def toScalaVarArgsFromScalaArrayBooleanImpl(array: Array[Boolean]): Seq[Boolean] = + new ArraySeq.ofBoolean(array) + + @inline def toScalaVarArgsFromScalaArrayUnitImpl(array: Array[Unit]): Seq[Unit] = + new ArraySeq.ofUnit(array) + def toJSVarArgsImpl[A](seq: Seq[A]): js.Array[A] = { seq match { case seq: WrappedVarArgs[A] => diff --git a/library/src/main/scala-old-collections/scala/scalajs/runtime/Compat.scala b/library/src/main/scala-old-collections/scala/scalajs/runtime/Compat.scala index c58ec71b7e..4ea04b77b5 100644 --- a/library/src/main/scala-old-collections/scala/scalajs/runtime/Compat.scala +++ b/library/src/main/scala-old-collections/scala/scalajs/runtime/Compat.scala @@ -13,6 +13,7 @@ package scala.scalajs.runtime import scala.collection.GenTraversableOnce +import scala.collection.mutable.WrappedArray import scala.scalajs.js private[runtime] object Compat { @@ -20,6 +21,37 @@ private[runtime] object Compat { @inline def toScalaVarArgsImpl[A](array: js.Array[A]): Seq[A] = new js.WrappedArray(array) + @inline def toScalaVarArgsFromScalaArrayAnyRefImpl(array: Array[AnyRef]): Seq[AnyRef] = + new WrappedArray.ofRef(array) + + @inline def toScalaVarArgsFromScalaArrayIntImpl(array: Array[Int]): Seq[Int] = + new WrappedArray.ofInt(array) + + @inline def 
toScalaVarArgsFromScalaArrayDoubleImpl(array: Array[Double]): Seq[Double] = + new WrappedArray.ofDouble(array) + + @inline def toScalaVarArgsFromScalaArrayLongImpl(array: Array[Long]): Seq[Long] = + new WrappedArray.ofLong(array) + + @inline def toScalaVarArgsFromScalaArrayFloatImpl(array: Array[Float]): Seq[Float] = + new WrappedArray.ofFloat(array) + + @inline def toScalaVarArgsFromScalaArrayCharImpl(array: Array[Char]): Seq[Char] = + new WrappedArray.ofChar(array) + + @inline def toScalaVarArgsFromScalaArrayByteImpl(array: Array[Byte]): Seq[Byte] = + new WrappedArray.ofByte(array) + + @inline def toScalaVarArgsFromScalaArrayShortImpl(array: Array[Short]): Seq[Short] = + new WrappedArray.ofShort(array) + + @inline def toScalaVarArgsFromScalaArrayBooleanImpl(array: Array[Boolean]): Seq[Boolean] = + new WrappedArray.ofBoolean(array) + + @inline def toScalaVarArgsFromScalaArrayUnitImpl(array: Array[Unit]): Seq[Unit] = + new WrappedArray.ofUnit(array) + + def toJSVarArgsImpl[A](seq: Seq[A]): js.Array[A] = { seq match { case seq: js.WrappedArray[A] => diff --git a/library/src/main/scala/scala/scalajs/runtime/package.scala b/library/src/main/scala/scala/scalajs/runtime/package.scala index d3ba4f766f..0c832f3be7 100644 --- a/library/src/main/scala/scala/scalajs/runtime/package.scala +++ b/library/src/main/scala/scala/scalajs/runtime/package.scala @@ -29,6 +29,36 @@ package object runtime { @inline def toScalaVarArgs[A](array: js.Array[A]): Seq[A] = toScalaVarArgsImpl(array) + @inline def toScalaVarArgsFromScalaArrayAnyRef(array: Array[AnyRef]): Seq[AnyRef] = + toScalaVarArgsFromScalaArrayAnyRefImpl(array) + + @inline def toScalaVarArgsFromScalaArrayInt(array: Array[Int]): Seq[Int] = + toScalaVarArgsFromScalaArrayIntImpl(array) + + @inline def toScalaVarArgsFromScalaArrayDouble(array: Array[Double]): Seq[Double] = + toScalaVarArgsFromScalaArrayDoubleImpl(array) + + @inline def toScalaVarArgsFromScalaArrayLong(array: Array[Long]): Seq[Long] = + 
toScalaVarArgsFromScalaArrayLongImpl(array) + + @inline def toScalaVarArgsFromScalaArrayFloat(array: Array[Float]): Seq[Float] = + toScalaVarArgsFromScalaArrayFloatImpl(array) + + @inline def toScalaVarArgsFromScalaArrayChar(array: Array[Char]): Seq[Char] = + toScalaVarArgsFromScalaArrayCharImpl(array) + + @inline def toScalaVarArgsFromScalaArrayByte(array: Array[Byte]): Seq[Byte] = + toScalaVarArgsFromScalaArrayByteImpl(array) + + @inline def toScalaVarArgsFromScalaArrayShort(array: Array[Short]): Seq[Short] = + toScalaVarArgsFromScalaArrayShortImpl(array) + + @inline def toScalaVarArgsFromScalaArrayBoolean(array: Array[Boolean]): Seq[Boolean] = + toScalaVarArgsFromScalaArrayBooleanImpl(array) + + @inline def toScalaVarArgsFromScalaArrayUnit(array: Array[Unit]): Seq[Unit] = + toScalaVarArgsFromScalaArrayUnitImpl(array) + @inline def toJSVarArgs[A](seq: Seq[A]): js.Array[A] = toJSVarArgsImpl(seq)