8000 bpo-46477: [Enum] ensure Flag subclasses have correct bitwise methods by ethanfurman · Pull Request #30816 · python/cpython · GitHub
[go: up one dir, main page]

Skip to content

bpo-46477: [Enum] ensure Flag subclasses have correct bitwise methods #30816

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 41 additions & 39 deletions Lib/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,18 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
if name not in classdict:
setattr(enum_class, name, getattr(first_enum, name))
#
# for Flag, add __or__, __and__, __xor__, and __invert__
if Flag is not None and issubclass(enum_class, Flag):
for name in (
'__or__', '__and__', '__xor__',
'__ror__', '__rand__', '__rxor__',
'__invert__'
):
if name not in classdict:
enum_method = getattr(Flag, name)
setattr(enum_class, name, enum_method)
classdict[name] = enum_method
#
# replace any other __new__ with our own (as long as Enum is not None,
# anyway) -- again, this is to support pickle
if Enum is not None:
Expand Down Expand Up @@ -1466,44 +1478,10 @@ def __str__(self):
def __bool__(self):
return bool(self._value_)

def __or__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return self.__class__(self._value_ | other._value_)

def __and__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return self.__class__(self._value_ & other._value_)

def __xor__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return self.__class__(self._value_ ^ other._value_)

def __invert__(self):
if self._inverted_ is None:
if self._boundary_ is KEEP:
# use all bits
self._inverted_ = self.__class__(~self._value_)
else:
# calculate flags not in this member
self._inverted_ = self.__class__(self._flag_mask_ ^ self._value_)
if isinstance(self._inverted_, self.__class__):
self._inverted_._inverted_ = self
return self._inverted_


class IntFlag(int, ReprEnum, Flag, boundary=EJECT):
"""
Support for integer-based Flags
"""


def __or__(self, other):
if isinstance(other, self.__class__):
other = other._value_
elif isinstance(other, int):
elif self._member_type_ is not object and isinstance(other, self._member_type_):
other = other
else:
return NotImplemented
Expand All @@ -1513,7 +1491,7 @@ def __or__(self, other):
def __and__(self, other):
if isinstance(other, self.__class__):
other = other._value_
elif isinstance(other, int):
elif self._member_type_ is not object and isinstance(other, self._member_type_):
other = other
else:
return NotImplemented
Expand All @@ -1523,17 +1501,34 @@ def __and__(self, other):
def __xor__(self, other):
if isinstance(other, self.__class__):
other = other._value_
elif isinstance(other, int):
elif self._member_type_ is not object and isinstance(other, self._member_type_):
other = other
else:
return NotImplemented
value = self._value_
return self.__class__(value ^ other)

__ror__ = __or__
def __invert__(self):
if self._inverted_ is None:
if self._boundary_ is KEEP:
# use all bits
self._inverted_ = self.__class__(~self._value_)
else:
# calculate flags not in this member
self._inverted_ = self.__class__(self._flag_mask_ ^ self._value_)
if isinstance(self._inverted_, self.__class__):
self._inverted_._inverted_ = self
return self._inverted_

__rand__ = __and__
__ror__ = __or__
__rxor__ = __xor__
__invert__ = Flag.__invert__


class IntFlag(int, ReprEnum, Flag, boundary=EJECT):
"""
Support for integer-based Flags
"""


def _high_bit(value):
Expand Down Expand Up @@ -1662,6 +1657,13 @@ def convert_class(cls):
body['_flag_mask_'] = None
body['_all_bits_'] = None
body['_inverted_'] = None
body['__or__'] = Flag.__or__
body['__xor__'] = Flag.__xor__
body['__and__'] = Flag.__and__
body['__ror__'] = Flag.__ror__
body['__rxor__'] = Flag.__rxor__
body['__rand__'] = Flag.__rand__
body['__invert__'] = Flag.__invert__
for name, obj in cls.__dict__.items():
if name in ('__dict__', '__weakref__'):
continue
Expand Down
97AC
7 changes: 7 additions & 0 deletions Lib/test/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2496,6 +2496,13 @@ def __new__(cls, val):
self.assertEqual(Some.x.value, 1)
self.assertEqual(Some.y.value, 2)

def test_custom_flag_bitwise(self):
class MyIntFlag(int, Flag):
ONE = 1
TWO = 2
FOUR = 4
self.assertTrue(isinstance(MyIntFlag.ONE | MyIntFlag.TWO, MyIntFlag), MyIntFlag.ONE | MyIntFlag.TWO)
self.assertTrue(isinstance(MyIntFlag.ONE | 2, MyIntFlag))

class TestOrder(unittest.TestCase):
"test usage of the `_order_` attribute"
Expand Down
0