8000 Merge pull request #21660 from charris/backport-21605 · numpy/numpy@a10a73e · GitHub
[go: up one dir, main page]

Skip to content

Commit a10a73e

Browse files
authored
Merge pull request #21660 from charris/backport-21605
MAINT: Adapt the npt._GenericAlias backport to Python 3.11 types.GenericAlias changes
2 parents ad186e3 + 11338ef commit a10a73e

File tree

5 files changed

+87
-11
lines changed

5 files changed

+87
-11
lines changed

.github/workflows/build_test.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,10 +250,12 @@ jobs:
250250
# use x86_64 cross-compiler to speed up the build
251251
sudo apt update
252252
sudo apt install -y gcc-arm-linux-gnueabihf g++-arm-linux-gnueabihf
253+
254+
# Keep the `test_requirements.txt` dependency-subset synced
253255
docker run --name the_container --interactive -v /:/host arm32v7/ubuntu:focal /bin/bash -c "
254256
apt update &&
255257
apt install -y git python3 python3-dev python3-pip &&
256-
pip3 install cython==0.29.30 setuptools\<49.2.0 hypothesis==6.23.3 pytest==6.2.5 &&
258+
pip3 install cython==0.29.30 setuptools\<49.2.0 hypothesis==6.23.3 pytest==6.2.5 'typing_extensions>=4.2.0' &&
257259
ln -s /host/lib64 /lib64 &&
258260
ln -s /host/lib/x86_64-linux-gnu /lib/x86_64-linux-gnu &&
259261
ln -s /host/usr/arm-linux-gnueabihf /usr/arm-linux-gnueabihf &&

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dependencies:
2020
- hypothesis
2121
# For type annotations
2222
- mypy=0.950
23+
- typing_extensions>=4.2.0
2324
# For building docs
2425
- sphinx=4.5.0
2526
- sphinx-panels

numpy/_typing/_generic_alias.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _reconstruct_alias(alias: _T, parameters: Iterator[TypeVar]) -> _T:
6464
args.append(value)
6565

6666
cls = type(alias)
67-
return cls(alias.__origin__, tuple(args))
67+
return cls(alias.__origin__, tuple(args), alias.__unpacked__)
6868

6969

7070
class _GenericAlias:
@@ -80,7 +80,14 @@ class _GenericAlias:
8080
8181
"""
8282

83-
__slots__ = ("__weakref__", "_origin", "_args", "_parameters", "_hash")
83+
__slots__ = (
84+
"__weakref__",
85+
"_origin",
86+
"_args",
87+
"_parameters",
88+
"_hash",
89+
"_starred",
90+
)
8491

8592
@property
8693
def __origin__(self) -> type:
@@ -95,25 +102,38 @@ def __parameters__(self) -> tuple[TypeVar, ...]:
95102
"""Type variables in the ``GenericAlias``."""
96103
return super().__getattribute__("_parameters")
97104

105+
@property
106+
def __unpacked__(self) -> bool:
107+
return super().__getattribute__("_starred")
108+
109+
@property
110+
def __typing_unpacked_tuple_args__(self) -> tuple[object, ...] | None:
111+
# NOTE: This should return `__args__` if `__origin__` is a tuple,
112+
# which should never be the case with how `_GenericAlias` is used
113+
# within numpy
114+
return None
115+
98116
def __init__(
99117
self,
100118
origin: type,
101119
args: object | tuple[object, ...],
120+
starred: bool = False,
102121
) -> None:
103122
self._origin = origin
104123
self._args = args if isinstance(args, tuple) else (args,)
105124
self._parameters = tuple(_parse_parameters(self.__args__))
125+
self._starred = starred
106126

107127
@property
108128
def __call__(self) -> type[Any]:
109129
return self.__origin__
110130

111131
def __reduce__(self: _T) -> tuple[
112132
type[_T],
113-
tuple[type[Any], tuple[object, ...]],
133+
tuple[type[Any], tuple[object, ...], bool],
114134
]:
115135
cls = type(self)
116-
return cls, (self.__origin__, self.__args__)
136+
return cls, (self.__origin__, self.__args__, self.__unpacked__)
117137

118138
def __mro_entries__(self, bases: Iterable[object]) -> tuple[type[Any]]:
119139
return (self.__origin__,)
@@ -130,7 +150,11 @@ def __hash__(self) -> int:
130150
try:
131151
return super().__getattribute__("_hash")
132152
except AttributeError:
133-
self._hash: int = hash(self.__origin__) ^ hash(self.__args__)
153+
self._hash: int = (
154+
hash(self.__origin__) ^
155+
hash(self.__args__) ^
156+
hash(self.__unpacked__)
157+
)
134158
return super().__getattribute__("_hash")
135159

136160
def __instancecheck__(self, obj: object) -> NoReturn:
@@ -147,7 +171,8 @@ def __repr__(self) -> str:
147171
"""Return ``repr(self)``."""
148172
args = ", ".join(_to_str(i) for i in self.__args__)
149173
origin = _to_str(self.__origin__)
150-
return f"{origin}[{args}]"
174+
prefix = "*" if self.__unpacked__ else ""
175+
return f"{prefix}{origin}[{args}]"
151176

152177
def __getitem__(self: _T, key: object | tuple[object, ...]) -> _T:
153178
"""Return ``self[key]``."""
@@ -169,9 +194,17 @@ def __eq__(self, value: object) -> bool:
169194
return NotImplemented
170195
return (
171196
self.__origin__ == value.__origin__ and
172-
self.__args__ == value.__args__
197+
self.__args__ == value.__args__ and
198+
self.__unpacked__ == getattr(
199+
value, "__unpacked__", self.__unpacked__
200+
)
173201
)
174202

203+
def __iter__(self: _T) -> Generator[_T, None, None]:
204+
"""Return ``iter(self)``."""
205+
cls = type(self)
206+
yield cls(self.__origin__, self.__args__, True)
207+
175208
_ATTR_EXCEPTIONS: ClassVar[frozenset[str]] = frozenset({
176209
"__origin__",
177210
"__args__",
@@ -181,6 +214,8 @@ def __eq__(self, value: object) -> bool:
181214
"__reduce_ex__",
182215
"__copy__",
183216
"__deepcopy__",
217+
"__unpacked__",
218+
"__typing_unpacked_tuple_args__",
184219
})
185220

186221
def __getattribute__(self, name: str) -> Any:

numpy/typing/tests/test_generic_alias.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytest
1111
import numpy as np
1212
from numpy._typing._generic_alias import _GenericAlias
13+
from typing_extensions import Unpack
1314

1415
ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)
1516
T1 = TypeVar("T1")
@@ -55,8 +56,6 @@ class TestGenericAlias:
5556
("__origin__", lambda n: n.__origin__),
5657
("__args__", lambda n: n.__args__),
5758
("__parameters__", lambda n: n.__parameters__),
58-
("__reduce__", lambda n: n.__reduce__()[1:]),
59-
("__reduce_ex__", lambda n: n.__reduce_ex__(1)[1:]),
6059
("__mro_entries__", lambda n: n.__mro_entries__([object])),
6160
("__hash__", lambda n: hash(n)),
6261
("__repr__", lambda n: repr(n)),
@@ -66,7 +65,6 @@ class TestGenericAlias:
6665
("__getitem__", lambda n: n[Union[T1, T2]][np.float32, np.float64]),
6766
("__eq__", lambda n: n == n),
6867
("__ne__", lambda n: n != np.ndarray),
69-
("__dir__", lambda n: dir(n)),
7068
("__call__", lambda n: n((1,), np.int64, BUFFER)),
7169
("__call__", lambda n: n(shape=(1,), dtype=np.int64, buffer=BUFFER)),
7270
("subclassing", lambda n: _get_subclass_mro(n)),
@@ -100,6 +98,45 @@ def test_copy(self, name: str, func: FuncType) -> None:
10098
value_ref = func(NDArray_ref)
10199
assert value == value_ref
102100

101+
def test_dir(self) -> None:
102+
value = dir(NDArray)
103+
if sys.version_info < (3, 9):
104+
return
105+
106+
# A number attributes only exist in `types.GenericAlias` in >= 3.11
107+
if sys.version_info < (3, 11, 0, "beta", 3):
108+
value.remove("__typing_unpacked_tuple_args__")
109+
if sys.version_info < (3, 11, 0, "beta", 1):
110+
value.remove("__unpacked__")
111+
assert value == dir(NDArray_ref)
112+
113+
@pytest.mark.parametrize("name,func,dev_version", [
114+
("__iter__", lambda n: len(list(n)), ("beta", 1)),
115+
("__iter__", lambda n: next(iter(n)), ("beta", 1)),
116+
("__unpacked__", lambda n: n.__unpacked__, ("beta", 1)),
117+
("Unpack", lambda n: Unpack[n], ("beta", 1)),
118+
119+
# The right operand should now have `__unpacked__ = True`,
120+
# and they are thus now longer equivalent
121+
("__ne__", lambda n: n != next(iter(n)), ("beta", 1)),
122+
123+
# >= beta3 stuff
124+
("__typing_unpacked_tuple_args__",
125+
lambda n: CA68 n.__typing_unpacked_tuple_args__, ("beta", 3)),
126+
])
127+
def test_py311_features(
128+
self,
129+
name: str,
130+
func: FuncType,
131+
dev_version: tuple[str, int],
132+
) -> None:
133+
"""Test Python 3.11 features."""
134+
value = func(NDArray)
135+
136+
if sys.version_info >= (3, 11, 0, *dev_version):
137+
value_ref = func(NDArray_ref)
138+
assert value == value_ref
139+
103140
def test_weakref(self) -> None:
104141
"""Test ``__weakref__``."""
105142
value = weakref.ref(NDArray)()

test_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ cffi; python_version < '3.10'
1111
# - Mypy relies on C API features not present in PyPy
1212
# NOTE: Keep mypy in sync with environment.yml
1313
mypy==0.950; platform_python_implementation != "PyPy"
14+
typing_extensions>=4.2.0

0 commit comments

Comments
 (0)
0