8000 TST: Add `npt._GenericAlias` tests for (backported) Python 3.11 features · rjeb/numpy@2bb0968 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2bb0968

Browse files
Bas van BeekBvB93
Bas van Beek
authored andcommitted
TST: Add npt._GenericAlias tests for (backported) Python 3.11 features
1 parent 4682699 commit 2bb0968

File tree

1 file changed

+42
-3
lines changed

1 file changed

+42
-3
lines changed

numpy/typing/tests/test_generic_alias.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytest
1111
import numpy as np
1212
from numpy._typing._generic_alias import _GenericAlias
13+
from typing_extensions import Unpack
1314

1415
ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)
1516
T1 = TypeVar("T1")
@@ -55,8 +56,8 @@ class TestGenericAlias:
5556
("__origin__", lambda n: n.__origin__),
5657
("__args__", lambda n: n.__args__),
5758
("__parameters__", lambda n: n.__parameters__),
58-
("__reduce__", lambda n: n.__reduce__()[1:]),
59-
("__reduce_ex__", lambda n: n.__reduce_ex__(1)[1:]),
59+
("__reduce__", lambda n: n.__reduce__()[1][:3]),
60+
("__reduce_ex__", lambda n: n.__reduce_ex__(1)[1][:3]),
6061
("__mro_entries__", lambda n: n.__mro_entries__([object])),
6162
("__hash__", lambda n: hash(n)),
6263
("__repr__", lambda n: repr(n)),
@@ -66,7 +67,6 @@ class TestGenericAlias:
6667
("__getitem__", lambda n: n[Union[T1, T2]][np.float32, np.float64]),
6768
("__eq__", lambda n: n == n),
6869
("__ne__", lambda n: n != np.ndarray),
69-
("__dir__", lambda n: dir(n)),
7070
("__call__", lambda n: n((1,), np.int64, BUFFER)),
7171
("__call__", lambda n: n(shape=(1,), dtype=np.int64, buffer=BUFFER)),
7272
("subclassing", lambda n: _get_subclass_mro(n)),
@@ -100,6 +100,45 @@ def test_copy(self, name: str, func: FuncType) -> None:
100100
value_ref = func(NDArray_ref)
101101
assert value == value_ref
102102

103+
def test_dir(self) -> None:
104+
value = dir(NDArray)
105+
if sys.version_info < (3, 9):
106+
return
107+
108+
# A number attributes only exist in `types.GenericAlias` in >= 3.11
109+
if sys.version_info < (3, 11, 0, "beta", 3):
110+
value.remove("__typing_unpacked_tuple_args__")
111+
if sys.version_info < (3, 11, 0, "beta", 1):
112+
value.remove("__unpacked__")
113+
assert value == dir(NDArray_ref)
114+
115+
@pytest.mark.parametrize("name,func,dev_version", [
116+
("__iter__", lambda n: len(list(n)), ("beta", 1)),
117+
("__iter__", lambda n: next(iter(n)), ("beta", 1)),
118+
("__unpacked__", lambda n: n.__unpacked__, ("beta", 1)),
119+
("Unpack", lambda n: Unpack[n], ("beta", 1)),
120+
121+
# The right operand should now have `__unpacked__ = True`,
122+
# and they are thus now longer equivalent
123+
("__ne__", lambda n: n != next(iter(n)), ("beta", 1)),
124+
125+
# >= beta3 stuff
126+
("__typing_unpacked_tuple_args__",
127+
lambda n: n.__typing_unpacked_tuple_args__, ("beta", 3)),
128+
])
129+
def test_py311_features(
130+
self,
131+
name: str,
132+
func: FuncType,
133+
dev_version: tuple[str, int],
134+
) -> None:
135+
"""Test Python 3.11 features."""
136+
value = func(NDArray)
137+
138+
if sys.version_info >= (3, 11, 0, *dev_version):
139+
value_ref = func(NDArray_ref)
140+
assert value == value_ref
141+
103142
def test_weakref(self) -> None:
104143
"""Test ``__weakref__``."""
105144
value = weakref.ref(NDArray)()

0 commit comments

Comments
 (0)
0