8000 Introduce TypedClosures in the IR, which are closures without JS inte… · scala-js/scala-js@fd2b383 · GitHub
[go: up one dir, main page]

Skip to content

Commit fd2b383

Browse files
committed
Introduce TypedClosures in the IR, which are closures without JS interop.
A `TypedClosure` is like a `Closure` but without any semantics for JS interop. This is stronger than `Char`, which is "merely" opaque to JS. A `Char` can still be passed to JS and has a meaningful `toString()`. A `TypedClosure` *cannot* be passed to JS in any way. That is enforced by making their type *not* a subtype of `any`. Since a `TypedClosure` has no JS interop semantics, it is free to strongly, statically type its parameters and result type. More importantly, we can freely choose its representation in the best possible way for the given target. On JS, that remains an arrow function. On Wasm, however, we represent is as a pair of `(capture data pointer, function pointer)`. This allows to compile them in an efficient way that does not require going through a JS bridge closure. The latter has been shown to have a devastating impact on performance when a Scala function is used in a tight loop. The type of a `TypedClosure` is a `ClosureType`. It records its parameter types and its result type. Closure types are non-variant: they are only subtypes of themselves. As mentioned, they are not subtypes of `any`. They are however subtypes of `void` and supertypes of `nothing`. To call a typed closure, we introduce a dedicated application node `ApplyTypedClosure`. IR checking ensures that actual arguments match the expected parameter types. The result type is directly used as the type of the application. There are no changes to the source language. In particular, there is no way to express typed closures or their types at the user level. However, we change the compilation of SAM anonymous functions, both for Scala functions and JVM-like SAMs, to use a `TypedClosure` instead of a `Closure`. We wrap the `TypedClosure` inside an instance of `scala.scalajs.runtime.TypedFunctionN`. These classes are generated "by hand" in the build, since there is no user-level way to define them. They have valid IR, though. These changes have no real impact on the JS output (only marginal naming differences). On Wasm, however, they make Scala functions much, much faster. Before, a Scala function in a typed loop would cause a Wasm implementation to be, in the worst measured case, 20x slower than on JS. After these changes, similar benchmarks become significantly faster on Wasm than on JS.
1 parent b63ca64 commit fd2b383

File tree

28 files changed

+1181
-167
lines changed

28 files changed

+1181
-167
lines changed

compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -962,6 +962,9 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
962962
case closure: js.Closure =>
963963
val newCaptureValues = closure.captureValues.map(transformExpr)
964964
closure.copy(captureValues = newCaptureValues)(closure.pos)
965+
case closure: js.TypedClosure =>
966+
val newCaptureValues = closure.captureValues.map(transformExpr)
967+
closure.copy(captureValues = newCaptureValues)(closure.pos)
965968

966969
case tree =>
967970
super.transform(tree, isStat)
@@ -2030,15 +2033,12 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
20302033
} yield {
20312034
js.ParamDef(name, originalName, ptpe, newMutable(name.name, mutable))(p.pos)
20322035
}
2033-
val transformer = new ir.Transformers.Transformer {
2036+
val transformer = new ir.Transformers.LocalScopeTransformer {
20342037
override def transform(tree: js.Tree, isStat: Boolean): js.Tree = tree match {
20352038
case js.VarDef(name, originalName, vtpe, mutab 9E88 le, rhs) =>
20362039
assert(isStat, s"found a VarDef in expression position at ${tree.pos}")
20372040
super.transform(js.VarDef(name, originalName, vtpe,
20382041
newMutable(name.name, mutable), rhs)(tree.pos), isStat)
2039-
case js.Closure(arrow, captureParams, params, restParam, body, captureValues) =>
2040-
js.Closure(arrow, captureParams, params, restParam, body,
2041-
captureValues.map(transformExpr))(tree.pos)
20422042
case _ =>
20432043
super.transform(tree, isStat)
20442044
}
@@ -2068,13 +2068,10 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
20682068
} yield {
20692069
js.ParamDef(name, originalName, newType(name, ptpe), mutable)(p.pos)
20702070
}
2071-
val transformer = new ir.Transformers.Transformer {
2071+
val transformer = new ir.Transformers.LocalScopeTransformer {
20722072
override def transform(tree: js.Tree, isStat: Boolean): js.Tree = tree match {
20732073
case tree @ js.VarRef(name) =>
20742074
js.VarRef(name)(newType(name, tree.tpe))(tree.pos)
2075-
case js.Closure(arrow, captureParams, params, restParam, body, captureValues) =>
2076-
js.Closure(arrow, captureParams, params, restParam, body,
2077-
captureValues.map(transformExpr))(tree.pos)
20782075
case _ =>
20792076
super.transform(tree, isStat)
20802077
}
@@ -3239,6 +3236,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
32393236
genNewArray(arr, args.map(genExpr))
32403237
case prim: jstpe.PrimRef =>
32413238
abort(s"unexpected primitive type $prim in New at $pos")
3239+
case typeRef: jstpe.ClosureTypeRef =>
3240+
abort(s"unexpected closure type $typeRef in New at $pos")
32423241
}
32433242
}
32443243
}
@@ -6220,10 +6219,10 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
62206219
* We identify the captures using the same method as the `delambdafy`
62216220
* phase. We have an additional hack for `this`.
62226221
*
6223-
* To translate them, we first construct a JS closure for the body:
6222+
* To translate them, we first construct a typed closure for the body:
62246223
* {{{
6225-
* lambda<this, capture1, ..., captureM>(
6226-
* _this, capture1, ..., captureM, arg1, ..., argN) {
6224+
* typed-lambda<_this = this, capture1: U1 = capture1, ..., captureM: UM = captureM>(
6225+
* arg1: T1, ..., argN: TN): TR = {
62276226
* _this.someMethod(arg1, ..., argN, capture1, ..., captureM)
62286227
* }
62296228
* }}}
@@ -6237,13 +6236,13 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
62376236
* this:
62386237
* {{{
62396238
* class AnonFun extends Object with FunctionalInterface {
6240-
* val f: any
6241-
* def <init>(f: any) {
6239+
* val f: (Ti...) => TR
6240+
* def <init>(f: (Ti...) => TR) {
62426241
* super();
62436242
* this.f = f
62446243
* }
6245-
* def theSAMMethod(params: Types...): Type =
6246-
* unbox((this.f)(boxParams...))
6244+
* def theSAMMethod(params: Ti...): TR =
6245+
* (this.f)(params...)
62476246
* }
62486247
* }}}
62496248
*/
@@ -6293,21 +6292,29 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
62936292
val patchedBody =
62946293
js.Block(paramsLocals :+ ensureResultBoxed(body, target))
62956294

6296-
val closure = js.Closure(
6297-
arrow = true,
6295+
val closure = js.TypedClosure(
62986296
allFormalCaptures,
62996297
patchedFormalArgs,
6300-
restParam = None,
6298+
resultType = jstpe.AnyType,
63016299
patchedBody,
63026300
allActualCaptures)
63036301

6302+
val arity = params.size
6303+
val ctorName = {
6304+
val objectClassRef = jstpe.ClassRef(ir.Names.ObjectClass)
6305+
val closureTypeRef =
6306+
jstpe.ClosureTypeRef(List.fill(arity)(objectClassRef), objectClassRef)
6307+
ir.Names.MethodName.constructor(closureTypeRef :: Nil)
6308+
}
6309+
63046310
// Wrap the closure in the appropriate box for the SAM type
63056311
val funSym = originalFunction.tpe.typeSymbolDirect
63066312
if (isFunctionSymbol(funSym)) {
63076313
/* This is a scala.FunctionN. We use the existing AnonFunctionN
63086314
* wrapper.
63096315
*/
6310-
genJSFunctionToScala(closure, params.size)
6316+
js.New(ir.Names.ClassName("scala.scalajs.runtime.TypedFunction" + arity),
6317+
js.MethodIdent(ctorName), List(closure))
63116318
} else {
63126319
/* This is an arbitrary SAM type (can only happen in 2.12).
63136320
* We have to synthesize a class like LambdaMetaFactory would do on
@@ -6317,13 +6324,13 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
63176324
abort(s"Cannot find the SAMFunction attachment on $originalFunction at $pos")
63186325
}
63196326

6320-
val samWrapperClassName = synthesizeSAMWrapper(funSym, sam)
6321-
js.New(samWrapperClassName, js.MethodIdent(ObjectArgConstructorName),
6322-
List(closure))
6327+
val samWrapperClassName = synthesizeSAMWrapper(funSym, sam, ctorName)
6328+
js.New(samWrapperClassName, js.MethodIdent(ctorName), List(closure))
63236329
}
63246330
}
63256331

6326-
private def synthesizeSAMWrapper(funSym: Symbol, samInfo: SAMFunction)(
6332+
private def synthesizeSAMWrapper(funSym: Symbol, samInfo: SAMFunction,
6333+
ctorName: ir.Names.MethodName)(
63276334
implicit pos: Position): ClassName = {
63286335
val intfName = encodeClassName(funSym)
63296336

@@ -6336,24 +6343,27 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
63366343

63376344
val classType = jstpe.ClassType(className)
63386345

6339-
// val f: Any
6346+
val arity = samInfo.sam.tpe.params.size
6347+
val closureType = jstpe.ClosureType(List.fill(arity)(jstpe.AnyType), jstpe.AnyType)
6348+
6349+
// val f: ((any, ..., any) => any)
63406350
val fFieldIdent = js.FieldIdent(FieldName(className, SimpleFieldName("f")))
63416351
val fFieldDef = js.FieldDef(js.MemberFlags.empty, fFieldIdent,
6342-
NoOriginalName, jstpe.AnyType)
6352+
NoOriginalName, closureType)
63436353

63446354
// def this(f: Any) = { this.f = f; super() }
63456355
val ctorDef = {
63466356
val fParamDef = js.ParamDef(js.LocalIdent(LocalName("f")),
6347-
NoOriginalName, jstpe.AnyType, mutable = false)
6357+
NoOriginalName, closureType, mutable = false)
63486358
js.MethodDef(
63496359
js.MemberFlags.empty.withNamespace(js.MemberNamespace.Constructor),
6350-
js.MethodIdent(ObjectArgConstructorName),
6360+
js.MethodIdent(ctorName),
63516361
NoOriginalName,
63526362
List(fParamDef),
63536363
jstpe.NoType,
63546364
Some(js.Block(List(
63556365
js.Assign(
6356-
js.Select(js.This()(classType), fFieldIdent)(jstpe.AnyType),
6366+
js.Select(js.This()(classType), fFieldIdent)(closureType),
63576367
fParamDef.ref),
63586368
js.ApplyStatically(js.ApplyFlags.empty.withConstructor(true),
63596369
js.This()(classType),
@@ -6403,8 +6413,9 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
64036413
yield (formal.ref, param.tpe)
64046414
}.map((ensureBoxed _).tupled)
64056415

6406-
val call = js.JSFunctionApply(
6407-
js.Select(js.This()(classType), fFieldIdent)(jstpe.AnyType),
6416+
val call = js.ApplyTypedClosure(
6417+
js.ApplyFlags.empty,
6418+
js.Select(js.This()(classType), fFieldIdent)(closureType),
64086419
actualParams)
64096420

64106421
val body = fromAny(call, enteringPhase(currentRun.posterasurePhase) {

ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,12 @@ object Hashers {
310310
mixMethodIdent(method)
311311
mixTrees(args)
312312

313+
case ApplyTypedClosure(flags, fun, args) =>
314+
mixTag(TagApplyTypedClosure)
315+
mixInt(ApplyFlags.toBits(flags))
316+
mixTree(fun)
317+
mixTrees(args)
318+
313319
case UnaryOp(op, lhs) =>
314320
mixTag(TagUnaryOp)
315321
mixInt(op)
@@ -527,6 +533,10 @@ object Hashers {
527533
mixTag(TagClassOf)
528534
mixTypeRef(typeRef)
529535

536+
case NullTypedClosure(tpe) =>
537+
mixTag(TagNullTypedClosure)
538+
mixType(tpe)
539+
530540
case VarRef(ident) =>
531541
mixTag(TagVarRef)
532542
mixLocalIdent(ident)
@@ -545,6 +555,14 @@ object Hashers {
545555
mixTree(body)
546556
mixTrees(captureValues)
547557

558+
case TypedClosure(captureParams, params, resultType, body, captureValues) =>
559+
mixTag(TagTypedClosure)
560+
mixParamDefs(captureParams)
561+
mixParamDefs(params)
562+
mixType(resultType)
563+
mixTree(body)
564+
mixTrees(captureValues)
565+
548566
case CreateJSClass(className, captureValues) =>
549567
mixTag(TagCreateJSClass)
550568
mixName(className)
@@ -597,13 +615,21 @@ object Hashers {
597615
case typeRef: ArrayTypeRef =>
598616
mixTag(TagArrayTypeRef)
599617
mixArrayTypeRef(typeRef)
618+
case typeRef: ClosureTypeRef =>
619+
mixTag(TagClosureTypeRef)
620+
mixClosureTypeRef(typeRef)
600621
}
601622

602623
def mixArrayTypeRef(arrayTypeRef: ArrayTypeRef): Unit = {
603624
mixTypeRef(arrayTypeRef.base)
604625
mixInt(arrayTypeRef.dimensions)
605626
}
606627

628+
def mixClosureTypeRef(closureTypeRef: ClosureTypeRef): Unit = {
629+
closureTypeRef.paramTypeRefs.foreach(mixTypeRef(_))
630+
mixTypeRef(closureTypeRef.resultTypeRef)
631+
}
632+
607633
def mixType(tpe: Type): Unit = tpe match {
608634
case AnyType => mixTag(TagAnyType)
609635
case NothingType => mixTag(TagNothingType)
@@ -628,6 +654,11 @@ object Hashers {
628654
mixTag(TagArrayType)
629655
mixArrayTypeRef(arrayTypeRef)
630656

657+
case ClosureType(paramTypes, resultType) =>
658+
mixTag(TagClosureType)
659+
mixTypes(paramTypes)
660+
mixType(resultType)
661+
631662
case RecordType(fields) =>
632663
mixTag(TagRecordType)
633664
for (RecordType.Field(name, originalName, tpe, mutable) <- fields) {
@@ -638,6 +669,9 @@ object Hashers {
638669
}
639670
}
640671

672+
def mixTypes(tpes: List[Type]): Unit =
673+
tpes.foreach(mixType)
674+
641675
def mixLocalIdent(ident: LocalIdent): Unit = {
642676
mixPos(ident.pos)
643677
mixName(ident.name)

ir/shared/src/main/scala/org/scalajs/ir/Names.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,18 @@ object Names {
414414
i += 1
415415
}
416416
appendTypeRef(base)
417+
case ClosureTypeRef(paramTypeRefs, resultTypeRef) =>
418+
builder.append('(')
419+
var first = true
420+
for (paramTypeRef <- paramTypeRefs) {
421+
if (first)
422+
first = false
423+
else
424+
builder.append(',')
425+
appendTypeRef(paramTypeRef)
426+
}
427+
builder.append(')')
428+
appendTypeRef(resultTypeRef)
417429
}
418430

419431
builder.append(simpleName.nameString)

ir/shared/src/main/scala/org/scalajs/ir/Printers.scala

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,10 @@ object Printers {
343343
print(method)
344344
printArgs(args)
345345

346+
case ApplyTypedClosure(flags, fun, args) =>
347+
print(fun)
348+
printArgs(args)
349+
346350
case UnaryOp(UnaryOp.String_length, lhs) =>
347351
print(lhs)
348352
print(".length")
@@ -846,6 +850,11 @@ object Printers {
846850
print(typeRef)
847851
print(']')
848852

853+
case NullTypedClosure(tpe) =>
854+
print("null<")
855+
print(tpe)
856+
print('>')
857+
849858
// Atomic expressions
850859

851860
case VarRef(ident) =>
@@ -874,6 +883,23 @@ object Printers {
874883
printBlock(body)
875884
print(')')
876885

886+
case TypedClosure(captureParams, params, resultType, body, captureValues) =>
887+
print("(typed-lambda<")
888+
var first = true
889+
for ((param, value) <- captureParams.zip(captureValues)) {
890+
if (first)
891+
first = false
892+
else
893+
print(", ")
894+
print(param)
895+
print(" = ")
896+
print(value)
897+
}
898+
print(">")
899+
printSig(params, restParam = None, resultType)
900+
printBlock(body)
901+
print(')')
902+
877903
case CreateJSClass(className, captureValues) =>
878904
print("createjsclass[")
879905
print(className)
@@ -1063,6 +1089,18 @@ object Printers {
10631089
print(base)
10641090
for (i <- 1 to dims)
10651091
print("[]")
1092+
case ClosureTypeRef(paramTypeRefs, resultTypeRef) =>
1093+
print('(')
1094+
var first = true
1095+
for (paramTypeRef <- paramTypeRefs) {
1096+
if (first)
1097+
first = false
1098+
else
1099+
print(", ")
1100+
print(paramTypeRef)
1101+
}
1102+
print(") => ")
1103+
print(resultTypeRef)
10661104
}
10671105

10681106
def print(tpe: Type): Unit = tpe match {
@@ -1085,6 +1123,20 @@ object Printers {
10851123
case ArrayType(arrayTypeRef) =>
10861124
print(arrayTypeRef)
10871125

1126+
case ClosureType(paramTypes, resultType) =>
1127+
print("((")
1128+
var first = true
1129+
for (paramType <- paramTypes) {
1130+
if (first)
1131+
first = false
1132+
else
1133+
print(", ")
1134+
print(paramType)
1135+
}
1136+
print(") => ")
1137+
print(resultType)
1138+
print(')')
1139+
10881140
case RecordType(fields) =>
10891141
print('(')
10901142
var first = true

0 commit comments

Comments
 (0)
0