10000 Fuse emitting and printing of trees in the backend by gzm0 · Pull Request #4917 · scala-js/scala-js · GitHub
[go: up one dir, main page]

Skip to content

Fuse emitting and printing of trees in the backend #4917

8000 New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Track in Emitter whether a module changed in an incremental run
In the next commit, we want to avoid caching entire classes because of
the memory cost. However, the BasicLinkerBackend relies on the
identity of the generated trees to detect changes: Since that identity
will change if we stop caching them, we need to provide an explicit
"changed" signal.
  • Loading branch information
gzm0 committed Jan 29, 2024
commit 5c56042c11adc6df0f830c71d35b053864bc0b66
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ final class ClosureLinkerBackend(config: LinkerBackendImpl.Config)
sjsModule <- moduleSet.modules.headOption
} yield {
val closureChunk = logger.time("Closure: Create trees)") {
buildChunk(emitterResult.body(sjsModule.id))
val (trees, _) = emitterResult.body(sjsModule.id)
buildChunk(trees)
}

logger.time("Closure: Compiler pass") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,7 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config)
val writer = new OutputWriter(output, config, skipContentCheck) {
protected def writeModuleWithoutSourceMap(moduleID: ModuleID, force: Boolean): Option[ByteBuffer] = {
val cache = printedModuleSetCache.getModuleCache(moduleID)
val printedTrees = emitterResult.body(moduleID)

val changed = cache.update(printedTrees)
val (printedTrees, changed) = emitterResult.body(moduleID)

if (force || changed || allChanged) {
rewrittenModules.incrementAndGet()
Expand All @@ -114,9 +112,7 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config)

protected def writeModuleWithSourceMap(moduleID: ModuleID, force: Boolean): Option[(ByteBuffer, ByteBuffer)] = {
val cache = printedModuleSetCache.getModuleCache(moduleID)
val printedTrees = emitterResult.body(moduleID)

val changed = cache.update(printedTrees)
val (printedTrees, changed) = emitterResult.body(moduleID)

if (force || changed || allChanged) {
rewrittenModules.incrementAndGet()
Expand Down Expand Up @@ -220,8 +216,6 @@ private object BasicLinkerBackend {

private sealed class PrintedModuleCache {
private var cacheUsed = false
private var changed = false
private var lastPrintedTrees: List[js.PrintedTree] = Nil

private var previousFinalJSFileSize: Int = 0
private var previousFinalSourceMapSize: Int = 0
Expand All @@ -239,15 +233,6 @@ private object BasicLinkerBackend {
previousFinalSourceMapSize = finalSourceMapSize
}

def update(newPrintedTrees: List[js.PrintedTree]): Boolean = {
val changed = !newPrintedTrees.corresponds(lastPrintedTrees)(_ eq _)
this.changed = changed
if (changed) {
lastPrintedTrees = newPrintedTrees
}
changed
}

def cleanAfterRun(): Boolean = {
val wasUsed = cacheUsed
cacheUsed = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ final class Emitter[E >: Null <: js.Tree](
}

private def emitInternal(moduleSet: ModuleSet,
logger: Logger): WithGlobals[Map[ModuleID, List[E]]] = {
logger: Logger): WithGlobals[Map[ModuleID, (List[E], Boolean)]] = {
// Reset caching stats.
statsClassesReused = 0
statsClassesInvalidated = 0
Expand Down Expand Up @@ -169,7 +169,7 @@ final class Emitter[E >: Null <: js.Tree](
*/
@tailrec
private def emitAvoidGlobalClash(moduleSet: ModuleSet,
logger: Logger, secondAttempt: Boolean): WithGlobals[Map[ModuleID, List[E]]] = {
logger: Logger, secondAttempt: Boolean): WithGlobals[Map[ModuleID, (List[E], Boolean)]] = {
val result = emitOnce(moduleSet, logger)

val mentionedDangerousGlobalRefs =
Expand All @@ -194,7 +194,7 @@ final class Emitter[E >: Null <: js.Tree](
}

private def emitOnce(moduleSet: ModuleSet,
logger: Logger): WithGlobals[Map[ModuleID, List[E]]] = {
logger: Logger): WithGlobals[Map[ModuleID, (List[E], Boolean)]] = {
// Genreate classes first so we can measure time separately.
val generatedClasses = logger.time("Emitter: Generate Classes") {
moduleSet.modules.map { module =>
Expand All @@ -212,18 +212,26 @@ final class Emitter[E >: Null <: js.Tree](

val moduleTrees = logger.time("Emitter: Write trees") {
moduleSet.modules.map { module =>
var changed = false
def extractChangedAndWithGlobals[T](x: (WithGlobals[T], Boolean)): T = {
changed ||= x._2
extractWithGlobals(x._1)
}

val moduleContext = ModuleContext.fromModule(module)
val moduleCache = state.moduleCaches.getOrElseUpdate(module.id, new ModuleCache)

val moduleClasses = generatedClasses(module.id)

val moduleImports = extractWithGlobals {
changed ||= moduleClasses.exists(_.changed)

val moduleImports = extractChangedAndWithGlobals {
moduleCache.getOrComputeImports(module.externalDependencies, module.internalDependencies) {
genModuleImports(module).map(postTransform(_, 0))
}
}

val topLevelExports = extractWithGlobals {
val topLevelExports = extractChangedAndWithGlobals {
/* We cache top level exports all together, rather than individually,
* since typically there are few.
*/
Expand All @@ -233,7 +241,7 @@ final class Emitter[E >: Null <: js.Tree](
}
}

val moduleInitializers = extractWithGlobals {
val moduleInitializers = extractChangedAndWithGlobals {
val initializers = module.initializers.toList
moduleCache.getOrComputeInitializers(initializers) {
WithGlobals.list(initializers.map { initializer =>
Expand Down Expand Up @@ -324,7 +332,7 @@ final class Emitter[E >: Null <: js.Tree](
trackedGlobalRefs = unionPreserveEmpty(trackedGlobalRefs, genClass.trackedGlobalRefs)
}

module.id -> allTrees
module.id -> (allTrees, changed)
}
}

Expand Down Expand Up @@ -382,8 +390,14 @@ final class Emitter[E >: Null <: js.Tree](
val classCache = classCaches.getOrElseUpdate(
new ClassID(linkedClass.ancestors, moduleContext), new ClassCache)

var changed = false
def extractChanged[T](x: (T, Boolean)): T = {
changed ||= x._2
x._1
}

val classTreeCache =
classCache.getCache(linkedClass.version)
extractChanged(classCache.getCache(linkedClass.version))

val kind = linkedClass.kind

Expand All @@ -396,6 +410,9 @@ final class Emitter[E >: Null <: js.Tree](
withGlobals.value
}

def extractWithGlobalsAndChanged[T](x: (WithGlobals[T], Boolean)): T =
extractWithGlobals(extractChanged(x))

// Main part

val main = List.newBuilder[E]
Expand Down Expand Up @@ -426,7 +443,7 @@ final class Emitter[E >: Null <: js.Tree](
val methodCache =
classCache.getStaticLikeMethodCache(namespace, methodDef.methodName)

main ++= extractWithGlobals(methodCache.getOrElseUpdate(methodDef.version, {
main ++= extractWithGlobalsAndChanged(methodCache.getOrElseUpdate(methodDef.version, {
classEmitter.genStaticLikeMethod(className, methodDef)(moduleContext, methodCache)
.map(postTransform(_, 0))
}))
Expand Down Expand Up @@ -486,7 +503,7 @@ final class Emitter[E >: Null <: js.Tree](
}

// JS constructor
val ctorWithGlobals = {
val ctorWithGlobals = extractChanged {
/* The constructor depends both on the class version, and the version
* of the inlineable init, if there is one.
*
Expand Down Expand Up @@ -571,13 +588,13 @@ final class Emitter[E >: Null <: js.Tree](
classCache.getMemberMethodCache(method.methodName)

val version = Version.combine(isJSClassVersion, method.version)
methodCache.getOrElseUpdate(version,
extractChanged(methodCache.getOrElseUpdate(version,
classEmitter.genMemberMethod(
className, // invalidated by overall class cache
isJSClass, // invalidated by isJSClassVersion
useESClass, // invalidated by isJSClassVersion
method // invalidated by method.version
)(moduleContext, methodCache).map(postTransform(_, memberIndent)))
)(moduleContext, methodCache).map(postTransform(_, memberIndent))))
}

// Exported Members
Expand All @@ -586,13 +603,13 @@ final class Emitter[E >: Null <: js.Tree](
} yield {
val memberCache = classCache.getExportedMemberCache(idx)
val version = Version.combine(isJSClassVersion, member.version)
memberCache.getOrElseUpdate(version,
extractChanged(memberCache.getOrElseUpdate(version,
classEmitter.genExportedMember(
className, // invalidated by overall class cache
isJSClass, // invalidated by isJSClassVersion
useESClass, // invalidated by isJSClassVersion
member // invalidated by version
)(moduleContext, memberCache).map(postTransform(_, memberIndent)))
)(moduleContext, memberCache).map(postTransform(_, memberIndent))))
}

val hasClassInitializer: Boolean = {
Expand All @@ -602,7 +619,7 @@ final class Emitter[E >: Null <: js.Tree](
}
}

val fullClass = {
val fullClass = extractChanged {
val fullClassCache = classCache.getFullClassCache()

fullClassCache.getOrElseUpdate(linkedClass.version, ctorWithGlobals,
Expand Down Expand Up @@ -714,7 +731,8 @@ final class Emitter[E >: Null <: js.Tree](
main.result(),
staticFields,
staticInitialization,
trackedGlobalRefs
trackedGlobalRefs,
changed
)
}

Expand Down Expand Up @@ -751,28 +769,33 @@ final class Emitter[E >: Null <: js.Tree](
}

def getOrComputeImports(externalDependencies: Set[String], internalDependencies: Set[ModuleID])(
compute: => WithGlobals[List[E]]): WithGlobals[List[E]] = {
compute: => WithGlobals[List[E]]): (WithGlobals[List[E]], Boolean) = {

_cacheUsed = true

if (externalDependencies != _lastExternalDependencies || internalDependencies != _lastInternalDependencies) {
_importsCache = compute
_lastExternalDependencies = externalDependencies
_lastInternalDependencies = internalDependencies
(_importsCache, true)
} else {
(_importsCache, false)
}
_importsCache

}

def getOrComputeTopLevelExports(topLevelExports: List[LinkedTopLevelExport])(
compute: => WithGlobals[List[E]]): WithGlobals[List[E]] = {
compute: => WithGlobals[List[E]]): (WithGlobals[List[E]], Boolean) = {

_cacheUsed = true

if (!sameTopLevelExports(topLevelExports, _lastTopLevelExports)) {
_topLevelExportsCache = compute
_lastTopLevelExports = topLevelExports
(_topLevelExportsCache, true)
} else {
(_topLevelExportsCache, false)
}
_topLevelExportsCache
}

private def sameTopLevelExports(tles1: List[LinkedTopLevelExport], tles2: List[LinkedTopLevelExport]): Boolean = {
Expand Down Expand Up @@ -803,15 +826,17 @@ final class Emitter[E >: Null <: js.Tree](
}

def getOrComputeInitializers(initializers: List[ModuleInitializer.Initializer])(
compute: => WithGlobals[List[E]]): WithGlobals[List[E]] = {
compute: => WithGlobals[List[E]]): (WithGlobals[List[E]], Boolean) = {

_cacheUsed = true

if (initializers != _lastInitializers) {
_initializersCache = compute
_lastInitializers = initializers
(_initializersCache, true)
} else {
(_initializersCache, false)
}
_initializersCache
}

def cleanAfterRun(): Boolean = {
Expand Down Expand Up @@ -856,17 +881,18 @@ final class Emitter[E >: Null <: js.Tree](
_fullClassCache.foreach(_.startRun())
}

def getCache(version: Version): DesugaredClassCache[List[E]] = {
def getCache(version: Version): (DesugaredClassCache[List[E]], Boolean) = {
_cacheUsed = true
if (_cache == null || !_lastVersion.sameVersion(version)) {
invalidate()
statsClassesInvalidated += 1
_lastVersion = version
_cache = new DesugaredClassCache[List[E]]
(_cache, true)
} else {
statsClassesReused += 1
(_cache, false)
}
_cacheUsed = true
_cache
}

def getMemberMethodCache(
Expand Down Expand Up @@ -932,17 +958,18 @@ final class Emitter[E >: Null <: js.Tree](
def startRun(): Unit = _cacheUsed = false

def getOrElseUpdate(version: Version,
v: => WithGlobals[T]): WithGlobals[T] = {
v: => WithGlobals[T]): (WithGlobals[T], Boolean) = {
_cacheUsed = true
if (_tree == null || !_lastVersion.sameVersion(version)) {
invalidate()
statsMethodsInvalidated += 1
_tree = v
_lastVersion = version
(_tree, true)
} else {
statsMethodsReused += 1
(_tree, false)
}
_cacheUsed = true
_tree
}

def cleanAfterRun(): Boolean = {
Expand Down Expand Up @@ -974,7 +1001,7 @@ final class Emitter[E >: Null <: js.Tree](

def getOrElseUpdate(version: Version, ctor: WithGlobals[List[E]],
memberMethods: List[WithGlobals[List[E]]], exportedMembers: List[WithGlobals[List[E]]],
compute: => WithGlobals[List[E]]): WithGlobals[List[E]] = {
compute: => WithGlobals[List[E]]): (WithGlobals[List[E]], Boolean) = {

@tailrec
def allSame[A <: AnyRef](xs: List[A], ys: List[A]): Boolean = {
Expand All @@ -984,6 +1011,8 @@ final class Emitter[E >: Null <: js.Tree](
}
}

_cacheUsed = true

if (_tree == null || !version.sameVersion(_lastVersion) || (_lastCtor ne ctor) ||
!allSame(_lastMemberMethods, memberMethods) ||
!allSame(_lastExportedMembers, exportedMembers)) {
Expand All @@ -993,10 +1022,10 @@ final class Emitter[E >: Null <: js.Tree](
_lastCtor = ctor
_lastMemberMethods = memberMethods
_lastExportedMembers = exportedMembers
(_tree, true)
} else {
(_tree, false)
}

_cacheUsed = true
_tree
}

def cleanAfterRun(): Boolean = {
Expand Down Expand Up @@ -1030,7 +1059,7 @@ object Emitter {
/** Result of an emitter run. */
final class Result[E] private[Emitter](
val header: String,
val body: Map[ModuleID, List[E]],
val body: Map[ModuleID, (List[E], Boolean)],
val footer: String,
val topLevelVarDecls: List[String],
val globalRefs: Set[String]
Expand Down Expand Up @@ -1121,7 +1150,8 @@ object Emitter {
val main: List[E],
val staticFields: List[E],
val staticInitialization: List[E],
val trackedGlobalRefs: Set[String]
val trackedGlobalRefs: Set[String],
val changed: Boolean
)

private final class OneTimeCache[A >: Null] {
Expand Down
0