8000 Add `js.async { ... js.await(p) ... }` blocks. · sjrd/scala-js@ad8f830 · GitHub
[go: up one dir, main page]

Skip to content

Commit ad8f830

Browse files
committed
Add js.async { ... js.await(p) ... } blocks.
We introduce a new pair of primitive methods, `js.async` and `js.await`. They correspond to JavaScript `async` functions and `await` expressions. `js.await(p)` awaits a `Promise`, but it must be directly scoped within a `js.async { ... }` block. --- At the IR level, `js.await(p)` directly translates to a dedicated IR node `JSAwait(arg)`. `js.async` blocks don't have a direct representation. Instead, the IR models `async function`s and `async =>`functions, as `Closure`s with an additionnal `async` flag. This corresponds to the JavaScript model for `async/await`. A `js.async { body }` block therefore corresponds to an immediately-applied `async Closure`, which in JavaScript would be written as `(async () => body)()`. --- On the JavaScript platform, async `Closure`s and `JSAwait` are directly compiled as their JavaScript equivalent. On WebAssembly, we leverage the JavaScript Promise Integration feature (JSPI). We turn async `Closure`s into WebAssembly.promising` functions. `js.await(p)` is compiled as a call to a unique `jsAwait` helper, which is declared as an "identity" `new WebAssembly.Suspending((x) => x)`. The static scoping rule of `js.await` guarantees that the suspending call is always performed in a valid context (one that will not throw a `WebAssembly.SuspendError`).
1 parent 70a4164 commit ad8f830

File tree

38 files changed

+589
-55
lines changed

38 files changed

+589
-55
lines changed

Jenkinsfile

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -365,59 +365,75 @@ def Tasks = [
365365
npm install &&
366366
sbtretry ++$scala \
367367
'set Global/enableWasmEverywhere := true' \
368+
'set scalaJSLinkerConfig in helloworld.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \
368369
helloworld$v/run &&
369370
sbtretry ++$scala \
370371
'set Global/enableWasmEverywhere := true' \
372+
'set scalaJSLinkerConfig in helloworld.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \
371373
'set scalaJSStage in Global := FullOptStage' \
372374
'set scalaJSLinkerConfig in helloworld.v$v ~= (_.withPrettyPrint(true))' \
373375
helloworld$v/run &&
374376
sbtretry ++$scala \
375377
'set Global/enableWasmEverywhere := true' \
378+
'set scalaJSLinkerConfig in reversi.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \
376379
reversi$v/fastLinkJS \
377380
reversi$v/fullLinkJS &&
378381
sbtretry ++$scala \
379382
'set Global/enableWasmEverywhere := true' \
380-
jUnitTestOutputsJVM$v/test jUnitTestOutputsJS$v/test testBridge$v/test \
381-
'set scalaJSStage in Global := FullOptStage' jUnitTestOutputsJS$v/test testBridge$v/test &&
383+
'set scalaJSLinkerConfig in jUnitTestOutputsJS.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \
384+
'set scalaJSLinkerConfig in testBridge.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \
385+
jUnitTestOutputsJS$v/test testBridge$v/test \
386+
'set scalaJSStage in Global := FullOptStage' \
387+
jUnitTestOutputsJS$v/test testBridge$v/test &&
382388
sbtretry ++$scala \
383389
'set Global/enableWasmEverywhere := true' \
390+
'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \
384391
$testSuite$v/test &&
385392
sbtretry ++$scala \
386393
'set Global/enableWasmEverywhere := true' \
394+
'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \
387395
'set scalaJSStage in Global := FullOptStage' \
388396
$testSuite$v/test &&
389397
sbtretry ++$scala \
390398
'set Global/enableWasmEverywhere := true' \
399+
'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \
391400
'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withOptimizer(false))' \
392401
$testSuite$v/test &&
393402
sbtretry ++$scala \
394403
'set Global/enableWasmEverywhere := true' \
404+
'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \
395405
'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withOptimizer(false))' \
396406
'set scalaJSStage in Global := FullOptStage' \
397407
$testSuite$v/test &&
398408
sbtretry ++$scala \
399409
'set Global/enableWasmEverywhere := true' \
410+
'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \
400411
'set scalaJSLinkerConfig in $testSuite.v$v ~= makeCompliant' \
401412
$testSuite$v/test &&
402413
sbtretry ++$scala \
403414
'set Global/enableWasmEverywhere := true' \
415+
'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \
404416
'set scalaJSLinkerConfig in $testSuite.v$v ~= makeCompliant' \
405417
'set scalaJSStage in Global := FullOptStage' \
406418
$testSuite$v/test &&
407419
sbtretry ++$scala \
408420
'set Global/enableWasmEverywhere := true' \
421+
'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \
409422
'set scalaJSLinkerConfig in $testSuite.v$v ~= makeCompliant' \
410423
'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withOptimizer(false))' \
411424
$testSuite$v/test &&
412425
sbtretry ++$scala \
413426
'set Global/enableWasmEverywhere := true' \
427+
'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \
414428
testingExample$v/testHtml &&
415429
sbtretry ++$scala \
416430
'set Global/enableWasmEverywhere := true' \
431+
'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \
417432
'set scalaJSStage in Global := FullOptStage' \
418433
testingExample$v/testHtml &&
419434
sbtretry ++$scala \
420435
'set Global/enableWasmEverywhere := true' \
436+
'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \
421437
irJS$v/fastLinkJS
422438
''',
423439

@@ -551,6 +567,8 @@ def allESVersions = [
551567
"ES2020",
552568
"ES2021" // We do not use anything specifically from ES2021, but always test the latest to avoid #4675
553569
]
570+
def defaultESVersion = "ES2015"
571+
def latestESVersion = "ES2021"
554572

555573
// The 'quick' matrix
556574
def quickMatrix = []
@@ -562,11 +580,12 @@ mainScalaVersions.each { scalaVersion ->
562580
quickMatrix.add([task: "test-suite-default-esversion", scala: scalaVersion, java: mainJavaVersion, testMinify: "false", testSuite: "testSuite"])
563581
quickMatrix.add([task: "test-suite-default-esversion", scala: scalaVersion, java: mainJavaVersion, testMinify: "true", testSuite: "testSuite"])
564582
quickMatrix.add([task: "test-suite-custom-esversion", scala: scalaVersion, java: mainJavaVersion, esVersion: "ES5_1", testSuite: "testSuite"])
565-
quickMatrix.add([task: "test-suite-webassembly", scala: scalaVersion, java: mainJavaVersion, testMinify: "false", testSuite: "testSuite"])
566-
quickMatrix.add([task: "test-suite-webassembly", scala: scalaVersion, java: mainJavaVersion, testMinify: "false", testSuite: "testSuiteEx"])
583+
quickMatrix.add([task: "test-suite-webassembly", scala: scalaVersion, java: mainJavaVersion, esVersion: defaultESVersion, testMinify: "false", testSuite: "testSuite"])
584+
quickMatrix.add([task: "test-suite-webassembly", scala: scalaVersion, java: mainJavaVersion, esVersion: latestESVersion, testMinify: "false", testSuite: "testSuite"])
585+
quickMatrix.add([task: "test-suite-webassembly", scala: scalaVersion, java: mainJavaVersion, esVersion: defaultESVersion, testMinify: "false", testSuite: "testSuiteEx"])
567586
quickMatrix.add([task: "test-suite-default-esversion", scala: scalaVersion, java: mainJavaVersion, testMinify: "false", testSuite: "scalaTestSuite"])
568587
quickMatrix.add([task: "test-suite-custom-esversion", scala: scalaVersion, java: mainJavaVersion, esVersion: "ES5_1", testSuite: "scalaTestSuite"])
569-
quickMatrix.add([task: "test-suite-webassembly", scala: scalaVersion, java: mainJavaVersion, testMinify: "false", testSuite: "scalaTestSuite"])
588+
quickMatrix.add([task: "test-suite-webassembly", scala: scalaVersion, java: mainJavaVersion, esVersion: defaultESVersion, testMinify: "false", testSuite: "scalaTestSuite"])
570589
quickMatrix.add([task: "bootstrap", scala: scalaVersion, java: mainJavaVersion])
571590
quickMatrix.add([task: "partest-fastopt", scala: scalaVersion, java: mainJavaVersion, partestopts: ""])
572591
quickMatrix.add([task: "partest-fastopt", scala: scalaVersion, java: mainJavaVersion, partestopts: "--wasm"])
@@ -591,7 +610,7 @@ otherScalaVersions.each { scalaVersion ->
591610
mainScalaVersions.each { scalaVersion ->
592611
otherJavaVersions.each { javaVersion ->
593612
quickMatrix.add([task: "test-suite-default-esversion", scala: scalaVersion, java: javaVersion, testMinify: "false", testSuite: "testSuite"])
594-
quickMatrix.add([task: "test-suite-webassembly", scala: scalaVersion, java: mainJavaVersion, testMinify: "false", testSuite: "testSuite"])
613+
quickMatrix.add([task: "test-suite-webassembly", scala: scalaVersion, java: mainJavaVersion, esVersion: defaultESVersion, testMinify: "false", testSuite: "testSuite"])
595614
}
596615
fullMatrix.add([task: "partest-noopt", scala: scalaVersion, java: mainJavaVersion, partestopts: ""])
597616
fullMatrix.add([task: "partest-noopt", scala: scalaVersion, java: mainJavaVersion, partestopts: "--wasm"])

compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
167167
private val fieldsMutatedInCurrentClass = new ScopedVar[mutable.Set[Name]]
168168
private val generatedSAMWrapperCount = new ScopedVar[VarBox[Int]]
169169
private val delambdafyTargetDefDefs = new ScopedVar[mutable.Map[Symbol, DefDef]]
170+
private val methodsAllowingJSAwait = new ScopedVar[mutable.Set[Symbol]]
170171

171172
def currentThisTypeNullable: jstpe.Type =
172173
encodeClassType(currentClassSym)
@@ -241,6 +242,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
241242
fieldsMutatedInCurrentClass := mutable.Set.empty,
242243
generatedSAMWrapperCount := new VarBox(0),
243244
delambdafyTargetDefDefs := mutable.Map.empty,
245+
methodsAllowingJSAwait := mutable.Set.empty,
244246
currentMethodSym := null,
245247
thisLocalVarName := null,
246248
enclosingLabelDefInfos := null,
@@ -469,7 +471,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
469471
currentClassSym := sym,
470472
fieldsMutatedInCurrentClass := mutable.Set.empty,
471473
generatedSAMWrapperCount := new VarBox(0),
472-
delambdafyTargetDefDefs := mutable.Map.empty
474+
delambdafyTargetDefDefs := mutable.Map.empty,
475+
methodsAllowingJSAwait := mutable.Set.empty
473476
) {
474477
val tree = if (isJSType(sym)) {
475478
if (!sym.isTraitOrInterface && isNonNativeJSClass(sym) &&
@@ -5333,6 +5336,35 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
53335336
// js.import.meta
53345337
js.JSImportMeta()
53355338

5339+
case JS_ASYNC =>
5340+
// js.async(arg)
5341+
assert(args.size == 1,
5342+
s"Expected exactly 1 argument for JS primitive $code but got " +
5343+
s"${args.size} at $pos")
5344+
val Block(stats, fun @ Function(_, Apply(target, _))) = args.head
5345+
methodsAllowingJSAwait += target.symbol
5346+
val genStats = stats.map(genStat(_))
5347+
val asyncExpr = genAnonFunction(fun) match {
5348+
case js.NewLambda(_, closure: js.Closure)
5349+
if closure.params.isEmpty && closure.resultType == jstpe.AnyType =>
5350+
val newFlags = closure.flags.withTyped(false).withAsync(true)
5351+
js.JSFunctionApply(closure.copy(flags = newFlags), Nil)
5352+
case other =>
5353+
abort(s"Unexpected tree generated for the Function0 argument to js.async at $pos: $other")
5354+
}
5355+
js.Block(genStats, asyncExpr)
5356+
5357+
case JS_AWAIT =>
5358+
// js.await(arg)
5359+
if (!methodsAllowingJSAwait.contains(currentMethodSym)) {
5360+
reporter.error(pos,
5361+
"Illegal use of js.await().\n" +
5362+
"It can only be used inside a js.async {...} block, without any lambda,\n" +
5363+
"by-name argument or nested method in-between.")
5364+
}
5365+
val arg = genArgs1
5366+
js.JSAwait(arg)
5367+
53365368
case DYNAMIC_IMPORT =>
53375369
assert(args.size == 1,
53385370
s"Expected exactly 1 argument for JS primitive $code but got " +

compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ trait JSDefinitions {
4747
lazy val JSPackage_native = getMemberMethod(ScalaJSJSPackageModule, newTermName("native"))
4848
lazy val JSPackage_undefined = getMemberMethod(ScalaJSJSPackageModule, newTermName("undefined"))
4949
lazy val JSPackage_dynamicImport = getMemberMethod(ScalaJSJSPackageModule, newTermName("dynamicImport"))
50+
lazy val JSPackage_async = getMemberMethod(ScalaJSJSPackageModule, newTermName("async"))
51+
lazy val JSPackage_await = getMemberMethod(ScalaJSJSPackageModule, newTermName("await"))
5052

5153
lazy val JSNativeAnnotation = getRequiredClass("scala.scalajs.js.native")
5254

compiler/src/main/scala/org/scalajs/nscplugin/JSPrimitives.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ abstract class JSPrimitives {
5151
final val JS_IMPORT = JS_NEW_TARGET + 1 // js.import.apply(specifier)
5252
final val JS_IMPORT_META = JS_IMPORT + 1 // js.import.meta
5353

54-
final val CONSTRUCTOROF = JS_IMPORT_META + 1 // runtime.constructorOf(clazz)
54+
final val JS_ASYNC = JS_IMPORT_META + 1 // js.async
55+
final val JS_AWAIT = JS_ASYNC + 1 // js.await
56+
10000
57+
final val CONSTRUCTOROF = JS_AWAIT + 1 // runtime.constructorOf(clazz)
5558
final val CREATE_INNER_JS_CLASS = CONSTRUCTOROF + 1 // runtime.createInnerJSClass
5659
final val CREATE_LOCAL_JS_CLASS = CREATE_INNER_JS_CLASS + 1 // runtime.createLocalJSClass
5760
final val WITH_CONTEXTUAL_JS_CLASS_VALUE = CREATE_LOCAL_JS_CLASS + 1 // runtime.withContextualJSClassValue
@@ -96,6 +99,8 @@ abstract class JSPrimitives {
9699

97100
addPrimitive(JSPackage_typeOf, TYPEOF)
98101
addPrimitive(JSPackage_native, JS_NATIVE)
102+
addPrimitive(JSPackage_async, JS_ASYNC)
103+
addPrimitive(JSPackage_await, JS_AWAIT)
99104

100105
addPrimitive(BoxedUnit_UNIT, UNITVAL)
101106

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Scala.js (https://www.scala-js.org/)
3+
*
4+
* Copyright EPFL.
5+
*
6+
* Licensed under Apache License 2.0
7+
* (https://www.apache.org/licenses/LICENSE-2.0).
8+
*
9+
* See the NOTICE file distributed with this work for
10+
* additional information regarding copyright ownership.
11+
*/
12+
13+
package org.scalajs.nscplugin.test
14+
15+
import org.scalajs.nscplugin.test.util._
16+
import org.junit.Test
17+
18+
// scalastyle:off line.size.limit
19+
20+
class JSAsyncAwaitTest extends DirectTest with TestHelpers {
21+
22+
override def preamble: String =
23+
"""import scala.scalajs.js
24+
"""
25+
26+
@Test
27+
def orphanAwait(): Unit = {
28+
"""
29+
class A {
30+
def foo(x: js.Promise[Int]): Int =
31+
js.await(x)
32+
}
33+
""" hasErrors
34+
"""
35+
|newSource1.scala:5: error: Illegal use of js.await().
36+
|It can only be used inside a js.async {...} block, without any lambda,
37+
|by-name argument or nested method in-between.
38+
| js.await(x)
39+
| ^
40+
"""
41+
42+
"""
43+
class A {
44+
def foo(x: js.Promise[Int]): js.Promise[Int] = js.async {
45+
val f: () => Int = () => js.await(x)
46+
f()
47+
}
48+
}
49+
""" hasErrors
50+
"""
51+
|newSource1.scala:5: error: Illegal use of js.await().
52+
|It can only be used inside a js.async {...} block, without any lambda,
53+
|by-name argument or nested method in-between.
54+
| val f: () => Int = () => js.await(x)
55+
| ^
56+
"""
57+
58+
"""
59+
class A {
60+
def foo(x: js.Promise[Int]): js.Promise[Int] = js.async {
61+
def f(): Int = js.await(x)
62+
f()
63+
}
64+
}
65+
""" hasErrors
66+
"""
67+
|newSource1.scala:5: error: Illegal use of js.await().
68+
|It can only be used inside a js.async {...} block, without any lambda,
69+
|by-name argument or nested method in-between.
70+
| def f(): Int = js.await(x)
71+
| ^
72+
"""
73+
}
74+
}

ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,10 @@ object Hashers {
242242
mixTree(default)
243243
mixType(tree.tpe)
244244

245+
case JSAwait(arg) =>
246+
mixTag(TagJSAwait)
247+
mixTree(arg)
248+
245249
case Debugger() =>
246250
mixTag(TagDebugger)
247251

ir/shared/src/main/scala/org/scalajs/ir/Printers.scala

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,11 @@ object Printers {
291291
undent()
292292
undent(); println(); print('}')
293293

294+
case JSAwait(arg) =>
295+
print("await(")
296+
print(arg)
297+
print(")")
298+
294299
case Debugger() =>
295300
print("debugger")
296301

@@ -896,12 +901,16 @@ object Printers {
896901
print(name)
897902

898903
case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) =>
904+
print("(")
905+
if (flags.async)
906+
print("async ")
899907
if (flags.typed)
900-
print("(typed-lambda<")
908+
print("typed-lambda")
901909
else if (flags 97AE .arrow)
902-
print("(arrow-lambda<")
910+
print("arrow-lambda")
903911
else
904-
print("(lambda<")
912+
print("lambda")
913+
print("<")
905914
var first = true
906915
for ((param, value) <- captureParams.zip(captureValues)) {
907916
if (first)

ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,10 @@ object Serializers {
326326
writeTree(default)
327327
writeType(tree.tpe)
328328

329+
case JSAwait(arg) =>
330+
writeTagAndPos(TagJSAwait)
331+
writeTree(arg)
332+
329333
case Debugger() =>
330334
writeTagAndPos(TagDebugger)
331335

@@ -1216,6 +1220,10 @@ object Serializers {
12161220
Match(readTree(), List.fill(readInt()) {
12171221
(readTrees().map(_.asInstanceOf[MatchableLiteral]), readTree())
12181222
}, readTree())(readType())
1223+
1224+
case TagJSAwait =>
1225+
JSAwait(readTree())
1226+
12191227
case TagDebugger => Debugger()
12201228

12211229
case TagNew => New(readClassName(), readMethodIdent(), readTrees())

ir/shared/src/main/scala/org/scalajs/ir/Tags.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ private[ir] object Tags {
133133
// New in 1.19
134134
final val TagApplyTypedClosure = TagLinkTimeProperty + 1
135135
final val TagNewLambda = TagApplyTypedClosure + 1
136+
final val TagJSAwait = TagNewLambda + 1
136137

137138
// Tags for member defs
138139

ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ object Transformers {
7777
Match(transform(selector), cases.map(c => (c._1, transform(c._2))),
7878
transform(default))(tree.tpe)
7979

80+
case JSAwait(arg) =>
81+
JSAwait(transform(arg))
82+
8083
// Scala expressions
8184

8285
case New(className, ctor, args) =>

ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ object Traversers {
6969
cases foreach (c => (c._1 map traverse, traverse(c._2)))
7070
traverse(default)
7171

72+
case JSAwait(arg) =>
73+
traverse(arg)
74+
7275
// Scala expressions
7376

7477
case New(_, _, args) =>

0 commit comments

Comments
 (0)
0