8000 Fix #1811: Add support for secondary constructors in JS classes · scala-js/scala-js@3599292 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3599292

Browse files
committed
Fix #1811: Add support for secondary constructors in JS classes
1 parent 4e71cc2 commit 3599292

File tree

5 files changed

+525
-58
lines changed

5 files changed

+525
-58
lines changed

compiler/src/main/scala/org/scalajs/core/compiler/GenJSCode.scala

Lines changed: 323 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -649,29 +649,334 @@ abstract class GenJSCode extends plugins.PluginComponent
649649
constructorTrees: List[DefDef]): js.Tree = {
650650
implicit val pos = classSym.pos
651651

652-
val (primaryCtorTree :: Nil, secondaryCtorTrees) =
653-
constructorTrees.partition(_.symbol.isPrimaryConstructor)
654-
655652
// Implementation restriction
656-
val sym = primaryCtorTree.symbol
653+
val syms = constructorTrees.map(_.symbol)
657654
val hasBadParam = enteringPhase(currentRun.uncurryPhase) {
658-
sym.paramss.flatten.exists(p => p.hasDefault || isRepeated(p))
655+
syms.exists(_.paramss.flatten.exists(p => p.hasDefault))
659656
}
660657
if (hasBadParam) {
661658
reporter.error(pos,
662659
"Implementation restriction: the constructor of a " +
663-
"Scala.js-defined JS classes cannot have default parameters nor " +
664-
"repeated parameters.")
660+
"Scala.js-defined JS classes cannot have default parameters.")
665661
}
666662

667-
// Implementation restriction
668-
for (tree <- secondaryCtorTrees) {
669-
reporter.error(tree.pos,
670-
"Implementation restriction: Scala.js-defined JS classes cannot " +
671-
"have secondary constructors")
663+
withNewLocalNameScope {
664+
val ctors: List[js.MethodDef] = constructorTrees.flatMap { tree =>
665+
genMethodWithCurrentLocalNameScope(tree)
666+
}
667+
668+
val dispatch =
669+
genJSConstructorExport(constructorTrees.map(_.symbol))
670+
val js.MethodDef(_, dispatchName, dispatchArgs, dispatchResultType,
671+
dispatchResolution) = dispatch
672+
673+
val isConstructor = ctors.map(_.name.name).toSet
674+
675+
val constructorTree = mkConstructorTree(ctors, isConstructor)
676+
677+
val overloadIdent = freshLocalIdent("overload")
678+
679+
// Section containing the overload resolution and casts of parameters
680+
val overloadSelection = mkOverloadSelection(constructorTree,
681+
overloadIdent, dispatchResolution, isConstructor)
682+
683+
/* Section containing all the code executed before the call to `this`
684+
* for every secondary constructor.
685+
*/
686+
val prePrimaryCtorBody =
687+
constructorTree.mkPreSuperCall(overloadIdent, isConstructor)
688+
689+
val primaryCtorBody = constructorTree.rootBody
690+
691+
/* Section containing all the code executed after the call to this for
692+
* every secondary constructor.
693+
*/
694+
val postPrimaryCtorBody =
695+
constructorTree.mkPostSuperCall(overloadIdent, isConstructor)
696+
697+
val newBody = js.Block(overloadSelection ::: prePrimaryCtorBody ::
698+
primaryCtorBody :: postPrimaryCtorBody :: Nil)
699+
700+
js.MethodDef(static = false, dispatchName, dispatchArgs, jstpe.NoType,
701+
newBody)(dispatch.optimizerHints, dispatch.hash)
702+
}
703+
}
704+
705+
private class ConstructorTree(overrideNum: Int, method: js.MethodDef,
706+
subConstructors: List[ConstructorTree]) {
707+
708+
def methodName: String = method.name.name
709+
710+
def hasSubConstructors: Boolean = subConstructors.nonEmpty
711+
712+
lazy val overrideNumBounds: (Int, Int) =
713+
if (subConstructors.isEmpty) (overrideNum, overrideNum)
714+
else (subConstructors.head.overrideNumBounds._1, overrideNum)
715+
716+
def get(methodName: String): Option[ConstructorTree] = {
717+
if (methodName == this.methodName) {
718+
Some(this)
719+
} else {
720+
subConstructors.iterator.map(_.get(methodName)).collectFirst {
721+
case Some(node) => node
722+
}
723+
}
724+
}
725+
726+
def getOverrideNum(methodName: String): Int =
727+
get(methodName).fold(-1)(_.getOverrideNum)
728+
729+
def rootBody: js.Tree = method.body
730+
731+
private def getOverrideNum: Int = overrideNum
732+
733+
def getParamRefs(methodName: String)(implicit pos: Position): List[js.VarRef] =
734+
get(methodName).fold(List.empty[js.VarRef])(_.getParamRefs)
735+
736+
private def getParamRefs(implicit pos: Position): List[js.VarRef] =
737+
method.args.map(_.ref)
738+
739+
def mkPreSuperCall(overrideNumIdent: js.Ident, isConstructor: Set[String])(
740+
implicit pos: Position): js.Tree = {
741+
val overrideNumRef = js.VarRef(overrideNumIdent)(jstpe.IntType)
742+
mkSubPreCalls(overrideNumRef, isConstructor)
743+
}
744+
745+
def mkPostSuperCall(overrideNumIdent: js.Ident,
746+
isConstructor: Set[String])(implicit pos: Position): js.Tree = {
747+
val overrideNumRef = js.VarRef(overrideNumIdent)(jstpe.IntType)
748+
js.Block(mkSubPostCalls(overrideNumRef, isConstructor))
749+
}
750+
751+
def getAllParamDefsAsVars(implicit pos: Position): List[js.VarDef] = {
752+
val localDefs = method.args.map { pDef =>
753+
js.VarDef(pDef.name, pDef.ptpe, mutable = true, jstpe.zeroOf(pDef.ptpe))
754+
}
755+
localDefs ++ subConstructors.flatMap(_.getAllParamDefsAsVars)
756+
}
757+
758+
private def mkSubPreCalls(overrideNumRef: js.VarRef,
759+
isConstructor: Set[String])(implicit pos: Position) = {
760+
val overrideNumss = subConstructors.map(_.overrideNumBounds)
761+
val paramRefs = getParamRefs
762+
val bodies = subConstructors.map { cg =>
763+
cg.mkPreSuperCallOnSndCtr(overrideNumRef, paramRefs, isConstructor)
764+
}
765+
overrideNumss.zip(bodies).foldRight[js.Tree](js.Skip()) {
766+
case ((numBounds, body), acc) =>
767+
val cond = mkOverrideNumsCond(overrideNumRef, numBounds)
768+
js.If(cond, body, acc)(jstpe.BooleanType)
769+
}
770+
}
771+
772+
private def mkPreSuperCallOnSndCtr(
773+
overrideNumRef: js.VarRef, outputParams: List[js.VarRef],
774+
isConstructor: Set[String])(implicit pos: Position): js.Tree = {
775+
val subCalls =
776+
mkSubPreCalls(overrideNumRef, isConstructor)
777+
778+
val preSuperCall = {
779+
method.body match {
780+
case js.Block(stats) =>
781+
val beforeSuperCall = stats.takeWhile {
782+
case js.ApplyStatic(_, mtd, _)
783+
if ir.Definitions.isConstructorName(mtd.name) =>
784+
false
785+
786+
case _ => true
787+
}
788+
val superCallParams = stats.collectFirst {
789+
case js.ApplyStatic(_, mtd, js.This() :: args)
790+
if ir.Definitions.isConstructorName(mtd.name) =>
791+
zipMap(outputParams, args) { (ref, tree) =>
792+
js.Assign(ref, tree)
793+
}
794+
}.getOrElse(Nil)
795+
796+
beforeSuperCall ::: superCallParams
797+
798+
case js.ApplyStatic(_, mtd, js.This() :: args)
799+
if ir.Definitions.isConstructorName(mtd.name) =>
800+
zipMap(outputParams, args)(js.Assign(_, _))
801+
802+
case _ => Nil
803+
}
804+
}
805+
806+
js.Block(subCalls :: preSuperCall)
807+
}
808+
809+
private def mkSubPostCalls(overrideNumRef: js.VarRef,
810+
isConstructor: Set[String])(implicit pos: Position): js.Tree = {
811+
val overrideNumss = subConstructors.map(_.overrideNumBounds)
812+
val bodies = subConstructors.map { cg =>
813+
cg.mkPostSuperCallOnSndCtr(overrideNumRef, isConstructor)
814+
}
815+
overrideNumss.zip(bodies).foldRight[js.Tree](js.Skip()) {
816+
case ((numBounds, js.Skip()), acc) => acc
817+
818+
case ((numBounds, body), acc) =>
819+
val cond = mkOverrideNumsCond(overrideNumRef, numBounds)
820+
js.If(cond, body, acc)(jstpe.BooleanType)
821+
}
822+
}
823+
824+
private def mkPostSuperCallOnSndCtr(overrideNumRef: js.VarRef,
825+
isConstructor: Set[String])(implicit pos: Position): js.Tree = {
826+
val postSuperCall = {
827+
method.body match {
828+
case js.Block(stats) =>
829+
stats.dropWhile {
830+
case js.ApplyStatic(_, mtd, _) if isConstructor(mtd.name) => false
831+
case _ => true
832+
}.tail
833+
834+
case _ => Nil
835+
}
836+
}
837+
js.Block(postSuperCall :+ mkSubPostCalls(overrideNumRef, isConstructor))
838+
}
839+
840+
private def mkOverrideNumsCond(numRef: js.VarRef,
841+
numBounds: (Int, Int))(implicit pos: Position) = numBounds match {
842+
case (lo, hi) if lo == hi =>
843+
js.BinaryOp(js.BinaryOp.===, js.IntLiteral(lo), numRef)
844+
845+
case (lo, hi) if lo == hi - 1 =>
846+
val lhs = js.BinaryOp(js.BinaryOp.===, numRef, js.IntLiteral(lo))
847+
val rhs = js.BinaryOp(js.BinaryOp.===, numRef, js.IntLiteral(hi))
848+
js.If(lhs, js.BooleanLiteral(true), rhs)(jstpe.BooleanType)
849+
850+
case (lo, hi) =>
851+
val lhs = js.BinaryOp(js.BinaryOp.Num_<=, js.IntLiteral(lo), numRef)
852+
val rhs = js.BinaryOp(js.BinaryOp.Num_<=, numRef, js.IntLiteral(hi))
853+
js.BinaryOp(js.BinaryOp.Boolean_&, lhs, rhs)
854+
js.If(lhs, rhs, js.BooleanLiteral(false))(jstpe.BooleanType)
855+
}
856+
}
857+
858+
private def isCallToContructor(cls: jstpe.ClassType,
859+
methodName: String): Boolean = {
860+
ir.Definitions.isConstructorName(methodName) && cls == currentClassType
861+
}
862+
863+
private def zipMap[T, U, V](xs: List[T], ys: List[U])(
864+
f: (T, U) => V): List[V] = {
865+
for ((x, y) <- xs zip ys) yield f(x, y)
866+
}
867+
868+
/** mkOverloadSelection return a list of `stats` with that starts with:
869+
* 1) The definition for the local variable that will hold the overload
870+
* resolution number.
871+
* 2) The definitions of all local variables that are used as parameters
872+
* in all the constructors.
873+
* 3) The overload resolution match/if statements. For each overload the
874+
* overload number is assigned and the parameters are cast and assigned
875+
* to their corresponding variables.
876+
*/
877+
private def mkOverloadSelection(constructorTree: ConstructorTree,
878+
overloadIdent: js.Ident, dispatchResolution: js.Tree,
879+
isConstructor: Set[String])(implicit pos: Position): List[js.Tree]= {
880+
if (!constructorTree.hasSubConstructors) {
881< F438 span class="diff-text-marker">+
dispatchResolution match {
882+
case js.Block(stats) =>
883+
val js.ApplyStatic(_, method, _) = stats.last
884+
val refs = constructorTree.getParamRefs(method.name)
885+
val rhss = stats.collect {
886+
case js.VarDef(_, _, _, rhs) => rhs
887+
}
888+
zipMap(refs, rhss) { (ref, rhs) =>
889+
js.VarDef(ref.ident, ref.tpe, mutable = false, rhs)
890+
}
891+
892+
case js.ApplyStatic(_, mtd, js.This() :: Nil)
893+
if ir.Definitions.isConstructorName(mtd.name) =>
894+
Nil
895+
896+
case tree =>
897+
List(tree)
898+
}
899+
} else {
900+
val overloadRef = js.VarRef(overloadIdent)(jstpe.IntType)
901+
902+
def transformDispatch(tree: js.Tree): js.Tree = tree match {
903+
case js.Block(stats) =>
904+
val js.ApplyStatic(_, method, _) = stats.last
905+
val num = constructorTree.getOverrideNum(method.name)
906+
val refs = overloadRef :: constructorTree.getParamRefs(method.name)
907+
val rhss = js.IntLiteral(num) :: stats.collect {
908+
case js.VarDef(_, _, _, rhs) => rhs
909+
}
910+
val newStats = zipMap(refs, rhss)(js.Assign(_, _))
911+
js.Block(newStats)
912+
913+
case js.ApplyStatic(_, method, js.This() :: Nil)
914+
if ir.Definitions.isConstructorName(method.name) =>
915+
js.Assign(overloadRef,
916+
js.IntLiteral(constructorTree.getOverrideNum(method.name)))
917+
918+
case js.Match(selector, cases, default) =>
919+
val newCases = cases.map {
920+
case (literals, body) => (literals, transformDispatch(body))
921+
}
922+
val newDefault = transformDispatch(default)
923+
js.Match(selector, newCases, newDefault)(tree.tpe)
924+
925+
case js.If(cond, thenp, elsep) =>
926+
js.If(cond, transformDispatch(thenp),
927+
transformDispatch(elsep))(tree.tpe)
928+
929+
case _ =>
930+
tree
931+
}
932+
933+
val newDispatchResolution = transformDispatch(dispatchResolution)
934+
val allParamDefsAsVars = constructorTree.getAllParamDefsAsVars
935+
val overrideNumDef =
936+
js.VarDef(overloadIdent, jstpe.IntType, mutable = true, js.IntLiteral(0))
937+
938+
overrideNumDef :: allParamDefsAsVars ::: newDispatchResolution :: Nil
939+
}
940+
}
941+
942+
private def mkConstructorTree(ctors: List[js.MethodDef],
943+
isConstructor: Set[String])(implicit pos: Position): ConstructorTree = {
944+
def superCall(tree: js.Tree): Option[String] = tree match {
945+
case js.Block(stats) =>
946+
stats.map(superCall).collectFirst {
947+
case Some(name) => name
948+
}
949+
950+
case js.ApplyStatic(_, method, js.This() :: _)
951+
if ir.Definitions.isConstructorName(method.name) =>
952+
Some(method.name)
953+
954+
case _ => None
955+
}
956+
957+
val (primaryCtor :: Nil, secondaryCtors) = ctors.partition {
958+
_.body match {
959+
case js.Block(stats) =>
960+
stats.exists(_.isInstanceOf[js.JSSuperConstructorCall])
961+
962+
case _: js.JSSuperConstructorCall => true
963+
case _ => false
964+
}
965+
}
966+
967+
val ctorToChildren = secondaryCtors.flatMap { ctor =>
968+
superCall(ctor.body).map(_ -> ctor)
969+
}.groupBy(_._1).mapValues(_.map(_._2)).withDefaultValue(Nil)
970+
971+
var overrideNum = -1
972+
def mkConstructorTree(method: js.MethodDef): ConstructorTree = {
973+
val methodName = method.name.name
974+
val subCtrTrees = ctorToChildren(methodName).map(mkConstructorTree)
975+
overrideNum += 1
976+
new ConstructorTree(overrideNum, method, subCtrTrees)
672977
}
673978

674-
genMethod(primaryCtorTree).get
979+
mkConstructorTree(primaryCtor)
675980
}
676981

677982
// Generate a method -------------------------------------------------------
@@ -721,9 +1026,7 @@ abstract class GenJSCode extends plugins.PluginComponent
7211026
val isJSClassConstructor =
7221027
sym.isClassConstructor && isScalaJSDefinedJSClass(currentClassSym)
7231028

724-
val methodName: js.PropertyName =
725-
if (isJSClassConstructor) js.StringLiteral("constructor")
726-
else encodeMethodSym(sym)
1029+
val methodName: js.PropertyName = encodeMethodSym(sym)
7271030

7281031
def jsParams = for (param <- params) yield {
7291032
implicit val pos = param.pos
@@ -771,12 +1074,11 @@ abstract class GenJSCode extends plugins.PluginComponent
7711074
val methodDef = {
7721075
if (isJSClassConstructor) {
7731076
val body0 = genStat(rhs)
774-
val body1 = moveAllStatementsAfterSuperConstructorCall(body0)
775-
val (patchedParams, patchedBody) =
776-
patchFunBodyWithBoxes(sym, jsParams, body1)
1077+
val body1 =
1078+
if (!sym.isPrimaryConstructor) body0
1079+
else moveAllStatementsAfterSuperConstructorCall(body0)
7771080
js.MethodDef(static = false, methodName,
778-
patchedParams, jstpe.NoType, patchedBody)(
779-
optimizerHints, None)
1081+
jsParams, jstpe.NoType, body1)(optimizerHints, None)
7801082
} else if (sym.isClassConstructor) {
7811083
js.MethodDef(static = false, methodName,
7821084
jsParams, jstpe.NoType,

0 commit comments

Comments
 (0)
0