8000 bpo-46477: [Enum] ensure Flag subclasses have correct bitwise methods… · python/cpython@353e3b2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 353e3b2

Browse files
authored
bpo-46477: [Enum] ensure Flag subclasses have correct bitwise methods (GH-30816)
1 parent 976dec9 commit 353e3b2

File tree

2 files changed

+48
-39
lines changed

2 files changed

+48
-39
lines changed

Lib/enum.py

Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,18 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
618618
if name not in classdict:
619619
setattr(enum_class, name, getattr(first_enum, name))
620620
#
621+
# for Flag, add __or__, __and__, __xor__, and __invert__
622+
if Flag is not None and issubclass(enum_class, Flag):
623+
for name in (
624+
'__or__', '__and__', '__xor__',
625+
'__ror__', '__rand__', '__rxor__',
626+
'__invert__'
627+
):
628+
if name not in classdict:
629+
enum_method = getattr(Flag, name)
630+
setattr(enum_class, name, enum_method)
631+
classdict[name] = enum_method
632+
#
621633
# replace any other __new__ with our own (as long as Enum is not None,
622634
# anyway) -- again, this is to support pickle
623635
if Enum is not None:
@@ -1466,44 +1478,10 @@ def __str__(self):
14661478
def __bool__(self):
14671479
return bool(self._value_)
14681480

1469-
def __or__(self, other):
1470-
if not isinstance(other, self.__class__):
1471-
return NotImplemented
1472-
return self.__class__(self._value_ | other._value_)
1473-
1474-
def __and__(self, other):
1475-
if not isinstance(other, self.__class__):
1476-
return NotImplemented
1477-
return self.__class__(self._value_ & other._value_)
1478-
1479-
def __xor__(self, other):
1480-
if not isinstance(other, self.__class__):
1481-
return NotImplemented
1482-
return self.__class__(self._value_ ^ other._value_)
1483-
1484-
def __invert__(self):
1485-
if self._inverted_ is None:
1486-
if self._boundary_ is KEEP:
1487-
# use all bits
1488-
self._inverted_ = self.__class__(~self._value_)
1489-
else:
1490-
# calculate flags not in this member
1491-
self._inverted_ = self.__class__(self._flag_mask_ ^ self._value_)
1492-
if isinstance(self._inverted_, self.__class__):
1493-
self._inverted_._inverted_ = self
1494-
return self._inverted_
1495-
1496-
1497-
class IntFlag(int, ReprEnum, Flag, boundary=EJECT):
1498-
"""
1499-
Support for integer-based Flags
1500-
"""
1501-
1502-
15031481
def __or__(self, other):
15041482
if isinstance(other, self.__class__):
15051483
other = other._value_
1506-
elif isinstance(other, int):
1484+
elif self._member_type_ is not object and isinstance(other, self._member_type_):
15071485
other = other
15081486
else:
15091487
return NotImplemented
@@ -1513,7 +1491,7 @@ def __or__(self, other):
15131491
def __and__(self, other):
15141492
if isinstance(other, self.__class__):
15151493
other = other._value_
1516-
elif isinstance(other, int):
1494+
elif self._member_type_ is not object and isinstance(other, self._member_type_):
15171495
other = other
15181496
else:
15191497
return NotImplemented
@@ -1523,17 +1501,34 @@ def __and__(self, other):
15231501
def __xor__(self, other):
15241502
if isinstance(other, self.__class__):
15251503
other = other._value_
1526-
elif isinstance(other, int):
1504+
elif self._member_type_ is not object and isinstance(other, self._member_type_):
15271505
other = other
15281506
else:
15291507
return NotImplemented
15301508
value = self._value_
15311509
return self.__class__(value ^ other)
15321510

1533-
__ror__ = __or__
1511+
def __invert__(self):
1512+
if self._inverted_ is None:
1513+
if self._boundary_ is KEEP:
1514+
# use all bits
1515+
self._inverted_ = self.__class__(~self._value_)
1516+
else:
1517+
# calculate flags not in this member
1518+
self._inverted_ = self.__class__(self._flag_mask_ ^ self._value_)
1519+
if isinstance(self._inverted_, self.__class__):
1520+
self._inverted_._inverted_ = self
1521+
return self._inverted_
1522+
15341523
__rand__ = __and__
1524+
__ror__ = __or__
15351525
__rxor__ = __xor__
1536-
__invert__ = Flag.__invert__
1526+
1527+
1528+
class IntFlag(int, ReprEnum, Flag, boundary=EJECT):
1529+
"""
1530+
Support for integer-based Flags
1531+
"""
15371532

15381533

15391534
def _high_bit(value):
@@ -1662,6 +1657,13 @@ def convert_class(cls):
16621657
body['_flag_mask_'] = None
16631658
body['_all_bits_'] = None
16641659
body['_inverted_'] = None
1660+
body['__or__'] = Flag.__or__
1661+
body['__xor__'] = Flag.__xor__
1662+
body['__and__'] = Flag.__and__
1663+
body['__ror__'] = Flag.__ror__
1664+
body['__rxor__'] = Flag.__rxor__
1665+
body['__rand__'] = Flag.__rand__
1666+
body['__invert__'] = Flag.__invert__
16651667
for name, obj in cls.__dict__.items():
16661668
if name in ('__dict__', '__weakref__'):
16671669
continue

Lib/test/test_enum.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2496,6 +2496,13 @@ def __new__(cls, val):
24962496
self.assertEqual(Some.x.value, 1)
24972497
self.assertEqual(Some.y.value, 2)
24982498

2499+
def test_custom_flag_bitwise(self):
2500+
class MyIntFlag(int, Flag):
2501+
ONE = 1
2502+
TWO = 2
2503+
FOUR = 4
2504+
self.assertTrue(isinstance(MyIntFlag.ONE | MyIntFlag.TWO, MyIntFlag), MyIntFlag.ONE | MyIntFlag.TWO)
2505+
self.assertTrue(isinstance(MyIntFlag.ONE | 2, MyIntFlag))
24992506

25002507
class TestOrder(unittest.TestCase):
25012508
"test usage of the `_order_` attribute"

0 commit comments

Comments
 (0)
0