8000 gh-94601: [Enum] fix inheritance for __str__ and friends (GH-94942) · python/cpython@e747562 · GitHub
[go: up one dir, main page]

Skip to content

Commit e747562

Browse files
gh-94601: [Enum] fix inheritance for __str__ and friends (GH-94942)
(cherry picked from commit c961d14) Co-authored-by: Ethan Furman <ethan@stoneleaf.us>
1 parent 8d0249e commit e747562

File tree

2 files changed

+42
-10
lines changed

2 files changed

+42
-10
lines changed

Lib/enum.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,10 @@ def __set_name__(self, enum_class, member_name):
247247
if not enum_class._use_args_:
248248
enum_member = enum_class._new_member_(enum_class)
249249
if not hasattr(enum_member, '_value_'):
250-
enum_member._value_ = value
250+
try:
251+
enum_member._value_ = enum_class._member_type_(*args)
252+
except Exception as exc:
253+
enum_member._value_ = value
251254
else:
252255
enum_member = enum_class._new_member_(enum_class, *args)
253256
if not hasattr(enum_member, '_value_'):
@@ -562,7 +565,13 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
562565
classdict['__str__'] = enum_class.__str__
563566
for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'):
564567
if name not in classdict:
565-
setattr(enum_class, name, getattr(first_enum, name))
568+
# check for mixin overrides before replacing
569+
enum_method = getattr(first_enum, name)
570+
found_method = getattr(enum_class, name)
571+
object_method = getattr(object, name)
572+
data_type_method = getattr(member_type, name)
573+
if found_method in (data_type_method, object_method):
574+
setattr(enum_class, name, enum_method)
566575
#
567576
# for Flag, add __or__, __and__, __xor__, and __invert__
568577
if Flag is not None and issubclass(enum_class, Flag):
@@ -950,16 +959,18 @@ def _find_data_repr_(mcls, class_name, bases):
950959
@classmethod
951960
def _find_data_type_(mcls, class_name, bases):
952961
data_types = set()
962+
base_chain = set()
953963
for chain in bases:
954964
candidate = None
955965
for base in chain.__mro__:
966+
base_chain.add(base)
956967
if base is object:
957968
continue
958969
elif issubclass(base, Enum):
959970
if base._member_type_ is not object:
960971
data_types.add(base._member_type_)
961972
break
962-
elif '__new__' in base.__dict__:
973+
elif '__new__' in base.__dict__ or '__init__' in base.__dict__:
963974
if issubclass(base, Enum):
964975
continue
965976
data_types.add(candidate or base)
@@ -1671,7 +1682,13 @@ def convert_class(cls):
16711682
enum_class = type(cls_name, (etype, ), body, boundary=boundary, _simple=True)
16721683
for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'):
16731684
if name not in body:
1674-
setattr(enum_class, name, getattr(etype, name))
1685+
# check for mixin overrides before replacing
1686+
enum_method = getattr(etype, name)
1687+
found_method = getattr(enum_class, name)
1688+
object_method = getattr(object, name)
1689+
data_type_method = getattr(member_type, name)
1690+
if found_method in (data_type_method, object_met 8000 hod):
1691+
setattr(enum_class, name, enum_method)
16751692
gnv_last_values = []
16761693
if issubclass(enum_class, Flag):
16771694
# Flag / IntFlag
@@ -2002,7 +2019,6 @@ def _old_convert_(etype, name, module, filter, source=None, *, boundary=None):
20022019
members.sort(key=lambda t: t[0])
20032020
cls = etype(name, members, module=module, boundary=boundary or KEEP)
20042021
cls.__reduce_ex__ = _reduce_ex_by_global_name
2005-
cls.__repr__ = global_enum_repr
20062022
return cls
20072023

20082024
_stdlib_enums = IntEnum, StrEnum, IntFlag

Lib/test/test_enum.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2658,23 +2658,39 @@ def test_repr_with_dataclass(self):
26582658
@dataclass
26592659
class Foo:
26602660
__qualname__ = 'Foo'
2661-
a: int = 0
2661+
a: int
26622662
class Entries(Foo, Enum):
2663-
ENTRY1 = Foo(1)
2663+
ENTRY1 = 1
2664+
self.assertTrue(isinstance(Entries.ENTRY1, Foo))
2665+
self.assertTrue(Entries._member_type_ is Foo, Entries._member_type_)
2666+
self.assertTrue(Entries.ENTRY1.value == Foo(1), Entries.ENTRY1.value)
26642667
self.assertEqual(repr(Entries.ENTRY1), '<Entries.ENTRY1: Foo(a=1)>')
26652668

2666-
def test_repr_with_non_data_type_mixin(self):
2669+
def test_repr_with_init_data_type_mixin(self):
26672670
# non-data_type is a mixin that doesn't define __new__
26682671
class Foo:
26692672
def __init__(self, a):
26702673
self.a = a
26712674
def __repr__(self):
26722675
return f'Foo(a={self.a!r})'
26732676
class Entries(Foo, Enum):
2674-
ENTRY1 = Foo(1)
2675-
2677+
ENTRY1 = 1
2678+
#
26762679
self.assertEqual(repr(Entries.ENTRY1), '<Entries.ENTRY1: Foo(a=1)>')
26772680

2681+
def test_repr_and_str_with_non_data_type_mixin(self):
2682+
# non-data_type is a mixin that doesn't define __new__
2683+
class Foo:
2684+
def __repr__(self):
2685+
return 'Foo'
2686+
def __str__(self):
2687+
return 'ooF'
2688+
class Entries(Foo, Enum):
2689+
ENTRY1 = 1
2690+
#
2691+
self.assertEqual(repr(Entries.ENTRY1), 'Foo')
2692+
self.assertEqual(str(Entries.ENTRY1), 'ooF')
2693+
26782694
def test_value_backup_assign(self):
26792695
# check that enum will add missing values when custom __new__ does not
26802696
class Some(Enum):

0 commit comments

Comments
 (0)
0