8000 Merge pull request #5081 from sjrd/inline-delambdafy-targets · scala-js/scala-js@030eff0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 030eff0

Browse files
authored
Merge pull request #5081 from sjrd/inline-delambdafy-targets
Codegen: Inline the target of `Function` nodes in their `js.Closure`s.
2 parents 4150da9 + 0d16b42 commit 030eff0

File tree

2 files changed

+165
-118
lines changed

2 files changed

+165
-118
lines changed

compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala

Lines changed: 155 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
159159
val currentClassSym = new ScopedVar[Symbol]
160160
private val fieldsMutatedInCurrentClass = new ScopedVar[mutable.Set[Name]]
161161
private val generatedSAMWrapperCount = new ScopedVar[VarBox[Int]]
162+
private val delambdafyTargetDefDefs = new ScopedVar[mutable.Map[Symbol, DefDef]]
162163

163164
def currentThisTypeNullable: jstpe.Type =
164165
encodeClassType(currentClassSym)
@@ -185,10 +186,12 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
185186
private val mutableLocalVars = new ScopedVar[mutable.Set[Symbol]]
186187
private val mutatedLocalVars = new ScopedVar[mutable.Set[Symbol]]
187188

188-
private def withPerMethodBodyState[A](methodSym: Symbol)(body: => A): A = {
189+
private def withPerMethodBodyState[A](methodSym: Symbol,
190+
initThisLocalVarIdent: Option[js.LocalIdent] = None)(body: => A): A = {
191+
189192
withScopedVars(
190193
currentMethodSym := methodSym,
191-
thisLocalVarIdent := None,
194+
thisLocalVarIdent := initThisLocalVarIdent,
192195
enclosingLabelDefInfos := Map.empty,
193196
isModuleInitialized := new VarBox(false),
194197
undefinedDefaultParams := mutable.Set.empty,
@@ -236,6 +239,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
236239
currentClassSym := clsSym,
237240
fieldsMutatedInCurrentClass := mutable.Set.empty,
238241
generatedSAMWrapperCount := new VarBox(0),
242+
delambdafyTargetDefDefs := mutable.Map.empty,
239243
currentMethodSym := null,
240244
thisLocalVarIdent := null,
241245
enclosingLabelDefInfos := null,
@@ -481,7 +485,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
481485
withScopedVars(
482486
currentClassSym := sym,
483487
fieldsMutatedInCurrentClass := mutable.Set.empty,
484-
generatedSAMWrapperCount := new VarBox(0)
488+
generatedSAMWrapperCount := new VarBox(0),
489+
delambdafyTargetDefDefs := mutable.Map.empty
485490
) {
486491
val tree = if (isJSType(sym)) {
487492
if (!sym.isTraitOrInterface && isNonNativeJSClass(sym) &&
@@ -590,6 +595,34 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
590595
}
591596
}
592597

598+
private def collectDefDefs(impl: Template): List[DefDef] = {
599+
val b = List.newBuilder[DefDef]
600+
601+
for (stat <- impl.body) {
602+
stat match {
603+
case stat: DefDef =>
604+
if (stat.symbol.isDelambdafyTarget)
605+
delambdafyTargetDefDefs += stat.symbol -> stat
606+
else
607+
b += stat
608+
609+
case EmptyTree | _:ValDef =>
610+
()
611+
612+
case _ =>
613+
abort(s"Unexpected tree in template: $stat at ${stat.pos}")
614+
}
615+
}
616+
617+
b.result()
618+
}
619+
620+
private def consumeDelambdafyTarget(sym: Symbol): DefDef = {
621+
delambdafyTargetDefDefs.remove(sym).getOrElse {
622+
abort(s"Cannot resolve delambdafy target $sym at ${sym.pos}")
623+
}
624+
}
625+
593626
// Generate a class --------------------------------------------------------
594627

595628
/** Gen the IR ClassDef for a class definition (maybe a module class).
@@ -640,26 +673,13 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
640673
val methodsBuilder = List.newBuilder[js.MethodDef]
641674
val jsNativeMembersBuilder = List.newBuilder[js.JSNativeMemberDef]
642675

643-
def gen(tree: Tree): Unit = {
644-
tree match {
645-
case EmptyTree => ()
646-
case Template(_, _, body) => body foreach gen
647-
648-
case ValDef(mods, name, tpt, rhs) =>
649-
() // fields are added via genClassFields()
650-
651-
case dd: DefDef =>
652-
if (dd.symbol.hasAnnotation(JSNativeAnnotation))
653-
jsNativeMembersBuilder += genJSNativeMemberDef(dd)
654-
else
655-
methodsBuilder ++= genMethod(dd)
656-
657-
case _ => abort("Illegal tree in gen of genClass(): " + tree)
658-
}
676+
for (dd <- collectDefDefs(impl)) {
677+
if (dd.symbol.hasAnnotation(JSNativeAnnotation))
678+
jsNativeMembersBuilder += genJSNativeMemberDef(dd)
679+
else
680+
methodsBuilder ++= genMethod(dd)
659681
}
660682

661-
gen(impl)
662-
663683
val fields = if (!isHijacked) genClassFields(cd) else Nil
664684

665685
val jsNativeMembers = jsNativeMembersBuilder.result()
@@ -797,44 +817,31 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
797817
val generatedMethods = new ListBuffer[js.MethodDef]
798818
val dispatchMethodNames = new ListBuffer[JSName]
799819

800-
def gen(tree: Tree): Unit = {
801-
tree match {
802-
case EmptyTree => ()
803-
case Template(_, _, body) => body foreach gen
804-
805-
case ValDef(mods, name, tpt, rhs) =>
806-
() // fields are added via genClassFields()
807-
808-
case dd: DefDef =>
809-
val sym = dd.symbol
810-
val exposed = isExposed(sym)
811-
812-
if (sym.isClassConstructor) {
813-
constructorTrees += dd
814-
} else if (exposed && sym.isAccessor && !sym.isLazy) {
815-
/* Exposed accessors must not be emitted, since the field they
816-
* access is enough.
817-
*/
818-
} else if (sym.hasAnnotation(JSOptionalAnnotation)) {
819-
// Optional methods must not be emitted
820-
} else {
821-
generatedMethods ++= genMethod(dd)
820+
for (dd <- collectDefDefs(cd.impl)) {
821+
val sym = dd.symbol
822+
val exposed = isExposed(sym)
822823

823-
// Collect the names of the dispatchers we have to create
824-
if (exposed && !sym.isDeferred) {
825-
/* We add symbols that we have to expose here. This way we also
826-
* get inherited stuff that is implemented in this class.
827-
*/
828-
dispatchMethodNames += jsNameOf(sym)
829-
}
830-
}
824+
if (sym.isClassConstructor) {
825+
constructorTrees += dd
826+
} else if (exposed && sym.isAccessor && !sym.isLazy) {
827+
/* Exposed accessors must not be emitted, since the field they
828+
* access is enough.
829+
*/
830+
} else if (sym.hasAnnotation(JSOptionalAnnotation)) {
831+
// Optional methods must not be emitted
832+
} else {
833+
generatedMethods ++= genMethod(dd)
831834

832-
case _ => abort("Illegal tree in gen of genClass(): " + tree)
835+
// Collect the names of the dispatchers we have to create
836+
if (exposed && !sym.isDeferred) {
837+
/* We add symbols that we have to expose here. This way we also
838+
* get inherited stuff that is implemented in this class.
839+
*/
840+
dispatchMethodNames += jsNameOf(sym)
841+
}
833842
}
834843
}
835844

836-
gen(cd.impl)
837-
838845
// Static members (exported from the companion object)
839846
val (staticFields, staticExports) = {
840847
/* Phase travel is necessary for non-top-level classes, because flatten
@@ -1158,20 +1165,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
11581165

11591166
val classIdent = encodeClassNameIdent(sym)
11601167

1161-
// fill in class info builder
1162-
def gen(tree: Tree): List[js.MethodDef] = {
1163-
tree match {
1164-
case EmptyTree => Nil
1165-
case Template(_, _, body) => body.flatMap(gen)
1166-
1167-
case dd: DefDef =>
1168-
genMethod(dd).toList
1169-
1170-
case _ =>
1171-
abort("Illegal tree in gen of genInterface(): " + tree)
1172-
}
1173-
}
1174-
val generatedMethods = gen(cd.impl)
1168+
val generatedMethods = collectDefDefs(cd.impl).flatMap(genMethod(_))
11751169
val interfaces = genClassInterfaces(sym, forJSClass = false)
11761170

11771171
val allMemberDefs =
@@ -2045,11 +2039,13 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
20452039
*
20462040
* Other (normal) methods are emitted with `genMethodDef()`.
20472041
*/
2048-
def genMethodWithCurrentLocalNameScope(dd: DefDef): js.MethodDef = {
2042+
def genMethodWithCurrentLocalNameScope(dd: DefDef,
2043+
initThisLocalVarIdent: Option[js.LocalIdent] = None): js.MethodDef = {
2044+
20492045
implicit val pos = dd.pos
20502046
val sym = dd.symbol
20512047

2052-
withPerMethodBodyState(sym) {
2048+
withPerMethodBodyState(sym, initThisLocalVarIdent) {
20532049
val methodName = encodeMethodSym(sym)
20542050
val originalName = originalNameOfMethod(sym)
20552051

@@ -6409,13 +6405,16 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
64096405
*
64106406
* To translate them, we first construct a JS closure for the body:
64116407
* {{{
6412-
* lambda<this, capture1, ..., captureM>(
6413-
* _this, capture1, ..., captureM, arg1, ..., argN) {
6414-
* _this.someMethod(arg1, ..., argN, capture1, ..., captureM)
6408+
* arrow-lambda<_this = this, capture1: U1 = capture1, ..., captureM: UM = captureM>(
6409+
* arg1: any, ..., argN: any): any = {
6410+
* val arg1Unboxed: T1 = arg1.asInstanceOf[T1];
6411+
* ...
6412+
* val argNUnboxed: TN = argN.asInstanceOf[TN];
6413+
* // inlined body of `someMethod`, boxed
64156414
* }
64166415
* }}}
64176416
* In the closure, input params are unboxed before use, and the result of
6418-
* `someMethod()` is boxed back.
6417+
* the body of `someMethod` is boxed back.
64196418
*
64206419
* Then, we wrap that closure in a class satisfying the expected type.
64216420
* For Scala function types, we use the existing
@@ -6440,61 +6439,109 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
64406439
targetTree @ Select(receiver, _), allArgs0)) = originalFunction
64416440

64426441
val captureSyms =
6443-
global.delambdafy.FreeVarTraverser.freeVarsOf(originalFunction)
6442+
global.delambdafy.FreeVarTraverser.freeVarsOf(originalFunction).toList
64446443
val target = targetTree.symbol
6445-
val params = paramTrees.map(_.symbol)
64466444

6447-
val allArgs = allArgs0 map genExpr
6445+
val isTargetStatic = compileAsStaticMethod(target)
64486446

6449-
val formalCaptures = captureSyms.toList.map(genParamDef(_, pos))
6450-
val actualCaptures = formalCaptures.map(_.ref)
6451-
6452-
val formalArgs = params.map(genParamDef(_))
6453-
6454-
val (allFormalCaptures, body, allActualCaptures) = if (!compileAsStaticMethod(target)) {
6455-
val thisActualCapture = genExpr(receiver)
6456-
val thisFormalCapture = js.ParamDef(
6457-
freshLocalIdent("this")(receiver.pos), thisOriginalName,
6458-
thisActualCapture.tpe, mutable = false)(receiver.pos)
6459-
val thisCaptureArg = thisFormalCapture.ref
6447+
// Gen actual captures in the local name scope of the enclosing method
6448+
val actualCaptures: List[js.Tree] = {
6449+
val base = captureSyms.map(genVarRef(_))
6450+
if (isTargetStatic)
6451+
base
6452+
else
6453+
genExpr(receiver) :: base
6454+
}
64606455

6461-
val body = if (isJSType(receiver.tpe) && target.owner != ObjectClass) {
6462-
assert(isNonNativeJSClass(target.owner) && !isExposed(target),
6463-
s"A Function lambda is trying to call an exposed JS method ${target.fullName}")
6464-
genApplyJSClassMethod(thisCaptureArg, target, allArgs)
6456+
val closure: js.Closure = withNewLocalNameScope {
6457+
// Gen the formal capture params for the closure
6458+
val thisFormalCapture: Option[js.ParamDef] = if (isTargetStatic) {
6459+
None
64656460
} else {
6466-
genApplyMethodMaybeStatically(thisCaptureArg, target, allArgs)
6461+
Some(js.ParamDef(
6462+
freshLocalIdent("this")(receiver.pos), thisOriginalName,
6463+
toIRType(receiver.tpe), mutable = false)(receiver.pos))
64676464
}
6465+
val formalCaptures: List[js.ParamDef] =
6466+
thisFormalCapture.toList ::: captureSyms.map(genParamDef(_, pos))
64686467

6469-
(thisFormalCapture :: formalCaptures,
6470-
body, thisActualCapture :: actualCaptures)
6471-
} else {
6472-
val body = genApplyStatic(target, allArgs)
6468+
// Gen the inlined target method body
6469+
val genMethodDef = {
6470+
genMethodWithCurrentLocalNameScope(consumeDelambdafyTarget(target),
6471+
initThisLocalVarIdent = thisFormalCapture.map(_.name))
6472+
}
6473+
val js.MethodDef(methodFlags, _, _, methodParams, _, methodBody) = genMethodDef
64736474

6474-
(formalCaptures, body, actualCaptures)
6475-
}
6475+
/* If the target method was not supposed to be static, but genMethodDef
6476+
* turns out to be static, it means it is a non-exposed method of a JS
6477+
* class. The `this` param was turned into a regular param, for which
6478+
* we need a `js.VarDef`.
6479+
*/
6480+
val (maybeThisParamAsVarDef, remainingMethodParams) = {
6481+
if (methodFlags.namespace.isStatic && !isTargetStatic) {
6482+
val thisParamDef :: remainingMethodParams = methodParams: @unchecked
6483+
val thisParamAsVarDef = js.VarDef(thisParamDef.name, thisParamDef.originalName,
6484+
thisParamDef.ptpe, thisParamDef.mutable, thisFormalCapture.get.ref)
6485+
(thisParamAsVarDef :: Nil, remainingMethodParams)
6486+
} else {
6487+
(Nil, methodParams)
6488+
}
6489+
}
64766490

6477-
val (patchedFormalArgs, paramsLocals) =
6478-
patchFunParamsWithBoxes(target, formalArgs, useParamsBeforeLambdaLift = true)
6491+
// After that, the args found in the `Function` node had better match the remaining method params
6492+
assert(remainingMethodParams.size == allArgs0.size,
6493+
s"Arity mismatch: $remainingMethodParams <-> $allArgs0 at $pos")
64796494

6480-
val patchedBody =
6481-
js.Block(paramsLocals :+ ensureResultBoxed(body, target))
6495+
/* Declare each method param as a VarDef, initialized to the corresponding arg.
6496+
* In practice, all the args are `This` nodes or `VarRef` nodes, so the
6497+
* optimizer will alias those VarDefs away.
6498+
* We do this because we have different Symbols, hence different
6499+
* encoded LocalIdents.
6500+
*/
6501+
val methodParamsAsVarDefs = for ((methodParam, arg) <- remainingMethodParams.zip(allArgs0)) yield {
6502+
js.VarDef(methodParam.name, methodParam.originalName, methodParam.ptpe,
6503+
methodParam.mutable, genExpr(arg))
6504+
}
64826505

6483-
val closure = js.Closure(
6506+
/* Adapt the params and result so that they are boxed from the outside.
6507+
* We need this because a `js.Closure` is always from `any`s to `any`.
6508+
*
6509+
* TODO In total we generate *3* locals for each original param: the
6510+
* patched ParamDef, the VarDef for the unboxed value, and a VarDef for
6511+
* the original parameter of the delambdafy target. In theory we only
6512+
* need 2: can we make it so?
6513+
*/
6514+
val formalArgs = paramTrees.map(p => genParamDef(p.symbol))
6515+
val (patchedFormalArgs, paramsLocals) =
6516+
patchFunParamsWithBoxes(target, formalArgs, useParamsBeforeLambdaLift = true)
6517+
val patchedBodyWithBox =
6518+
ensureResultBoxed(methodBody.get, target)
6519+
6520+
// Finally, assemble all the pieces
6521+
val fullClosureBody = js.Block(
6522+
paramsLocals :::
6523+
maybeThisParamAsVarDef :::
6524+
methodParamsAsVarDefs :::
6525+
patchedBodyWithBox ::
6526+
Nil
6527+
)
6528+
js.Closure(
64846529
arrow = true,
6485-
allFormalCaptures,
6530+
formalCaptures,
64866531
patchedFormalArgs,
64876532
restParam = None,
6488-
patchedBody,
6489-
allActualCaptures)
6533+
fullClosureBody,
6534+
actualCaptures
6535+
)
6536+
}
64906537

64916538
// Wrap the closure in the appropriate box for the SAM type
64926539
val funSym = originalFunction.tpe.typeSymbolDirect
64936540
if (isFunctionSymbol(funSym)) {
64946541
/* This is a scala.FunctionN. We use the existing AnonFunctionN
64956542
* wrapper.
64966543
*/
6497-
genJSFunctionToScala(closure, params.size)
6544+
genJSFunctionToScala(closure, paramTrees.size)
64986545
} else {
64996546
/* This is an arbitrary SAM type (can only happen in 2.12).
65006547
* We have to synthesize a class like LambdaMetaFactory would do on

0 commit comments

Comments
 (0)
0