@@ -618,6 +618,18 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
618
618
if name not in classdict :
619
619
setattr (enum_class , name , getattr (first_enum , name ))
620
620
#
621
+ # for Flag, add __or__, __and__, __xor__, and __invert__
622
+ if Flag is not None and issubclass (enum_class , Flag ):
623
+ for name in (
624
+ '__or__' , '__and__' , '__xor__' ,
625
+ '__ror__' , '__rand__' , '__rxor__' ,
626
+ '__invert__'
627
+ ):
628
+ if name not in classdict :
629
+ enum_method = getattr (Flag , name )
630
+ setattr (enum_class , name , enum_method )
631
+ classdict [name ] = enum_method
632
+ #
621
633
# replace any other __new__ with our own (as long as Enum is not None,
622
634
# anyway) -- again, this is to support pickle
623
635
if Enum is not None :
@@ -1466,44 +1478,10 @@ def __str__(self):
1466
1478
def __bool__ (self ):
1467
1479
return bool (self ._value_ )
1468
1480
1469
- def __or__ (self , other ):
1470
- if not isinstance (other , self .__class__ ):
1471
- return NotImplemented
1472
- return self .__class__ (self ._value_ | other ._value_ )
1473
-
1474
- def __and__ (self , other ):
1475
- if not isinstance (other , self .__class__ ):
1476
- return NotImplemented
1477
- return self .__class__ (self ._value_ & other ._value_ )
1478
-
1479
- def __xor__ (self , other ):
1480
- if not isinstance (other , self .__class__ ):
1481
- return NotImplemented
1482
- return self .__class__ (self ._value_ ^ other ._value_ )
1483
-
1484
- def __invert__ (self ):
1485
- if self ._inverted_ is None :
1486
- if self ._boundary_ is KEEP :
1487
- # use all bits
1488
- self ._inverted_ = self .__class__ (~ self ._value_ )
1489
- else :
1490
- # calculate flags not in this member
1491
- self ._inverted_ = self .__class__ (self ._flag_mask_ ^ self ._value_ )
1492
- if isinstance (self ._inverted_ , self .__class__ ):
1493
- self ._inverted_ ._inverted_ = self
1494
- return self ._inverted_
1495
-
1496
-
1497
- class IntFlag (int , ReprEnum , Flag , boundary = EJECT ):
1498
- """
1499
- Support for integer-based Flags
1500
- """
1501
-
1502
-
1503
1481
def __or__ (self , other ):
1504
1482
if isinstance (other , self .__class__ ):
1505
1483
other = other ._value_
1506
- elif isinstance (other , int ):
1484
+ elif self . _member_type_ is not object and isinstance (other , self . _member_type_ ):
1507
1485
other = other
1508
1486
else :
1509
1487
return NotImplemented
@@ -1513,7 +1491,7 @@ def __or__(self, other):
1513
1491
def __and__ (self , other ):
1514
1492
if isinstance (other , self .__class__ ):
1515
1493
other = other ._value_
1516
- elif isinstance (other , int ):
1494
+ elif self . _member_type_ is not object and isinstance (other , self . _member_type_ ):
1517
1495
other = other
1518
1496
else :
1519
1497
return NotImplemented
@@ -1523,17 +1501,34 @@ def __and__(self, other):
1523
1501
def __xor__ (self , other ):
1524
1502
if isinstance (other , self .__class__ ):
1525
1503
other = other ._value_
1526
- elif isinstance (other , int ):
1504
+ elif self . _member_type_ is not object and isinstance (other , self . _member_type_ ):
1527
1505
other = other
1528
1506
else :
1529
1507
return NotImplemented
1530
1508
value = self ._value_
1531
1509
return self .__class__ (value ^ other )
1532
1510
1533
- __ror__ = __or__
1511
+ def __invert__ (self ):
1512
+ if self ._inverted_ is None :
1513
+ if self ._boundary_ is KEEP :
1514
+ # use all bits
1515
+ self ._inverted_ = self .__class__ (~ self ._value_ )
1516
+ else :
1517
+ # calculate flags not in this member
1518
+ self ._inverted_ = self .__class__ (self ._flag_mask_ ^ self ._value_ )
1519
+ if isinstance (self ._inverted_ , self .__class__ ):
1520
+ self ._inverted_ ._inverted_ = self
1521
+ return self ._inverted_
1522
+
1534
1523
__rand__ = __and__
1524
+ __ror__ = __or__
1535
1525
__rxor__ = __xor__
1536
- __invert__ = Flag .__invert__
1526
+
1527
+
1528
+ class IntFlag (int , ReprEnum , Flag , boundary = EJECT ):
1529
+ """
1530
+ Support for integer-based Flags
1531
+ """
1537
1532
1538
1533
1539
1534
def _high_bit (value ):
@@ -1662,6 +1657,13 @@ def convert_class(cls):
1662
1657
body ['_flag_mask_' ] = None
1663
1658
body ['_all_bits_' ] = None
1664
1659
body ['_inverted_' ] = None
1660
+ body ['__or__' ] = Flag .__or__
1661
+ body ['__xor__' ] = Flag .__xor__
1662
+ body ['__and__' ] = Flag .__and__
1663
+ body ['__ror__' ] = Flag .__ror__
1664
+ body ['__rxor__' ] = Flag .__rxor__
1665
+ body ['__rand__' ] = Flag .__rand__
1666
+ body ['__invert__' ] = Flag .__invert__
1665
1667
for name , obj in cls .__dict__ .items ():
1666
1668
if name in ('__dict__' , '__weakref__' ):
1667
1669
continue
0 commit comments