8000 Introduce TypedClosures in the IR, which are closures without JS inte… · scala-js/scala-js@f3e5d03 · GitHub
[go: up one dir, main page]

Skip to content

Commit f3e5d03

Browse files
committed
Introduce TypedClosures in the IR, which are closures without JS interop.
A `TypedClosure` is like a `Closure` but without any semantics for JS interop. This is stronger than `Char`, which is "merely" opaque to JS. A `Char` can still be passed to JS and has a meaningful `toString()`. A `TypedClosure` *cannot* be passed to JS in any way. That is enforced by making their type *not* a subtype of `any`. Since a `TypedClosure` has no JS interop semantics, it is free to strongly, statically type its parameters and result type. More importantly, we can freely choose its representation in the best possible way for the given target. On JS, that remains an arrow function. On Wasm, however, we represent is as a pair of `(capture data pointer, function pointer)`. This allows to compile them in an efficient way that does not require going through a JS bridge closure. The latter has been shown to have a devastating impact on performance when a Scala function is used in a tight loop. The type of a `TypedClosure` is a `ClosureType`. It records its parameter types and its result type. Closure types are non-variant: they are only subtypes of themselves. As mentioned, they are not subtypes of `any`. They are however subtypes of `void` and supertypes of `nothing`. Unfortunately, they must also be nullable to have a default value, so they are also supertypes of `null`. To call a typed closure, we introduce a dedicated application node `ApplyTypedClosure`. IR checking ensures that actual arguments match the expected parameter types. The result type is directly used as the type of the application. There are no changes to the source language. In particular, there is no way to express typed closures or their types at the user level. However, we change the compilation of SAM anonymous functions, both for Scala functions and JVM-like SAMs, to use a `TypedClosure` instead of a `Closure`. We wrap the `TypedClosure` inside an instance of `scala.scalajs.runtime.TypedFunctionN`. These classes are generated "by hand" in the build, since there is no user-level way to define them. They have valid IR, though. These changes have no real impact on the JS output (only marginal naming differences). On Wasm, however, they make Scala functions much, much faster. Before, a Scala function in a typed loop would cause a Wasm implementation to be, in the worst measured case, 20x slower than on JS. After these changes, similar benchmarks become significantly faster on Wasm than on JS.
1 parent 48b179d commit f3e5d03

File tree

28 files changed

+1164
-191
lines changed

28 files changed

+1164
-191
lines changed

compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala

Lines changed: 40 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -960,16 +960,10 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
960960
}
961961

962962
// After the super call, substitute `selfRef` for `This()`
963-
val afterSuper = new ir.Transformers.Transformer {
963+
val afterSuper = new ir.Transformers.LocalScopeTransformer {
964964
override def transform(tree: js.Tree, isStat: Boolean): js.Tree = tree match {
965965
case js.This() =>
966966
selfRef(tree.pos)
967-
968-
// Don't traverse closure boundaries
969-
case closure: js.Closure =>
970-
val newCaptureValues = closure.captureValues.map(transformExpr)
971-
closure.copy(captureValues = newCaptureValues)(closure.pos)
972-
973967
case tree =>
974968
super.transform(tree, isStat)
975969
}
@@ -2037,15 +2031,12 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
20372031
} yield {
20382032
js.ParamDef(name, originalName, ptpe, newMutable(name.name, mutable))(p.pos)
20392033
}
2040-
val transformer = new ir.Transformers.Transformer {
2034+
val transformer = new ir.Transformers.LocalScopeTransformer {
20412035
override def transform(tree: js.Tree, isStat: Boolean): js.Tree = tree match {
20422036
case js.VarDef(name, originalName, vtpe, mutable, rhs) =>
20432037
assert(isStat, s"found a VarDef in expression position at ${tree.pos}")
20442038
super.transform(js.VarDef(name, originalName, vtpe,
20452039
newMutable(name.name, mutable), rhs)(tree.pos), isStat)
2046-
case js.Closure(arrow, captureParams, params, restParam, body, captureValues) =>
2047-
js.Closure(arrow, captureParams, params, restParam, body,
2048-
captureValues.map(transformExpr))(tree.pos)
20492040
case _ =>
20502041
super.transform(tree, isStat)
20512042
}
@@ -2075,13 +2066,10 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
20752066
} yield {
20762067
js.ParamDef(name, originalName, newType(name, ptpe), mutable)(p.pos)
20772068
}
2078-
val transformer = new ir.Transformers.Transformer {
2069+
val transformer = new ir.Transformers.LocalScopeTransformer {
20792070
override def transform(tree: js.Tree, isStat: Boolean): js.Tree = tree match {
20802071
case tree @ js.VarRef(name) =>
20812072
js.VarRef(name)(newType(name, tree.tpe))(tree.pos)
2082-
case js.Closure(arrow, captureParams, params, restParam, body, captureValues) =>
2083-
js.Closure(arrow, captureParams, params, restParam, body,
2084-
captureValues.map(transformExpr))(tree.pos)
20852073
case _ =>
20862074
super.transform(tree, isStat)
20872075
}
@@ -3281,6 +3269,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
32813269
genNewArray(arr, args.map(genExpr))
32823270
case prim: jstpe.PrimRef =>
32833271
abort(s"unexpected primitive type $prim in New at $pos")
3272+
case typeRef: jstpe.ClosureTypeRef =>
3273+
abort(s"unexpected closure type $typeRef in New at $pos")
32843274
}
32853275
}
32863276
}
@@ -6261,10 +6251,10 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
62616251
* We identify the captures using the same method as the `delambdafy`
62626252
* phase. We have an additional hack for `this`.
62636253
*
6264-
* To translate them, we first construct a JS closure for the body:
6254+
* To translate them, we first construct a typed closure for the body:
62656255
* {{{
6266-
* lambda<this, capture1, ..., captureM>(
6267-
* _this, capture1, ..., captureM, arg1, ..., argN) {
6256+
* typed-lambda<_this = this, capture1: U1 = capture1, ..., captureM: UM = captureM>(
6257+
* arg1: T1, ..., argN: TN): TR = {
62686258
* _this.someMethod(arg1, ..., argN, capture1, ..., captureM)
62696259
* }
62706260
* }}}
@@ -6278,13 +6268,13 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
62786268
* this:
62796269
* {{{
62806270
* class AnonFun extends Object with FunctionalInterface {
6281-
* val f: any
6282-
* def <init>(f: any) {
6271+
* val f: (Ti...) => TR
6272+
* def <init>(f: (Ti...) => TR) {
62836273
* super();
62846274
* this.f = f
62856275
* }
6286-
* def theSAMMethod(params: Types...): Type =
6287-
* unbox((this.f)(boxParams...))
6276+
* def theSAMMethod(params: Ti...): TR =
6277+
* (this.f)(params...)
62886278
* }
62896279
* }}}
62906280
*/
@@ -6334,21 +6324,29 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
63346324
val patchedBody =
63356325
js.Block(paramsLocals :+ ensureResultBoxed(body, target))
63366326

6337-
val closure = js.Closure(
6338-
arrow = true,
6327+
val closure = js.TypedClosure(
63396328
allFormalCaptures,
63406329
patchedFormalArgs,
6341-
restParam = None,
6330+
resultType = jstpe.AnyType,
63426331
patchedBody,
63436332
allActualCaptures)
63446333

6334+
val arity = params.size
6335+
val ctorName = {
6336+
val objectClassRef = jstpe.ClassRef(ir.Names.ObjectClass)
6337+
val closureTypeRef =
6338+
jstpe.ClosureTypeRef(List.fill(arity)(objectClassRef), objectClassRef)
6339+
ir.Names.MethodName.constructor(closureTypeRef :: Nil)
6340+
}
6341+
63456342
// Wrap the closure in the appropriate box for the SAM type
63466343
val funSym = originalFunction.tpe.typeSymbolDirect
63476344
if (isFunctionSymbol(funSym)) {
63486345
/* This is a scala.FunctionN. We use the existing AnonFunctionN
63496346
* wrapper.
63506347
*/
6351-
genJSFunctionToScala(closure, params.size)
6348+
js.New(ir.Names.ClassName("scala.scalajs.runtime.TypedFunction" + arity),
6349+
js.MethodIdent(ctorName), List(closure))
63526350
} else {
63536351
/* This is an arbitrary SAM type (can only happen in 2.12).
63546352
* We have to synthesize a class like LambdaMetaFactory would do on
@@ -6358,13 +6356,13 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
63586356
abort(s"Cannot find the SAMFunction attachment on $originalFunction at $pos")
63596357
}
63606358

6361-
val samWrapperClassName = synthesizeSAMWrapper(funSym, sam)
6362-
js.New(samWrapperClassName, js.MethodIdent(ObjectArgConstructorName),
6363-
List(closure))
6359+
val samWrapperClassName = synthesizeSAMWrapper(funSym, sam, ctorName)
6360+
js.New(samWrapperCl F42D assName, js.MethodIdent(ctorName), List(closure))
63646361
}
63656362
}
63666363

6367-
private def synthesizeSAMWrapper(funSym: Symbol, samInfo: SAMFunction)(
6364+
private def synthesizeSAMWrapper(funSym: Symbol, samInfo: SAMFunction,
6365+
ctorName: ir.Names.MethodName)(
63686366
implicit pos: Position): ClassName = {
63696367
val intfName = encodeClassName(funSym)
63706368

@@ -6377,24 +6375,28 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
63776375

63786376
val thisType = jstpe.ClassType(className, nullable = false)
63796377

6380-
// val f: Any
6378+
val arity = samInfo.sam.tpe.params.size
6379+
val closureType = jstpe.ClosureType(List.fill(arity)(jstpe.AnyType),
6380+
jstpe.AnyType, nullable = true)
6381+
6382+
// val f: ((any, ..., any) => any)
63816383
val fFieldIdent = js.FieldIdent(FieldName(className, SimpleFieldName("f")))
63826384
val fFieldDef = js.FieldDef(js.MemberFlags.empty, fFieldIdent,
6383-
NoOriginalName, jstpe.AnyType)
6385+
NoOriginalName, closureType)
63846386

63856387
// def this(f: Any) = { this.f = f; super() }
63866388
val ctorDef = {
63876389
val fParamDef = js.ParamDef(js.LocalIdent(LocalName("f")),
6388-
NoOriginalName, jstpe.AnyType, mutable = false)
6390+
NoOriginalName, closureType, mutable = false)
63896391
js.MethodDef(
63906392
js.MemberFlags.empty.withNamespace(js.MemberNamespace.Constructor),
6391-
js.MethodIdent(ObjectArgConstructorName),
6393+
js.MethodIdent(ctorName),
63926394
NoOriginalName,
63936395
List(fParamDef),
63946396
jstpe.NoType,
63956397
Some(js.Block(List(
63966398
js.Assign(
6397-
js.Select(js.This()(thisType), fFieldIdent)(jstpe.AnyType),
6399+
js.Select(js.This()(thisType), fFieldIdent)(closureType),
63986400
fParamDef.ref),
63996401
js.ApplyStatically(js.ApplyFlags.empty.withConstructor(true),
64006402
js.This()(thisType),
@@ -6444,8 +6446,9 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
64446446
yield (formal.ref, param.tpe)
64456447
}.map((ensureBoxed _).tupled)
64466448

6447-
val call = js.JSFunctionApply(
6448-
js.Select(js.This()(thisType), fFieldIdent)(jstpe.AnyType),
6449+
val call = js.ApplyTypedClosure(
6450+
js.ApplyFlags.empty,
6451+
js.Select(js.This()(thisType), fFieldIdent)(closureType),
64496452
actualParams)
64506453

64516454
val body = fromAny(call, enteringPhase(currentRun.posterasurePhase) {

ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,12 @@ object Hashers {
310310
mixMethodIdent(method)
311311
mixTrees(args)
312312

313+
case ApplyTypedClosure(flags, fun, args) =>
314+
mixTag(TagApplyTypedClosure)
315+
mixInt(ApplyFlags.toBits(flags))
316+
mixTree(fun)
317+
mixTrees(args)
318+
313319
case UnaryOp(op, lhs) =>
314320
mixTag(TagUnaryOp)
315321
mixInt(op)
@@ -545,6 +551,14 @@ object Hashers {
545551
mixTree(body)
546552
mixTrees(captureValues)
547553

554+
case TypedClosure(captureParams, params, resultType, body, captureValues) =>
555+
mixTag(TagTypedClosure)
556+
mixParamDefs(captureParams)
557+
mixParamDefs(params)
558+
mixType(resultType)
559+
mixTree(body)
560+
mixTrees(captureValues)
561+
548562
case CreateJSClass(className, captureValues) =>
549563
mixTag(TagCreateJSClass)
550564
mixName(className)
@@ -597,13 +611,21 @@ object Hashers {
597611
case typeRef: ArrayTypeRef =>
598612
mixTag(TagArrayTypeRef)
599613
mixArrayTypeRef(typeRef)
614+
case typeRef: ClosureTypeRef =>
615+
mixTag(TagClosureTypeRef)
616+
mixClosureTypeRef(typeRef)
600617
}
601618

602619
def mixArrayTypeRef(arrayTypeRef: ArrayTypeRef): Unit = {
603620
mixTypeRef(arrayTypeRef.base)
604621
mixInt(arrayTypeRef.dimensions)
605622
}
606623

624+
def mixClosureTypeRef(closureTypeRef: ClosureTypeRef): Unit = {
625+
closureTypeRef.paramTypeRefs.foreach(mixTypeRef(_))
626+
mixTypeRef(closureTypeRef.resultTypeRef)
627+
}
628+
607629
def mixType(tpe: Type): Unit = tpe match {
608630
case AnyType => mixTag(TagAnyType)
609631
case AnyNotNullType => mixTag(TagAnyNotNullType)
@@ -629,6 +651,11 @@ object Hashers {
629651
mixTag(if (nullable) TagArrayType else TagNonNullArrayType)
630652
mixArrayTypeRef(arrayTypeRef)
631653

654+
case ClosureType(paramTypes, resultType, nullable) =>
655+
mixTag(if (nullable) TagClosureType else TagNonNullClosureType)
656+
mixTypes(paramTypes)
657+
mixType(resultType)
658+
632659
case RecordType(fields) =>
633660
mixTag(TagRecordType)
634661
for (RecordType.Field(name, originalName, tpe, mutable) <- fields) {
@@ -639,6 +666,9 @@ object Hashers {
639666
}
640667
}
641668

669+
def mixTypes(tpes: List[Type]): Unit =
670+
tpes.foreach(mixType)
671+
642672
def mixLocalIdent(ident: LocalIdent): Unit = {
643673
mixPos(ident.pos)
644674
mixName(ident.name)

ir/shared/src/main/scala/org/scalajs/ir/Names.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,18 @@ object Names {
414414
i += 1
415415
}
416416
appendTypeRef(base)
417+
case ClosureTypeRef(paramTypeRefs, resultTypeRef) =>
418+
builder.append('(')
419+
var first = true
420+
for (paramTypeRef <- paramTypeRefs) {
421+
if (first)
422+
first = false
423+
else
424+
builder.append(',')
425+
appendTypeRef(paramTypeRef)
426+
}
427+
builder.append(')')
428+
appendTypeRef(resultTypeRef)
417429
}
418430

419431
builder.append(simpleName.nameString)

ir/shared/src/main/scala/org/scalajs/ir/Printers.scala

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,10 @@ object Printers {
343343
print(method)
344344
printArgs(args)
345345

346+
case ApplyTypedClosure(flags, fun, args) =>
347+
print(fun)
348+
printArgs(args)
349+
346350
case UnaryOp(op, lhs) =>
347351
import UnaryOp._
348352

@@ -896,6 +900,23 @@ object Printers {
896900
printBlock(body)
897901
print(')')
898902

903+
case TypedClosure(captureParams, params, resultType, body, captureValues) =>
904+
print("(typed-lambda<")
905+
var first = true
906+
for ((param, value) <- captureParams.zip(captureValues)) {
907+
if (first)
908+
first = false
909+
else
910+
print(", ")
911+
print(param)
912+
print(" = ")
913+
print(value)
914+
}
915+
print(">")
916+
printSig(params, restParam = None, resultType)
917+
printBlock(body)
918+
print(')')
919+
899920
case CreateJSClass(className, captureValues) =>
900921
print("createjsclass[")
901922
print(className)
@@ -1085,6 +1106,18 @@ object Printers {
10851106
print(base)
10861107
for (i <- 1 to dims)
10871108
print("[]")
1109+
case ClosureTypeRef(paramTypeRefs, resultTypeRef) =>
1110+
print('(')
1111+
var first = true
1112+
for (paramTypeRef <- paramTypeRefs) {
1113+
if (first)
1114+
first = false
1115+
else
1116+
print(", ")
1117+
print(paramTypeRef)
1118+
}
1119+
print(") => ")
1120+
print(resultTypeRef)
10881121
}
10891122

10901123
def print(tpe: Type): Unit = tpe match {
@@ -1114,6 +1147,22 @@ object Printers {
11141147
if (!nullable)
11151148
print("!")
11161149

1150+
case ClosureType(paramTypes, resultType, nullable) =>
1151+
print("((")
1152+
var first = true
1153+
for (paramType <- paramTypes) {
1154+
if (first)
1155+
first = false
1156+
else
1157+
print(", ")
1158+
print(paramType)
1159+
}
1160+
print(") => ")
1161+
print(resultType)
1162+
print(')')
1163+
if (!nullable)
1164+
print('!')
1165+
11171166
case RecordType(fields) =>
11181167
print('(')
11191168
var first = true

0 commit comments

Comments
 (0)
0