8000 Introduce IR ops for unsigned extension and comparisons. · scala-js/scala-js@a366057 · GitHub
[go: up one dir, main page]

Skip to content

Commit a366057

Browse files
committed
Introduce IR ops for unsigned extension and comparisons.
This completes the set of `UnaryOp`s and `BinaryOp`s to directly manipulate the unsigned representation of integers. Unlike other operations, such as `Int_unsigned_/`, the unsigned extension and comparisons have efficient (and convenient) implementations in user land. It is common for regular code to directly use the efficient implementation (e.g., `x.toLong & 0xffffffffL`) instead of the dedicated library method (`Integer.toUnsignedLong`). If we only replaced the body of the library methods with IR nodes, we would miss improvements in all the other code. Therefore, in this case, we instead recognize the user-space patterns in the optimizer, and replace them with the unsigned IR operations through folding. Moreover, for unsigned comparisons, we also recognize the patterns in the compiler backend. The purpose here is mostly to make sure that all these opcodes end up in the serialized IR, so that we effectively test them along the entire pipeline. When targeting JavaScript, the new IR nodes do not actually make any difference. For `int` operations, the Emitter sort of "undoes" the folding of the optimizer to implement them. That said, it could choose an alternative implementation based on `>>> 0`, which we should investigate in the future. For `Long`s, the subexpressions of the patterns are expanded into the `RuntimeLong` operations before folding gets a chance to recognize them (when they have not been transformed by the compiler backend). That's fine, because internal folding of the underlying `int` operations will do the best possible thing anyway. The size increase is only due to the additional always-reachable methods in `RuntimeLong`. Those can be removed by standard JS minifiers. When targeting Wasm, this allows the emitter to produce the dedicated Wasm opcodes, which are more likely to be efficient. To be fair, we could have achieved the same result by recognizing the patterns in the Wasm emitter instead. The deeper reason to add those IR operations is for completeness. They were the last operations from a standard set that were missing in the IR.
1 parent 052d861 commit a366057

File tree

15 files changed

+478
-119
lines changed

15 files changed

+478
-119
lines changed

compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala

Lines changed: 89 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4526,50 +4526,78 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
45264526
if (opType == jstpe.AnyType) rsrc_in
45274527
else adaptPrimitive(rsrc_in, if (isShift) jstpe.IntType else opType)
45284528

4529+
def regular(op: js.BinaryOp.Code): js.Tree =
4530+
js.BinaryOp(op, lsrc, rsrc)
4531+
45294532
(opType: @unchecked) match {
45304533
case jstpe.IntType =>
4531-
val op = (code: @switch) match {
4532-
case ADD => Int_+
4533-
case SUB => Int_-
4534-
case MUL => Int_*
4535-
case DIV => Int_/
4536-
case MOD => Int_%
4537-
case OR => Int_|
4538-
case AND => Int_&
4539-
case XOR => Int_^
4540-
case LSL => Int_<<
4541-
case LSR => Int_>>>
4542-
case ASR => Int_>>
4543-
case EQ => Int_==
4544-
case NE => Int_!=
4545-
case LT => Int_<
4546-
case LE => Int_<=
4547-
case GT => Int_>
4548-
case GE => Int_>=
4534+
def comparison(signedOp: js.BinaryOp.Code, unsignedOp: js.BinaryOp.Code): js.Tree = {
4535+
(lsrc, rsrc) match {
4536+
case (IntFlipSign(flippedLhs), IntFlipSign(flippedRhs)) =>
4537+
js.BinaryOp(unsignedOp, flippedLhs, flippedRhs)
4538+
case (IntFlipSign(flippedLhs), js.IntLiteral(r)) =>
4539+
js.BinaryOp(unsignedOp, flippedLhs, js.IntLiteral(r ^ Int.MinValue)(rsrc.pos))
4540+
case (js.IntLiteral(l), IntFlipSign(flippedRhs)) =>
4541+
js.BinaryOp(unsignedOp, js.IntLiteral(l ^ Int.MinValue)(lsrc.pos), flippedRhs)
4542+
case _ =>
4543+
regular(signedOp)
4544+
}
4545+
}
4546+
4547+
(code: @switch) match {
4548+
case ADD => regular(Int_+)
4549+
case SUB => regular(Int_-)
4550+
case MUL => regular(Int_*)
4551+
case DIV => regular(Int_/)
4552+
case MOD => regular(Int_%)
4553+
case OR => regular(Int_|)
4554+
case AND => regular(Int_&)
4555+
case XOR => regular(Int_^)
4556+
case LSL => regular(Int_<<)
4557+
case LSR => regular(Int_>>>)
4558+
case ASR => regular(Int_>>)
4559+
case EQ => regular(Int_==)
4560+
case NE => regular(Int_!=)
4561+
4562+
case LT => comparison(Int_<, Int_unsigned_<)
4563+
case LE => comparison(Int_<=, Int_unsigned_<=)
4564+
case GT => comparison(Int_>, Int_unsigned_>)
4565+
case GE => comparison(Int_>=, Int_unsigned_>=)
45494566
}
4550-
js.BinaryOp(op, lsrc, rsrc)
45514567

45524568
case jstpe.LongType =>
4553-
val op = (code: @switch) match {
4554-
case ADD => Long_+
4555-
case SUB => Long_-
4556-
case MUL => Long_*
4557-
case DIV => Long_/
4558-
case MOD => Long_%
4559-
case OR => Long_|
4560-
case XOR => Long_^
4561-
case AND => Long_&
4562-
case LSL => Long_<<
4563-
case LSR => Long_>>>
4564-
case ASR => Long_>>
4565-
case EQ => Long_==
4566-
case NE => Long_!=
4567-
case LT => Long_<
4568-
case LE => Long_<=
4569-
case GT => Long_>
4570-
case GE => Long_>=
4569+
def comparison(signedOp: js.BinaryOp.Code, unsignedOp: js.BinaryOp.Code): js.Tree = {
4570+
(lsrc, rsrc) match {
4571+
case (LongFlipSign(flippedLhs), LongFlipSign(flippedRhs)) =>
4572+
js.BinaryOp(unsignedOp, flippedLhs, flippedRhs)
4573+
case (LongFlipSign(flippedLhs), js.LongLiteral(r)) =>
4574+
js.BinaryOp(unsignedOp, flippedLhs, js.LongLiteral(r ^ Long.MinValue)(rsrc.pos))
4575+
case (js.LongLiteral(l), LongFlipSign(flippedRhs)) =>
4576+
js.BinaryOp(unsignedOp, js.LongLiteral(l ^ Long.MinValue)(lsrc.pos), flippedRhs)
4577+
case _ =>
4578+
regular(signedOp)
4579+
}
4580+
}
4581+
4582+
(code: @switch) match {
4583+
case ADD => regular(Long_+)
4584+
case SUB => regular(Long_-)
4585+
case MUL => regular(Long_*)
4586+
case DIV => regular(Long_/)
4587+
case MOD => regular(Long_%)
4588+
case OR => regular(Long_|)
4589+
case XOR => regular(Long_^)
4590+
case AND => regular(Long_&)
4591+
case LSL => regular(Long_<<)
4592+
case LSR => regular(Long_>>>)
4593+
case ASR => regular(Long_>>)
4594+
case EQ => regular(Long_==)
4595+
case NE => regular(Long_!=)
4596+
case LT => comparison(Long_<, Long_unsigned_<)
4597+
case LE => comparison(Long_<=, Long_unsigned_<=)
4598+
case GT => comparison(Long_>, Long_unsigned_>)
4599+
case GE => comparison(Long_>=, Long_unsigned_>=)
45714600
}
4572-
js.BinaryOp(op, lsrc, rsrc)
45734601

45744602
case jstpe.FloatType =>
45754603
def withFloats(op: Int): js.Tree =
@@ -7357,6 +7385,28 @@ private object GenJSCode {
73577385
}
73587386
}
73597387

7388+
private object IntFlipSign {
7389+
def unapply(tree: js.Tree): Option[js.Tree] = tree match {
7390+
case js.BinaryOp(js.BinaryOp.Int_^, lhs, js.IntLiteral(Int.MinValue)) =>
7391+
Some(lhs)
7392+
case js.BinaryOp(js.BinaryOp.Int_^, js.IntLiteral(Int.MinValue), rhs) =>
7393+
Some(rhs)
7394+
case _ =>
7395+
None
7396+
}
7397+
}
7398+
7399+
private object LongFlipSign {
7400+
def unapply(tree: js.Tree): Option[js.Tree] = tree match {
7401+
case js.BinaryOp(js.BinaryOp.Long_^, lhs, js.LongLiteral(Long.MinValue)) =>
7402+
Some(lhs)
7403+
case js.BinaryOp(js.BinaryOp.Long_^, js.LongLiteral(Long.MinValue), rhs) =>
7404+
Some(rhs)
7405+
case _ =>
7406+
None
7407+
}
7408+
}
7409+
73607410
private abstract class JavalibOpBody {
73617411
/** Generates the body of this special method, given references to the receiver and parameters. */
73627412
def generate(receiver: js.Tree, args: List[js.Tree])(implicit pos: ir.Position): js.Tree
@@ -7425,6 +7475,7 @@ private object GenJSCode {
74257475

74267476
val byClass: Map[ClassName, Map[MethodName, JavalibOpBody]] = Map(
74277477
jswkn.BoxedIntegerClass.withSuffix("$") -> Map(
7478+
m("toUnsignedLong", List(I), J) -> ArgUnaryOp(unop.UnsignedIntToLong),
74287479
m("divideUnsigned", List(I, I), I) -> ArgBinaryOp(binop.Int_unsigned_/),
74297480
m("remainderUnsigned", List(I, I), I) -> ArgBinaryOp(binop.Int_unsigned_%),
74307481
m("numberOfLeadingZeros", List(I), I) -> ArgUnaryOp(unop.Int_clz)

compiler/src/test/scala/org/scalajs/nscplugin/test/OptimizationTest.scala

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,72 @@ class OptimizationTest extends JSASTTest {
582582
case js.LoadModule(`testName`) =>
583583
}
584584
}
585+
586+
@Test
587+
def unsignedComparisonsInt: Unit = {
588+
import js.BinaryOp._
589+
590+
val comparisons = List(
591+
(Int_unsigned_<, "<"),
592+
(Int_unsigned_<=, "<="),
593+
(Int_unsigned_>, ">"),
594+
(Int_unsigned_>=, ">=")
595+
)
596+
597+
for ((op, codeOp) <- comparisons) {
598+
s"""
599+
class Test {
600+
private final val SignBit = Int.MinValue
601+
602+
def unsignedComparisonsInt(x: Int, y: Int): Unit = {
603+
(x ^ 0x80000000) $codeOp (y ^ 0x80000000)
604+
(SignBit ^ x) $codeOp (y ^ SignBit)
605+
(SignBit ^ x) $codeOp 0x80000010
606+
0x00000020 $codeOp (y ^ SignBit)
607+
}
608+
}
609+
""".hasExactly(4, "unsigned comparisons") {
610+
case js.BinaryOp(`op`, _, _) =>
611+
}.hasNot("any Int_^") {
612+
case js.BinaryOp(Int_^, _, _) =>
613+
}.hasNot("any signed comparison") {
614+
case js.BinaryOp(Int_< | Int_<= | Int_> | Int_>=, _, _) =>
615+
}
616+
}
617+
}
618+
619+
@Test
620+
def unsignedComparisonsLong: Unit = {
621+
import js.BinaryOp._
622+
623+
val comparisons = List(
624+
(Long_unsigned_<, "<"),
625+
(Long_unsigned_<=, "<="),
626+
(Long_unsigned_>, ">"),
627+
(Long_unsigned_>=, ">=")
628+
)
629+
630+
for ((op, codeOp) <- comparisons) {
631+
s"""
632+
class Test {
633+
private final val SignBit = Long.MinValue
634+
635+
def unsignedComparisonsInt(x: Long, y: Long): Unit = {
636+
(x ^ 0x8000000000000000L) $codeOp (y ^ 0x8000000000000000L)
637+
(SignBit ^ x) $codeOp (y ^ SignBit)
638+
(SignBit ^ x) $codeOp 0x8000000000000010L
639+
0x0000000000000020L $codeOp (y ^ SignBit)
640+
}
641+
}
642+
""".hasExactly(4, "unsigned comparisons") {
643+
case js.BinaryOp(`op`, _, _) =>
644+
}.hasNot("any Long_^") {
645+
case js.BinaryOp(Long_^, _, _) =>
646+
}.hasNot("any signed comparison") {
647+
case js.BinaryOp(Long_< | Long_<= | Long_> | Long_>=, _, _) =>
648+
}
649+
}
650+
}
585651
}
586652

587653
object OptimizationTest {

ir/shared/src/main/scala/org/scalajs/ir/Printers.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,8 @@ object Printers {
453453

454454
case Int_clz => p("<clz>(", ")")
455455
case Long_clz => p("<clz>(", ")")
456+
457+
case UnsignedIntToLong => p("<toLongUnsigned>(", ")")
456458
}
457459

458460
case BinaryOp(BinaryOp.Int_-, IntLiteral(0), rhs) =>
@@ -584,6 +586,16 @@ object Printers {
584586
case Int_unsigned_% => "unsigned_%[int]"
585587
case Long_unsigned_/ => "unsigned_/[long]"
586588
case Long_unsigned_% => "unsigned_%[long]"
589+
590+
case Int_unsigned_< => "unsigned_<[int]"
591+
case Int_unsigned_<= => "unsigned_<=[int]"
592+
case Int_unsigned_> => "unsigned_>[int]"
593+
case Int_unsigned_>= => "unsigned_>=[int]"
594+
595+
case Long_unsigned_< => "unsigned_<[long]"
596+
case Long_unsigned_<= => "unsigned_<=[long]"
597+
case Long_unsigned_> => "unsigned_>[long]"
598+
case Long_unsigned_>= => "unsigned_>=[long]"
587599
})
588600
print(' ')
589601
print(rhs)

ir/shared/src/main/scala/org/scalajs/ir/Trees.scala

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,7 @@ object Trees {
520520
// Other nodes introduced in 1.20
521521
final val Int_clz = 38
522522
final val Long_clz = 39
523+
final val UnsignedIntToLong = 40
523524

524525
def isClassOp(op: Code): Boolean =
525526
op >= Class_name && op <= Class_superClass
@@ -545,7 +546,7 @@ object Trees {
545546
String_length | Array_length | IdentityHashCode | Float_toBits |
546547
Int_clz | Long_clz =>
547548
IntType
548-
case IntToLong | DoubleToLong | Double_toBits =>
549+
case IntToLong | DoubleToLong | Double_toBits | UnsignedIntToLong =>
549550
LongType
550551
case DoubleToFloat | LongToFloat | Float_fromBits =>
551552
FloatType
@@ -685,11 +686,22 @@ object Trees {
685686
final val Class_newArray = 62
686687

687688
// New in 1.20
689+
688690
final val Int_unsigned_/ = 63
689691
final val Int_unsigned_% = 64
690692
final val Long_unsigned_/ = 65
691693
final val Long_unsigned_% = 66
692694

695+
final val Int_unsigned_< = 67
696+
final val Int_unsigned_<= = 68
697+
final val Int_unsigned_> = 69
698+
final val Int_unsigned_>= = 70
699+
700+
final val Long_unsigned_< = 71
701+
final val Long_unsigned_<= = 72
702+
final val Long_unsigned_> = 73
703+
final val Long_unsigned_>= = 74
704+
693705
def isClassOp(op: Code): Boolean =
694706
op >= Class_isInstance && op <= Class_newArray
695707

@@ -699,7 +711,9 @@ object Trees {
699711
Int_== | Int_!= | Int_< | Int_<= | Int_> | Int_>= |
700712
Long_== | Long_!= | Long_< | Long_<= | Long_> | Long_>= |
701713
Double_== | Double_!= | Double_< | Double_<= | Double_> | Double_>= |
702-
Class_isInstance | Class_isAssignableFrom =>
714+
Class_isInstance | Class_isAssignableFrom |
715+
Int_unsigned_< | Int_unsigned_<= | Int_unsigned_> | Int_unsigned_>= |
716+
Long_unsigned_< | Long_unsigned_<= | Long_unsigned_> | Long_unsigned_>= =>
703717
BooleanType
704718
case String_+ =>
705719
StringType

ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,8 @@ class PrintersTest {
522522

523523
assertPrintEquals("<clz>(x)", UnaryOp(Int_clz, ref("x", IntType)))
524524
assertPrintEquals("<clz>(x)", UnaryOp(Long_clz, ref("x", LongType)))
525+
526+
assertPrintEquals("<toLongUnsigned>(x)", UnaryOp(UnsignedIntToLong, ref("x", IntType)))
525527
}
526528

527529
@Test def printPseudoUnaryOp(): Unit = {
@@ -683,6 +685,24 @@ class PrintersTest {
683685
BinaryOp(Long_unsigned_/, ref("x", LongType), ref("y", LongType)))
684686
assertPrintEquals("(x unsigned_%[long] y)",
685687
BinaryOp(Long_unsigned_%, ref("x", LongType), ref("y", LongType)))
688+
689+
assertPrintEquals("(x unsigned_<[int] y)",
690+
BinaryOp(Int_unsigned_<, ref("x", IntType), ref("y", IntType)))
691+
assertPrintEquals("(x unsigned_<=[int] y)",
692+
BinaryOp(Int_unsigned_<=, ref("x", IntType), ref("y", IntType)))
693+
assertPrintEquals("(x unsigned_>[int] y)",
694+
BinaryOp(Int_unsigned_>, ref("x", IntType), ref("y", IntType)))
695+
assertPrintEquals("(x unsigned_>=[int] y)",
696+
BinaryOp(Int_unsigned_>=, ref("x", IntType), ref("y", IntType)))
697+
698+
assertPrintEquals("(x unsigned_<[long] y)",
699+
BinaryOp(Long_unsigned_<, ref("x", LongType), ref("y", LongType)))
700+
assertPrintEquals("(x unsigned_<=[long] y)",
701+
BinaryOp(Long_unsigned_<=, ref("x", LongType), ref("y", LongType)))
702+
assertPrintEquals("(x unsigned_>[long] y)",
703+
BinaryOp(Long_unsigned_>, ref("x", LongType), ref("y", LongType)))
704+
assertPrintEquals("(x unsigned_>=[long] y)",
705+
BinaryOp(Long_unsigned_>=, ref("x", LongType), ref("y", LongType)))
686706
}
687707

688708
@Test def printNewArray(): Unit = {

javalib/src/main/scala/java/lang/Integer.scala

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -186,18 +186,20 @@ object Integer {
186186
parse(s, base)
187187
}
188188

189-
@inline def compare(x: scala.Int, y: scala.Int): scala.Int =
190-
if (x == y) 0 else if (x < y) -1 else 1
189+
@inline def compare(x: scala.Int, y: scala.Int): scala.Int = {
190+
if (x == y) 0
191+
else if (x < y) -1
192+
else 1
193+
}
191194

192195
@inline def compareUnsigned(x: scala.Int, y: scala.Int): scala.Int = {
193-
import Utils.toUint
194196
if (x == y) 0
195-
else if (toUint(x) > toUint(y)) 1
196-
else -1
197+
else if ((x ^ Int.MinValue) < (y ^ Int.MinValue)) -1
198+
else 1
197199
}
198200

199201
@inline def toUnsignedLong(x: Int): scala.Long =
200-
x.toLong & 0xffffffffL
202+
throw new Error("stub") // body replaced by the compiler back-end
201203

202204
// Wasm intrinsic
203205
def bitCount(i: scala.Int): scala.Int = {

javalib/src/main/scala/java/lang/Long.scala

Lines changed: 7 additions & 4 deletions< C7B1 /span>
Original file line numberDiff line numberDiff line change
@@ -337,16 +337,19 @@ object Long {
337337
@inline def hashCode(value: scala.Long): Int =
338338
value.toInt ^ (value >>> 32).toInt
339339

340-
// Intrinsic
340+
// RuntimeLong intrinsic
341341
@inline def compare(x: scala.Long, y: scala.Long): scala.Int = {
342342
if (x == y) 0
343343
else if (x < y) -1
344344
else 1
345345
}
346346

347-
// TODO Intrinsic?
348-
@inline def compareUnsigned(x: scala.Long, y: scala.Long): scala.Int =
349-
compare(x ^ SignBit, y ^ SignBit)
347+
// TODO RuntimeLong intrinsic?
348+
@inline def compareUnsigned(x: scala.Long, y: scala.Long): scala.Int = {
349+
if (x == y) 0
350+
else if ((x ^ scala.Long.MinValue) < (y ^ scala.Long.MinValue)) -1
351+
else 1
352+
}
350353

351354
@inline def divideUnsigned(dividend: scala.Long, divisor: scala.Long): scala.Long =
352355
throw new Error("stub") // body replaced by the compiler back-end

0 commit comments

Comments
 (0)
0