8000 Introduce TypedClosures in the IR, which are closures without JS inte… · scala-js/scala-js@acc04e4 · GitHub
[go: up one dir, main page]

Skip to content

Commit acc04e4

Browse files
committed
Introduce TypedClosures in the IR, which are closures without JS interop.
A `TypedClosure` is like a `Closure` but without any semantics for JS interop. This is stronger than `Char`, which is "merely" opaque to JS. A `Char` can still be passed to JS and has a meaningful `toString()`. A `TypedClosure` *cannot* be passed to JS in any way. That is enforced by making their type *not* a subtype of `any`. Since a `TypedClosure` has no JS interop semantics, it is free to strongly, statically type its parameters and result type. More importantly, we can freely choose its representation in the best possible way for the given target. On JS, that remains an arrow function. On Wasm, however, we represent is as a pair of `(capture data pointer, function pointer)`. This allows to compile them in an efficient way that does not require going through a JS bridge closure. The latter has been shown to have a devastating impact on performance when a Scala function is used in a tight loop. The type of a `TypedClosure` is a `ClosureType`. It records its parameter types and its result type. Closure types are non-variant: they are only subtypes of themselves. As mentioned, they are not subtypes of `any`. They are however subtypes of `void` and supertypes of `nothing`. Unfortunately, they must also be nullable to have a default value, so they are also supertypes of `null`. To call a typed closure, we introduce a dedicated application node `ApplyTypedClosure`. IR checking ensures that actual arguments match the expected parameter types. The result type is directly used as the type of the application. There are no changes to the source language. In particular, there is no way to express typed closures or their types at the user level. However, we change the compilation of SAM anonymous functions, both for Scala functions and JVM-like SAMs, to use a `TypedClosure` instead of a `Closure`. We wrap the `TypedClosure` inside an instance of `scala.scalajs.runtime.TypedFunctionN`. These classes are generated "by hand" in the build, since there is no user-level way to define them. They have valid IR, though. These changes have no real impact on the JS output (only marginal naming differences). On Wasm, however, they make Scala functions much, much faster. Before, a Scala function in a typed loop would cause a Wasm implementation to be, in the worst measured case, 20x slower than on JS. After these changes, similar benchmarks become significantly faster on Wasm than on JS.
1 parent ddebbe1 commit acc04e4

File tree

28 files changed

+1135
-191
lines changed

28 files changed

+1135
-191
lines changed

compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala

Lines changed: 39 additions & 37 deletions
E377 < 2851 td data-grid-cell-id="diff-827fa2ae909eff6a714b058ec14d861994e8c9200b17948aafb30ae312f7cee9-6359-6360-2" data-line-anchor="diff-827fa2ae909eff6a714b058ec14d861994e8c9200b17948aafb30ae312f7cee9R6360" data-selected="false" role="gridcell" style="background-color:var(--bgColor-default);padding-right:24px" tabindex="-1" valign="top" class="focusable-grid-cell diff-text-cell right-side-diff-cell left-side">
js.This()(classType),
Original file line numberDiff line numberDiff line change
@@ -953,16 +953,10 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
953953
}
954954

955955
// After the super call, substitute `selfRef` for `This()`
956-
val afterSuper = new ir.Transformers.Transformer {
956+
val afterSuper = new ir.Transformers.LocalScopeTransformer {
957957
override def transform(tree: js.Tree, isStat: Boolean): js.Tree = tree match {
958958
case js.This() =>
959959
selfRef(tree.pos)
960-
961-
// Don't traverse closure boundaries
962-
case closure: js.Closure =>
963-
val newCaptureValues = closure.captureValues.map(transformExpr)
964-
closure.copy(captureValues = newCaptureValues)(closure.pos)
965-
966960
case tree =>
967961
super.transform(tree, isStat)
968962
}
@@ -2030,15 +2024,12 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
20302024
} yield {
20312025
js.ParamDef(name, originalName, ptpe, newMutable(name.name, mutable))(p.pos)
20322026
}
2033-
val transformer = new ir.Transformers.Transformer {
2027+
val transformer = new ir.Transformers.LocalScopeTransformer {
20342028
override def transform(tree: js.Tree, isStat: Boolean): js.Tree = tree match {
20352029
case js.VarDef(name, originalName, vtpe, mutable, rhs) =>
20362030
assert(isStat, s"found a VarDef in expression position at ${tree.pos}")
20372031
super.transform(js.VarDef(name, originalName, vtpe,
20382032
newMutable(name.name, mutable), rhs)(tree.pos), isStat)
2039-
case js.Closure(arrow, captureParams, params, restParam, body, captureValues) =>
2040-
js.Closure(arrow, captureParams, params, restParam, body,
2041-
captureValues.map(transformExpr))(tree.pos)
20422033
case _ =>
20432034
super.transform(tree, isStat)
20442035
}
@@ -2068,13 +2059,10 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
20682059
} yield {
20692060
js.ParamDef(name, originalName, newType(name, ptpe), mutable)(p.pos)
20702061
}
2071-
val transformer = new ir.Transformers.Transformer {
2062+
val transformer = new ir.Transformers.LocalScopeTransformer {
20722063
override def transform(tree: js.Tree, isStat: Boolean): js.Tree = tree match {
20732064
case tree @ js.VarRef(name) =>
20742065
js.VarRef(name)(newType(name, tree.tpe))(tree.pos)
2075-
case js.Closure(arrow, captureParams, params, restParam, body, captureValues) =>
2076-
js.Closure(arrow, captureParams, params, restParam, body,
2077-
captureValues.map(transformExpr))(tree.pos)
20782066
case _ =>
20792067
super.transform(tree, isStat)
20802068
}
@@ -3239,6 +3227,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
32393227
genNewArray(arr, args.map(genExpr))
32403228
case prim: jstpe.PrimRef =>
32413229
abort(s"unexpected primitive type $prim in New at $pos")
3230+
case typeRef: jstpe.ClosureTypeRef =>
3231+
abort(s"unexpected closure type $typeRef in New at $pos")
32423232
}
32433233
}
32443234
}
@@ -6220,10 +6210,10 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
62206210
* We identify the captures using the same method as the `delambdafy`
62216211
* phase. We have an additional hack for `this`.
62226212
*
6223-
* To translate them, we first construct a JS closure for the body:
6213+
* To translate them, we first construct a typed closure for the body:
62246214
* {{{
6225-
* lambda<this, capture1, ..., captureM>(
6226-
* _this, capture1, ..., captureM, arg1, ..., argN) {
6215+
* typed-lambda<_this = this, capture1: U1 = capture1, ..., captureM: UM = captureM>(
6216+
* arg1: T1, ..., argN: TN): TR = {
62276217
* _this.someMethod(arg1, ..., argN, capture1, ..., captureM)
62286218
* }
62296219
* }}}
@@ -6237,13 +6227,13 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
62376227
* this:
62386228
* {{{
62396229
* class AnonFun extends Object with FunctionalInterface {
6240-
* val f: any
6241-
* def <init>(f: any) {
6230+
* val f: (Ti...) => TR
6231+
* def <init>(f: (Ti...) => TR) {
62426232
* super();
62436233
* this.f = f
62446234
* }
6245-
* def theSAMMethod(params: Types...): Type =
6246-
* unbox((this.f)(boxParams...))
6235+
* def theSAMMethod(params: Ti...): TR =
6236+
* (this.f)(params...)
62476237
* }
62486238
* }}}
62496239
*/
@@ -6293,21 +6283,29 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
62936283
val patchedBody =
62946284
js.Block(paramsLocals :+ ensureResultBoxed(body, target))
62956285

6296-
val closure = js.Closure(
6297-
arrow = true,
6286+
val closure = js.TypedClosure(
62986287
allFormalCaptures,
62996288
patchedFormalArgs,
6300-
restParam = None,
6289+
resultType = jstpe.AnyType,
63016290
patchedBody,
63026291
allActualCaptures)
63036292

6293+
val arity = params.size
6294+
val ctorName = {
6295+
val objectClassRef = jstpe.ClassRef(ir.Names.ObjectClass)
6296+
val closureTypeRef =
6297+
jstpe.ClosureTypeRef(List.fill(arity)(objectClassRef), objectClassRef)
6298+
ir.Names.MethodName.constructor(closureTypeRef :: Nil)
6299+
}
6300+
63046301
// Wrap the closure in the appropriate box for the SAM type
63056302
val funSym = originalFunction.tpe.typeSymbolDirect
63066303
if (isFunctionSymbol(funSym)) {
63076304
/* This is a scala.FunctionN. We use the existing AnonFunctionN
63086305
* wrapper.
63096306
*/
6310-
genJSFunctionToScala(closure, params.size)
6307+
js.New(ir.Names.ClassName("scala.scalajs.runtime.TypedFunction" + arity),
6308+
js.MethodIdent(ctorName), List(closure))
63116309
} else {
63126310
/* This is an arbitrary SAM type (can only happen in 2.12).
63136311
* We have to synthesize a class like LambdaMetaFactory would do on
@@ -6317,13 +6315,13 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
63176315
abort(s"Cannot find the SAMFunction attachment on $originalFunction at $pos")
63186316
}
63196317

6320-
val samWrapperClassName = synthesizeSAMWrapper(funSym, sam)
6321-
js.New(samWrapperClassName, js.MethodIdent(ObjectArgConstructorName),
6322-
List(closure))
6318+
val samWrapperClassName = synthesizeSAMWrapper(funSym, sam, ctorName)
6319+
js.New(samWrapperClassName, js.MethodIdent(ctorName), List(closure))
63236320
}
63246321
}
63256322

6326-
private def synthesizeSAMWrapper(funSym: Symbol, samInfo: SAMFunction)(
6323+
private def synthesizeSAMWrapper(funSym: Symbol, samInfo: SAMFunction,
6324+
ctorName: ir.Names.MethodName)(
63276325
implicit pos: Position): ClassName = {
63286326
val intfName = encodeClassName(funSym)
63296327

@@ -6336,24 +6334,27 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
63366334

63376335
val classType = jstpe.ClassType(className)
63386336

6339-
// val f: Any
6337+
val arity = samInfo.sam.tpe.params.size
6338+
val closureType = jstpe.ClosureType(List.fill(arity)(jstpe.AnyType), jstpe.AnyType)
6339+
6340+
// val f: ((any, ..., any) => any)
63406341
val fFieldIdent = js.FieldIdent(FieldName(className, SimpleFieldName("f")))
63416342
val fFieldDef = js.FieldDef(js.MemberFlags.empty, fFieldIdent,
6342-
NoOriginalName, jstpe.AnyType)
6343+
NoOriginalName, closureType)
63436344

63446345
// def this(f: Any) = { this.f = f; super() }
63456346
val ctorDef = {
63466347
val fParamDef = js.ParamDef(js.LocalIdent(LocalName("f")),
6347-
NoOriginalName, jstpe.AnyType, mutable = false)
6348+
NoOriginalName, closureType, mutable = false)
63486349
js.MethodDef(
63496350
js.MemberFlags.empty.withNamespace(js.MemberNamespace.Constructor),
6350-
js.MethodIdent(ObjectArgConstructorName),
6351+
js.MethodIdent(ctorName),
63516352
NoOriginalName,
63526353
List(fParamDef),
63536354
jstpe.NoType,
63546355
Some(js.Block(List(
63556356
js.Assign(
6356-
js.Select(js.This()(classType), fFieldIdent)(jstpe.AnyType),
6357+
js.Select(js.This()(classType), fFieldIdent)(closureType),
63576358
fParamDef.ref),
63586359
js.ApplyStatically(js.ApplyFlags.empty.withConstructor(true),
63596360
@@ -6403,8 +6404,9 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
64036404
yield (formal.ref, param.tpe)
64046405
}.map((ensureBoxed _).tupled)
64056406

6406-
val call = js.JSFunctionApply(
6407-
js.Select(js.This()(classType), fFieldIdent)(jstpe.AnyType),
6407+
val call = js.ApplyTypedClosure(
6408+
js.ApplyFlags.empty,
6409+
js.Select(js.This()(classType), fFieldIdent)(closureType),
64086410
actualParams)
64096411

64106412
val body = fromAny(call, enteringPhase(currentRun.posterasurePhase) {

ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,12 @@ object Hashers {
310310
mixMethodIdent(method)
311311
mixTrees(args)
312312

313+
case ApplyTypedClosure(flags, fun, args) =>
314+
mixTag(TagApplyTypedClosure)
315+
mixInt(ApplyFlags.toBits(flags))
316+
mixTree(fun)
317+
mixTrees(args)
318+
313319
case UnaryOp(op, lhs) =>
314320
mixTag(TagUnaryOp)
315321
mixInt(op)
@@ -545,6 +551,14 @@ object Hashers {
545551
mixTree(body)
546552
mixTrees(captureValues)
547553

554+
case TypedClosure(captureParams, params, resultType, body, captureValues) =>
555+
mixTag(TagTypedClosure)
556+
mixParamDefs(captureParams)
557+
mixParamDefs(params)
558+
mixType(resultType)
559+
mixTree(body)
560+
mixTrees(captureValues)
561+
548562
case CreateJSClass(className, captureValues) =>
549563
mixTag(TagCreateJSClass)
550564
mixName(className)
@@ -597,13 +611,21 @@ object Hashers {
597611
case typeRef: ArrayTypeRef =>
598612
mixTag(TagArrayTypeRef)
599613
mixArrayTypeRef(typeRef)
614+
case typeRef: ClosureTypeRef =>
615+
mixTag(TagClosureTypeRef)
616+
mixClosureTypeRef(typeRef)
600617
}
601618

602619
def mixArrayTypeRef(arrayTypeRef: ArrayTypeRef): Unit = {
603620
mixTypeRef(arrayTypeRef.base)
604621
mixInt(arrayTypeRef.dimensions)
605622
}
606623

624+
def mixClosureTypeRef(closureTypeRef: ClosureTypeRef): Unit = {
625+
closureTypeRef.paramTypeRefs.foreach(mixTypeRef(_))
626+
mixTypeRef(closureTypeRef.resultTypeRef)
627+
}
628+
607629
def mixType(tpe: Type): Unit = tpe match {
608630
case AnyType => mixTag(TagAnyType)
609631
case NothingType => mixTag(TagNothingType)
@@ -628,6 +650,11 @@ object Hashers {
628650
mixTag(TagArrayType)
629651
mixArrayTypeRef(arrayTypeRef)
630652

653+
case ClosureType(paramTypes, resultType) =>
654+
mixTag(TagClosureType)
655+
mixTypes(paramTypes)
656+
mixType(resultType)
657+
631658
case RecordType(fields) =>
632659
mixTag(TagRecordType)
633660
for (RecordType.Field(name, originalName, tpe, mutable) <- fields) {
@@ -638,6 +665,9 @@ object Hashers {
638665
}
639666
}
640667

668+
def mixTypes(tpes: List[Type]): Unit =
669+
tpes.foreach(mixType)
670+
641671
def mixLocalIdent(ident: LocalIdent): Unit = {
642672
mixPos(ident.pos)
643673
mixName(ident.name)

ir/shared/src/main/scala/org/scalajs/ir/Names.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,18 @@ object Names {
414414
i += 1
415415
}
416416
appendTypeRef(base)
417+
case ClosureTypeRef(paramTypeRefs, resultTypeRef) =>
418+
builder.append('(')
419+
var first = true
420+
for (paramTypeRef <- paramTypeRefs) {
421+
if (first)
422+
first = false
423+
else
424+
builder.append(',')
425+
appendTypeRef(paramTypeRef)
426+
}
427+
builder.append(')')
428+
appendTypeRef(resultTypeRef)
417429
}
418430

419431
builder.append(simpleName.nameString)

ir/shared/src/main/scala/org/scalajs/ir/Printers.scala

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,10 @@ object Printers {
343343
print(method)
344344
printArgs(args)
345345

346+
case ApplyTypedClosure(flags, fun, args) =>
347+
print(fun)
348+
printArgs(args)
349+
346350
case UnaryOp(UnaryOp.String_length, lhs) =>
347351
print(lhs)
348352
print(".length")
@@ -874,6 +878,23 @@ object Printers {
874878
printBlock(body)
875879
print(')')
876880

881+
case TypedClosure(captureParams, params, resultType, body, captureValues) =>
882+
print("(typed-lambda<")
883+
var first = true
884+
for ((param, value) <- captureParams.zip(captureValues)) {
885+
if (first)
886+
first = false
887+
else
888+
print(", ")
889+
print(param)
890+
print(" = ")
891+
print(value)
892+
}
893+
print(">")
894+
printSig(params, restParam = None, resultType)
895+
printBlock(body)
896+
print(')')
897+
877898
case CreateJSClass(className, captureValues) =>
878899
print("createjsclass[")
879900
print(className)
@@ -1063,6 +1084,18 @@ object Printers {
10631084
print(base)
10641085
for (i <- 1 to dims)
10651086
print("[]")
1087+
case ClosureTypeRef(paramTypeRefs, resultTypeRef) =>
1088+
print('(')
1089+
var first = true
1090+
for (paramTypeRef <- paramTypeRefs) {
1091+
if (first)
1092+
first = false
1093+
else
1094+
print(", ")
1095+
print(paramTypeRef)
1096+
}
1097+
print(") => ")
1098+
print(resultTypeRef)
10661099
}
10671100

10681101
def print(tpe: Type): Unit = tpe match {
@@ -1085,6 +1118,20 @@ object Printers {
10851118
case ArrayType(arrayTypeRef) =>
10861119
print(arrayTypeRef)
10871120

1121+
case ClosureType(paramTypes, resultType) =>
1122+
print("((")
1123+
var first = true
1124+
for (paramType <- paramTypes) {
1125+
if (first)
1126+
first = false
1127+
else
1128+
print(", ")
1129+
print(paramType)
1130+
}
1131+
print(") => ")
1132+
print(resultType)
1133+
print(')')
1134+
10881135
case RecordType(fields) =>
10891136
print('(')
10901137
var first = true

0 commit comments

Comments
 (0)
0