8000 [3.13] gh-125710: [Enum] fix hashable<->nonhashable comparisons for m… · python/cpython@5bb0538 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5bb0538

Browse files
[3.13] gh-125710: [Enum] fix hashable<->nonhashable comparisons for member values (GH-125735) (GH-125851)
gh-125710: [Enum] fix hashable<->nonhashable comparisons for member values (GH-125735) (cherry picked from commit aaed91c) Co-authored-by: Ethan Furman <ethan@stoneleaf.us>
1 parent e52095a commit 5bb0538

File tree

3 files changed

+28
-6
lines changed

3 files changed

+28
-6
lines changed

Lib/enum.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,8 @@ def __set_name__(self, enum_class, member_name):
328328
# to the map, and by-value lookups for this value will be
329329
# linear.
330330
enum_class._value2member_map_.setdefault(value, enum_member)
331+
if value not in enum_class._hashable_values_:
332+
enum_class._hashable_values_.append(value)
331333
except TypeError:
332334
# keep track of the value in a list so containment checks are quick
333335
enum_class._unhashable_values_.append(value)
@@ -545,7 +547,8 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
545547
classdict['_member_names_'] = []
546548
classdict['_member_map_'] = {}
547549
classdict['_value2member_map_'] = {}
548-
classdict['_unhashable_values_'] = []
550+
classdict['_hashable_values_'] = [] # for comparing with non-hashable types
551+
classdict['_unhashable_values_'] = [] # e.g. frozenset() with set()
549552
classdict['_unhashable_values_map_'] = {}
550553
classdict['_member_type_'] = member_type
551554
# now set the __repr__ for the value
@@ -755,7 +758,10 @@ def __contains__(cls, value):
755758
try:
756759
return value in cls._value2member_map_
757760
except TypeError:
758-
return value in cls._unhashable_values_
761+
return (
762+
value in cls._unhashable_values_ # both structures are lists
763+
or value in cls._hashable_values_
764+
)
759765

760766
def __delattr__(cls, attr):
761767
# nicer error message when someone tries to delete an attribute
@@ -1165,8 +1171,11 @@ def __new__(cls, value):
11651171
pass
11661172
except TypeError:
11671173
# not there, now do long search -- O(n) behavior
1168-
for name, values in cls._unhashable_values_map_.items():
1169-
if value in values:
1174+
for name, unhashable_values in cls._unhashable_values_map_.items():
1175+
if value in unhashable_values:
1176+
return cls[name]
1177+
for name, member in cls._member_map_.items():
1178+
if value == member._value_:
11701179
return cls[name]
11711180
# still not found -- verify that members exist, in-case somebody got here mistakenly
11721181
# (such as via super when trying to override __new__)
@@ -1232,6 +1241,7 @@ def _add_value_alias_(self, value):
12321241
# to the map, and by-value lookups for this value will be
12331242
# linear.
12341243
cls._value2member_map_.setdefault(value, self)
1244+
cls._hashable_values_.append(value)
12351245
except TypeError:
12361246
# keep track of the value in a list so containment checks are quick
12371247
cls._unhashable_values_.append(value)
@@ -1762,6 +1772,7 @@ def convert_class(cls):
17621772
body['_member_names_'] = member_names = []
17631773
body['_member_map_'] = member_map = {}
17641774
body['_value2member_map_'] = value2member_map = {}
1775+
body['_hashable_values_'] = hashable_values = []
17651776
body['_unhashable_values_'] = unhashable_values = []
17661777
body['_unhashable_values_map_'] = {}
17671778
body['_member_type_'] = member_type = etype._member_type_
@@ -1825,7 +1836,7 @@ def convert_class(cls):
18251836
contained = value2member_map.get(member._value_)
18261837
except TypeError:
18271838
contained = None
1828-
if member._value_ in unhashable_values:
1839+
if member._value_ in unhashable_values or member.value in hashable_values:
18291840
for m in enum_class:
18301841
if m._value_ == member._value_:
18311842
contained = m
@@ -1845,6 +1856,7 @@ def convert_class(cls):
18451856
else:
18461857
enum_class._add_member_(name, member)
18471858
value2member_map[value] = member
1859+
hashable_values.append(value)
18481860
if _is_single_bit(value):
18491861
# not a multi-bit alias, record in _member_names_ and _flag_mask_
18501862
member_names.append(name)
@@ -1881,7 +1893,7 @@ def convert_class(cls):
18811893
contained = value2member_map.get(member._value_)
18821894
except TypeError:
18831895
contained = None
1884-
if member._value_ in unhashable_values:
1896+
if member._value_ in unhashable_values or member._value_ in hashable_values:
18851897
for m in enum_class:
18861898
if m._value_ == member._value_:
18871899
contained = m
@@ -1907,6 +1919,8 @@ def convert_class(cls):
19071919
# to the map, and by-value lookups for this value will be
19081920
# linear.
19091921
enum_class._value2member_map_.setdefault(value, member)
1922+
if value not in hashable_values:
1923+
hashable_values.append(value)
19101924
except TypeError:
19111925
# keep track of the value in a list so containment checks are quick
19121926
enum_class._unhashable_values_.append(value)

Lib/test/test_enum.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3474,6 +3474,13 @@ def test_empty_names(self):
34743474
self.assertRaisesRegex(TypeError, '.int. object is not iterable', Enum, 'bad_enum', names=0)
34753475
self.assertRaisesRegex(TypeError, '.int. object is not iterable', Enum, 'bad_enum', 0, type=int)
34763476

3477+
def test_nonhashable_matches_hashable(self): # issue 125710
3478+
class Directions(Enum):
3479+
DOWN_ONLY = frozenset({"sc"})
3480+
UP_ONLY = frozenset({"cs"})
3481+
UNRESTRICTED = frozenset({"sc", "cs"})
3482+
self.assertIs(Directions({"sc"}), Directions.DOWN_ONLY)
3483+
341A
34773484

34783485
class TestOrder(unittest.TestCase):
34793486
"test usage of the `_order_` attribute"
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[Enum] fix hashable<->nonhashable comparisons for member values

0 commit comments

Comments
 (0)
0