8000 Wasm: Implement PriorityQueue without js.Array in Wasm backend · scala-js/scala-js@da5c453 · GitHub
[go: up one dir, main page]

Skip to content

Commit da5c453

Browse files
committed
Wasm: Implement PriorityQueue without js.Array in Wasm backend
Current js.Array-based PriorityQueue implementation on Wasm requires JS-interop for every operation, and JS-interop is very slow. We use a `linkTimeIf` to select a `scala.Array`-based implementation of internal data structure for PriorityQueue, based on wether it is on JS or WebAssembly.
1 parent 283f719 commit da5c453

File tree

1 file changed

+172
-49
lines changed

1 file changed

+172
-49
lines changed

javalib/src/main/scala/java/util/PriorityQueue.scala

Lines changed: 172 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,46 @@
1212

1313
package java.util
1414

15+
import java.lang.Utils.roundUpToPowerOfTwo
16+
1517
import scala.annotation.tailrec
1618

1719
import scala.scalajs.js
20+
import scala.scalajs.LinkingInfo
1821

1922
class PriorityQueue[E] private (
20-
private val comp: Comparator[_ >: E], internal: Boolean)
23+
private val comp: Comparator[_ >: E], internal: Boolean, initialCapacity: Int)
2124
extends AbstractQueue[E] with Serializable {
2225

2326
def this() =
24-
this(NaturalComparator, internal = true)
27+
this(NaturalComparator, internal = true, initialCapacity = 16)
2528

2629
def this(initialCapacity: Int) = {
27-
this()
28-
if (initialCapacity < 1)
29-
throw new IllegalArgumentException()
30+
this(
31+
NaturalComparator,
32+
internal = true,
33+
{
34+
if (initialCapacity < 1)
35+
throw new IllegalArgumentException
36+
initialCapacity
37+
}
38+
)
3039
}
3140

3241
def this(comparator: Comparator[_ >: E]) = {
33-
this(NaturalComparator.select(comparator), internal = true)
42+
this(NaturalComparator.select(comparator), internal = true, initialCapacity = 16)
3443
}
3544

3645
def this(initialCapacity: Int, comparator: Comparator[_ >: E]) = {
37-
this(comparator)
38-
if (initialCapacity < 1)
39-
throw new IllegalArgumentException()
46+
this(
47+
NaturalComparator.select(comparator),
48+
internal = true,
49+
{
50+
if (initialCapacity < 1)
51+
throw new IllegalArgumentException()
52+
initialCapacity
53+
}
54+
)
4055
}
4156

4257
def this(c: Collection[_ <: E]) = {
@@ -47,47 +62,74 @@ class PriorityQueue[E] private (
4762
NaturalComparator.select(c.comparator().asInstanceOf[Comparator[_ >: E]])
4863
case _ =>
4964
NaturalComparator
50-
}, internal = true)
65+
}, internal = true, roundUpToPowerOfTwo(c.size()))
5166
addAll(c)
5267
}
5368

5469
def this(c: PriorityQueue[_ <: E]) = {
55-
this(c.comp.asInstanceOf[Comparator[_ >: E]], internal = true)
70+
this(c.comp.asInstanceOf[Comparator[_ >: E]], internal = true, roundUpToPowerOfTwo(c.size()))
5671
addAll(c)
5772
}
5873

5974
def this(sortedSet: SortedSet[_ <: E]) = {
6075
this(NaturalComparator.select(
6176
sortedSet.comparator().asInstanceOf[Comparator[_ >: E]]),
62-
internal = true)
77+
internal = true,
78+
roundUpToPowerOfTwo(sortedSet.size()))
6379
addAll(sortedSet)
6480
}
6581

82+
/* Get the best available implementation of inner array for the given platform.
83+
*
84+
* Use Array[AnyRef] in WebAssembly to avoid JS-interop. In JS, use js.Array.
85+
* It is resizable by nature, so manual resizing is not needed.
86+
*
87+
* `linkTimeIf` is needed here to ensure the optimizer knows
88+
* there is only one implementation of `InnerArrayImpl`, and de-virtualize/inline
89+
* the function calls.
90+
*/
91+
92+
private val innerImpl: InnerArrayImpl =
93+
LinkingInfo.linkTimeIf[InnerArrayImpl](LinkingInfo.isWebAssembly) {
94+
InnerArrayImpl.JArrayImpl
95+
} {
96+
InnerArrayImpl.JSArrayImpl
97+
}
98+
99+
private var inner: innerImpl.Repr = innerImpl.make(initialCapacity)
100+
101+
// Wasm only
66102
// The index 0 is not used; the root is at index 1.
67103
// This is standard practice in binary heaps, to simplify arithmetics.
68-
private[this] val inner = js.Array[E](null.asInstanceOf[E])
104+
private var _size = 1
69105

70106
override def add(e: E): Boolean = {
71107
if (e == null)
72108
throw new NullPointerException()
73-
inner.push(e)
74-
fixUp(inner.length - 1)
109+
110+
if (LinkingInfo.isWebAssembly) {
111+
val minCapacity = innerImpl.length(inner) + 1
112+
if (innerImpl.capacity(inner) < minCapacity)
113+
inner = innerImpl.resized(inner, minCapacity)
114+
}
115+
innerImpl.push(inner, e)
116+
fixUp(innerImpl.length(inner) - 1)
75117
true
76118
}
77119

78120
def offer(e: E): Boolean = add(e)
79121

80122
def peek(): E =
81-
if (inner.length > 1) inner(1)
123+
if (innerImpl.length(inner) > 1) innerImpl.get(inner, 1)
82124
else null.asInstanceOf[E]
83125

84126
override def remove(o: Any): Boolean = {
85127
if (o == null) {
86128
false
87129
} else {
88-
val len = inner.length
130+
val len = innerImpl.length(inner)
89131
var i = 1
90-
while (i != len && !o.equals(inner(i))) {
132+
while (i != len && !o.equals(innerImpl.get(inner, i))) {
91133
i += 1
92134
}
93135

@@ -101,9 +143,9 @@ class PriorityQueue[E] private (
101143
}
102144

103145
private def removeExact(o: Any): Unit = {
104-
val len = inner.length
146+
val len = innerImpl.length(inner)
105147
var i = 1
106-
while (i != len && (o.asInstanceOf[AnyRef] ne inner(i).asInstanceOf[AnyRef])) {
148+
while (i != len && (o.asInstanceOf[AnyRef] ne innerImpl.get(inner, i).asInstanceOf[AnyRef])) {
107149
i += 1
108150
}
109151
if (i == len)
@@ -112,23 +154,25 @@ class PriorityQueue[E] private (
112154
}
113155

114156
private def removeAt(i: Int): Unit = {
115-
val newLength = inner.length - 1
157+
val newLength = innerImpl.length(inner) - 1
116158
if (i == newLength) {
117-
inner.length = newLength
159+
innerImpl.setLength(inner, newLength)
118160
} else {
119-
inner(i) = inner(newLength)
120-
inner.length = newLength
161+
innerImpl.set(inner, i, innerImpl.get(inner, newLength))
162+
innerImpl.setLength(inner, newLength)
121163
fixUpOrDown(i)
122164
}
165+
if (LinkingInfo.isWebAssembly)
166+
innerImpl.set(inner, innerImpl.length(inner), null.asInstanceOf[E]) // free reference for GC
123167
}
124168

125169
override def contains(o: Any): Boolean = {
126170
if (o == null) {
127171
false
128172
} else {
129-
val len = inner.length
173+
val len = innerImpl.length(inner)
130174
var i = 1
131-
while (i != len && !o.equals(inner(i))) {
175+
while (i != len && !o.equals(innerImpl.get(inner, i))) {
132176
i += 1
133177
}
134178
i != len
@@ -137,16 +181,20 @@ class PriorityQueue[E] private (
137181

138182
def iterator(): Iterator[E] = {
139183
new Iterator[E] {
140-
private[this] var inner: js.Array[E] = PriorityQueue.this.inner
184+
private[this] var inner: innerImpl.Repr = PriorityQueue.this.inner
185+
// Wasm only
186+
private[this] var innerIterSize: Int = innerImpl.length(PriorityQueue.this.inner)
141187
private[this] var nextIdx: Int = 1
142188
private[this] var last: E = _ // null
143189

144-
def hasNext(): Boolean = nextIdx < inner.length
190+
def hasNext(): Boolean =
191+
if (LinkingInfo.isWebAssembly) nextIdx < innerIterSize
192+
else nextIdx < innerImpl.length(inner)
145193

146194
def next(): E = {
147195
if (!hasNext())
148196
throw new NoSuchElementException("empty iterator")
149-
last = inner(nextIdx)
197+
last = innerImpl.get(inner, nextIdx)
150198
nextIdx += 1
151199
last
152200
}
@@ -173,7 +221,9 @@ class PriorityQueue[E] private (
173221
if (last == null)
174222
throw new IllegalStateException()
175223
if (inner eq PriorityQueue.this.inner) {
176-
inner = inner.jsSlice(nextIdx)
224+
if (LinkingInfo.isWebAssembly)
225+
innerIterSize = innerImpl.length(inner) - nextIdx
226+
inner = innerImpl.copyFrom(inner, nextIdx)
177227
nextIdx = 0
178228
}
179229
removeExact(last)
@@ -182,19 +232,21 @@ class PriorityQueue[E] private (
182232
}
183233
}
184234

185-
def size(): Int = inner.length - 1
235+
def size(): Int = innerImpl.length(inner) - 1
186236

187237
override def clear(): Unit =
188-
inner.length = 1
238+
innerImpl.clear(inner)
189239

190240
def poll(): E = {
191241
val inner = this.inner // local copy
192-
if (inner.length > 1) {
193-
val newSize = inner.length - 1
194-
val result = inner(1)
195-
inner(1) = inner(newSize)
196-
inner.length = newSize
242+
if (innerImpl.length(inner) > 1) {
243+
val newSize = innerImpl.length(inner) - 1
244+
val result = innerImpl.get(inner, 1)
245+
innerImpl.set(inner, 1, innerImpl.get(inner, newSize))
246+
innerImpl.setLength(inner, newSize)
197247
fixDown(1)
248+
if (LinkingInfo.isWebAssembly)
249+
innerImpl.set(inner, newSize, null.asInstanceOf[E]) // free reference for GC
198250
result
199251
} else {
200252
null.asInstanceOf[E]
@@ -212,7 +264,7 @@ class PriorityQueue[E] private (
212264
*/
213265
private[this] def fixUpOrDown(m: Int): Unit = {
214266
val inner = this.inner // local copy
215-
if (m > 1 && comp.compare(inner(m >> 1), inner(m)) > 0)
267+
if (m > 1 && comp.compare(innerImpl.get(inner, m >> 1), innerImpl.get(inner, m)) > 0)
216268
fixUp(m)
217269
else
218270
fixDown(m)
@@ -227,18 +279,18 @@ class PriorityQueue[E] private (
227279
/* At each step, even though `m` changes, the element moves with it, and
228280
* hence inner(m) is always the same initial `innerAtM`.
229281
*/
230-
val innerAtM = inner(m)
282+
val innerAtM = innerImpl.get(inner, m)
231283

232284
@inline @tailrec
233285
def loop(m: Int): Unit = {
234286
if (m > 1) {
235287
val parent = m >> 1
236-
val innerAtParent = inner(parent)
288+
val innerAtParent = innerImpl.get(inner, parent)
237289
if (comp.compare(innerAtParent, innerAtM) > 0) {
238-
inner(parent) = innerAtM
239-
inner(m) = innerAtParent
240-
loop(parent)
290+
innerImpl.set(inner, parent, innerAtM)
291+
innerImpl.set(inner, m, innerAtParent)
241292
}
293+
loop(parent)
242294
}
243295
}
244296

@@ -250,22 +302,22 @@ class PriorityQueue[E] private (
250302
*/
251303
private[this] def fixDown(m: Int): Unit = {
252304
val inner = this.inner // local copy
253-
val size = inner.length - 1
305+
val size = innerImpl.length(inner) - 1
254306

255307
/* At each step, even though `m` changes, the element moves with it, and
256308
* hence inner(m) is always the same initial `innerAtM`.
257309
*/
258-
val innerAtM = inner(m)
310+
val innerAtM = innerImpl.get(inner, m)
259311

260312
@inline @tailrec
261313
def loop(m: Int): Unit = {
262314
var j = 2 * m // left child of `m`
263315
if (j <= size) {
264-
var innerAtJ = inner(j)
316+
var innerAtJ = innerImpl.get(inner, j)
265317

266318
// if the left child is greater than the right child, switch to the right child
267319
if (j < size) {
268-
val innerAtJPlus1 = inner(j + 1)
320+
val innerAtJPlus1 = innerImpl.get(inner, j + 1)
269321
if (comp.compare(innerAtJ, innerAtJPlus1) > 0) {
270322
j += 1
271323
innerAtJ = innerAtJPlus1
@@ -274,13 +326,84 @@ class PriorityQueue[E] private (
274326

275327
// if the node `m` is greater than the selected child, swap and recurse
276328
if (comp.compare(innerAtM, innerAtJ) > 0) {
277-
inner(m) = innerAtJ
278-
inner(j) = innerAtM
329+
innerImpl.set(inner, m, innerAtJ)
330+
innerImpl.set(inner, j, innerAtM)
279331
loop(j)
280332
}
281333
}
282334
}
283335

284336
loop(m)
285337
}
338+
339+
private sealed abstract class InnerArrayImpl {
340+
type Repr <: AnyRef
341+
342+
def make(initialCapacity: Int): Repr
343+
def length(v: Repr): Int
344+
/** Set the length of innerArray.
345+
*
346+
* In WebAssembly, freeing the reference for GC is needed
347+
* when we shrink the inner array.
348+
*/
349+
def setLength(v: Repr, newLength: Int): Unit
350+
def get(v: Repr, index: Int): E
351+
def set(v: Repr, index: Int, e: E): Unit
352+
def push(v: Repr, e: E): Unit
353+
/** Wasm only. */
354+
def resized(v: Repr, minCapacity: Int): Repr
355+
/** Wasm only. */
356+
def capacity(v: Repr): Int
357+
def copyFrom(v: Repr, from: Int): Repr
358+
def clear(v: Repr): Unit
359+
}
360+
361+
private object InnerArrayImpl {
362+
object JSArrayImpl extends InnerArrayImpl {
363+
type Repr = js.Array[E]
364+
365+
// The index 0 is not used; the root is at index 1.
366+
// This is standard practice in binary heaps, to simplify arithmetics.
367+
@inline def make(_initialCapacity: Int): Repr = js.Array[E](null.asInstanceOf[E])
368+
@inline def length(v: Repr): Int = v.length
369+
@inline def setLength(v: Repr, newLength: Int): Unit =
370+
v.length = newLength
371+
@inline def get(v: Repr, index: Int): E = v(index)
372+
@inline def set(v: Repr, index: Int, e: E): Unit =
373+
v(index) = e
374+
@inline def push(v: Repr, e: E): Unit =
375+
v.push(e)
376+
@inline def resized(v: Repr, minCapacity: Int): Repr = v // no used
377+
@inline def capacity(v: Repr): Int = 0 // no used
378+
@inline def copyFrom(v: Repr, from: Int): Repr =
379+
v.jsSlice(from)
380+
@inline def clear(v: Repr): Unit =
381+
v.length = 1
382+
}
383+
384+
object JArrayImpl extends InnerArrayImpl {
385+
type Repr = Array[AnyRef]
386+
387+
@inline def make(initialCapacity: Int): Repr = new Array[AnyRef](initialCapacity)
388+
@inline def length(v: Repr): Int = _size
389+
@inline def setLength(v: Repr, newLength: Int): Unit =
390+
_size = newLength
391+
@inline def get(v: Repr, index: Int): E = v(index).asInstanceOf[E]
392+
@inline def set(v: Repr, index: Int, e: E): Unit =
393+
v(index) = e.asInstanceOf[AnyRef]
394+
@inline def push(v: Repr, e: E): Unit = {
395+
v(_size) = e.asInstanceOf[AnyRef]
396+
_size += 1
397+
}
398+
@inline def resized(v: Repr, minCapacity: Int): Repr =
399+
Arrays.copyOf(v, roundUpToPowerOfTwo(minCapacity))
400+
@inline def capacity(v: Repr): Int = v.length
401+
@inline def copyFrom(v: Repr, from: Int): Repr =
402+
Arrays.copyOfRange(v, from, _size)
403+
@inline def clear(v: Repr): Unit = {
404+
Arrays.fill(v, null)
405+
_size = 1
406+
}
407+
}
408+
}
286409
}

0 commit comments

Comments
 (0)
0