8000 gh-125710: [Enum] fix hashable<->nonhashable comparisons for member v… · python/cpython@aaed91c · GitHub
[go: up one dir, main page]

Skip to content

Commit aaed91c

Browse files
authored
gh-125710: [Enum] fix hashable<->nonhashable comparisons for member values (GH-125735)
1 parent 079875e commit aaed91c

File tree

3 files changed

+28
-6
lines changed

3 files changed

+28
-6
lines changed

Lib/enum.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,8 @@ def __set_name__(self, enum_class, member_name):
327327
# to the map, and by-value lookups for this value will be
328328
# linear.
329329
enum_class._value2member_map_.setdefault(value, enum_member)
330+
if value not in enum_class._hashable_values_:
331+
enum_class._hashable_values_.append(value)
330332
except TypeError:
331333
# keep track of the value in a list so containment checks are quick
332334
enum_class._unhashable_values_.append(value)
@@ -538,7 +540,8 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
538540
classdict['_member_names_'] = []
539541
classdict['_member_map_'] = {}
540542
classdict['_value2member_map_'] = {}
541-
classdict['_unhashable_values_'] = []
543+
classdict['_hashable_values_'] = [] # for comparing with non-hashable types
544+
classdict['_unhashable_values_'] = [] # e.g. frozenset() with set()
542545
classdict['_unhashable_values_map_'] = {}
543546
classdict['_member_type_'] = member_type
544547
# now set the __repr__ for the value
@@ -748,7 +751,10 @@ def __contains__(cls, value):
748751
try:
749752
return value in cls._value2member_map_
750753
except TypeError:
751-
return value in cls._unhashable_values_
754+
return (
755+
value in cls._unhashable_values_ # both structures are lists
756+
or value in cls._hashable_values_
757+
)
752758

753759
def __delattr__(cls, attr):
754760
# nicer error message when someone tries to delete an attribute
@@ -1166,8 +1172,11 @@ def __new__(cls, value):
11661172
pass
11671173
except TypeError:
11681174
# not there, now do long search -- O(n) behavior
1169-
for name, values in cls._unhashable_values_map_.items():
1170-
if value in values:
1175+
for name, unhashable_values in cls._unhashable_values_map_.items():
1176+
if value in unhashable_values:
1177+
return cls[name]
1178+
for name, member in cls._member_map_.items():
1179+
if value == member._value_:
11711180
return cls[name]
11721181
# still not found -- verify that members exist, in-case somebody got here mistakenly
11731182
# (such as via super when trying to override __new__)
@@ -1233,6 +1242,7 @@ def _add_value_alias_(self, value):
12331242
# to the map, and by-value lookups for this value will be
12341243
# linear.
12351244
cls._value2member_map_.setdefault(value, self)
1245+
cls._hashable_values_.append(value)
12361246
except TypeError:
12371247
# keep track of the value in a list so containment checks are quick
12381248
cls._unhashable_values_.append(value)
@@ -1763,6 +1773,7 @@ def convert_class(cls):
17631773
body['_member_names_'] = member_names = []
17641774
body['_member_map_'] = member_map = {}
17651775
body['_value2member_map_'] = value2member_map = {}
1776+
body['_hashable_values_'] = hashable_values = []
17661777
body['_unhashable_values_'] = unhashable_values = []
17671778
body['_unhashable_values_map_'] = {}
17681779
body['_member_type_'] = member_type = etype._member_type_
@@ -1826,7 +1837,7 @@ def convert_class(cls):
18261837
contained = value2member_map.get(member._value_)
18271838
except TypeError:
18281839
contained = None
1829-
if member._value_ in unhashable_values:
1840+
if member._value_ in unhashable_values or member.value in hashable_values:
18301841
for m in enum_class:
18311842
if m._value_ == member._value_:
18321843
contained = m
@@ -1846,6 +1857,7 @@ def convert_class(cls):
18461857
else:
18471858
enum_class._add_member_(name, member)
18481859
value2member_map[value] = member
1860+
hashable_values.append(value)
18491861
if _is_single_bit(value):
18501862
# not a multi-bit alias, record in _member_names_ and _flag_mask_
18511863
member_names.append(name)
@@ -1882,7 +1894,7 @@ def convert_class(cls):
18821894
contained = value2member_map.get(member._value_)
18831895
except TypeError:
18841896
contained = None
1885-
if member._value_ in unhashable_values:
1897+
if member._value_ in unhashable_values or member._value_ in hashable_values:
18861898
for m in enum_class:
18871899
if m._value_ == member._value_:
18881900
contained = m
@@ -1908,6 +1920,8 @@ def convert_class(cls):
19081920
# to the map, and by-value lookups for this value will be
19091921
# linear.
19101922
enum_class._value2member_map_.setdefault(value, member)
1923+
if value not in hashable_values:
1924+
hashable_values.append(value)
19111925
except TypeError:
19121926
# keep track of the value in a list so containment checks are quick
19131927
enum_class._unhashable_values_.append(value)

Lib/test/test_enum.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3460,6 +3460,13 @@ def test_empty_names(self):
34603460
self.assertRaisesRegex(TypeError, '.int. object is not iterable', Enum, 'bad_enum', names=0)
34613461
self.assertRaisesRegex(TypeError, '.int. object is not iterable', Enum, 'bad_enum', 0, type=int)
34623462

3463+
def test_nonhashable_matches_hashable(self): # issue 125710
3464+
class Directions(Enum):
3465+
DOWN_ONLY = frozenset({"sc"})
3466+
UP_ONLY = frozenset({"cs"})
3467+
UNRESTRICTED = frozenset({"sc", "cs"})
3468+
self.assertIs(Directions({"sc"}), Directions.DOWN_ONLY)
3469+
34633470

34643471
class TestOrder(unittest.TestCase):
34653472
"test usage of the `_order_` attribute"
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[Enum] fix hashable<->nonhashable comparisons for member values

0 commit comments

Comments
 (0)
0