8000 gh-103365: [Enum] STRICT boundary corrections (GH-103494) · python/cpython@2194071 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2194071

Browse files
authored
gh-103365: [Enum] STRICT boundary corrections (GH-103494)
STRICT boundary: - fix bitwise operations - make default for Flag
1 parent efb8a25 commit 2194071

File tree

4 files changed

+82
-38
lines changed

4 files changed

+82
-38
lines changed

Doc/library/enum.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,8 @@ Data Types
696696

697697
.. attribute:: STRICT
698698

699-
Out-of-range values cause a :exc:`ValueError` to be raised::
699+
Out-of-range values cause a :exc:`ValueError` to be raised. This is the
700+
default for :class:`Flag`::
700701

701702
>>> from enum import Flag, STRICT, auto
702703
>>> class StrictFlag(Flag, boundary=STRICT):
@@ -714,7 +715,7 @@ Data Types
714715
.. attribute:: CONFORM
715716

716717
Out-of-range values have invalid values removed, leaving a valid *Flag*
717-
value. This is the default for :class:`Flag`::
718+
value::
718719

719720
>>> from enum import Flag, CONFORM, auto
720721
>>> class ConformFlag(Flag, boundary=CONFORM):

Lib/enum.py

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,13 @@ def __set_name__(self, enum_class, member_name):
275275
enum_member.__objclass__ = enum_class
276276
enum_member.__init__(*args)
277277
enum_member._sort_order_ = len(enum_class._member_names_)
278+
279+
if Flag is not None and issubclass(enum_class, Flag):
280+
enum_class._flag_mask_ |= value
281+
if _is_single_bit(value):
282+
enum_class._singles_mask_ |= value
283+
enum_class._all_bits_ = 2 ** ((enum_class._flag_mask_).bit_length()) - 1
284+
278285
# If another member with the same value was already defined, the
279286
# new member becomes an alias to the existing one.
280287
try:
@@ -532,12 +539,8 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
532539
classdict['_use_args_'] = use_args
533540
#
534541
# convert future enum members into temporary _proto_members
535-
# and record integer values in case this will be a Flag
536-
flag_mask = 0
537542
for name in member_names:
538543
value = classdict[name]
539-
if isinstance(value, int):
540-
flag_mask |= value
541544
6D40 classdict[name] = _proto_member(value)
542545
#
543546
# house-keeping structures
@@ -554,8 +557,9 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
554557
boundary
555558
or getattr(first_enum, '_boundary_', None)
556559
)
557-
classdict['_flag_mask_'] = flag_mask
558-
classdict['_all_bits_'] = 2 ** ((flag_mask).bit_length()) - 1
560+
classdict['_flag_mask_'] = 0
561+
classdict['_singles_mask_'] = 0
562+
classdict['_all_bits_'] = 0
559563
classdict['_inverted_'] = None
560564
try:
561565
exc = None
@@ -644,21 +648,10 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
644648
):
645649
delattr(enum_class, '_boundary_')
646650
delattr(enum_class, '_flag_mask_')
651+
delattr(enum_class, '_singles_mask_')
647652
delattr(enum_class, '_all_bits_')
648653
delattr(enum_class, '_inverted_')
649654
elif Flag is not None and issubclass(enum_class, Flag):
650-
# ensure _all_bits_ is correct and there are no missing flags
651-
single_bit_total = 0
652-
multi_bit_total = 0
653-
for flag in enum_class._member_map_.values():
654-
flag_value = flag._value_
655-
if _is_single_bit(flag_value):
656-
single_bit_total |= flag_value
657-
else:
658-
# multi-bit flags are considered aliases
659-
multi_bit_total |= flag_value
660-
enum_class._flag_mask_ = single_bit_total
661-
#
662655
# set correct __iter__
663656
member_list = [m._value_ for m in enum_class]
664657
if member_list != sorted(member_list):
@@ -1303,8 +1296,8 @@ def _reduce_ex_by_global_name(self, proto):
13031296
class FlagBoundary(StrEnum):
13041297
"""
13051298
control how out of range values are handled
1306-
"strict" -> error is raised
1307-
"conform" -> extra bits are discarded [default for Flag]
1299+
"strict" -> error is raised [default for Flag]
1300+
"conform" -> extra bits are discarded
13081301
"eject" -> lose flag status
13091302
"keep" -> keep flag status and all bits [default for IntFlag]
13101303
"""
@@ -1315,7 +1308,7 @@ class FlagBoundary(StrEnum):
13151308
STRICT, CONFORM, EJECT, KEEP = FlagBoundary
13161309

13171310

1318-
class Flag(Enum, boundary=CONFORM):
1311+
class Flag(Enum, boundary=STRICT):
13191312
"""
13201313
Support for flags
13211314
"""
@@ -1394,6 +1387,7 @@ def _missing_(cls, value):
13941387
# - value must not include any skipped flags (e.g. if bit 2 is not
13951388
# defined, then 0d10 is invalid)
13961389
flag_mask = cls._flag_mask_
1390+
singles_mask = cls._singles_mask_
13971391
all_bits = cls._all_bits_
13981392
neg_value = None
13991393
if (
@@ -1425,7 +1419,8 @@ def _missing_(cls, value):
14251419
value = all_bits + 1 + value
14261420
# get members and unknown
14271421
unknown = value & ~flag_mask
1428-
member_value = value & flag_mask
1422+
aliases = value & ~singles_mask
1423+
member_value = value & singles_mask
14291424
if unknown and cls._boundary_ is not KEEP:
14301425
raise ValueError(
14311426
'%s(%r) --> unknown values %r [%s]'
@@ -1439,11 +1434,25 @@ def _missing_(cls, value):
14391434
pseudo_member = cls._member_type_.__new__(cls, value)
14401435
if not hasattr(pseudo_member, '_value_'):
14411436
pseudo_member._value_ = value
1442-
if member_value:
1443-
pseudo_member._name_ = '|'.join([
1444-
m._name_ for m in cls._iter_member_(member_value)
1445-
])
1446-
if unknown:
1437+
if member_value or aliases:
1438+
members = []
1439+
combined_value = 0
1440+
for m in cls._iter_member_(member_value):
1441+
members.append(m)
1442+
combined_value |= m._value_
1443+
if aliases:
1444+
value = member_value | aliases
1445+
for n, pm in cls._member_map_.items():
1446+
if pm not in members and pm._value_ and pm._value_ & value == pm._value_:
1447+
members.append(pm)
1448+
combined_value |= pm._value_
1449+
unknown = value ^ combined_value
1450+
pseudo_member._name_ = '|'.join([m._name_ for m in members])
1451+
if not combined_value:
1452+
pseudo_member._name_ = None
1453+
elif unknown and cls._boundary_ is STRICT:
1454+
raise ValueError('%r: no members with value %r' % (cls, unknown))
1455+
elif unknown:
14471456
pseudo_member._name_ += '|%s' % cls._numeric_repr_(unknown)
14481457
else:
14491458
pseudo_member._name_ = None
@@ -1675,6 +1684,7 @@ def convert_class(cls):
16751684
body['_boundary_'] = boundary or etype._boundary_
16761685
body['_flag_mask_'] = None
16771686
body['_all_bits_'] = None
1687+
body['_singles_mask_'] = None
16781688
body['_inverted_'] = None
16791689
body['__or__'] = Flag.__or__
16801690
body['__xor__'] = Flag.__xor__
@@ -1750,7 +1760,8 @@ def convert_class(cls):
17501760
else:
17511761
multi_bits |= value
17521762
gnv_last_values.append(value)
1753-
enum_class._flag_mask_ = single_bits
1763+
enum_class._flag_mask_ = single_bits | multi_bits
1764+
enum_class._singles_mask_ = single_bits
17541765
enum_class._all_bits_ = 2 ** ((single_bits|multi_bits).bit_length()) - 1
17551766
# set correct __iter__
17561767
member_list = [m._value_ for m in enum_class]

Lib/test/test_enum.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2873,6 +2873,8 @@ def __new__(cls, c):
28732873
#
28742874
a = ord('a')
28752875
#
2876+
self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343)
2877+
self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900672)
28762878
self.assertEqual(FlagFromChar.a, 158456325028528675187087900672)
28772879
self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673)
28782880
#
@@ -2887,6 +2889,8 @@ def __new__(cls, c):
28872889
a = ord('a')
28882890
z = 1
28892891
#
2892+
self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343)
2893+
self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900674)
28902894
self.assertEqual(FlagFromChar.a.value, 158456325028528675187087900672)
28912895
self.assertEqual((FlagFromChar.a|FlagFromChar.z).value, 158456325028528675187087900674)
28922896
#
@@ -2900,6 +2904,8 @@ def __new__(cls, c):
29002904
#
29012905
a = ord('a')
29022906
#
2907+
self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343)
2908+
self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900672)
29032909
self.assertEqual(FlagFromChar.a, 158456325028528675187087900672)
29042910
self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673)
29052911

@@ -3077,18 +3083,18 @@ def test_bool(self):
30773083
self.assertEqual(bool(f.value), bool(f))
30783084

30793085
def test_boundary(self):
3080-
self.assertIs(enum.Flag._boundary_, CONFORM)
3081-
class Iron(Flag, boundary=STRICT):
3086+
self.assertIs(enum.Flag._boundary_, STRICT)
3087+
class Iron(Flag, boundary=CONFORM):
30823088
ONE = 1
30833089
TWO = 2
30843090
EIGHT = 8
3085-
self.assertIs(Iron._boundary_, STRICT)
3091+
self.assertIs(Iron._boundary_, CONFORM)
30863092
#
3087-
class Water(Flag, boundary=CONFORM):
3093+
class Water(Flag, boundary=STRICT):
30883094
ONE = 1
30893095
TWO = 2
30903096
EIGHT = 8
3091-
self.assertIs(Water._boundary_, CONFORM)
3097+
self.assertIs(Water._boundary_, STRICT)
30923098
#
30933099
class Space(Flag, boundary=EJECT):
30943100
ONE = 1
@@ -3101,17 +3107,42 @@ class Bizarre(Flag, boundary=KEEP):
31013107
c = 4
31023108
d = 6
31033109
#
3104-
self.assertRaisesRegex(ValueError, 'invalid value 7', Iron, 7)
3110+
self.assertRaisesRegex(ValueError, 'invalid value 7', Water, 7)
31053111
#
3106-
self.assertIs(Water(7), Water.ONE|Water.TWO)
3107-
self.assertIs(Water(~9), Water.TWO)
3112+
self.assertIs(Iron(7), Iron.ONE|Iron.TWO)
3113+
self.assertIs(Iron(~9), Iron.TWO)
31083114
#
31093115
self.assertEqual(Space(7), 7)
31103116
self.assertTrue(type(Space(7)) is int)
31113117
#
31123118
self.assertEqual(list(Bizarre), [Bizarre.c])
31133119
self.assertIs(Bizarre(3), Bizarre.b)
31143120
self.assertIs(Bizarre(6), Bizarre.d)
3121+
#
3122+
class SkipFlag(enum.Flag):
3123+
A = 1
3124+
B = 2
3125+
C = 4 | B
3126+
#
3127+
self.assertTrue(SkipFlag.C in (SkipFlag.A|SkipFlag.C))
3128+
self.assertRaisesRegex(ValueError, 'SkipFlag.. invalid value 42', SkipFlag, 42)
3129+
#
3130+
class SkipIntFlag(enum.IntFlag):
3131+
A = 1
3132+
B = 2
3133+
C = 4 | B
3134+
#
3135+
self.assertTrue(SkipIntFlag.C in (SkipIntFlag.A|SkipIntFlag.C))
3136+
self.assertEqual(SkipIntFlag(42).value, 42)
3137+
#
3138+
class MethodHint(Flag):
3139+
HiddenText = 0x10
3140+
DigitsOnly = 0x01
3141+
LettersOnly = 0x02
8D94
3142+
OnlyMask = 0x0f
3143+
#
3144+
self.assertEqual(str(MethodHint.HiddenText|MethodHint.OnlyMask), 'MethodHint.HiddenText|DigitsOnly|LettersOnly|OnlyMask')
3145+
31153146

31163147
def test_iter(self):
31173148
Color = self.Color
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Set default Flag boundary to ``STRICT`` and fix bitwise operations.

0 commit comments

Comments
 (0)
0