8000 Introduce IR ops for unsigned extension and comparisons. by sjrd · Pull Request #5186 · scala-js/scala-js · GitHub
[go: up one dir, main page]

Skip to content

Introduce IR ops for unsigned extension and comparisons. #5186

8000
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 89 additions & 38 deletions compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4526,50 +4526,78 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
if (opType == jstpe.AnyType) rsrc_in
else adaptPrimitive(rsrc_in, if (isShift) jstpe.IntType else opType)

def regular(op: js.BinaryOp.Code): js.Tree =
js.BinaryOp(op, lsrc, rsrc)

(opType: @unchecked) match {
case jstpe.IntType =>
val op = (code: @switch) match {
case ADD => Int_+
case SUB => Int_-
case MUL => Int_*
case DIV => Int_/
case MOD => Int_%
case OR => Int_|
case AND => Int_&
case XOR => Int_^
case LSL => Int_<<
case LSR => Int_>>>
case ASR => Int_>>
case EQ => Int_==
case NE => Int_!=
case LT => Int_<
case LE => Int_<=
case GT => Int_>
case GE => Int_>=
def comparison(signedOp: js.BinaryOp.Code, unsignedOp: js.BinaryOp.Code): js.Tree = {
(lsrc, rsrc) match {
case (IntFlipSign(flippedLhs), IntFlipSign(flippedRhs)) =>
js.BinaryOp(unsignedOp, flippedLhs, flippedRhs)
case (IntFlipSign(flippedLhs), js.IntLiteral(r)) =>
js.BinaryOp(unsignedOp, flippedLhs, js.IntLiteral(r ^ Int.MinValue)(rsrc.pos))
case (js.IntLiteral(l), IntFlipSign(flippedRhs)) =>
js.BinaryOp(unsignedOp, js.IntLiteral(l ^ Int.MinValue)(lsrc.pos), flippedRhs)
case _ =>
regular(signedOp)
}
} 10000

(code: @switch) match {
case ADD => regular(Int_+)
case SUB => regular(Int_-)
case MUL => regular(Int_*)
case DIV => regular(Int_/)
case MOD => regular(Int_%)
case OR => regular(Int_|)
case AND => regular(Int_&)
case XOR => regular(Int_^)
case LSL => regular(Int_<<)
case LSR => regular(Int_>>>)
case ASR => regular(Int_>>)
case EQ => regular(Int_==)
case NE => regular(Int_!=)

case LT => comparison(Int_<, Int_unsigned_<)
case LE => comparison(Int_<=, Int_unsigned_<=)
case GT => comparison(Int_>, Int_unsigned_>)
case GE => comparison(Int_>=, Int_unsigned_>=)
}
js.BinaryOp(op, lsrc, rsrc)

case jstpe.LongType =>
val op = (code: @switch) match {
case ADD => Long_+
case SUB => Long_-
case MUL => Long_*
case DIV => Long_/
case MOD => Long_%
case OR => Long_|
case XOR => Long_^
case AND => Long_&
case LSL => Long_<<
case LSR => Long_>>>
case ASR => Long_>>
case EQ => Long_==
case NE => Long_!=
case LT => Long_<
case LE => Long_<=
case GT => Long_>
case GE => Long_>=
def comparison(signedOp: js.BinaryOp.Code, unsignedOp: js.BinaryOp.Code): js.Tree = {
(lsrc, rsrc) match {
case (LongFlipSign(flippedLhs), LongFlipSign(flippedRhs)) =>
js.BinaryOp(unsignedOp, flippedLhs, flippedRhs)
case (LongFlipSign(flippedLhs), js.LongLiteral(r)) =>
js.BinaryOp(unsignedOp, flippedLhs, js.LongLiteral(r ^ Long.MinValue)(rsrc.pos))
case (js.LongLiteral(l), LongFlipSign(flippedRhs)) =>
js.BinaryOp(unsignedOp, js.LongLiteral(l ^ Long.MinValue)(lsrc.pos), flippedRhs)
case _ =>
regular(signedOp)
}
}

(code: @switch) match {
case ADD => regular(Long_+)
case SUB => regular(Long_-)
case MUL => regular(Long_*)
case DIV => regular(Long_/)
case MOD => regular(Long_%)
case OR => regular(Long_|)
case XOR => regular(Long_^)
case AND => regular(Long_&)
case LSL => regular(Long_<<)
case LSR => regular(Long_>>>)
case ASR => regular(Long_>>)
case EQ => regular(Long_==)
case NE => regular(Long_!=)
case LT => comparison(Long_<, Long_unsigned_<)
case LE => comparison(Long_<=, Long_unsigned_<=)
case GT => comparison(Long_>, Long_unsigned_>)
case GE => comparison(Long_>=, Long_unsigned_>=)
}
js.BinaryOp(op, lsrc, rsrc)

case jstpe.FloatType =>
def withFloats(op: Int): js.Tree =
Expand Down Expand Up @@ -7357,6 +7385,28 @@ private object GenJSCode {
}
}

private object IntFlipSign {
def unapply(tree: js.Tree): Option[js.Tree] = tree match {
case js.BinaryOp(js.BinaryOp.Int_^, lhs, js.IntLiteral(Int.MinValue)) =>
Some(lhs)
case js.BinaryOp(js.BinaryOp.Int_^, js.IntLiteral(Int.MinValue), rhs) =>
Some(rhs)
case _ =>
None
}
}

private object LongFlipSign {
def unapply(tree: js.Tree): Option[js.Tree] = tree match {
case js.BinaryOp(js.BinaryOp.Long_^, lhs, js.LongLiteral(Long.MinValue)) =>
Some(lhs)
case js.BinaryOp(js.BinaryOp.Long_^, js.LongLiteral(Long.MinValue), rhs) =>
Some(rhs)
case _ =>
None
}
}

private abstract class JavalibOpBody {
/** Generates the body of this special method, given references to the receiver and parameters. */
def generate(receiver: js.Tree, args: List[js.Tree])(implicit pos: ir.Position): js.Tree
Expand Down Expand Up @@ -7425,6 +7475,7 @@ private object GenJSCode {

val byClass: Map[ClassName, Map[MethodName, JavalibOpBody]] = Map(
jswkn.BoxedIntegerClass.withSuffix("$") -> Map(
m("toUnsignedLong", List(I), J) -> ArgUnaryOp(unop.UnsignedIntToLong),
m("divideUnsigned", List(I, I), I) -> ArgBinaryOp(binop.Int_unsigned_/),
m("remainderUnsigned", List(I, I), I) -> ArgBinaryOp(binop.Int_unsigned_%),
m("numberOfLeadingZeros", List(I), I) -> ArgUnaryOp(unop.Int_clz)
Expand Down
1E0A
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,72 @@ class OptimizationTest extends JSASTTest {
case js.LoadModule(`testName`) =>
}
}

@Test
def unsignedComparisonsInt: Unit = {
import js.BinaryOp._

val comparisons = List(
(Int_unsigned_<, "<"),
(Int_unsigned_<=, "<="),
(Int_unsigned_>, ">"),
(Int_unsigned_>=, ">=")
)

for ((op, codeOp) <- comparisons) {
s"""
class Test {
private final val SignBit = Int.MinValue

def unsignedComparisonsInt(x: Int, y: Int): Unit = {
(x ^ 0x80000000) $codeOp (y ^ 0x80000000)
(SignBit ^ x) $codeOp (y ^ SignBit)
(SignBit ^ x) $codeOp 0x80000010
0x00000020 $codeOp (y ^ SignBit)
}
}
""".hasExactly(4, "unsigned comparisons") {
case js.BinaryOp(`op`, _, _) =>
}.hasNot("any Int_^") {
case js.BinaryOp(Int_^, _, _) =>
}.hasNot("any signed comparison") {
case js.BinaryOp(Int_< | Int_<= | Int_> | Int_>=, _, _) =>
}
}
}

@Test
def unsignedComparisonsLong: Unit = {
import js.BinaryOp._

val comparisons = List(
(Long_unsigned_<, "<"),
(Long_unsigned_<=, "<="),
(Long_unsigned_>, ">"),
(Long_unsigned_>=, ">=")
)

for ((op, codeOp) <- comparisons) {
s"""
class Test {
private final val SignBit = Long.MinValue

def unsignedComparisonsInt(x: Long, y: Long): Unit = {
(x ^ 0x8000000000000000L) $codeOp (y ^ 0x8000000000000000L)
(SignBit ^ x) $codeOp (y ^ SignBit)
(SignBit ^ x) $codeOp 0x8000000000000010L
0x0000000000000020L $codeOp (y ^ SignBit)
}
}
""".hasExactly(4, "unsigned comparisons") {
case js.BinaryOp(`op`, _, _) =>
}.hasNot("any Long_^") {
case js.BinaryOp(Long_^, _, _) =>
}.hasNot("any signed comparison") {
case js.BinaryOp(Long_< | Long_<= | Long_> | Long_>=, _, _) =>
}
}
}
}

object OptimizationTest {
Expand Down
12 changes: 12 additions & 0 deletions ir/shared/src/main/scala/org/scalajs/ir/Printers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,8 @@ object Printers {

case Int_clz => p("<clz>(", ")")
case Long_clz => p("<clz>(", ")")

case UnsignedIntToLong => p("<toLongUnsigned>(", ")")
}

case BinaryOp(BinaryOp.Int_-, IntLiteral(0), rhs) =>
Expand Down Expand Up @@ -584,6 +586,16 @@ object Printers {
case Int_unsigned_% => "unsigned_%[int]"
case Long_unsigned_/ => "unsigned_/[long]"
case Long_unsigned_% => "unsigned_%[long]"

case Int_unsigned_< => "unsigned_<[int]"
case Int_unsigned_<= => "unsigned_<=[int]"
case Int_unsigned_> => "unsigned_>[int]"
case Int_unsigned_>= => "unsigned_>=[int]"

case Long_unsigned_< => "unsigned_<[long]"
case Long_unsigned_<= => "unsigned_<=[long]"
case Long_unsigned_> => "unsigned_>[long]"
case Long_unsigned_>= => "unsigned_>=[long]"
})
print(' ')
print(rhs)
Expand Down
18 changes: 16 additions & 2 deletions ir/shared/src/main/scala/org/scalajs/ir/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ object Trees {
// Other nodes introduced in 1.20
final val Int_clz = 38
final val Long_clz = 39
final val UnsignedIntToLong = 40

def isClassOp(op: Code): Boolean =
op >= Class_name && op <= Class_superClass
Expand All @@ -545,7 +546,7 @@ object Trees {
String_length | Array_length | IdentityHashCode | Float_toBits |
Int_clz | Long_clz =>
IntType
case IntToLong | DoubleToLong | Double_toBits =>
case IntToLong | DoubleToLong | Double_toBits | UnsignedIntToLong =>
LongType
case DoubleToFloat | LongToFloat | Float_fromBits =>
FloatType
Expand Down Expand Up @@ -685,11 +686,22 @@ object Trees {
final val Class_newArray = 62

// New in 1.20

final val Int_unsigned_/ = 63
final val Int_unsigned_% = 64
final val Long_unsigned_/ = 65
final val Long_unsigned_% = 66

final val Int_unsigned_< = 67
final val Int_unsigned_<= = 68
final val Int_unsigned_> = 69
final val Int_unsigned_>= = 70

final val Long_unsigned_< = 71
final val Long_unsigned_<= = 72
final val Long_unsigned_> = 73
final val Long_unsigned_>= = 74

def isClassOp(op: Code): Boolean =
op >= Class_isInstance && op <= Class_newArray

Expand All @@ -699,7 +711,9 @@ object Trees {
Int_== | Int_!= | Int_< | Int_<= | Int_> | Int_>= |
Long_== | Long_!= | Long_< | Long_<= | Long_> | Long_>= |
Double_== | Double_!= | Double_< | Double_<= | Double_> | Double_>= |
Class_isInstance | Class_isAssignableFrom =>
Class_isInstance | Class_isAssignableFrom |
Int_unsigned_< | Int_unsigned_<= | Int_unsigned_> | Int_unsigned_>= |
Long_unsigned_< | Long_unsigned_<= | Long_unsigned_> | Long_unsigned_>= =>
BooleanType
case String_+ =>
StringType
Expand Down
20 changes: 20 additions & 0 deletions ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,8 @@ class PrintersTest {

assertPrintEquals("<clz>(x)", UnaryOp(Int_clz, ref("x", IntType)))
assertPrintEquals("<clz>(x)", UnaryOp(Long_clz, ref("x", LongType)))

assertPrintEquals("<toLongUnsigned>(x)", UnaryOp(UnsignedIntToLong, ref("x", IntType)))
}

@Test def printPseudoUnaryOp(): Unit = {
Expand Down Expand Up @@ -683,6 +685,24 @@ class PrintersTest {
BinaryOp(Long_unsigned_/, ref("x", LongType), ref("y", LongType)))
assertPrintEquals("(x unsigned_%[long] y)",
BinaryOp(Long_unsigned_%, ref("x", LongType), ref("y", LongType)))

assertPrintEquals("(x unsigned_<[int] y)",
BinaryOp(Int_unsigned_<, ref("x", IntType), ref("y", IntType)))
assertPrintEquals("(x unsigned_<=[int] y)",
BinaryOp(Int_unsigned_<=, ref("x", IntType), ref("y", IntType)))
assertPrintEquals("(x unsigned_>[int] y)",
BinaryOp(Int_unsigned_>, ref("x", IntType), ref("y", IntType)))
assertPrintEquals("(x unsigned_>=[int] y)",
BinaryOp(Int_unsigned_>=, ref("x", IntType), ref("y", IntType)))

assertPrintEquals("(x unsigned_<[long] y)",
BinaryOp(Long_unsigned_<, ref("x", LongType), ref("y", LongType)))
assertPrintEquals("(x unsigned_<=[long] y)",
BinaryOp(Long_unsigned_<=, ref("x", LongType), ref("y", LongType)))
assertPrintEquals("(x unsigned_>[long] y)",
BinaryOp(Long_unsigned_>, ref("x", LongType), ref("y", LongType)))
assertPrintEquals("(x unsigned_>=[long] y)",
BinaryOp(Long_unsigned_>=, ref("x", LongType), ref("y", LongType)))
}

@Test def printNewArray(): Unit = {
Expand Down
14 changes: 8 additions & 6 deletions javalib/src/main/scala/java/lang/Integer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -186,18 +186,20 @@ object Integer {
parse(s, base)
}

@inline def compare(x: scala.Int, y: scala.Int): scala.Int =
if (x == y) 0 else if (x < y) -1 else 1
@inline def compare(x: scala.Int, y: scala.Int): scala.Int = {
if (x == y) 0
else if (x < y) -1
else 1
}

@inline def compareUnsigned(x: scala.Int, y: scala.Int): scala.Int = {
import Utils.toUint
if (x == y) 0
else if (toUint(x) > toUint(y)) 1
else -1
else if ((x ^ Int.MinValue) < (y ^ Int.MinValue)) -1
else 1
}

@inline def toUnsignedLong(x: Int): scala.Long =
x.toLong & 0xffffffffL
throw new Error("stub") // body replaced by the compiler back-end

// Wasm intrinsic
def bitCount(i: scala.Int): scala.Int = {
Expand Down
11 changes: 7 additions & 4 deletions javalib/src/main/scala/java/lang/Long.scala
Original file line number Diff line number Diff line change
Expand Up @@ -337,16 +337,19 @@ object Long {
@inline def hashCode(value: scala.Long): Int =
value.toInt ^ (value >>> 32).toInt

// Intrinsic
// RuntimeLong intrinsic
@inline def compare(x: scala.Long, y: scala.Long): scala.Int = {
if (x == y) 0
else if (x < y) -1
else 1
}

// TODO Intrinsic?
@inline def compareUnsigned(x: scala.Long, y: scala.Long): scala.Int =
compare(x ^ SignBit, y ^ SignBit)
// TODO RuntimeLong intrinsic?
@inline def compareUnsigned(x: scala.Long, y: scala.Long): scala.Int = {
if (x == y) 0
else if ((x ^ scala.Long.MinValue) < (y ^ scala.Long.MinValue)) -1
else 1
}

@inline def divideUnsigned(dividend: scala.Long, divisor: scala.Long): scala.Long =
throw new Error("stub") // body replaced by the compiler back-end
Expand Down
Loading
0