8000 Merge pull request #11023 from lrytz/t13033b · scala/scala@4345ced · GitHub
[go: up one dir, main page]

Skip to content

Commit 4345ced

Browse files
authored
Merge pull request #11023 from lrytz/t13033b
Mix in the `productPrefix` hash statically in case class `hashCode`
2 parents 802e74d + ab634a0 commit 4345ced

File tree

9 files changed

+175
-30
lines changed

9 files changed

+175
-30
lines changed

src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ package typechecker
1515

1616
import scala.collection.mutable
1717
import scala.collection.mutable.ListBuffer
18+
import scala.runtime.Statics
1819
import scala.tools.nsc.Reporting.WarningCategory
1920
import symtab.Flags._
2021

@@ -90,11 +91,13 @@ trait SyntheticMethods extends ast.TreeDSL {
9091
else templ
9192
}
9293

94+
def Lit(c: Any) = LIT.typed(c)
95+
9396
def accessors = clazz.caseFieldAccessors
9497
val arity = accessors.size
9598

9699
def forwardToRuntime(method: Symbol): Tree =
97-
forwardMethod(method, getMember(ScalaRunTimeModule, (method.name prepend "_")))(mkThis :: _)
100+
forwardMethod(method, getMember(ScalaRunTimeModule, method.name.prepend("_")))(mkThis :: _)
98101

99102
def callStaticsMethodName(name: TermName)(args: Tree*): Tree = {
100103
val method = RuntimeStaticsModule.info.member(name)
@@ -285,8 +288,8 @@ trait SyntheticMethods extends ast.TreeDSL {
285288

286289
def hashcodeImplementation(sym: Symbol): Tree = {
287290
sym.tpe.finalResultType.typeSymbol match {
288-
case UnitClass | NullClass => Literal(Constant(0))
289-
case BooleanClass => If(Ident(sym), Literal(Constant(1231)), Literal(Constant(1237)))
291+
case UnitClass | NullClass => Lit(0)
292+
case BooleanClass => If(Ident(sym), Lit(1231), Lit(1237))
290293
case IntClass => Ident(sym)
291294
case ShortClass | ByteClass | CharClass => Select(Ident(sym), nme.toInt)
292295
case LongClass => callStaticsMethodName(nme.longHash)(Ident(sym))
@@ -299,29 +302,51 @@ trait SyntheticMethods extends ast.TreeDSL {
299302
def specializedHashcode = {
300303
createMethod(nme.hashCode_, Nil, IntTpe) { m =>
301304
val accumulator = m.newVariable(newTermName("acc"), m.pos, SYNTHETIC) setInfo IntTpe
302-
val valdef = ValDef(accumulator, Literal(Constant(0xcafebabe)))
305+
val valdef = ValDef(accumulator, Lit(0xcafebabe))
303306
val mixPrefix =
304307
Assign(
305308
Ident(accumulator),
306-
callStaticsMethod("mix")(Ident(accumulator),
307-
Apply(gen.mkAttributedSelect(gen.mkAttributedSelect(mkThis, Product_productPrefix), Object_hashCode), Nil)))
309+
callStaticsMethod("mix")(Ident(accumulator), Lit(clazz.name.decode.hashCode)))
308310
val mixes = accessors map (acc =>
309311
Assign(
310312
Ident(accumulator),
311313
callStaticsMethod("mix")(Ident(accumulator), hashcodeImplementation(acc))
312314
)
313315
)
314-
val finish = callStaticsMethod("finalizeHash")(Ident(accumulator), Literal(Constant(arity)))
316+
val finish = callStaticsMethod("finalizeHash")(Ident(accumulator), Lit(arity))
315317

316318
Block(valdef :: mixPrefix :: mixes, finish)
317319
}
318320
}
319-
def chooseHashcode = {
321+
322+
def productHashCode: Tree = {
323+
// case `hashCode` used to call `ScalaRunTime._hashCode`, but that implementation mixes in the result
324+
// of `productPrefix`, which causes scala/bug#13033.
325+
// Because case hashCode has two possible implementations (`specializedHashcode` and `productHashCode`) we
326+
// need to fix it twice.
327+
// 1. `specializedHashcode` above was changed to mix in the case class name statically.
328+
// 2. we can achieve the same thing here by calling `MurmurHash3Module.productHash` with a `seed` that mixes
329+
// in the case class name already. This is backwards and forwards compatible:
330+
// - the new generated code works with old and new standard libraries
331+
// - the `MurmurHash3Module.productHash` implementation returns the same result as before when called by
332+
// previously compiled case classes
333+
// Alternatively, we could decide to always generate the full implementation (like `specializedHashcode`)
334+
// at the cost of bytecode size.
335+
createMethod(nme.hashCode_, Nil, IntTpe) { _ =>
336+
if (arity == 0) Lit(clazz.name.decode.hashCode)
337+
else gen.mkMethodCall(MurmurHash3Module, TermName("productHash"), List(
338+
mkThis,
339+
Lit(Statics.mix(0xcafebabe, clazz.name.decode.hashCode)),
340+
Lit(true)
341+
))
342+
}
343+
}
344+
345+
def chooseHashcode =
320346
if (accessors exists (x => isPrimitiveValueType(x.tpe.finalResultType)))
321347
specializedHashcode
322348
else
323-
forwardToRuntime(Object_hashCode)
324-
}
349+
productHashCode
325350

326351
def valueClassMethods = List(
327352
Any_hashCode -> (() => hashCodeDerivedValueClassMethod),

src/library/scala/runtime/ScalaRunTime.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,16 @@ object ScalaRunTime {
152152
// More background at ticket #2318.
153153
def ensureAccessible(m: JMethod): JMethod = scala.reflect.ensureAccessible(m)
154154

155+
// This is called by the synthetic case class `toString` method.
156+
// It originally had a `CaseClass` parameter type which was changed to `Product`.
155157
def _toString(x: Product): String =
156158
x.productIterator.mkString(x.productPrefix + "(", ",", ")")
157159

158-
def _hashCode(x: Product): Int = scala.util.hashing.MurmurHash3.productHash(x)
160+
// This method is called by case classes compiled by older Scala 2.13 / Scala 3 versions, so it needs to stay.
161+
// In newer versions, the synthetic case class `hashCode` has either the calculation inlined or calls
162+
// `MurmurHash3.productHash`.
163+
// There used to be an `_equals` method as well which was removed in 5e7e81ab2a.
164+
def _hashCode(x: Product): Int = scala.util.hashing.MurmurHash3.caseClassHash(x)
159165

160166
/** A helper for case classes. */
161167
def typedProductIterator[T](x: Product): Iterator[T] = {

src/library/scala/util/hashing/MurmurHash3.scala

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,16 @@ private[hashing] class MurmurHash3 {
6060
finalizeHash(h, 2)
6161
}
6262

63-
/** Compute the hash of a product */
63+
// @deprecated("use `caseClassHash` instead", "2.13.17")
64+
// The deprecation is commented because this method is called by the synthetic case class hashCode.
65+
// In this case, the `seed` already has the case class name mixed in and `ignorePrefix` is set to true.
66+
// Case classes compiled before 2.13.17 call this method with `productSeed` and `ignorePrefix = false`.
67+
// See `productHashCode` in `SyntheticMethods` for details.
6468
final def productHash(x: Product, seed: Int, ignorePrefix: Boolean = false): Int = {
6569
val arr = x.productArity
66-
// Case objects have the hashCode inlined directly into the
67-
// synthetic hashCode method, but this method should still give
68-
// a correct result if passed a case object.
69-
if (arr == 0) {
70-
x.productPrefix.hashCode
71-
} else {
70+
if (arr == 0)
71+
if (!ignorePrefix) x.productPrefix.hashCode else seed
72+
else {
7273
var h = seed
7374
if (!ignorePrefix) h = mix(h, x.productPrefix.hashCode)
7475
var i = 0
@@ -80,6 +81,24 @@ private[hashing] class MurmurHash3 {
8081
}
8182
}
8283

84+
/** See the [[MurmurHash3.caseClassHash(x:Product,caseClassName:String)]] overload */
85+
final def caseClassHash(x: Product, seed: Int, caseClassName: String): Int = {
86+
val arr = x.productArity
87+
val aye = (if (caseClassName != null) caseClassName else x.productPrefix).hashCode
88+
if (arr == 0) aye
89+
else {
90+
var h = seed
91+
h = mix(h, aye)
92+
var i = 0
93+
while (i < arr) {
94+
h = mix(h, x.productElement(i).##)
95+
i += 1
96+
}
97+
finalizeHash(h, arr)
98+
}
99+
}
100+
101+
83102
/** Compute the hash of a string */
84103
final def stringHash(str: String, seed: Int): Int = {
85104
var h = seed
@@ -337,14 +356,46 @@ object MurmurHash3 extends MurmurHash3 {
337356
final val mapSeed = "Map".hashCode
338357
final val setSeed = "Set".hashCode
339358

340-
def arrayHash[@specialized T](a: Array[T]): Int = arrayHash(a, arraySeed)
341-
def bytesHash(data: Array[Byte]): Int = bytesHash(data, arraySeed)
342-
def orderedHash(xs: IterableOnce[Any]): Int = orderedHash(xs, symmetricSeed)
343-
def productHash(x: Product): Int = productHash(x, productSeed)
344-
def stringHash(x: String): Int = stringHash(x, stringSeed)
345-
def unorderedHash(xs: IterableOnce[Any]): Int = unorderedHash(xs, traversableSeed)
359+
def arrayHash[@specialized T](a: Array[T]): Int = arrayHash(a, arraySeed)
360+
def bytesHash(data: Array[Byte]): Int = bytesHash(data, arraySeed)
361+
def orderedHash(xs: IterableOnce[Any]): Int = orderedHash(xs, symmetricSeed)
362+
def stringHash(x: String): Int = stringHash(x, stringSeed)
363+
def unorderedHash(xs: IterableOnce[Any]): Int = unorderedHash(xs, traversableSeed)
346364
def rangeHash(start: Int, step: Int, last: Int): Int = rangeHash(start, step, last, seqSeed)
347365

366+
@deprecated("use `caseClassHash` instead", "2.13.17")
367+
def productHash(x: Product): Int = caseClassHash(x, productSeed, null)
368+
369+
/**
370+
* Compute the `hashCode` of a case class instance. This method returns the same value as `x.hashCode`
371+
* if `x` is an instance of a case class with the default, synthetic `hashCode`.
372+
*
373+
* This method can be used to implement case classes with a cached `hashCode`:
374+
* {{{
375+
* case class C(data: Data) {
376+
* override lazy val hashCode: Int = MurmurHash3.caseClassHash(this)
377+
* }
378+
* }}}
379+
*
380+
* '''NOTE''': For case classes (or subclasses) that override `productPrefix`, the `caseClassName` parameter
381+
* needs to be specified in order to obtain the same result as the synthetic `hashCode`. Otherwise, the value
382+
* is not in sync with the case class `equals` method (scala/bug#13033).
383+
*
384+
* {{{
385+
* scala> case class C(x: Int) { override def productPrefix = "Y" }
386+
*
387+
* scala> C(1).hashCode
388+
* val res0: Int = -668012062
389+
*
390+
* scala> MurmurHash3.caseClassHash(C(1))
391+
* val res1: Int = 1015658380
392+
*
393+
* scala> MurmurHash3.caseClassHash(C(1), "C")
394+
* val res2: Int = -668012062
395+
* }}}
396+
*/
397+
def caseClassHash(x: Product, caseClassName: String = null): Int = caseClassHash(x, productSeed, caseClassName)
398+
348399
private[scala] def arraySeqHash[@specialized T](a: Array[T]): Int = arrayHash(a, seqSeed)
349400
private[scala] def tuple2Hash(x: Any, y: Any): Int = tuple2Hash(x.##, y.##, productSeed)
350401

@@ -397,8 +448,13 @@ object MurmurHash3 extends MurmurHash3 {
397448
def hash(xs: IterableOnce[Any]) = orderedHash(xs)
398449
}
399450

451+
@deprecated("use `caseClassHashing` instead", "2.13.17")
400452
def productHashing = new Hashing[Product] {
401-
def hash(x: Product) = productHash(x)
453+
def hash(x: Product) = caseClassHash(x)
454+
}
455+
456+
def caseClassHashing = new Hashing[Product] {
457+
def hash(x: Product) = caseClassHash(x)
402458
}
403459

404460
def stringHashing = new Hashing[String] {

src/reflect/scala/reflect/internal/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ trait Definitions extends api.StandardDefinitions {
401401
lazy val SpecializableModule = requiredModule[Specializable]
402402

403403
lazy val ScalaRunTimeModule = requiredModule[scala.runtime.ScalaRunTime.type]
404+
lazy val MurmurHash3Module = requiredModule[scala.util.hashing.MurmurHash3.type]
404405
lazy val SymbolModule = requiredModule[scala.Symbol.type]
405406
def Symbol_apply = getMemberMethod(SymbolModule, nme.apply)
406407

src/reflect/scala/reflect/runtime/JavaUniverseForce.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ trait JavaUniverseForce { self: runtime.JavaUniverse =>
299299
definitions.PredefModule
300300
definitions.SpecializableModule
301301
definitions.ScalaRunTimeModule
302+
definitions.MurmurHash3Module
302303
definitions.SymbolModule
303304
definitions.ScalaNumberClass
304305
definitions.DelayedInitClass

test/files/run/caseClassHash.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ object Test {
1111

1212
println("## method 1: " + foo1.##)
1313
println("## method 2: " + foo2.##)
14-
println(" Murmur 1: " + scala.util.hashing.MurmurHash3.productHash(foo1))
15-
println(" Murmur 2: " + scala.util.hashing.MurmurHash3.productHash(foo2))
14+
println(" Murmur 1: " + scala.util.hashing.MurmurHash3.caseClassHash(foo1))
15+
println(" Murmur 2: " + scala.util.hashing.MurmurHash3.caseClassHash(foo2))
1616
}
1717
}
1818

test/files/run/idempotency-case-classes.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ C(2,3)
2929
};
3030
override <synthetic> def hashCode(): Int = {
3131
<synthetic> var acc: Int = -889275714;
32-
acc = scala.runtime.Statics.mix(acc, C.this.productPrefix.hashCode());
32+
acc = scala.runtime.Statics.mix(acc, 67);
3333
acc = scala.runtime.Statics.mix(acc, x);
3434
acc = scala.runtime.Statics.mix(acc, y);
3535
scala.runtime.Statics.finalizeHash(acc, 2)

test/files/run/macroPlugins-namerHooks.check

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ enterStat(super.<init>())
1818
enterSym(<synthetic> def copy$default$1 = x)
1919
enterSym(<synthetic> def copy$default$2 = y)
2020
enterSym(<synthetic> var acc: Int = -889275714)
21-
enterSym(acc = scala.runtime.Statics.mix(acc, C.this.productPrefix.hashCode()))
21+
enterSym(acc = scala.runtime.Statics.mix(acc, 67))
2222
enterSym(acc = scala.runtime.Statics.mix(acc, x))
2323
enterSym(acc = scala.runtime.Statics.mix(acc, y))
2424
enterStat(<synthetic> var acc: Int = -889275714)
25-
enterStat(acc = scala.runtime.Statics.mix(acc, C.this.productPrefix.hashCode()))
25+
enterStat(acc = scala.runtime.Statics.mix(acc, 67))
2626
enterStat(acc = scala.runtime.Statics.mix(acc, x))
2727
enterStat(acc = scala.runtime.Statics.mix(acc, y))
2828
enterSym(<synthetic> val C$1: C = x$1.asInstanceOf[C])

test/files/run/t13033.scala

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
//> using options -deprecation
2+
3+
import scala.util.hashing.MurmurHash3.caseClassHash
4+
5+
case class C1(a: Int)
6+
class C2(a: Int) extends C1(a) { override def productPrefix = "C2" }
7+
class C3(a: Int) extends C1(a) { override def productPrefix = "C3" }
8+
case class C4(a: Int) { override def productPrefix = "Sea4" }
9+
case class C5()
10+
case object C6
11+
case object C6b { override def productPrefix = "Sea6b" }
12+
case class C7(s: String) // hashCode forwards to ScalaRunTime._hashCode if there are no primitives
13+
class C8(s: String) extends C7(s) { override def productPrefix = "C8" }
14+
15+
case class VCC(x: Int) extends AnyVal
16+
17+
object Test extends App {
18+
val c1 = C1(1)
19+
val c2 = new C2(1)
20+
val c3 = new C3(1)
21+
assert(c1 == c2)
22+
assert(c2 == c1)
23+
assert(c2 == c3)
24+
assert(c1.hashCode == c2.hashCode)
25+
assert(c2.hashCode == c3.hashCode)
26+
27+
assert(c1.hashCode == caseClassHash(c1))
28+
// `caseClassHash` mixes in the `productPrefix.hashCode`, while `hashCode` mixes in the case class name statically
29+
assert(c2.hashCode != caseClassHash(c2))
30+
assert(c2.hashCode == caseClassHash(c2, c1.productPrefix))
31+
32+
val c4 = C4(1)
33+
assert(c4.hashCode != caseClassHash(c4))
34+
assert(c4.hashCode == caseClassHash(c4, "C4"))
35+
36+
assert((1, 2).hashCode == caseClassHash(1 -> 2))
37+
assert(("", "").hashCode == caseClassHash("" -> ""))
38+
39+
assert(C5().hashCode == caseClassHash(C5()))
40+
assert(C6.hashCode == caseClassHash(C6))
41+
assert(C6b.hashCode == caseClassHash(C6b, "C6b"))
42+
43+
val c7 = C7("hi")
44+
val c8 = new C8("hi")
45+
assert(c7.hashCode == caseClassHash(c7))
46+
assert(c7 == c8)
47+
assert(c7.hashCode == c8.hashCode)
48+
assert(c8.hashCode != caseClassHash(c8))
49+
assert(c8.hashCode == caseClassHash(c8, "C7"))
50+
51+
52+
// should be true -- scala/bug#13034
53+
assert(!VCC(1).canEqual(VCC(1)))
54+
// also due to scala/bug#13034
55+
assert(VCC(1).canEqual(1))
56+
}

0 commit comments

Comments
 (0)
0