8000 Codegen: Inline the target of `Function` nodes in their `js.Closure`s. by sjrd · Pull Request #5081 · scala-js/scala-js · GitHub
[go: up one dir, main page]

Skip to content

Codegen: Inline the target of Function nodes in their js.Closures. #5081

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Codegen: Inline the target of Function nodes in their js.Closures.
The body of `Function` nodes always has a simple shape that calls a
helper method. We previously generated that call in the body of the
`js.Closure`, and marked the target method `@inline` so that the
optimizer would always inline it.

Instead, we now directly "inline" it from the codegen, by
generating the `js.MethodDef` right inside the `js.Closure` scope.

As is, this does not change the generated code. However, it may
speed up (cold) linker runs, since it will have less work to do.
Notably, it performs two fewer knowledge queries to find and inline
the target method. It also reduces the total amount of methods to
manipulate in the incremental analysis.

More importantly, this will be necessary later if we want to add
support for `async/await` or `function*/yield`. Indeed, for those,
we will need `await`/`yield` expressions to be lexically scoped
in the body of their enclosing closure. That won't work if they are
in the body of a separate helper method.
  • Loading branch information
sjrd committed Dec 1, 2024
commit 0d16b42e54d823dbd79a06fb6247730a01206831
263 changes: 155 additions & 108 deletions compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala
8000
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
val currentClassSym = new ScopedVar[Symbol]
private val fieldsMutatedInCurrentClass = new ScopedVar[mutable.Set[Name]]
private val generatedSAMWrapperCount = new ScopedVar[VarBox[Int]]
private val delambdafyTargetDefDefs = new ScopedVar[mutable.Map[Symbol, DefDef]]

def currentThisTypeNullable: jstpe.Type =
encodeClassType(currentClassSym)
Expand All @@ -185,10 +186,12 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
private val mutableLocalVars = new ScopedVar[mutable.Set[Symbol]]
private val mutatedLocalVars = new ScopedVar[mutable.Set[Symbol]]

private def withPerMethodBodyState[A](methodSym: Symbol)(body: => A): A = {
private def withPerMethodBodyState[A](methodSym: Symbol,
initThisLocalVarIdent: Option[js.LocalIdent] = None)(body: => A): A = {

withScopedVars(
currentMethodSym := methodSym,
thisLocalVarIdent := None,
thisLocalVarIdent := initThisLocalVarIdent,
enclosingLabelDefInfos := Map.empty,
isModuleInitialized := new VarBox(false),
undefinedDefaultParams := mutable.Set.empty,
Expand Down Expand Up @@ -236,6 +239,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
currentClassSym := clsSym,
fieldsMutatedInCurrentClass := mutable.Set.empty,
generatedSAMWrapperCount := new VarBox(0),
delambdafyTargetDefDefs := mutable.Map.empty,
currentMethodSym := null,
thisLocalVarIdent := null,
enclosingLabelDefInfos := null,
Expand Down Expand Up @@ -481,7 +485,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
withScopedVars(
currentClassSym := sym,
fieldsMutatedInCurrentClass := mutable.Set.empty,
generatedSAMWrapperCount := new VarBox(0)
generatedSAMWrapperCount := new VarBox(0),
delambdafyTargetDefDefs := mutable.Map.empty
) {
val tree = if (isJSType(sym)) {
if (!sym.isTraitOrInterface && isNonNativeJSClass(sym) &&
Expand Down Expand Up @@ -590,6 +595,34 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
}
}

private def collectDefDefs(impl: Template): List[DefDef] = {
val b = List.newBuilder[DefDef]

for (stat <- impl.body) {
stat match {
case stat: DefDef =>
if (stat.symbol.isDelambdafyTarget)
delambdafyTargetDefDefs += stat.symbol -> stat
else
b += stat

case EmptyTree | _:ValDef =>
()

case _ =>
abort(s"Unexpected tree in template: $stat at ${stat.pos}")
}
}

b.result()
}

private def consumeDelambdafyTarget(sym: Symbol): DefDef = {
delambdafyTargetDefDefs.remove(sym).getOrElse {
abort(s"Cannot resolve delambdafy target $sym at ${sym.pos}")
}
}

// Generate a class --------------------------------------------------------

/** Gen the IR ClassDef for a class definition (maybe a module class).
Expand Down Expand Up @@ -640,26 +673,13 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
val methodsBuilder = List.newBuilder[js.MethodDef]
val jsNativeMembersBuilder = List.newBuilder[js.JSNativeMemberDef]

def gen(tree: Tree): Unit = {
tree match {
case EmptyTree => ()
case Template(_, _, body) => body foreach gen
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to double check: This was dead code?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not strictly dead code, but it would only happen for the top-level "recursive" call (we cannot have a Template inside a Template, and would always happen for the top-level call (the input is statically typed as Template, so we know that). The new code directly iterates over impl.body instead.


case ValDef(mods, name, tpt, rhs) =>
() // fields are added via genClassFields()

case dd: DefDef =>
if (dd.symbol.hasAnnotation(JSNativeAnnotation))
jsNativeMembersBuilder += genJSNativeMemberDef(dd)
else
methodsBuilder ++= genMethod(dd)

case _ => abort("Illegal tree in gen of genClass(): " + tree)
}
for (dd <- collectDefDefs(impl)) {
if (dd.symbol.hasAnnotation(JSNativeAnnotation))
jsNativeMembersBuilder += genJSNativeMemberDef(dd)
else
methodsBuilder ++= genMethod(dd)
}

gen(impl)

val fields = if (!isHijacked) genClassFields(cd) else Nil

val jsNativeMembers = jsNativeMembersBuilder.result()
Expand Down Expand Up @@ -797,44 +817,31 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
val generatedMethods = new ListBuffer[js.MethodDef]
val dispatchMethodNames = new ListBuffer[JSName]

def gen(tree: Tree): Unit = {
tree match {
case EmptyTree => ()
case Template(_, _, body) => body foreach gen
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here: This was dead-code?


case ValDef(mods, name, tpt, rhs) =>
() // fields are added via genClassFields()

case dd: DefDef =>
val sym = dd.symbol
val exposed = isExposed(sym)

if (sym.isClassConstructor) {
constructorTrees += dd
} else if (exposed && sym.isAccessor && !sym.isLazy) {
/* Exposed accessors must not be emitted, since the field they
* access is enough.
*/
} else if (sym.hasAnnotation(JSOptionalAnnotation)) {
// Optional methods must not be emitted
} else {
generatedMethods ++= genMethod(dd)
for (dd <- collectDefDefs(cd.impl)) {
val sym = dd.symbol
val exposed = isExposed(sym)

// Collect the names of the dispatchers we have to create
if (exposed && !sym.isDeferred) {
/* We add symbols that we have to expose here. This way we also
* get inherited stuff that is implemented in this class.
*/
dispatchMethodNames += jsNameOf(sym)
}
}
if (sym.isClassConstructor) {
constructorTrees += dd
} else if (exposed && sym.isAccessor && !sym.isLazy) {
/* Exposed accessors must not be emitted, since the field they
* access is enough.
*/
} else if (sym.hasAnnotation(JSOptionalAnnotation)) {
// Optional methods must not be emitted
} else {
generatedMethods ++= genMethod(dd)

case _ => abort("Illegal tree in gen of genClass(): " + tree)
// Collect the names of the dispatchers we have to create
if (exposed && !sym.isDeferred) {
/* We add symbols that we have to expose here. This way we also
* get inherited stuff that is implemented in this class.
*/
dispatchMethodNames += jsNameOf(sym)
}
}
}

gen(cd.impl)

// Static members (exported from the companion object)
val (staticFields, staticExports) = {
/* Phase travel is necessary for non-top-level classes, because flatten
Expand Down Expand Up @@ -1158,20 +1165,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)

val classIdent = encodeClassNameIdent(sym)

// fill in class info builder
def gen(tree: Tree): List[js.MethodDef] = {
tree match {
case EmptyTree => Nil
case Template(_, _, body) => body.flatMap(gen)

case dd: DefDef =>
genMethod(dd).toList

case _ =>
abort("Illegal tree in gen of genInterface(): " + tree)
}
}
val generatedMethods = gen(cd.impl)
val generatedMethods = collectDefDefs(cd.impl).flatMap(genMethod(_))
val interfaces = genClassInterfaces(sym, forJSClass = false)

val allMemberDefs =
Expand Down Expand Up @@ -2045,11 +2039,13 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
*
* Other (normal) methods are emitted with `genMethodDef()`.
*/
def genMethodWithCurrentLocalNameScope(dd: DefDef): js.MethodDef = {
def genMethodWithCurrentLocalNameScope(dd: DefDef,
initThisLocalVarIdent: Option[js.LocalIdent] = None): js.MethodDef = {

implicit val pos = dd.pos
val sym = dd.symbol

withPerMethodBodyState(sym) {
withPerMethodBodyState(sym, initThisLocalVarIdent) {
val methodName = encodeMethodSym(sym)
val originalName = originalNameOfMethod(sym)

Expand Down Expand Up @@ -6409,13 +6405,16 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
*
* To translate them, we first construct a JS closure for the body:
* {{{
* lambda<this, capture1, ..., captureM>(
* _this, capture1, ..., captureM, arg1, ..., argN) {
* _this.someMethod(arg1, ..., argN, capture1, ..., captureM)
* arrow-lambda<_this = this, capture1: U1 = capture1, ..., captureM: UM = captureM>(
* arg1: any, ..., argN: any): any = {
* val arg1Unboxed: T1 = arg1.asInstanceOf[T1];
* ...
* val argNUnboxed: TN = argN.asInstanceOf[TN];
* // inlined body of `someMethod`, boxed
* }
* }}}
* In the closure, input params are unboxed before use, and the result of
* `someMethod()` is boxed back.
* the body of `someMethod` is boxed back.
*
* Then, we wrap that closure in a class satisfying the expected type.
* For Scala function types, we use the existing
Expand All @@ -6440,61 +6439,109 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
targetTree @ Select(receiver, _), allArgs0)) = originalFunction

val captureSyms =
global.delambdafy.FreeVarTraverser.freeVarsOf(originalFunction)
global.delambdafy.FreeVarTraverser.freeVarsOf(originalFunction).toList
val target = targetTree.symbol
val params = paramTrees.map(_.symbol)

val allArgs = allArgs0 map genExpr
val isTargetStatic = compileAsStaticMethod(target)

val formalCaptures = captureSyms.toList.map(genParamDef(_, pos))
val actualCaptures = formalCaptures.map(_.ref)

val formalArgs = params.map(genParamDef(_))

val (allFormalCaptures, body, allActualCaptures) = if (!compileAsStaticMethod(target)) {
val thisActualCapture = genExpr(receiver)
val thisFormalCapture = js.ParamDef(
freshLocalIdent("this")(receiver.pos), thisOriginalName,
thisActualCapture.tpe, mutable = false)(receiver.pos)
val thisCaptureArg = thisFormalCapture.ref
// Gen actual captures in the local name scope of the enclosing method
val actualCaptures: List[js.Tree] = {
val base = captureSyms.map(genVarRef(_))
if (isTargetStatic)
base
else
genExpr(receiver) :: base
}

val body = if (isJSType(receiver.tpe) && target.owner != ObjectClass) {
assert(isNonNativeJSClass(target.owner) && !isExposed(target),
s"A Function lambda is trying to call an exposed JS method ${target.fullName}")
genApplyJSClassMethod(thisCaptureArg, target, allArgs)
val closure: js.Closure = withNewLocalNameScope {
// Gen the formal capture params for the closure
val thisFormalCapture: Option[js.ParamDef] = if (isTargetStatic) {
None
} else {
genApplyMethodMaybeStatically(thisCaptureArg, target, allArgs)
Some(js.ParamDef(
freshLocalIdent("this")(receiver.pos), thisOriginalName,
toIRType(receiver.tpe), mutable = false)(receiver.pos))
}
val formalCaptures: List[js.ParamDef] =
thisFormalCapture.toList ::: captureSyms.map(genParamDef(_, pos))

(thisFormalCapture :: formalCaptures,
body, thisActualCapture :: actualCaptures)
} else {
val body = genApplyStatic(target, allArgs)
// Gen the inlined target method body
val genMethodDef = {
genMethodWithCurrentLocalNameScope(consumeDelambdafyTarget(target),
initThisLocalVarIdent = thisFormalCapture.map(_.name))
}
val js.MethodDef(methodFlags, _, _, methodParams, _, methodBody) = genMethodDef

(formalCaptures, body, actualCaptures)
}
/* If the target method was not supposed to be static, but genMethodDef
* turns out to be static, it means it is a non-exposed method of a JS
* class. The `this` param was turned into a regular param, for which
* we need a `js.VarDef`.
*/
val (maybeThisParamAsVarDef, remainingMethodParams) = {
if (methodFlags.namespace.isStatic && !isTargetStatic) {
val thisParamDef :: remainingMethodParams = methodParams: @unchecked
val thisParamAsVarDef = js.VarDef(thisParamDef.name, thisParamDef.originalName,
thisParamDef.ptpe, thisParamDef.mutable, thisFormalCapture.get.ref)
(thisParamAsVarDef :: Nil, remainingMethodParams)
} else {
(Nil, methodParams)
}
}

val (patchedFormalArgs, paramsLocals) =
patchFunParamsWithBoxes(target, formalArgs, useParamsBeforeLambdaLift = true)
// After that, the args found in the `Function` node had better match the remaining method params
assert(remainingMethodParams.size == allArgs0.size,
s"Arity mismatch: $remainingMethodParams <-> $allArgs0 at $pos")

val patchedBody =
js.Block(paramsLocals :+ ensureResultBoxed(body, target))
/* Declare each method param as a VarDef, initialized to the corresponding arg.
* In practice, all the args are `This` nodes or `VarRef` nodes, so the
* optimizer will alias those VarDefs away.
* We do this because we have different Symbols, hence different
* encoded LocalIdents.
*/
val methodParamsAsVarDefs = for ((methodParam, arg) <- remainingMethodParams.zip(allArgs0)) yield {
js.VarDef(methodParam.name, methodParam.originalName, methodParam.ptpe,
methodParam.mutable, genExpr(arg))
}

val closure = js.Closure(
/* Adapt the params and result so that they are boxed from the outside.
* We need this because a `js.Closure` is always from `any`s to `any`.
*
* TODO In total we generate *3* locals for each original param: the
* patched ParamDef, the VarDef for the unboxed value, and a VarDef for
* the original parameter of the delambdafy target. In theory we only
* need 2: can we make it so?
*/
val formalArgs = paramTrees.map(p => genParamDef(p.symbol))
val (patchedFormalArgs, paramsLocals) =
patchFunParamsWithBoxes(target, formalArgs, useParamsBeforeLambdaLift = true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a bit ugly IMO that we introduce another set of VarDefs here for unboxing.

Copy link
Member Author
@sjrd sjrd Dec 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will definitely change again with #5003. That said per se the additional VarDefs will still be there, since types don't always align.

The changes in this PR are neutral in terms of number of assignments that need to be processed by the optimizer: we traded Bindings for ParamDefs in a method call for Bindings to local VarDefs.

We can look into improving this a bit further after #5003 if we think it's worth it.

val patchedBodyWithBox =
ensureResultBoxed(methodBody.get, target)

// Finally, assemble all the pieces
val fullClosureBody = js.Block(
paramsLocals :::
maybeThisParamAsVarDef :::
methodParamsAsVarDefs :::
patchedBodyWithBox ::
Nil
)
js.Closure(
arrow = true,
allFormalCaptures,
formalCaptures,
patchedFormalArgs,
restParam = None,
patchedBody,
allActualCaptures)
fullClosureBody,
actualCaptures
)
}

// Wrap the closure in the appropriate box for the SAM type
val funSym = originalFunction.tpe.typeSymbolDirect
if (isFunctionSymbol(funSym)) {
/* This is a scala.FunctionN. We use the existing AnonFunctionN
* wrapper.
*/
genJSFunctionToScala(closure, params.size)
genJSFunctionToScala(closure, paramTrees.size)
} else {
/* This is an arbitrary SAM type (can only happen in 2.12).
* We have to synthesize a class like LambdaMetaFactory would do on
Expand Down
Loading
0