8000 MAINT: Adapt the `npt._GenericAlias` backport to Python 3.11 `types.G… · rjeb/numpy@4461ec4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4461ec4

Browse files
Bas van BeekBvB93
authored andcommitted
MAINT: Adapt the npt._GenericAlias backport to Python 3.11 types.GenericAlias changes
1 parent 7e15fd7 commit 4461ec4

File tree

1 file changed

+42
-7
lines changed

1 file changed

+42
-7
lines changed

numpy/_typing/_generic_alias.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _reconstruct_alias(alias: _T, parameters: Iterator[TypeVar]) -> _T:
6464
args.append(value)
6565

6666
cls = type(alias)
67-
return cls(alias.__origin__, tuple(args))
67+
return cls(alias.__origin__, tuple(args), alias.__unpacked__)
6868

6969

7070
class _GenericAlias:
@@ -80,7 +80,14 @@ class _GenericAlias:
8080
8181
"""
8282

83-
__slots__ = ("__weakref__", "_origin", "_args", "_parameters", "_hash")
83+
__slots__ = (
84+
"__weakref__",
85+
"_origin",
86+
"_args",
87+
"_parameters",
88+
"_hash",
89+
"_starred",
90+
)
8491

8592
@property
8693
def __origin__(self) -> type:
@@ -95,25 +102,38 @@ def __parameters__(self) -> tuple[TypeVar, ...]:
95102
"""Type variables in the ``GenericAlias``."""
96103
return super().__getattribute__("_parameters")
97104

105+
@property
106+
def __unpacked__(self) -> bool:
107+
return super().__getattribute__("_starred")
108+
109+
@property
110+
def __typing_unpacked_tuple_args__(self) -> tuple[object, ...] | None:
111+
# NOTE: This should return `__args__` if `__origin__` is a tuple,
112+
# which should never be the case with how `_GenericAlias` is used
113+
# within numpy
114+
return None
115+
98116
def __init__(
99117
self,
100118
origin: type,
101119
args: object | tuple[object, ...],
120+
starred: bool = False,
102121
) -> None:
103122
self._origin = origin
104123
self._args = args if isinstance(args, tuple) else (args,)
105124
self._parameters = tuple(_parse_parameters(self.__args__))
125+
self._starred = starred
106126

107127
@property
108128
def __call__(self) -> type[Any]:
109129
return self.__origin__
110130

111131
def __reduce__(self: _T) -> tuple[
112132
type[_T],
113-
tuple[type[Any], tuple[object, ...]],
133+
tuple[type[Any], tuple[object, ...], bool],
114134
]:
115135
cls = type(self)
116-
return cls, (self.__origin__, self.__args__)
136+
return cls, (self.__origin__, self.__args__, self.__unpacked__)
117137

118138
def __mro_entries__(self, bases: Iterable[object]) -> tuple[type[Any]]:
119139
return (self.__origin__,)
@@ -130,7 +150,11 @@ def __hash__(self) -> int:
130150
try:
131151
return super().__getattribute__("_hash")
132152
except AttributeError:
133-
self._hash: int = hash(self.__origin__) ^ hash(self.__args__)
153+
self._hash: int = (
154+
hash(self.__origin__) ^
155+
hash(self.__args__) ^
156+
hash(self.__unpacked__)
157+
)
134158
return super().__getattribute__("_hash")
135159

136160
def __instancecheck__(self, obj: object) -> NoReturn:
@@ -147,7 +171,8 @@ def __repr__(self) -> str:
147171
"""Return ``repr(self)``."""
148172
args = ", ".join(_to_str(i) for i in self.__args__)
149173
origin = _to_str(self.__origin__)
150-
return f"{origin}[{args}]"
174+
prefix = "*" if self.__unpacked__ else ""
175+
return f"{prefix}{origin}[{args}]"
151176

152177
def __getitem__(self: _T, key: object | tuple[object, ...]) -> _T:
153178
"""Return ``self[key]``."""
@@ -169,9 +194,17 @@ def __eq__(self, value: object) -> bool:
169194
return NotImplemented
170195
return (
171196
self.__origin__ == value.__origin__ and
172-
self.__args__ == value.__args__
197+
self.__args__ == value.__args__ and
198+
self.__unpacked__ == getattr(
199+
value, "__unpacked__", self.__unpacked__
200+
)
173201
)
174202

203+
def __iter__(self: _T) -> Generator[_T, None, None]:
204+
"""Return ``iter(self)``."""
205+
cls = type(self)
206+
yield cls(self.__origin__, self.__args__, True)
207+
175208
_ATTR_EXCEPTIONS: ClassVar[frozenset[str]] = frozenset({
176209
"__origin__",
177210
"__args__",
@@ -181,6 +214,8 @@ def __eq__(self, value: object) -> bool:
181214
"__reduce_ex__",
182215
"__copy__",
183216
"__deepcopy__",
217+
"__unpacked__",
218+
"__typing_unpacked_tuple_args__",
184219
})
185220

186221
def __getattribute__(self, name: str) -> Any:

0 commit comments

Comments
 (0)
0