From 98b32b3f2cbf3797b523e0ac8803cc25a6bf539c Mon Sep 17 00:00:00 2001 From: Ethan Furman Date: Sat, 22 Jan 2022 17:37:33 -0800 Subject: [PATCH] ensure Flag subclasses have correct bitwise methods --- Lib/enum.py | 80 ++++++++++++++++++++++--------------------- Lib/test/test_enum.py | 7 ++++ 2 files changed, 48 insertions(+), 39 deletions(-) diff --git a/Lib/enum.py b/Lib/enum.py index b5104677312933..85245c95f9a9c7 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -618,6 +618,18 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k if name not in classdict: setattr(enum_class, name, getattr(first_enum, name)) # + # for Flag, add __or__, __and__, __xor__, and __invert__ + if Flag is not None and issubclass(enum_class, Flag): + for name in ( + '__or__', '__and__', '__xor__', + '__ror__', '__rand__', '__rxor__', + '__invert__' + ): + if name not in classdict: + enum_method = getattr(Flag, name) + setattr(enum_class, name, enum_method) + classdict[name] = enum_method + # # replace any other __new__ with our own (as long as Enum is not None, # anyway) -- again, this is to support pickle if Enum is not None: @@ -1466,44 +1478,10 @@ def __str__(self): def __bool__(self): return bool(self._value_) - def __or__(self, other): - if not isinstance(other, self.__class__): - return NotImplemented - return self.__class__(self._value_ | other._value_) - - def __and__(self, other): - if not isinstance(other, self.__class__): - return NotImplemented - return self.__class__(self._value_ & other._value_) - - def __xor__(self, other): - if not isinstance(other, self.__class__): - return NotImplemented - return self.__class__(self._value_ ^ other._value_) - - def __invert__(self): - if self._inverted_ is None: - if self._boundary_ is KEEP: - # use all bits - self._inverted_ = self.__class__(~self._value_) - else: - # calculate flags not in this member - self._inverted_ = self.__class__(self._flag_mask_ ^ self._value_) - if isinstance(self._inverted_, self.__class__): - self._inverted_._inverted_ = self - return self._inverted_ - - -class IntFlag(int, ReprEnum, Flag, boundary=EJECT): - """ - Support for integer-based Flags - """ - - def __or__(self, other): if isinstance(other, self.__class__): other = other._value_ - elif isinstance(other, int): + elif self._member_type_ is not object and isinstance(other, self._member_type_): other = other else: return NotImplemented @@ -1513,7 +1491,7 @@ def __or__(self, other): def __and__(self, other): if isinstance(other, self.__class__): other = other._value_ - elif isinstance(other, int): + elif self._member_type_ is not object and isinstance(other, self._member_type_): other = other else: return NotImplemented @@ -1523,17 +1501,34 @@ def __and__(self, other): def __xor__(self, other): if isinstance(other, self.__class__): other = other._value_ - elif isinstance(other, int): + elif self._member_type_ is not object and isinstance(other, self._member_type_): other = other else: return NotImplemented value = self._value_ return self.__class__(value ^ other) - __ror__ = __or__ + def __invert__(self): + if self._inverted_ is None: + if self._boundary_ is KEEP: + # use all bits + self._inverted_ = self.__class__(~self._value_) + else: + # calculate flags not in this member + self._inverted_ = self.__class__(self._flag_mask_ ^ self._value_) + if isinstance(self._inverted_, self.__class__): + self._inverted_._inverted_ = self + return self._inverted_ + __rand__ = __and__ + __ror__ = __or__ __rxor__ = __xor__ - __invert__ = Flag.__invert__ + + +class IntFlag(int, ReprEnum, Flag, boundary=EJECT): + """ + Support for integer-based Flags + """ def _high_bit(value): @@ -1662,6 +1657,13 @@ def convert_class(cls): body['_flag_mask_'] = None body['_all_bits_'] = None body['_inverted_'] = None + body['__or__'] = Flag.__or__ + body['__xor__'] = Flag.__xor__ + body['__and__'] = Flag.__and__ + body['__ror__'] = Flag.__ror__ + body['__rxor__'] = Flag.__rxor__ + body['__rand__'] = Flag.__rand__ + body['__invert__'] = Flag.__invert__ for name, obj in cls.__dict__.items(): if name in ('__dict__', '__weakref__'): continue diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index d7ce8add78715b..b8a7914355c530 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -2496,6 +2496,13 @@ def __new__(cls, val): self.assertEqual(Some.x.value, 1) self.assertEqual(Some.y.value, 2) + def test_custom_flag_bitwise(self): + class MyIntFlag(int, Flag): + ONE = 1 + TWO = 2 + FOUR = 4 + self.assertTrue(isinstance(MyIntFlag.ONE | MyIntFlag.TWO, MyIntFlag), MyIntFlag.ONE | MyIntFlag.TWO) + self.assertTrue(isinstance(MyIntFlag.ONE | 2, MyIntFlag)) class TestOrder(unittest.TestCase): "test usage of the `_order_` attribute"