8000 Make the label of `Return` nodes mandatory. · scala-js/scala-js@c61d78d · GitHub
[go: up one dir, main page]

Skip to content

Commit c61d78d

Browse files
committed
Make the label of Return nodes mandatory.
A `Return` with a `None` label implicitly meant returning from the enclosing `Closure` or `MethodDef`. In the spirit of solving all problems with `Labeled` blocks, we can get rid of this special-case by introducing an explicit function-wide `Labeled` block when necessary. Instead of emitting: def foo__I(): int = { ... if (c) return 5 ... } we would now use def foo__I(): int = { _return[int]: { ... if (c) return@_return 5 ... } } which is equivalent. This complicates somewhat the compiler back-end, since it now needs to take care of adding that label if necessary (it could *always* add one, but that would not be nice, and defeat most of the purpose of this change), and it does not really simplify anything later in the pipeline either. Moreover, we observe a slight regression of quality of generated .js code in some rare cases: when there is a `return` in a `void`-returning method, and that method is not inlined, the `FunctionEmitter` cannot replace the `break _return` by a `return` in JS. This is due to a technicality in `FunctionEmitter`, and this restriction could be lifted in the future. The main benefits are more regular IR, as well as a potential performance improvement in the optimizer. Indeed, previously, every time we inlined a method, we needed to preemptively summon a lot of complicated machinery to introduce a label in case the body contains a `return`. In most cases, it doesn't, and we detect after the fact that the label is useless and can be removed. Instead, we can now straightforwardly process the body without any extra action. In the rare cases where the body did contain a `return`, the optimizer will encounter the `Labeled` block, and spawn the machinery then. In all the other cases, the machinery is completely by-passed.
1 parent 72a6809 commit c61d78d

File tree

10 files changed

+103
-82
lines changed

10 files changed

+103
-82
lines changed

compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1716,7 +1716,7 @@ abstract class GenJSCode extends plugins.PluginComponent
17161716

17171717
val bodyIsStat = resultIRType == jstpe.NoType
17181718

1719-
def genBody() = tree match {
1719+
def genBodyWithinReturnableScope(): js.Tree = tree match {
17201720
case Block(
17211721
(thisDef @ ValDef(_, nme.THIS, _, initialThis)) :: otherStats,
17221722
rhs) =>
@@ -1793,6 +1793,12 @@ abstract class GenJSCode extends plugins.PluginComponent
17931793
else genExpr(tree)
17941794
}
17951795

1796+
def genBody(): js.Tree = {
1797+
withNewReturnableScope(resultIRType) {
1798+
genBodyWithinReturnableScope()
1799+
}
1800+
}
1801+
17961802
if (!isNonNativeJSClass(currentClassSym) ||
17971803
isRawJSFunctionDef(currentClassSym)) {
17981804
val body = {
@@ -1958,7 +1964,7 @@ abstract class GenJSCode extends plugins.PluginComponent
19581964
js.Return(toIRType(expr.tpe) match {
19591965
case jstpe.NoType => js.Block(genStat(expr), js.Undefined())
19601966
case _ => genExpr(expr)
1961-
}, None)
1967+
}, getEnclosingReturnLabel())
19621968

19631969
case t: Try =>
19641970
genTry(t, isStat)
@@ -2267,9 +2273,9 @@ abstract class GenJSCode extends plugins.PluginComponent
22672273
js.While(js.BooleanLiteral(true), {
22682274
js.Labeled(labelIdent, jstpe.NoType, {
22692275
if (bodyType == jstpe.NoType)
2270-
js.Block(genStat(rhs), js.Return(js.Undefined(), Some(blockLabelIdent)))
2276+
js.Block(genStat(rhs), js.Return(js.Undefined(), blockLabelIdent))
22712277
else
2272-
js.Return(genExpr(rhs), Some(blockLabelIdent))
2278+
js.Return(genExpr(rhs), blockLabelIdent)
22732279
})
22742280
})
22752281
})
@@ -2634,7 +2640,7 @@ abstract class GenJSCode extends plugins.PluginComponent
26342640
* labeled block surrounding the match.
26352641
*/
26362642
countsOfReturnsToMatchEnd(sym) += 1
2637-
js.Return(genExpr(args.head), Some(encodeLabelSym(sym)))
2643+
js.Return(genExpr(args.head), encodeLabelSym(sym))
26382644
} else {
26392645
/* No other label apply should ever happen. If it does, then we
26402646
* have missed a pattern of LabelDef/LabelApply and some new
@@ -2715,7 +2721,7 @@ abstract class GenJSCode extends plugins.PluginComponent
27152721
}
27162722

27172723
// The actual jump (return(labelDefIdent) undefined;)
2718-
val jump = js.Return(js.Undefined(), Some(encodeLabelSym(sym)))
2724+
val jump = js.Return(js.Undefined(), encodeLabelSym(sym))
27192725

27202726
quadruplets match {
27212727
case Nil => jump
@@ -3078,7 +3084,7 @@ abstract class GenJSCode extends plugins.PluginComponent
30783084
def genJumpToElseClause(implicit pos: ir.Position): js.Tree = {
30793085
if (optElseClauseLabel.isEmpty)
30803086
optElseClauseLabel = Some(freshLocalIdent("default"))
3081-
js.Return(js.Undefined(), optElseClauseLabel)
3087+
js.Return(js.Undefined(), optElseClauseLabel.get)
30823088
}
30833089

30843090
for (caze @ CaseDef(pat, guard, body) <- cases) {
@@ -3160,10 +3166,9 @@ abstract class GenJSCode extends plugins.PluginComponent
31603166
val matchResultLabel = freshLocalIdent("matchResult")
31613167
val patchedClauses = for ((alts, body) <- clauses) yield {
31623168
implicit val pos = body.pos
3163-
val lab = Some(matchResultLabel)
31643169
val newBody =
3165-
if (isStat) js.Block(body, js.Return(js.Undefined(), lab))
3166-
else js.Return(body, lab)
3170+
if (isStat) js.Block(body, js.Return(js.Undefined(), matchResultLabel))
3171+
else js.Return(body, matchResultLabel)
31673172
(alts, newBody)
31683173
}
31693174
js.Labeled(matchResultLabel, resultType, js.Block(List(
@@ -3220,11 +3225,12 @@ abstract class GenJSCode extends plugins.PluginComponent
32203225
// Peculiar shape generated by `return x match {...}` - #2928
32213226
case Return(retExpr: LabelDef) if isCaseLabelDef(retExpr) =>
32223227
val result = translateMatch(retExpr)
3228+
val label = getEnclosingReturnLabel()
32233229
if (result.tpe == jstpe.NoType) {
32243230
// Could not actually reproduce this, but better be safe than sorry
3225-
js.Block(result, js.Return(js.Undefined(), None))
3231+
js.Block(result, js.Return(js.Undefined(), label))
32263232
} else {
3227-
js.Return(result, None)
3233+
js.Return(result, label)
32283234
}
32293235

32303236
case _ =>
@@ -3354,7 +3360,7 @@ abstract class GenJSCode extends plugins.PluginComponent
33543360

33553361
if (revAlts.size == returnCount - 1) {
33563362
def tryDropReturn(body: js.Tree): Option[js.Tree] = body match {
3357-
case jse.BlockOrAlone(prep, js.Return(result, Some(`label`))) =>
3363+
case jse.BlockOrAlone(prep, js.Return(result, `label`)) =>
33583364
Some(js.Block(prep :+ result)(body.pos))
33593365

33603366
case _ =>

compiler/src/main/scala/org/scalajs/nscplugin/JSEncoding.scala

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import scala.tools.nsc._
1212
import org.scalajs.ir
1313
import ir.{Trees => js, Types => jstpe}
1414

15-
import util.ScopedVar
15+
import util.{ScopedVar, VarBox}
1616
import ScopedVar.withScopedVars
1717

1818
/** Encoding of symbol names for JavaScript
@@ -59,13 +59,15 @@ trait JSEncoding extends SubComponent { self: GenJSCode =>
5959
// Fresh local name generator ----------------------------------------------
6060

6161
private val usedLocalNames = new ScopedVar[mutable.Set[String]]
62+
private val returnLabelName = new ScopedVar[VarBox[Option[String]]]
6263
private val localSymbolNames = new ScopedVar[mutable.Map[Symbol, String]]
6364
private val isReserved =
6465
Set("arguments", "eval", ScalaJSEnvironmentName)
6566

6667
def withNewLocalNameScope[A](body: => A): A = {
6768
withScopedVars(
6869
usedLocalNames := mutable.Set.empty,
70+
returnLabelName := null,
6971
localSymbolNames := mutable.Map.empty
7072
)(body)
7173
}
@@ -77,6 +79,21 @@ trait JSEncoding extends SubComponent { self: GenJSCode =>
7779
usedLocalNames += name
7880
}
7981

82+
def withNewReturnableScope(tpe: jstpe.Type)(body: => js.Tree)(
83+
implicit pos: ir.Position): js.Tree = {
84+
withScopedVars(
85+
returnLabelName := new VarBox(None)
86+
) {
87+
val inner = body
88+
returnLabelName.get.value match {
89+
case None =>
90+
inner
91+
case Some(labelName) =>
92+
js.Labeled(js.Ident(labelName), tpe, inner)
93+
}
94+
}
95+
}
96+
8097
private def freshName(base: String = "x"): String = {
8198
var suffix = 1
8299
var longName = base
@@ -97,6 +114,15 @@ trait JSEncoding extends SubComponent { self: GenJSCode =>
97114
private def localSymbolName(sym: Symbol): String =
98115
localSymbolNames.getOrElseUpdate(sym, freshName(sym.name.toString))
99116

117+
def getEnclosingReturnLabel()(implicit pos: ir.Position): js.Ident = {
118+
val box = returnLabelName.get
119+
if (box == null)
120+
throw new IllegalStateException(s"No enclosing returnable scope at $pos")
121+
if (box.value.isEmpty)
122+
box.value = Some(freshName("_return"))
123+
js.Ident(box.value.get)
124+
}
125+
100126
// Encoding methods ----------------------------------------------------------
101127

102128
def encodeLabelSym(sym: Symbol)(implicit pos: Position): js.Ident = {

ir/src/main/scala/org/scalajs/ir/Hashers.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ object Hashers {
120120
case Return(expr, label) =>
121121
mixTag(TagReturn)
122122
mixTree(expr)
123-
mixOptIdent(label)
123+
mixIdent(label)
124124

125125
case If(cond, thenp, elsep) =>
126126
mixTag(TagIf)

ir/src/main/scala/org/scalajs/ir/Printers.scala

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -164,15 +164,10 @@ object Printers {
164164
print(rhs)
165165

166166
case Return(expr, label) =>
167-
if (label.isEmpty) {
168-
print("return ")
169-
print(expr)
170-
} else {
171-
print("return(")
172-
print(label.get)
173-
print(") ")
174-
print(expr)
175-
}
167+
print("return@")
168+
print(label)
169+
print(" ")
170+
print(expr)
176171

177172
case If(cond, BooleanLiteral(true), elsep) =>
178173
print(cond)

ir/src/main/scala/org/scalajs/ir/Serializers.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ object Serializers {
155155

156156
case Return(expr, label) =>
157157
writeByte(TagReturn)
158-
writeTree(expr); writeOptIdent(label)
158+
writeTree(expr); writeIdent(label)
159159

160160
case If(cond, thenp, elsep) =>
161161
writeByte(TagIf)
@@ -839,7 +839,7 @@ object Serializers {
839839
case TagBlock => Block(readTrees())
840840
case TagLabeled => Labeled(readIdent(), readType(), readTree())
841841
case TagAssign => Assign(readTree(), readTree())
842-
case TagReturn => Return(readTree(), readOptIdent())
842+
case TagReturn => Return(readTree(), readIdent())
843843
case TagIf => If(readTree(), readTree(), readTree())(readType())
844844
case TagWhile => While(readTree(), readTree())
845845
case TagDoWhile => DoWhile(readTree(), readTree())

ir/src/main/scala/org/scalajs/ir/Trees.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ object Trees {
177177
val tpe = NoType // cannot be in expression position
178178
}
179179

180-
case class Return(expr: Tree, label: Option[Ident])(
180+
case class Return(expr: Tree, label: Ident)(
181181
implicit val pos: Position) extends Tree {
182182
val tpe = NothingType
183183
}

ir/src/test/scala/org/scalajs/ir/PrintersTest.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,7 @@ class PrintersTest {
147147
}
148148

149149
@Test def printReturn(): Unit = {
150-
assertPrintEquals("return 5", Return(i(5), None))
151-
assertPrintEquals("return(lab) 5", Return(i(5), Some("lab")))
150+
assertPrintEquals("return@lab 5", Return(i(5), "lab"))
152151
}
153152

154153
@Test def printIf(): Unit = {

linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -462,10 +462,6 @@ private[emitter] class FunctionEmitter(jsGen: JSGen) {
462462

463463
val env = env0.withParams(params)
464464

465-
val withReturn =
466-
if (isStat) body
467-
else Return(body, None)
468-
469465
val translateRestParam =
470466
if (esFeatures.useECMAScript2015) false
471467
else params.nonEmpty && params.last.rest
@@ -477,13 +473,17 @@ private[emitter] class FunctionEmitter(jsGen: JSGen) {
477473
val newParams =
478474
(if (translateRestParam) params.init else params).map(transformParamDef)
479475

480-
val newBody = transformStat(withReturn, Set.empty)(env) match {
476+
val newBody =
477+
if (isStat) transformStat(body, Set.empty)(env)
478+
else pushLhsInto(Lhs.ReturnFromFunction, body, Set.empty)(env)
479+
480+
val cleanedNewBody = newBody match {
481481
case js.Block(stats :+ js.Return(js.Undefined())) => js.Block(stats)
482482
case other => other
483483
}
484484

485485
js.Function(arrow && useArrowFunctions, newParams,
486-
js.Block(extractRestParam, newBody))
486+
js.Block(extractRestParam, cleanedNewBody))
487487
}
488488

489489
private def makeExtractRestParam(params: List[ParamDef])(
@@ -1455,9 +1455,9 @@ private[emitter] class FunctionEmitter(jsGen: JSGen) {
14551455
doVarDef(name, tpe, mutable, rhs)
14561456
case Lhs.Assign(lhs) =>
14571457
doAssign(lhs, rhs)
1458-
case Lhs.Return(None) =>
1458+
case Lhs.ReturnFromFunction =>
14591459
js.Return(transformExpr(rhs, env.expectedReturnType))
1460-
case Lhs.Return(Some(l)) =>
1460+
case Lhs.Return(l) =>
14611461
doReturnToLabel(l)
14621462
}
14631463

@@ -1485,10 +1485,10 @@ private[emitter] class FunctionEmitter(jsGen: JSGen) {
14851485
js.Block(varDef, assign)
14861486
}
14871487

1488-
case Lhs.Return(None) =>
1488+
case Lhs.ReturnFromFunction =>
14891489
throw new AssertionError("Cannot return a record value.")
14901490

1491-
case Lhs.Return(Some(l)) =>
1491+
case Lhs.Return(l) =>
14921492
doReturnToLabel(l)
14931493
}
14941494

@@ -2792,7 +2792,11 @@ private object FunctionEmitter {
27922792
case class Assign(lhs: Tree) extends Lhs
27932793
case class VarDef(name: Ident, tpe: Type, mutable: Boolean) extends Lhs
27942794

2795-
case class Return(label: Option[Ident]) extends Lhs {
2795+
case object ReturnFromFunction extends Lhs {
2796+
override def hasNothingType: Boolean = true
2797+
}
2798+
2799+
case class Return(label: Ident) extends Lhs {
27962800
override def hasNothingType: Boolean = true
27972801
}
27982802

linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,7 @@ private final class IRChecker(unit: LinkingUnit,
710710
typecheckExpect(body, env.withLabeledReturnType(label.name, tpe), tpe)
711711

712712
case Return(expr, label) =>
713-
env.returnTypes.get(label.map(_.name)).fold[Unit] {
713+
env.returnTypes.get(label.name).fold[Unit] {
714714
reportError(s"Cannot return to label $label.")
715715
typecheckExpr(expr, env)
716716
} { returnType =>
@@ -1277,7 +1277,7 @@ private final class IRChecker(unit: LinkingUnit,
12771277
/** Local variables in scope (including through closures). */
12781278
val locals: Map[String, LocalDef],
12791279
/** Return types by label. */
1280-
val returnTypes: Map[Option[String], Type],
1280+
val returnTypes: Map[String, Type],
12811281
/** Whether we're in a constructor of the class */
12821282
val inConstructor: Boolean
12831283
) {
@@ -1291,13 +1291,9 @@ private final class IRChecker(unit: LinkingUnit,
12911291
this.inConstructor)
12921292
}
12931293

1294-
def withReturnType(returnType: Type): Env =
1295-
new Env(this.thisTpe, this.locals,
1296-
returnTypes + (None -> returnType), this.inConstructor)
1297-
12981294
def withLabeledReturnType(label: String, returnType: Type): Env =
12991295
new Env(this.thisTpe, this.locals,
1300-
returnTypes + (Some(label) -> returnType), this.inConstructor)
1296+
returnTypes + (label -> returnType), this.inConstructor)
13011297

13021298
def withInConstructor(inConstructor: Boolean): Env =
13031299
new Env(this.thisTpe, this.locals, this.returnTypes, inConstructor)
@@ -1313,9 +1309,7 @@ private final class IRChecker(unit: LinkingUnit,
13131309
val paramLocalDefs =
13141310
for (p @ ParamDef(ident, tpe, mutable, _) <- allParams)
13151311
yield ident.name -> LocalDef(ident.name, tpe, mutable)(p.pos)
1316-
new Env(thisType, paramLocalDefs.toMap,
1317-
Map(None -> (if (resultType == NoType) AnyType else resultType)),
1318-
isConstructor)
1312+
new Env(thisType, paramLocalDefs.toMap, Map.empty, isConstructor)
13191313
}
13201314
}
13211315

0 commit comments

Comments
 (0)
0