8000 Make the label of `Return` nodes mandatory. by sjrd · Pull Request #3323 · scala-js/scala-js · GitHub
[go: up one dir, main page]

Skip to content

Make the label of Return nodes mandatory. #3323

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 4, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1716,7 +1716,7 @@ abstract class GenJSCode extends plugins.PluginComponent

val bodyIsStat = resultIRType == jstpe.NoType

def genBody() = tree match {
def genBodyWithinReturnableScope(): js.Tree = tree match {
case Block(
(thisDef @ ValDef(_, nme.THIS, _, initialThis)) :: otherStats,
rhs) =>
Expand Down Expand Up @@ -1793,6 +1793,12 @@ abstract class GenJSCode extends plugins.PluginComponent
else genExpr(tree)
}

def genBody(): js.Tree = {
withNewReturnableScope(resultIRType) {
genBodyWithinReturnableScope()
}
}

if (!isNonNativeJSClass(currentClassSym) ||
isRawJSFunctionDef(currentClassSym)) {
val body = {
Expand Down Expand Up @@ -1958,7 +1964,7 @@ abstract class GenJSCode extends plugins.PluginComponent
js.Return(toIRType(expr.tpe) match {
case jstpe.NoType => js.Block(genStat(expr), js.Undefined())
case _ => genExpr(expr)
}, None)
}, getEnclosingReturnLabel())

case t: Try =>
genTry(t, isStat)
Expand Down Expand Up @@ -2267,9 +2273,9 @@ abstract class GenJSCode extends plugins.PluginComponent
js.While(js.BooleanLiteral(true), {
js.Labeled(labelIdent, jstpe.NoType, {
if (bodyType == jstpe.NoType)
js.Block(genStat(rhs), js.Return(js.Undefined(), Some(blockLabelIdent)))
js.Block(genStat(rhs), js.Return(js.Undefined(), blockLabelIdent))
else
js.Return(genExpr(rhs), Some(blockLabelIdent))
js.Return(genExpr(rhs), blockLabelIdent)
})
})
})
Expand Down Expand Up @@ -2634,7 +2640,7 @@ abstract class GenJSCode extends plugins.PluginComponent
* labeled block surrounding the match.
*/
countsOfReturnsToMatchEnd(sym) += 1
js.Return(genExpr(args.head), Some(encodeLabelSym(sym)))
js.Return(genExpr(args.head), encodeLabelSym(sym))
} else {
/* No other label apply should ever happen. If it does, then we
* have missed a pattern of LabelDef/LabelApply and some new
Expand Down Expand Up @@ -2715,7 +2721,7 @@ abstract class GenJSCode extends plugins.PluginComponent
}

// The actual jump (return(labelDefIdent) undefined;)
val jump = js.Return(js.Undefined(), Some(encodeLabelSym(sym)))
val jump = js.Return(js.Undefined(), encodeLabelSym(sym))

quadruplets match {
case Nil => jump
Expand Down Expand Up @@ -3078,7 +3084,7 @@ abstract class GenJSCode extends plugins.PluginComponent
def genJumpToElseClause(implicit pos: ir.Position): js.Tree = {
if (optElseClauseLabel.isEmpty)
optElseClauseLabel = Some(freshLocalIdent("default"))
js.Return(js.Undefined(), optElseClauseLabel)
js.Return(js.Undefined() 8000 , optElseClauseLabel.get)
}

for (caze @ CaseDef(pat, guard, body) <- cases) {
Expand Down Expand Up @@ -3160,10 +3166,9 @@ abstract class GenJSCode extends plugins.PluginComponent
val matchResultLabel = freshLocalIdent("matchResult")
val patchedClauses = for ((alts, body) <- clauses) yield {
implicit val pos = body.pos
val lab = Some(matchResultLabel)
val newBody =
if (isStat) js.Block(body, js.Return(js.Undefined(), lab))
else js.Return(body, lab)
if (isStat) js.Block(body, js.Return(js.Undefined(), matchResultLabel))
else js.Return(body, matchResultLabel)
(alts, newBody)
}
js.Labeled(matchResultLabel, resultType, js.Block(List(
Expand Down Expand Up @@ -3220,11 +3225,12 @@ abstract class GenJSCode extends plugins.PluginComponent
// Peculiar shape generated by `return x match {...}` - #2928
case Return(retExpr: LabelDef) if isCaseLabelDef(retExpr) =>
val result = translateMatch(retExpr)
val label = getEnclosingReturnLabel()
if (result.tpe == jstpe.NoType) {
// Could not actually reproduce this, but better be safe than sorry
js.Block(result, js.Return(js.Undefined(), None))
js.Block(result, js.Return(js.Undefined(), label))
} else {
js.Return(result, None)
js.Return(result, label)
}

case _ =>
Expand Down Expand Up @@ -3354,7 +3360,7 @@ abstract class GenJSCode extends plugins.PluginComponent

if (revAlts.size == returnCount - 1) {
def tryDropReturn(body: js.Tree): Option[js.Tree] = body match {
case jse.BlockOrAlone(prep, js.Return(result, Some(`label`))) =>
case jse.BlockOrAlone(prep, js.Return(result, `label`)) =>
Some(js.Block(prep :+ result)(body.pos))

case _ =>
Expand Down
28 changes: 27 additions & 1 deletion compiler/src/main/scala/org/scalajs/nscplugin/JSEncoding.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import scala.tools.nsc._
import org.scalajs.ir
import ir.{Trees => js, Types => jstpe}

import util.ScopedVar
import util.{ScopedVar, VarBox}
import ScopedVar.withScopedVars

/** Encoding of symbol names for JavaScript
Expand Down Expand Up @@ -59,13 +59,15 @@ trait JSEncoding extends SubComponent { self: GenJSCode =>
// Fresh local name generator ----------------------------------------------

private val usedLocalNames = new ScopedVar[mutable.Set[String]]
private val returnLabelName = new ScopedVar[VarBox[Option[String]]]
private val localSymbolNames = new ScopedVar[mutable.Map[Symbol, String]]
private val isReserved =
Set("arguments", "eval", ScalaJSEnvironmentName)

def withNewLocalNameScope[A](body: => A): A = {
withScopedVars(
usedLocalNames := mutable.Set.empty,
returnLabelName := null,
localSymbolNames := mutable.Map.empty
)(body)
}
Expand All @@ -77,6 +79,21 @@ trait JSEncoding extends SubComponent { self: GenJSCode =>
usedLocalNames += name
}

def withNewReturnableScope(tpe: jstpe.Type)(body: => js.Tree)(
implicit pos: ir.Position): js.Tree = {
withScopedVars(
returnLabelName := new VarBox(None)
) {
val inner = body
returnLabelName.get.value match {
case None =>
inner
case Some(labelName) =>
js.Labeled(js.Ident(labelName), tpe, inner)
}
}
}

private def freshName(base: String = "x"): String = {
var suffix = 1
var longName = base
Expand All @@ -97,6 +114,15 @@ trait JSEncoding extends SubComponent { self: GenJSCode =>
private def localSymbolName(sym: Symbol): String =
localSymbolNames.getOrElseUpdate(sym, freshName(sym.name.toString))

def getEnclosingReturnLabel()(implicit pos: ir.Position): js.Ident = {
val box = returnLabelName.get
if (box == null)
throw new IllegalStateException(s"No enclosing returnable scope at $pos")
if (box.value.isEmpty)
box.value = Some(freshName("_return"))
js.Ident(box.value.get)
}

// Encoding methods ----------------------------------------------------------

def encodeLabelSym(sym: Symbol)(implicit pos: Position): js.Ident = {
Expand Down
2 changes: 1 addition & 1 deletion ir/src/main/scala/org/scalajs/ir/Hashers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ object Hashers {
case Return(expr, label) =>
mixTag(TagReturn)
mixTree(expr)
mixOptIdent(label)
mixIdent(label)

case If(cond, thenp, elsep) =>
mixTag(TagIf)
Expand Down
13 changes: 4 additions & 9 deletions ir/src/main/scala/org/scalajs/ir/Printers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,10 @@ object Printers {
print(rhs)

case Return(expr, label) =>
if (label.isEmpty) {
print("return ")
print(expr)
} else {
print("return(")
print(label.get)
print(") ")
print(expr)
}
print("return@")
print(label)
print(" ")
print(expr)

case If(cond, BooleanLiteral(true), elsep) =>
print(cond)
Expand Down
4 changes: 2 additions & 2 deletions ir/src/main/scala/org/scalajs/ir/Serializers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ object Serializers {

case Return(expr, label) =>
writeByte(TagReturn)
writeTree(expr); writeOptIdent(label)
writeTree(expr); writeIdent(label)

case If(cond, thenp, elsep) =>
writeByte(TagIf)
Expand Down Expand Up @@ -839,7 +839,7 @@ object Serializers {
case TagBlock => Block(readTrees())
case TagLabeled => Labeled(readIdent(), readType(), readTree())
case TagAssign => Assign(readTree(), readTree())
case TagReturn => Return(readTree(), readOptIdent())
case TagReturn => Return(readTree(), readIdent())
case TagIf => If(readTree(), readTree(), readTree())(readType())
case TagWhile => While(readTree(), readTree())
case TagDoWhile => DoWhile(readTree(), readTree())
Expand Down
2 changes: 1 addition & 1 deletion ir/src/main/scala/org/scalajs/ir/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ object Trees {
val tpe = NoType // cannot be in expression position
}

case class Return(expr: Tree, label: Option[Ident])(
case class Return(expr: Tree, label: Ident)(
implicit val pos: Position) extends Tree {
val tpe = NothingType
}
Expand Down
3 changes: 1 addition & 2 deletions ir/src/test/scala/org/scalajs/ir/PrintersTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,7 @@ class PrintersTest {
}

@Test def printReturn(): Unit = {
assertPrintEquals("return 5", Return(i(5), None))
assertPrintEquals("return(lab) 5", Return(i(5), Some("lab")))
assertPrintEquals("return@lab 5", Return(i(5), "lab"))
}

@Test def printIf(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -462,10 +462,6 @@ private[emitter] class FunctionEmitter(jsGen: JSGen) {

val env = env0.withParams(params)

val withReturn =
if (isStat) body
else Return(body, None)

val translateRestParam =
if (esFeatures.useECMAScript2015) false
else params.nonEmpty && params.last.rest
Expand All @@ -477,13 +473,17 @@ private[emitter] class FunctionEmitter(jsGen: JSGen) {
val newParams =
(if (translateRestParam) params.init else params).map(transformParamDef)

val newBody = transformStat(withReturn, Set.empty)(env) match {
val newBody =
if (isStat) transformStat(body, Set.empty)(env)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So basically this line would have to hijack the labeled block. Fair enough.

else pushLhsInto(Lhs.ReturnFromFunction, body, Set.empty)(env)

val cleanedNewBody = newBody match {
case js.Block(stats :+ js.Return(js.Undefined())) => js.Block(stats)
case other => other
}

js.Function(arrow && useArrowFunctions, newParams,
js.Block(extractRestParam, newBody))
js.Block(extractRestParam, cleanedNewBody))
}

private def makeExtractRestParam(params: List[ParamDef])(
Expand Down Expand Up @@ -1455,9 +1455,9 @@ private[emitter] class FunctionEmitter(jsGen: JSGen) {
doVarDef(name, tpe, mutable, rhs)
case Lhs.Assign(lhs) =>
doAssign(lhs, rhs)
case Lhs.Return(None) =>
case Lhs.ReturnFromFunction =>
js.Return(transformExpr(rhs, env.expectedReturnType))
case Lhs.Return(Some(l)) =>
case Lhs.Return(l) =>
doReturnToLabel(l)
}

Expand Down Expand Up @@ -1485,10 +1485,10 @@ private[emitter] class FunctionEmitter(jsGen: JSGen) {
js.Block(varDef, assign)
}

case Lhs.Return(None) =>
case Lhs.ReturnFromFunction =>
throw new AssertionError("Cannot return a record value.")

case Lhs.Return(Some(l)) =>
case Lhs.Return(l) =>
doReturnToLabel(l)
}

Expand Down Expand Up @@ -2792,7 +2792,11 @@ private object FunctionEmitter {
case class Assign(lhs: Tree) extends Lhs
case class VarDef(name: Ident, tpe: Type, mutable: Boolean) extends Lhs

case class Return(label: Option[Ident]) extends Lhs {
case object ReturnFromFunction extends Lhs {
override def hasNothingType: Boolean = true
}

case class Return(label: Ident) extends Lhs {
override def hasNothingType: Boolean = true
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ private final class IRChecker(unit: LinkingUnit,
typecheckExpect(body, env.withLabeledReturnType(label.name, tpe), tpe)

case Return(expr, label) =>
env.returnTypes.get(label.map(_.name)).fold[Unit] {
env.returnTypes.get(label.name).fold[Unit] {
reportError(s"Cannot return to label $label.")
typecheckExpr(expr, env)
} { returnType =>
Expand Down Expand Up @@ -1277,7 +1277,7 @@ private final class IRChecker(unit: LinkingUnit,
/** Local variables in scope (including through closures). */
val locals: Map[String, LocalDef],
/** Return types by label. */
val returnTypes: Map[Option[String], Type],
val returnTypes: Map[String, Type],
/** Whether we're in a constructor of the class */
val inConstructor: Boolean
) {
Expand All @@ -1291,13 +1291,9 @@ private final class IRChecker(unit: LinkingUnit,
this.inConstructor)
}

def withReturnType(returnType: Type): Env =
new Env(this.thisTpe, this.locals,
returnTypes + (None -> returnType), this.inConstructor)

def withLabeledReturnType(label: String, returnType: Type): Env =
new Env(this.thisTpe, this.locals,
returnTypes + (Some(label) -> returnType), this.inConstructor)
returnTypes + (label -> returnType), this.inConstructor)

def withInConstructor(inConstructor: Boolean): Env =
new Env(this.thisTpe, this.locals, this.returnTypes, inConstructor)
Expand All @@ -1313,9 +1309,7 @@ private final class IRChecker(unit: LinkingUnit,
val paramLocalDefs =
for (p @ ParamDef(ident, tpe, mutable, _) <- allParams)
yield ident.name -> LocalDef(ident.name, tpe, mutable)(p.pos)
new Env(thisType, paramLocalDefs.toMap,
Map(None -> (if (resultType == NoType) AnyType else resultType)),
isConstructor)
new Env(thisType, paramLocalDefs.toMap, Map.empty, isConstructor)
}
}

Expand Down
Loading
0