8000 Merge pull request #4532 from sjrd/ir-match-more-literals · renowncoder/scala-js@e37a082 · GitHub
[go: up one dir, main page]

Skip to content

Commit e37a082

Browse files
authored
Merge pull request scala-js#4532 from sjrd/ir-match-more-literals
Fix scala-js#3843: Allow Match nodes to be used for string switches.
2 parents 26beb3f + 808b054 commit e37a082

File tree

7 files changed

+88
-54
lines changed

7 files changed

+88
-54
lines changed

compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3494,7 +3494,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
34943494
body.symbol
34953495
}.getOrElse(NoSymbol)
34963496

3497-
var clauses: List[(List[js.Tree], js.Tree)] = Nil
3497+
var clauses: List[(List[js.MatchableLiteral], js.Tree)] = Nil
34983498
var optElseClause: Option[js.Tree] = None
34993499
var optElseClauseLabel: Option[js.LabelIdent] = None
35003500

@@ -3545,9 +3545,19 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
35 8000 453545
genStatOrExpr(body, isStat)
35463546
}
35473547

3548+
def invalidCase(tree: Tree): Nothing =
3549+
abort(s"Invalid case in alternative in switch-like pattern match: $tree at: ${tree.pos}")
3550+
3551+
def genMatchableLiteral(tree: Literal): js.MatchableLiteral = {
3552+
genExpr(tree) match {
3553+
case matchableLiteral: js.MatchableLiteral => matchableLiteral
3554+
case otherExpr => invalidCase(tree)
3555+
}
3556+
}
3557+
35483558
pat match {
35493559
case lit: Literal =>
3550-
clauses = (List(genExpr(lit)), genBody(body)) :: clauses
3560+
clauses = (List(genMatchableLiteral(lit)), genBody(body)) :: clauses
35513561
case Ident(nme.WILDCARD) =>
35523562
optElseClause = Some(body match {
35533563
case LabelDef(_, Nil, rhs) if hasSynthCaseSymbol(body) =>
@@ -3558,16 +3568,13 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
35583568
case Alternative(alts) =>
35593569
val genAlts = {
35603570
alts map {
3561-
case lit: Literal => genExpr(lit)
3562-
case _ =>
3563-
abort("Invalid case in alternative in switch-like pattern match: " +
3564-
tree + " at: " + tree.pos)
3571+
case lit: Literal => genMatchableLiteral(lit)
3572+
case _ => invalidCase(tree)
35653573
}
35663574
}
35673575
clauses = (genAlts, genBody(body)) :: clauses
35683576
case _ =>
3569-
abort("Invalid case statement in switch-like pattern match: " +
3570-
tree + " at: " + (tree.pos))
3577+
invalidCase(tree)
35713578
}
35723579
}
35733580

@@ -3580,12 +3587,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
35803587
* case is a typical product of `match`es that are full of
35813588
* `case n if ... =>`, which are used instead of `if` chains for
35823589
* convenience and/or readability.
3583-
*
3584-
* When no optimization applies, and any of the case values is not a
3585-
* literal int, we emit a series of `if..else` instead of a `js.Match`.
3586-
* This became necessary in 2.13.2 with strings and nulls.
35873590
*/
3588-
def buildMatch(cases: List[(List[js.Tree], js.Tree)],
3591+
def buildMatch(cases: List[(List[js.MatchableLiteral], js.Tree)],
35893592
default: js.Tree, tpe: jstpe.Type): js.Tree = {
35903593

35913594
def isInt(tree: js.Tree): Boolean = tree.tpe == jstpe.IntType
@@ -3609,32 +3612,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
36093612
js.If(js.BinaryOp(op, genSelector, uniqueAlt), caseRhs, default)(tpe)
36103613

36113614
case _ =>
3612-
if (isInt(genSelector) &&
3613-
cases.forall(_._1.forall(_.isInstanceOf[js.IntLiteral]))) {
3614-
// We have int literals only: use a js.Match
3615-
val intCases = cases.asInstanceOf[List[(List[js.IntLiteral], js.Tree)]]
3616-
js.Match(genSelector, intCases, default)(tpe)
3617-
} else {
3618-
// We have other stuff: generate an if..else chain
3619-
val (tempSelectorDef, tempSelectorRef) = genSelector match {
3620-
case varRef: js.VarRef =>
3621-
(js.Skip(), varRef)
3622-
case _ =>
3623-
val varDef = js.VarDef(freshLocalIdent(), NoOriginalName,
3624-
genSelector.tpe, mutable = false, genSelector)
3625-
(varDef, varDef.ref)
3626-
}
3627-
val ifElseChain = cases.foldRight(default) { (caze, elsep) =>
3628-
val conds = caze._1.map { caseValue =>
3629-
js.BinaryOp(js.BinaryOp.===, tempSelectorRef, caseValue)
3630-
}
3631-
val cond = conds.reduceRight[js.Tree] { (left, right) =>
3632-
js.If(left, js.BooleanLiteral(true), right)(jstpe.BooleanType)
3633-
}
3634-
js.If(cond, caze._2, elsep)(tpe)
3635-
}
3636-
js.Block(tempSelectorDef, ifElseChain)
3637-
}
3615+
// We have more than one case: use a js.Match
3616+
js.Match(genSelector, cases, default)(tpe)
36383617
}
36393618
}
36403619

ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import scala.util.matching.Regex
1818

1919
object ScalaJSVersions extends VersionChecks(
2020
current = "1.7.0-SNAPSHOT",
21-
binaryEmitted = "1.6"
21+
binaryEmitted = "1.7-SNAPSHOT"
2222
)
2323

2424
/** Helper class to allow for testing of logic. */

ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1087,7 +1087,7 @@ object Serializers {
10871087
case TagThrow => Throw(readTree())
10881088
case TagMatch =>
10891089
Match(readTree(), List.fill(readInt()) {
1090-
(readTrees().map(_.asInstanceOf[IntLiteral]), readTree())
1090+
(readTrees().map(_.asInstanceOf[MatchableLiteral]), readTree())
10911091
}, readTree())(readType())
10921092
case TagDebugger => Debugger()
10931093

ir/shared/src/main/scala/org/scalajs/ir/Trees.scala

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,13 +199,30 @@ object Trees {
199199
}
200200

201201
/** A break-free switch (without fallthrough behavior).
202+
*
202203
* Unlike a JavaScript switch, it can be used in expression position.
203-
* It supports alternatives explicitly (hence the `List[IntLiteral]` in
204-
* cases), whereas in a switch one would use the fallthrough behavior to
204+
* It supports alternatives explicitly (hence the `List[MatchableLiteral]`
205+
* in cases), whereas in a switch one would use the fallthrough behavior to
205206
* implement alternatives.
206207
* (This is not a pattern matching construct like in Scala.)
208+
*
209+
* The selector must be either an `int` (`IntType`) or a `java.lang.String`.
210+
* The cases can be any `MatchableLiteral`, even if they do not make sense
211+
* for the type of the selecter (they simply will never match).
212+
*
213+
* Because `+0.0 === -0.0` in JavaScript, and because those semantics are
214+
* used in a JS `switch`, we have to prevent the selector from ever being
215+
* `-0.0`. Otherwise, it would be matched by a `case IntLiteral(0)`. At the
216+
* same time, we must allow at least `int` and `java.lang.String` to support
217+
* all switchable `match`es from Scala. Since the latter two have no common
218+
* super type that does not allow `-0.0`, we really have to special-case
219+
* those two types.
220+
*
221+
* This is also why we restrict `MatchableLiteral`s to `IntLiteral`,
222+
* `StringLiteral` and `Null`. Allowing more cases would only make IR
223+
* checking more complicated, without bringing any added value.
207224
*/
208-
sealed case class Match(selector: Tree, cases: List[(List[IntLiteral], Tree)],
225+
sealed case class Match(selector: Tree, cases: List[(List[MatchableLiteral], Tree)],
209226
default: Tree)(val tpe: Type)(implicit val pos: Position) extends Tree
210227

211228
sealed case class Debugger()(implicit val pos: Position) extends Tree {
@@ -866,11 +883,23 @@ object Trees {
866883
/** Marker for literals. Literals are always pure. */
867884
sealed trait Literal extends Tree
868885

886+
/** Marker for literals that can be used in a [[Match]] case.
887+
*
888+
* Matchable literals are:
889+
*
890+
* - `IntLiteral`
891+
* - `StringLiteral`
892+
* - `Null`
893+
*
894+
* See [[Match]] for the rationale about that specific set.
895+
*/
896+
sealed trait MatchableLiteral extends Literal
897+
869898
sealed case class Undefined()(implicit val pos: Position) extends Literal {
870899
val tpe = UndefType
871900
}
872901

873-
sealed case class Null()(implicit val pos: Position) extends Literal {
902+
sealed case class Null()(implicit val pos: Position) extends MatchableLiteral {
874903
val tpe = NullType
875904
}
876905

@@ -895,7 +924,7 @@ object Trees {
895924
}
896925

897926
sealed case class IntLiteral(value: Int)(
898-
implicit val pos: Position) extends Literal {
927+
implicit val pos: Position) extends MatchableLiteral {
899928
val tpe = IntType
900929
}
901930

@@ -915,7 +944,7 @@ object Trees {
915944
}
916945

917946
sealed case class StringLiteral(value: String)(
918-
implicit val pos: Position) extends Literal {
947+
implicit val pos: Position) extends MatchableLiteral {
919948
val tpe = StringType
920949
}
921950

linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1663,7 +1663,7 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) {
16631663
val newCases = {
16641664
for {
16651665
(values, body) <- cases
1666-
newValues = values.map(v => js.IntLiteral(v.value)(v.pos))
1666+
newValues = values.map(transformExprNoChar(_))
16671667
// add the break statement
16681668
newBody = js.Block(
16691669
pushLhsInto(newLhs, body, tailPosLabels),

linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -744,9 +744,16 @@ private final class IRChecker(unit: LinkingUnit, logger: Logger) {
744744
typecheckExpr(expr, env)
745745

746746
case Match(selector, cases, default) =>
747+
// Ty 179B pecheck the selector as an int or a java.lang.String
748+
typecheck(selector, env)
749+
if (!isSubtype(selector.tpe, IntType) && !isSubtype(selector.tpe, BoxedStringType)) {
750+
reportError(
751+
i"int or java.lang.String exp F438 ected but ${selector.tpe} found" +
752+
i"for tree of type ${selector.getClass.getName}")
753+
}
754+
755+
// The alternatives are MatchableLiterals, no point typechecking them
747756
val tpe = tree.tpe
748-
typecheckExpect(selector, env, IntType)
749-
// The alternatives are IntLiterals, no point typechecking them
750757
for ((_, body) <- cases)
751758
typecheckExpect(body, env, tpe)
752759
typecheckExpect(default, env, tpe)
@@ -1396,6 +1403,8 @@ private final class IRChecker(unit: LinkingUnit, logger: Logger) {
13961403
}
13971404

13981405
object IRChecker {
1406+
private val BoxedStringType = ClassType(BoxedStringClass)
1407+
13991408
/** Checks that the IR in a [[frontend.LinkingUnit LinkingUnit]] is correct.
14001409
*
14011410
* @return Count of IR checking errors (0 in case of success)

linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -461,9 +461,9 @@ private[optimizer] abstract class OptimizerCore(config: CommonPhaseConfig) {
461461
case Match(selector, cases, default) =>
462462
val newSelector = transformExpr(selector)
463463
newSelector match {
464-
case IntLiteral(selectorValue) =>
464+
case selectorValue: MatchableLiteral =>
465465
val body = cases.collectFirst {
466-
case (alts, body) if alts.exists(_.value == selectorValue) => body
466+
case (alts, body) if alts.exists(matchableLiteral_===(_, selectorValue)) => body
467467
}.getOrElse(default)
468468
transform(body, isStat)
469469
case _ =>
@@ -798,9 +798,9 @@ private[optimizer] abstract class OptimizerCore(config: CommonPhaseConfig) {
798798
case Match(selector, cases, default) =>
799799
val newSelector = transformExpr(selector)
800800
newSelector match {
801-
case IntLiteral(selectorValue) =>
801+
case selectorValue: MatchableLiteral =>
802802
val body = cases.collectFirst {
803-
case (alts, body) if alts.exists(_.value == selectorValue) => body
803+
case (alts, body) if alts.exists(matchableLiteral_===(_, selectorValue)) => body
804804
}.getOrElse(default)
805805
pretransformExpr(body)(cont)
806806
case _ =>
@@ -2952,6 +2952,23 @@ private[optimizer] abstract class OptimizerCore(config: CommonPhaseConfig) {
29522952
}
29532953
}
29542954

2955+
/** Performs `===` for two matchable literals.
2956+
*
2957+
* This corresponds to the test used by a `Match` at run-time, to decide
2958+
* which case is selected.
2959+
*
2960+
* The result is always known statically.
2961+
*/
2962+
private def matchableLiteral_===(lhs: MatchableLiteral,
2963+
rhs: MatchableLiteral): Boolean = {
2964+
(lhs, rhs) match {
2965+
case (IntLiteral(l), IntLiteral(r)) => l == r
2966+
case (StringLiteral(l), StringLiteral(r)) => l == r
2967+
case (Null(), Null()) => true
2968+
case _ => false
2969+
}
2970+
}
2971+
29552972
private def constantFoldBinaryOp_except_String_+(op: BinaryOp.Code,
29562973
lhs: Literal, rhs: Literal)(implicit pos: Position): Literal = {
29572974
import BinaryOp._

0 commit comments

Comments
 (0)
0