8000 Merge pull request #4917 from gzm0/more-cache · scala-js/scala-js@64e7725 · GitHub
[go: up one dir, main page]

Skip to content

Commit 64e7725

Browse files
authored
Merge pull request #4917 from gzm0/more-cache
Fuse emitting and printing of trees in the backend
2 parents 51a363e + 42efb2a commit 64e7725

File tree

10 files changed

+472
-295
lines changed

10 files changed

+472
-295
lines changed

linker/jvm/src/main/scala/org/scalajs/linker/backend/closure/ClosureLinkerBackend.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ final class ClosureLinkerBackend(config: LinkerBackendImpl.Config)
6060
.withTrackAllGlobalRefs(true)
6161
.withInternalModulePattern(m => OutputPatternsImpl.moduleName(config.outputPatterns, m.id))
6262

63-
new Emitter(emitterConfig)
63+
new Emitter(emitterConfig, ClosureLinkerBackend.PostTransformer)
6464
}
6565

6666
val symbolRequirements: SymbolRequirement = emitter.symbolRequirements
@@ -106,7 +106,8 @@ final class ClosureLinkerBackend(config: LinkerBackendImpl.Config)
106106
sjsModule <- moduleSet.modules.headOption
107107
} yield {
108108
val closureChunk = logger.time("Closure: Create trees)") {
109-
buildChunk(emitterResult.body(sjsModule.id))
109+
val (trees, _) = emitterResult.body(sjsModule.id)
110+
buildChunk(trees)
110111
}
111112

112113
logger.time("Closure: Compiler pass") {
@@ -295,4 +296,11 @@ private object ClosureLinkerBackend {
295296
Function.prototype.apply;
296297
var NaN = 0.0/0.0, Infinity = 1.0/0.0, undefined = void 0;
297298
"""
299+
300+
private object PostTransformer extends Emitter.PostTransformer[js.Tree] {
301+
// Do not apply ClosureAstTransformer eagerly:
302+
// The ASTs used by closure are highly mutable, so re-using them is non-trivial.
303+
// Since closure is slow anyways, we haven't built the optimization.
304+
def transformStats(trees: List[js.Tree], indent: Int): List[js.Tree] = trees
305+
}
298306
}

linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala

Lines changed: 56 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import scala.concurrent._
1717
import java.nio.ByteBuffer
1818
import java.nio.charset.StandardCharsets
1919

20+
import java.util.concurrent.atomic.AtomicInteger
21+
2022
import org.scalajs.logging.Logger
2123

2224
import org.scalajs.linker.interface.{IRFile, OutputDirectory, Report}
@@ -36,12 +38,19 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config)
3638

3739
import BasicLinkerBackend._
3840

41+
private[this] var totalModules = 0
42+
private[this] val rewrittenModules = new AtomicInteger(0)
43+
3944
private[this] val emitter = {
4045
val emitterConfig = Emitter.Config(config.commonConfig.coreSpec)
4146
.withJSHeader(config.jsHeader)
4247
.withInternalModulePattern(m => OutputPatternsImpl.moduleName(config.outputPatterns, m.id))
4348

44-
new Emitter(emitterConfig)
49+
val postTransformer =
50+
if (config.sourceMap) PostTransformerWithSourceMap
51+
else PostTransformerWithoutSourceMap
52+
53+
new Emitter(emitterConfig, postTransformer)
4554
}
4655

4756
val symbolRequirements: SymbolRequirement = emitter.symbolRequirements
@@ -61,31 +70,35 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config)
6170
implicit ec: ExecutionContext): Future[Report] = {
6271
verifyModuleSet(moduleSet)
6372

73+
// Reset stats.
74+
75+
totalModules = moduleSet.modules.size
76+
rewrittenModules.set(0)
77+
6478
val emitterResult = logger.time("Emitter") {
6579
emitter.emit(moduleSet, logger)
6680
}
6781

6882
val skipContentCheck = !isFirstRun
6983
isFirstRun = false
7084

71-
printedModuleSetCache.startRun(moduleSet)
7285
val allChanged =
7386
printedModuleSetCache.updateGlobal(emitterResult.header, emitterResult.footer)
7487

7588
val writer = new OutputWriter(output, config, skipContentCheck) {
7689
protected def writeModuleWithoutSourceMap(moduleID: ModuleID, force: Boolean): Option[ByteBuffer] = {
7790
val cache = printedModuleSetCache.getModuleCache(moduleID)
78-
val changed = cache.update(emitterResult.body(moduleID))
91+
val (printedTrees, changed) = emitterResult.body(moduleID)
7992

8093
if (force || changed || allChanged) {
81-
printedModuleSetCache.incRewrittenModules()
94+
rewrittenModules.incrementAndGet()
8295

8396
val jsFileWriter = new ByteArrayWriter(sizeHintFor(cache.getPreviousFinalJSFileSize()))
8497

8598
jsFileWriter.write(printedModuleSetCache.headerBytes)
8699
jsFileWriter.writeASCIIString("'use strict';\n")
87100

88-
for (printedTree <- cache.printedTrees)
101+
for (printedTree <- printedTrees)
89102
jsFileWriter.write(printedTree.jsCode)
90103

91104
jsFileWriter.write(printedModuleSetCache.footerBytes)
@@ -99,10 +112,10 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config)
99112

100113
protected def writeModuleWithSourceMap(moduleID: ModuleID, force: Boolean): Option[(ByteBuffer, ByteBuffer)] = {
101114
val cache = printedModuleSetCache.getModuleCache(moduleID)
102-
val changed = cache.update(emitterResult.body(moduleID))
115+
val (printedTrees, changed) = emitterResult.body(moduleID)
103116

104117
if (force || changed || allChanged) {
105-
printedModuleSetCache.incRewrittenModules()
118+
rewrittenModules.incrementAndGet()
106119

107120
val jsFileWriter = new ByteArrayWriter(sizeHintFor(cache.getPreviousFinalJSFileSize()))
108121
val sourceMapWriter = new ByteArrayWriter(sizeHintFor(cache.getPreviousFinalSourceMapSize()))
@@ -120,7 +133,7 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config)
120133
jsFileWriter.writeASCIIString("'use strict';\n")
121134
smWriter.nextLine()
122135

123-
for (printedTree <- cache.printedTrees) {
136+
for (printedTree <- printedTrees) {
124137
jsFileWriter.write(printedTree.jsCode)
125138
smWriter.insertFragment(printedTree.sourceMapFragment)
126139
}
@@ -145,9 +158,15 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config)
145158
writer.write(moduleSet)
146159
}.andThen { case _ =>
147160
printedModuleSetCache.cleanAfterRun()
148-
printedModuleSetCache.logStats(logger)
161+
logStats(logger)
149162
}
150163
}
164+
165+
private def logStats(logger: Logger): Unit = {
166+
// Message extracted in BasicLinkerBackendTest
167+
logger.debug(
168+
s"BasicBackend: total modules: $totalModules; re-written: ${rewrittenModules.get()}")
169+
}
151170
}
152171

153172
private object BasicLinkerBackend {
@@ -161,20 +180,6 @@ private object BasicLinkerBackend {
161180

162181
private val modules = new java.util.concurrent.ConcurrentHashMap[ModuleID, PrintedModuleCache]
163182

164-
private var totalModules = 0
165-
private val rewrittenModules = new java.util.concurrent.atomic.AtomicInteger(0)
166-
167-
private var totalTopLevelTrees = 0
168-
private var recomputedTopLevelTrees = 0
169-
170-
def startRun(moduleSet: ModuleSet): Unit = {
171-
totalModules = moduleSet.modules.size
172-
rewrittenModules.set(0)
173-
174-
totalTopLevelTrees = 0
175-
recomputedTopLevelTrees = 0
176-
}
177-
178183
def updateGlobal(header: String, footer: String): Boolean = {
179184
if (header == lastHeader && footer == lastFooter) {
180185
false
@@ -193,61 +198,30 @@ private object BasicLinkerBackend {
193198
def headerNewLineCount: Int = _headerNewLineCountCache
194199

195200
def getModuleCache(moduleID: ModuleID): PrintedModuleCache = {
196-
val result = modules.computeIfAbsent(moduleID, { _ =>
197-
if (withSourceMaps) new PrintedModuleCacheWithSourceMaps
198-
else new PrintedModuleCache
199-
})
200-
201+
val result = modules.computeIfAbsent(moduleID, _ => new PrintedModuleCache)
201202
result.startRun()
202203
result
203204
}
204205

205-
def incRewrittenModules(): Unit =
206-
rewrittenModules.incrementAndGet()
207-
208206
def cleanAfterRun(): Unit = {
209207
val iter = modules.entrySet().iterator()
210208
while (iter.hasNext()) {
211209
val moduleCache = iter.next().getValue()
212-
if (moduleCache.cleanAfterRun()) {
213-
totalTopLevelTrees += moduleCache.getTotalTopLevelTrees
214-
recomputedTopLevelTrees += moduleCache.getRecomputedTopLevelTrees
215-
} else {
210+
if (!moduleCache.cleanAfterRun()) {
216211
iter.remove()
217212
}
218213
}
219214
}
220-
221-
def logStats(logger: Logger): Unit = {
222-
/* These messages are extracted in BasicLinkerBackendTest to assert that
223-
* we do not invalidate anything in a no-op second run.
224-
*/
225-
logger.debug(
226-
s"BasicBackend: total top-level trees: $totalTopLevelTrees; re-computed: $recomputedTopLevelTrees")
227-
logger.debug(
228-
s"BasicBackend: total modules: $totalModules; re-written: ${rewrittenModules.get()}")
229-
}
230-
}
231-
232-
private final class PrintedTree(val jsCode: Array[Byte], val sourceMapFragment: SourceMapWriter.Fragment) {
233-
var cachedUsed: Boolean = false
234215
}
235216

236217
private sealed class PrintedModuleCache {
237218
private var cacheUsed = false
238-
private var changed = false
239-
private var lastJSTrees: List[js.Tree] = Nil
240-
private var printedTreesCache: List[PrintedTree] = Nil
241-
private val cache = new java.util.IdentityHashMap[js.Tree, PrintedTree]
242219

243220
private var previousFinalJSFileSize: Int = 0
244221
private var previousFinalSourceMapSize: Int = 0
245222

246-
private var recomputedTopLevelTrees = 0
247-
248223
def startRun(): Unit = {
249224
cacheUsed = true
250-
recomputedTopLevelTrees = 0
251225
}
252226

253227
def getPreviousFinalJSFileSize(): Int = previousFinalJSFileSize
@@ -259,72 +233,42 @@ private object BasicLinkerBackend {
259233
previousFinalSourceMapSize = finalSourceMapSize
260234
}
261235

262-
def update(newJSTrees: List[js.Tree]): Boolean = {
263-
val changed = !newJSTrees.corresponds(lastJSTrees)(_ eq _)
264-
this.changed = changed
265-
if (changed) {
266-
printedTreesCache = newJSTrees.map(getOrComputePrintedTree(_))
267-
lastJSTrees = newJSTrees
268-
}
269-
changed
270-
}
271-
272-
private def getOrComputePrintedTree(tree: js.Tree): PrintedTree = {
273-
val result = cache.computeIfAbsent(tree, { (tree: js.Tree) =>
274-
recomputedTopLevelTrees += 1
275-
computePrintedTree(tree)
276-
})
277-
278-
result.cachedUsed = true
279-
result
280-
}
281-
282-
protected def computePrintedTree(tree: js.Tree): PrintedTree = {
283-
val jsCodeWriter = new ByteArrayWriter()
284-
val printer = new Printers.JSTreePrinter(jsCodeWriter)
285-
286-
printer.printStat(tree)
287-
288-
new PrintedTree(jsCodeWriter.toByteArray(), SourceMapWriter.Fragment.Empty)
236+
def cleanAfterRun(): Boolean = {
237+
val wasUsed = cacheUsed
238+
cacheUsed = false
239+
wasUsed
289240
}
241+
}
290242

291-
def printedTrees: List[PrintedTree] = printedTreesCache
243+
private object PostTransformerWithoutSourceMap extends Emitter.PostTransformer[js.PrintedTree] {
244+
def transformStats(trees: List[js.Tree], indent: Int): List[js.PrintedTree] = {
245+
if (trees.isEmpty) {
246+
Nil // Fast path
247+
} else {
248+
val jsCodeWriter = new ByteArrayWriter()
249+
val printer = new Printers.JSTreePrinter(jsCodeWriter, indent)
292250

293-
def cleanAfterRun(): Boolean = {
294-
if (cacheUsed) {
295-
cacheUsed = false
296-
297-
if (changed) {
298-
val iter = cache.entrySet().iterator()
299-
while (iter.hasNext()) {
300-
val printedTree = iter.next().getValue()
301-
if (printedTree.cachedUsed)
302-
printedTree.cachedUsed = false
303-
else
304-
iter.remove()
305-
}
306-
}
251+
trees.map(printer.printStat(_))
307252

308-
true
309-
} else {
310-
false
253+
js.PrintedTree(jsCodeWriter.toByteArray(), SourceMapWriter.Fragment.Empty) :: Nil
311254
}
312255
}
313-
314-
def getTotalTopLevelTrees: Int = lastJSTrees.size
315-
def getRecomputedTopLevelTrees: Int = recomputedTopLevelTrees
316256
}
317257

318-
private final class PrintedModuleCacheWithSourceMaps extends PrintedModuleCache {
319-
override protected def computePrintedTree(tree: js.Tree): PrintedTree = {
320-
val jsCodeWriter = new ByteArrayWriter()
321-
val smFragmentBuilder = new SourceMapWriter.FragmentBuilder()
322-
val printer = new Printers.JSTreePrinterWithSourceMap(jsCodeWriter, smFragmentBuilder)
258+
private object PostTransformerWithSourceMap extends Emitter.PostTransformer[js.PrintedTree] {
259+
def transformStats(trees: List[js.Tree], indent: Int): List[js.PrintedTree] = {
260+
if (trees.isEmpty) {
261+
Nil // Fast path
262+
} else {
263+
val jsCodeWriter = new ByteArrayWriter()
264+
val smFragmentBuilder = new SourceMapWriter.FragmentBuilder()
265+
val printer = new Printers.JSTreePrinterWithSourceMap(jsCodeWriter, smFragmentBuilder, indent)
323266

324-
printer.printStat(tree)
325-
smFragmentBuilder.complete()
267+
trees.map(printer.printStat(_))
268+
smFragmentBuilder.complete()
326269

327-
new PrintedTree(jsCodeWriter.toByteArray(), smFragmentBuilder.result())
270+
js.PrintedTree(jsCodeWriter.toByteArray(), smFragmentBuilder.result()) :: Nil
271+
}
328272
}
329273
}
330274
}

linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/ClassEmitter.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ private[emitter] final class ClassEmitter(sjsGen: SJSGen) {
4545

4646
def buildClass(className: ClassName, isJSClass: Boolean, jsClassCaptures: Option[List[ParamDef]],
4747
hasClassInitializer: Boolean,
48-
superClass: Option[ClassIdent], storeJSSuperClass: Option[js.Tree], useESClass: Boolean,
48+
superClass: Option[ClassIdent], storeJSSuperClass: List[js.Tree], useESClass: Boolean,
4949
members: List[js.Tree])(
5050
implicit moduleContext: ModuleContext,
5151
globalKnowledge: GlobalKnowledge, pos: Position): WithGlobals[List[js.Tree]] = {
@@ -75,7 +75,7 @@ private[emitter] final class ClassEmitter(sjsGen: SJSGen) {
7575
val createClassValueVar = genEmptyMutableLet(classValueIdent)
7676

7777
val entireClassDefWithGlobals = if (useESClass) {
78-
genJSSuperCtor(superClass, storeJSSuperClass.isDefined).map { jsSuperClass =>
78+
genJSSuperCtor(superClass, storeJSSuperClass.nonEmpty).map { jsSuperClass =>
7979
List(classValueVar := js.ClassDef(Some(classValueIdent), Some(jsSuperClass), members))
8080
}
8181
} else {
@@ -86,7 +86,7 @@ private[emitter] final class ClassEmitter(sjsGen: SJSGen) {
8686
entireClassDef <- entireClassDefWithGlobals
8787
createStaticFields <- genCreateStaticFieldsOfJSClass(className)
8888
} yield {
89-
storeJSSuperClass.toList ::: entireClassDef ::: createStaticFields
89+
storeJSSuperClass ::: entireClassDef ::: createStaticFields
9090
}
9191

9292
jsClassCaptures.fold {

linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/CoreJSLib.scala

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ import PolyfillableBuiltin._
3232

3333
private[emitter] object CoreJSLib {
3434

35-
def build(sjsGen: SJSGen, moduleContext: ModuleContext,
36-
globalKnowledge: GlobalKnowledge): WithGlobals[Lib] = {
37-
new CoreJSLibBuilder(sjsGen)(moduleContext, globalKnowledge).build()
35+
def build[E](sjsGen: SJSGen, postTransform: List[Tree] => E, moduleContext: ModuleContext,
36+
globalKnowledge: GlobalKnowledge): WithGlobals[Lib[E]] = {
37+
new CoreJSLibBuilder(sjsGen)(moduleContext, globalKnowledge).build(postTransform)
3838
}
3939

4040
/** A fully built CoreJSLib
@@ -52,10 +52,10 @@ private[emitter] object CoreJSLib {
5252
* @param initialization Things that depend on Scala.js generated classes.
5353
* These must have class definitions (but not static fields) available.
5454
*/
55-
final class Lib private[CoreJSLib] (
56-
val preObjectDefinitions: List[Tree],
57-
val postObjectDefinitions: List[Tree],
58-
val initialization: L BAD0 ist[Tree])
55+
final class Lib[E] private[CoreJSLib] (
56+
val preObjectDefinitions: E,
57+
val postObjectDefinitions: E,
58+
val initialization: E)
5959

6060
private class CoreJSLibBuilder(sjsGen: SJSGen)(
6161
implicit moduleContext: ModuleContext, globalKnowledge: GlobalKnowledge) {
@@ -115,9 +115,11 @@ private[emitter] object CoreJSLib {
115115
private val specializedArrayTypeRefs: List[NonArrayTypeRef] =
116116
ClassRef(ObjectClass) :: orderedPrimRefsWithoutVoid
117117

118-
def build(): WithGlobals[Lib] = {
119-
val lib = new Lib(buildPreObjectDefinitions(),
120-
buildPostObjectDefinitions(), buildInitializations())
118+
def build[E](postTransform: List[Tree] => E): WithGlobals[Lib[E]] = {
119+
val lib = new Lib(
120+
postTransform(buildPreObjectDefinitions()),
121+
postTransform(buildPostObjectDefinitions()),
122+
postTransform(buildInitializations()))
121123
WithGlobals(lib, trackedGlobalRefs)
122124
}
123125

0 commit comments

Comments
 (0)
0