8000 chore: enable mypy to type check enums · toby-bro/python-betterproto@87f46fe · GitHub
[go: up one dir, main page]

Skip to content

Commit 87f46fe

Browse files
committed
chore: enable mypy to type check enums
1 parent 2cdc77a commit 87f46fe

File tree

1 file changed

+42
-22
lines changed

1 file changed

+42
-22
lines changed

src/betterproto/enum.py

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
Dict,
1313
Optional,
1414
Tuple,
15+
Type,
1516
)
1617

17-
1818
if TYPE_CHECKING:
1919
from collections.abc import (
2020
Generator,
@@ -31,17 +31,22 @@ def _is_descriptor(obj: object) -> bool:
3131
return (
3232
hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
3333
)
34+
if TYPE_CHECKING:
35+
BaseMetaType = EnumMeta
36+
BaseType = IntEnum
37+
else:
38+
BaseMetaType = type
39+
BaseType = int
3440

35-
36-
class EnumType(EnumMeta if TYPE_CHECKING else type):
41+
class EnumType(BaseMetaType):
3742
_value_map_: Mapping[int, Enum]
38-
_member_map_: Mapping[str, Enum]
43+
_member_map_: Mapping[str, Enum] # type: ignore[assignment]
3944

4045
def __new__(
4146
mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]
42-
) -> Self:
43-
value_map = {}
44-
member_map = {}
47+
) -> Type[Enum]:
48+
value_map: dict[str, Enum] = {}
49+
member_map: dict[str, Enum] = {}
4550

4651
new_mcs = type(
4752
f"{name}Type",
@@ -60,7 +65,7 @@ def __new__(
6065
if not _is_descriptor(value) and not name.startswith("__")
6166
}
6267

63-
cls = type.__new__(
68+
cls: Type[Enum] = type.__new__(
6469
new_mcs,
6570
name,
6671
bases,
@@ -72,7 +77,7 @@ def __new__(
7277
for name, value in members.items():
7378
member = value_map.get(value)
7479
if member is None:
75-
member = cls.__new__(cls, name=name, value=value) # type: ignore
80+
member = cls.__new__(cls, name=name, value=value)
7681
value_map[value] = member
7782
member_map[name] = member
7883
type.__setattr__(new_mcs, name, member)
@@ -81,56 +86,65 @@ def __new__(
8186

8287
if not TYPE_CHECKING:
8388

89+
@classmethod
8490
def __call__(cls, value: int) -> Enum:
8591
try:
8692
return cls._value_map_[value]
8793
except (KeyError, TypeError):
8894
raise ValueError(f"{value!r} is not a valid {cls.__name__}") from None
8995

96+
@classmethod
9097
def __iter__(cls) -> Generator[Enum, None, None]:
9198
yield from cls._member_map_.values()
9299

93-
if sys.version_info >= (3, 8): # 3.8 added __reversed__ to dict_values
100+
if sys.version_info >= (3, 8):
94101

102+
@classmethod
95103
def __reversed__(cls) -> Generator[Enum, None, None]:
96104
yield from reversed(cls._member_map_.values())
97105

98106
else:
99107

108+
@classmethod
100109
def __reversed__(cls) -> Generator[Enum, None, None]:
101110
yield from reversed(tuple(cls._member_map_.values()))
102111

112+
@classmethod
103113
def __getitem__(cls, key: str) -> Enum:
104114
return cls._member_map_[key]
105115

106-
@property
116+
@classmethod
107117
def __members__(cls) -> MappingProxyType[str, Enum]:
108118
return MappingProxyType(cls._member_map_)
109119

120+
@classmethod
110121
def __repr__(cls) -> str:
111122
return f"<enum {cls.__name__!r}>"
112123

113-
def __len__(cls) -> int:
124+
@classmethod
125+
def __len__(cls) -> int:
114126
return len(cls._member_map_)
115127

116-
def __setattr__(cls, name: str, value: Any) -> Never:
117-
raise AttributeError(f"{cls.__name__}: cannot reassign Enum members.")
128+
@classmethod
129+
def __setattr__(cls, name: str, value: Any) -> Never:
130+
raise AttributeError(f"{cls.__name__}: cannot reassign Enum classes.")
118131

132+
@classmethod
119133
def __delattr__(cls, name: str) -> Never:
120-
raise AttributeError(f"{cls.__name__}: cannot delete Enum members.")
134+
raise AttributeError(f"{cls.__name__}: cannot delete Enum classes.")
121135

136+
@classmethod
122137
def __contains__(cls, member: object) -> bool:
123-
return isinstance(member, cls) and member.name in cls._member_map_
124-
138+
return isinstance(member, cls) and isinstance(member, Enum) and member.name in cls._member_map_
125139

126-
class Enum(IntEnum if TYPE_CHECKING else int, metaclass=EnumType):
140+
class Enum(BaseType, metaclass=EnumType):
127141
"""
128142
The base class for protobuf enumerations, all generated enumerations will
129143
inherit from this. Emulates `enum.IntEnum`.
130144
"""
131145

132-
name: Optional[str]
133-
value: int
146+
name: str
147+
value: int # type: ignore[misc]
134148

135149
if not TYPE_CHECKING:
136150

@@ -178,7 +192,10 @@ def try_value(cls, value: int = 0) -> Self:
178192
``value`` isn't actually a member.
179193
"""
180194
try:
181-
return cls._value_map_[value]
195+
value = cls._value_map_[value]
196+
if not isinstance(value, type(cls)):
197+
raise TypeError(f'{value} should be of same type as {cls.__name__}')
198+
return value
182199
except (KeyError, TypeError):
183200
return cls.__new__(cls, name=None, value=value)
184201

@@ -197,6 +214,9 @@ def from_string(cls, name: str) -> Self:
197214
The member was not found in the Enum.
198215
"""
199216
try:
200-
return cls._member_map_[name]
217+
member = cls._member_map_[name]
218+
if not isinstance(member, type(cls)):
219+
raise TypeError(f'{member} should be of the same type as {cls.__name__}')
220+
return member
201221
except KeyError as e:
202222
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e

0 commit comments

Comments
 (0)
0