8000 Fix #1811: Add support for secondary constructors in JS classes · scala-js/scala-js@16e8dc5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 16e8dc5

Browse files
committed
Fix #1811: Add support for secondary constructors in JS classes
1 parent c146755 commit 16e8dc5

File tree

5 files changed

+542
-60
lines changed

5 files changed

+542
-60
lines changed

compiler/src/main/scala/org/scalajs/core/compiler/GenJSCode.scala

Lines changed: 334 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -649,29 +649,345 @@ abstract class GenJSCode extends plugins.PluginComponent
649649
constructorTrees: List[DefDef]): js.Tree = {
650650
implicit val pos = classSym.pos
651651

652-
val (primaryCtorTree :: Nil, secondaryCtorTrees) =
653-
constructorTrees.partition(_.symbol.isPrimaryConstructor)
654-
655652
// Implementation restriction
656-
val sym = primaryCtorTree.symbol
653+
val syms = constructorTrees.map(_.symbol)
657654
val hasBadParam = enteringPhase(currentRun.uncurryPhase) {
658-
sym.paramss.flatten.exists(p => p.hasDefault || isRepeated(p))
655+
syms.exists(_.paramss.flatten.exists(p => p.hasDefault))
659656
}
660657
if (hasBadParam) {
661658
reporter.error(pos,
662659
"Implementation restriction: the constructor of a " +
663-
"Scala.js-defined JS classes cannot have default parameters nor " +
664-
"repeated parameters.")
660+
"Scala.js-defined JS classes cannot have default parameters.")
665661
}
666662

667-
// Implementation restriction
668-
for (tree <- secondaryCtorTrees) {
669-
reporter.error(tree.pos,
670-
"Implementation restriction: Scala.js-defined JS classes cannot " +
671-
"have secondary constructors")
663+
withNewLocalNameScope {
664+
val ctors: List[js.MethodDef] = constructorTrees.flatMap { tree =>
665+
genMethodWithCurrentLocalNameScope(tree)
666+
}
667+
668+
val dispatch =
669+
genJSConstructorExport(constructorTrees.map(_.symbol))
670+
val js.MethodDef(_, dispatchName, dispatchArgs, dispatchResultType,
671+
dispatchResolution) = dispatch
672+
673+
val jsConstructorBuilder = mkJSConstructorBuilder(ctors)
674+
675+
val overloadIdent = freshLocalIdent("overload")
676+
677+
// Section containing the overload resolution and casts of parameters
678+
val overloadSelection = mkOverloadSelection(jsConstructorBuilder,
679+
overloadIdent, dispatchResolution)
680+
681+
/* Section containing all the code executed before the call to `this`
682+
* for every secondary constructor.
683+
*/
684+
val prePrimaryCtorBody =
685+
jsConstructorBuilder.mkPrePrimaryCtorBody(overloadIdent)
686+
687+
val primaryCtorBody = jsConstructorBuilder.primaryCtorBody
688+
689+
/* Section containing all the code executed after the call to this for
690+
* every secondary constructor.
691+
*/
692+
val postPrimaryCtorBody =
693+
jsConstructorBuilder.mkPostPrimaryCtorBody(overloadIdent)
694+
695+
val newBody = js.Block(overloadSelection ::: prePrimaryCtorBody ::
696+
primaryCtorBody :: postPrimaryCtorBody :: Nil)
697+
698+
js.MethodDef(static = false, dispatchName, dispatchArgs, jstpe.NoType,
699+
newBody)(dispatch.optimizerHints, None)
700+
}
701+
}
702+
703+
private class ConstructorTree(val overrideNum: Int, val method: js.MethodDef,
704+
val subConstructors: List[ConstructorTree]) {
705+
706+
lazy val overrideNumBounds: (Int, Int) =
707+
if (subConstructors.isEmpty) (overrideNum, overrideNum)
708+
else (subConstructors.head.overrideNumBounds._1, overrideNum)
709+
710+
def get(methodName: String): Option[ConstructorTree] = {
711+
if (methodName == this.method.name.name) {
712+
Some(this)
713+
} else {
714+
subConstructors.iterator.map(_.get(methodName)).collectFirst {
715+
case Some(node) => node
716+
}
717+
}
718+
}
719+
720+
def getParamRefs(implicit pos: Position): List[js.VarRef] =
721+
method.args.map(_.ref)
722+
723+
def getAllParamDefsAsVars(implicit pos: Position): List[js.VarDef] = {
724+
val localDefs = method.args.map { pDef =>
725+
js.VarDef(pDef.name, pDef.ptpe, mutable = true, jstpe.zeroOf(pDef.ptpe))
726+
}
727+
localDefs ++ subConstructors.flatMap(_.getAllParamDefsAsVars)
728+
}
729+
}
730+
731+
private class JSConstructorBuilder(root: ConstructorTree) {
732+
733+
def primaryCtorBody: js.Tree = root.method.body
734+
735+
def hasSubConstructors: Boolean = root.subConstructors.nonEmpty
736+
737+
def getOverrideNum(methodName: String): Int =
738+
root.get(methodName).fold(-1)(_.overrideNum)
739+
740+
def getParamRefsFor(methodName: String)(implicit pos: Position): List[js.VarRef] =
741+
root.get(methodName).fold(List.empty[js.VarRef])(_.getParamRefs)
742+
743+
def getAllParamDefsAsVars(implicit pos: Position): List[js.VarDef] =
744+
root.getAllParamDefsAsVars
745+
746+
def mkPrePrimaryCtorBody(overrideNumIdent: js.Ident)(
747+
implicit pos: Position): js.Tree = {
748+
val overrideNumRef = js.VarRef(overrideNumIdent)(jstpe.IntType)
749+
mkSubPreCalls(root, overrideNumRef)
750+
}
751+
752+
def mkPostPrimaryCtorBody(overrideNumIdent: js.Ident)(
753+
implicit pos: Position): js.Tree = {
754+
val overrideNumRef = js.VarRef(overrideNumIdent)(jstpe.IntType)
755+
js.Block(mkSubPostCalls(root, overrideNumRef))
756+
}
757+
758+
private def mkSubPreCalls(constructorTree: ConstructorTree,
759+
overrideNumRef: js.VarRef)(implicit pos: Position): js.Tree = {
760+
val overrideNumss = constructorTree.subConstructors.map(_.overrideNumBounds)
761+
val paramRefs = constructorTree.getParamRefs
762+
val bodies = constructorTree.subConstructors.map { constructorTree =>
763+
mkPrePrimaryCtorBodyOnSndCtr(constructorTree, overrideNumRef, paramRefs)
764+
}
765+
overrideNumss.zip(bodies).foldRight[js.Tree](js.Skip()) {
766+
case ((numBounds, body), acc) =>
767+
val cond = mkOverrideNumsCond(overrideNumRef, numBounds)
768+
js.If(cond, body, acc)(jstpe. F438 BooleanType)
769+
}
770+
}
771+
772+
private def mkPrePrimaryCtorBodyOnSndCtr(constructorTree: ConstructorTree,
773+
overrideNumRef: js.VarRef, outputParams: List[js.VarRef])(
774+
implicit pos: Position): js.Tree = {
775+
val subCalls =
776+
mkSubPreCalls(constructorTree, overrideNumRef)
777+
778+
val preSuperCall = {
779+
constructorTree.method.body match {
780+
case js.Block(stats) =>
781+
val beforeSuperCall = stats.takeWhile {
782+
case js.ApplyStatic(_, mtd, _) => !ir.Definitions.isConstructorName(mtd.name)
783+
case _ => true
784+
}
785+
val superCallParams = stats.collectFirst {
786+
case js.ApplyStatic(_, mtd, js.This() :: args)
787+
if ir.Definitions.isConstructorName(mtd.name) =>
788+
zipMap(outputParams, args) { (ref, tree) =>
789+
js.Assign(ref, tree)
790+
}
791+
}.getOrElse(Nil)
792+
793+
beforeSuperCall ::: superCallParams
794+
795+
case js.ApplyStatic(_, mtd, js.This() :: args)
796+
if ir.Definitions.isConstructorName(mtd.name) =>
797+
zipMap(outputParams, args)(js.Assign(_, _))
798+
799+
case _ => Nil
800+
}
801+
}
802+
803+
js.Block(subCalls :: preSuperCall)
804+
}
805+
806+
private def mkSubPostCalls(constructorTree: ConstructorTree,
807+
overrideNumRef: js.VarRef)(implicit pos: Position): js.Tree = {
808+
val overrideNumss = constructorTree.subConstructors.map(_.overrideNumBounds)
809+
val bodies = constructorTree.subConstructors.map { ct =>
810+
mkPostPrimaryCtorBodyOnSndCtr(ct, overrideNumRef)
811+
}
812+
overrideNumss.zip(bodies).foldRight[js.Tree](js.Skip()) {
813+
case ((numBounds, js.Skip()), acc) => acc
814+
815+
case ((numBounds, body), acc) =>
816+
val cond = mkOverrideNumsCond(overrideNumRef, numBounds)
817+
js.If(cond, body, acc)(jstpe.BooleanType)
818+
}
819+
}
820+
821+
private def mkPostPrimaryCtorBodyOnSndCtr(constructorTree: ConstructorTree,
822+
overrideNumRef: js.VarRef)(implicit pos: Position): js.Tree = {
823+
val postSuperCall = {
824+
constructorTree.method.body match {
825+
case js.Block(stats) =>
826+
stats.dropWhile {
827+
case js.ApplyStatic(_, mtd, _) => !ir.Definitions.isConstructorName(mtd.name)
828+
case _ => true
829+
}.tail
830+
831+
case _ => Nil
832+
}
833+
}
834+
js.Block(postSuperCall :+ mkSubPostCalls(constructorTree, overrideNumRef))
835+
}
836+
837+
private def mkOverrideNumsCond(numRef: js.VarRef,
838+
numBounds: (Int, Int))(implicit pos: Position) = numBounds match {
839+
case (lo, hi) if lo == hi =>
840+
js.BinaryOp(js.BinaryOp.===, js.IntLiteral(lo), numRef)
841+
842+
case (lo, hi) if lo == hi - 1 =>
843+
val lhs = js.BinaryOp(js.BinaryOp.===, numRef, js.IntLiteral(lo))
844+
val rhs = js.BinaryOp(js.BinaryOp.===, numRef, js.IntLiteral(hi))
845+
js.If(lhs, js.BooleanLiteral(true), rhs)(jstpe.BooleanType)
846+
847+
case (lo, hi) =>
848+
val lhs = js.BinaryOp(js.BinaryOp.Num_<=, js.IntLiteral(lo), numRef)
849+
val rhs = js.BinaryOp(js.BinaryOp.Num_<=, numRef, js.IntLiteral(hi))
850+
js.BinaryOp(js.BinaryOp.Boolean_&, lhs, rhs)
851+
js.If(lhs, rhs, js.BooleanLiteral(false))(jstpe.BooleanType)
852+
}
853+
}
854+
855+
private def zipMap[T, U, V](xs: List[T], ys: List[U])(
856+
f: (T, U) => V): List[V] = {
857+
for ((x, y) <- xs zip ys) yield f(x, y)
858+
}
859+
860+
/** mkOverloadSelection return a list of `stats` with that starts with:
861+
* 1) The definition for the local variable that will hold the overload
862+
* resolution number.
863+
* 2) The definitions of all local variables that are used as parameters
864+
* in all the constructors.
865+
* 3) The overload resolution match/if statements. For each overload the
866+
* overload number is assigned and the parameters are cast and assigned
867+
* to their corresponding variables.
868+
*/
869+
private def mkOverloadSelection(jsConstructorBuilder: JSConstructorBuilder,
870+
overloadIdent: js.Ident, dispatchResolution: js.Tree)(
871+
implicit pos: Position): List[js.Tree]= {
872+
if (!jsConstructorBuilder.hasSubConstructors) {
873+
dispatchResolution match {
874+
/* Dispatch to constructor with no arguments.
875+
* Contains trivial parameterless call to the constructor.
876+
*/
877+
case js.ApplyStatic(_, mtd, js.This() :: Nil)
878+
if ir.Definitions.isConstructorName(mtd.name) =>
879+
Nil
880+
881+
/* Dispatch to constructor with no arguments
882+
* Where js.Block's stats.init corresponds to the parameter casts and
883+
* js.Block's stats.last contains the call to the constructor.
884+
*/
885+
case js.Block(stats) =>
886+
val js.ApplyStatic(_, method, _) = stats.last
887+
val refs = jsConstructorBuilder.getParamRefsFor(method.name)
888+
val paramCasts = stats.init.map(_.asInstanceOf[js.VarDef])
889+
zipMap(refs, paramCasts) { (ref, paramCast) =>
890+
js.VarDef(ref.ident, ref.tpe, mutable = false, paramCast.rhs)
891+
}
892+
}
893+
} else {
894+
val overloadRef = js.VarRef(overloadIdent)(jstpe.IntType)
895+
896+
/* transformDispatch takes the body of the method generated by
897+
* `genJSConstructorExport` and transform it recursively.
898+
*/
899+
def transformDispatch(tree: js.Tree): js.Tree = tree match {
900+
/* Dispatch to constructor with no arguments.
901+
* Contains trivial parameterless call to the constructor.
902+
*/
903+
case js.ApplyStatic(_, method, js.This() :: Nil)
904+
if ir.Definitions.isConstructorName(method.name) =>
905+
js.Assign(overloadRef,
906+
js.IntLiteral(jsConstructorBuilder.getOverrideNum(method.name)))
907+
908+
/* Dispatch to constructor with no arguments
909+
* Where js.Block's stats.init corresponds to the parameter casts and
910+
* js.Block's stats.last contains the call to the constructor.
911+
*/
912+
case js.Block(stats) =>
913+
val js.ApplyStatic(_, method, _) = stats.last
914+
915+
val num = jsConstructorBuilder.getOverrideNum(method.name)
916+
val overloadAssign = js.Assign(overloadRef, js.IntLiteral(num))
917+
918+
val refs = jsConstructorBuilder.getParamRefsFor(method.name)
919+
val paramCasts = stats.init.map(_.asInstanceOf[js.VarDef])
920+
val parameterAssigns = zipMap(refs, paramCasts) { (ref, paramCast) =>
921+
js.Assign(ref, paramCast.rhs)
922+
}
923+
924+
js.Block(overloadAssign :: parameterAssigns)
925+
926+
// Parameter count resolution
927+
case js.Match(selector, cases, default) =>
928+
val newCases = cases.map {
929+
case (literals, body) => (literals, transformDispatch(body))
930+
}
931+
val newDefault = transformDispatch(default)
932+
js.Match(selector, newCases, newDefault)(tree.tpe)
933+
934+
// Parameter type resolution
935+
case js.If(cond, thenp, elsep) =>
936+
js.If(cond, transformDispatch(thenp),
937+
transformDispatch(elsep))(tree.tpe)
938+
939+
// Throw(StringLiteral(No matching overload))
940+
case tree: js.Throw =>
941+
tree
942+
}
943+
944+
val newDispatchResolution = transformDispatch(dispatchResolution)
945+
val allParamDefsAsVars = jsConstructorBuilder.getAllParamDefsAsVars
946+
val overrideNumDef =
947+
js.VarDef(overloadIdent, jstpe.IntType, mutable = true, js.IntLiteral(0))
948+
949+
overrideNumDef :: allParamDefsAsVars ::: newDispatchResolution :: Nil
950+
}
951+
}
952+
953+
private def mkJSConstructorBuilder(ctors: List[js.MethodDef])(
954+
implicit pos: Position): JSConstructorBuilder = {
955+
def findCtorForwarderCall(tree: js.Tree): String = tree match {
956+
case js.ApplyStatic(_, method, js.This() :: _)
957+
if ir.Definitions.isConstructorName(method.name) =>
958+
method.name
959+
960+
case js.Block(stats) =>
961+
stats.collectFirst {
962+
case js.ApplyStatic(_, method, js.This() :: _)
963+
if ir.Definitions.isConstructorName(method.name) =>
964+
method.name
965+
}.get
966+
}
967+
968+
val (primaryCtor :: Nil, secondaryCtors) = ctors.partition {
969+
_.body match {
970+
case js.Block(stats) =>
971+
stats.exists(_.isInstanceOf[js.JSSuperConstructorCall])
972+
973+
case _: js.JSSuperConstructorCall => true
974+
case _ => false
975+
}
976+
}
977+
978+
val ctorToChildren = secondaryCtors.map { ctor =>
979+
findCtorForwarderCall(ctor.body) -> ctor
980+
}.groupBy(_._1).mapValues(_.map(_._2)).withDefaultValue(Nil)
981+
982+
var overrideNum = -1
983+
def mkConstructorTree(method: js.MethodDef): ConstructorTree = {
984+
val methodName = method.name.name
985+
val subCtrTrees = ctorToChildren(methodName).map(mkConstructorTree)
986+
overrideNum += 1
987+
new ConstructorTree(overrideNum, method, subCtrTrees)
672988
}
673989

674-
genMethod(primaryCtorTree).get
990+
new JSConstructorBuilder(mkConstructorTree(primaryCtor))
675991
}
676992

677993
// Generate a method -------------------------------------------------------
@@ -720,9 +1036,7 @@ abstract class GenJSCode extends plugins.PluginComponent
7201036
val isJSClassConstructor =
7211037
sym.isClassConstructor && isScalaJSDefinedJSClass(currentClassSym)
7221038

723-
val methodName: js.PropertyName =
724-
if (isJSClassConstructor) js.StringLiteral("constructor")
725-
else encodeMethodSym(sym)
1039+
val methodName: js.PropertyName = encodeMethodSym(sym)
7261040

7271041
def jsParams = for (param <- params) yield {
7281042
implicit val pos = param.pos
@@ -793,12 +1107,11 @@ abstract class GenJSCode extends plugins.PluginComponent
7931107
val methodDef = {
7941108
if (isJSClassConstructor) {
7951109
val body0 = genStat(rhs)
796-
val body1 = moveAllStatementsAfterSuperConstructorCall(body0)
797-
val (patchedParams, patchedBody) =
798-
patchFunBodyWithBoxes(sym, jsParams, body1)
1110+
val body1 =
1111+
if (!sym.isPrimaryConstructor) body0
1112+
else moveAllStatementsAfterSuperConstructorCall(body0)
7991113
js.MethodDef(static = false, methodName,
800-
patchedParams, jstpe.NoType, patchedBody)(
801-
optimizerHints, None)
1114+
jsParams, jstpe.NoType, body1)(optimizerHints, None)
8021115
} else if (sym.isClassConstructor) {
8031116
js.MethodDef(static = false, methodName,
8041117
jsParams, jstpe.NoType,

0 commit comments

Comments
 (0)
0