8000 TST: Added tests for `NestedSequence` · numpy/numpy@dc33633 · GitHub
[go: up one dir, main page]

Skip to content
10000

Commit dc33633

Browse files
author
Bas van Beek
committed
TST: Added tests for NestedSequence
1 parent a2e2a2f commit dc33633

File tree

4 files changed

+132
-0
lines changed

4 files changed

+132
-0
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from typing import Sequence, Tuple, List
2+
import numpy.typing as npt
3+
4+
a: Sequence[float]
5+
b: List[complex]
6+
c: Tuple[str, ...]
7+
d: int
8+
e: str
9+
10+
def func(a: npt.NestedSequence[int]) -> None:
11+
...
12+
13+
reveal_type(func(a)) # E: incompatible type
14+
reveal_type(func(b)) # E: incompatible type
15+
reveal_type(func(c)) # E: incompatible type
16+
reveal_type(func(d)) # E: incompatible type
17+
reveal_type(func(e)) # E: incompatible type

numpy/typing/tests/data/reveal/nbit_base_example.py renamed to numpy/typing/tests/data/reveal/examples.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,12 @@ def add(a: np.floating[T], b: np.integer[T]) -> np.floating[T]:
1616
reveal_type(add(f4, i8)) # E: {float64}
1717
reveal_type(add(f8, i4)) # E: {float64}
1818
reveal_type(add(f4, i4)) # E: {float32}
19+
20+
21+
def get_dtype(seq: npt.NestedSequence[int]) -> np.dtype[np.int_]:
22+
return np.asarray(seq).dtype
23+
24+
25+
reveal_type(get_dtype([1])) # E: numpy.dtype[{int_}]
26+
reveal_type(get_dtype([[1]])) # E: numpy.dtype[{int_}]
27+
reveal_type(get_dtype([[[1]]])) # E: numpy.dtype[{int_}]
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from typing import Sequence, Tuple, List, Any
2+
import numpy.typing as npt
3+
4+
a: Sequence[int]
5+
b: Sequence[Sequence[int]]
6+
c: Sequence[Sequence[Sequence[int]]]
7+
d: Sequence[Sequence[Sequence[Sequence[int]]]]
8+
e: Sequence[bool]
9+
f: Tuple[int, ...]
10+
g: List[int]
11+
h: Sequence[Any]
12+
13+
def func(a: npt.NestedSequence[int]) -> None:
14+
...
15+
16+
reveal_type(func(a)) # E: None
17+
reveal_type(func(b)) # E: None
18+
reveal_type(func(c)) # E: None
19+
reveal_type(func(d)) # E: None
20+
reveal_type(func(e)) # E: None
21+
reveal_type(func(f)) # E: None
22+
reveal_type(func(g)) # E: None
23+
reveal_type(func(h)) # E: None
24+
25+
reveal_type(isinstance(1, npt.NestedSequence)) # E: bool
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""A module with runtime tests for `numpy.typing.NestedSequence`."""
2+
3+
import sys
4+
from typing import Callable, Any
5+
from collections.abc import Sequence
6+
7+
import pytest
8+
import numpy as np
9+
from numpy.typing import NestedSequence
10+
from numpy.typing._nested_sequence import _ProtocolMixin
11+
12+
if sys.version_info >= (3, 8):
13+
from typing import Protocol
14+
HAVE_PROTOCOL = True
15+
else:
16+
try:
17+
from typing_extensions import Protocol
18+
except ImportError:
19+
HAVE_PROTOCOL = False
20+
else:
21+
HAVE_PROTOCOL = True
22+
23+
if HAVE_PROTOCOL:
24+
class _SubClass(NestedSequence[int]):
25+
def __init__(self, seq):
26+
self._seq = seq
27+
28+
def __getitem__(self, s):
29+
return self._seq[s]
30+
31+
def __len__(self):
32+
return len(self._seq)
33+
34+
SEQ = _SubClass([0, 0, 1])
35+
else:
36+
SEQ = NotImplemented
37+
38+
39+
class TestNestedSequence:
40+
"""Runtime tests for `numpy.typing.NestedSequence`."""
41+
42+
@pytest.mark.parametrize(
43+
"name,func",
44+
[
45+
("__instancecheck__", lambda: isinstance(1, _ProtocolMixin)),
46+
("__subclasscheck__", lambda: issubclass(int, _ProtocolMixin)),
47+
("__init__", lambda: _ProtocolMixin()),
48+
("__init_subclass__", lambda: type("SubClass", (_ProtocolMixin,), {})),
49+
]
50+
)
51+
def test_raises(self, name: str, func: Callable[[], Any]) -> None:
52+
"""Test that the `_ProtocolMixin` methods successfully raise."""
53+
with pytest.raises(RuntimeError):
54+
func()
55+
56+
@pytest.mark.parametrize(
57+
"name,ref,func",
58+
[
59+
("__contains__", True, lambda: 0 in SEQ),
60+
("__getitem__", 0, lambda: SEQ[0]),
61+
("__getitem__", [0, 0, 1], lambda: SEQ[:]),
62+
("__iter__", 0, lambda: next(iter(SEQ))),
63+
("__len__", 3, lambda: len(SEQ)),
64+
("__reversed__", 1, lambda: next(reversed(SEQ))),
65+
("count", 2, lambda: SEQ.count(0)),
66+
("index", 0, lambda: SEQ.index(0)),
67+
("index", 1, lambda: SEQ.index(0, start=1)),
68+
("__instancecheck__", True, lambda: isinstance([1], NestedSequence)),
69+
("__instancecheck__", False, lambda: isinstance(1, NestedSequence)),
70+
("__subclasscheck__", True, lambda: issubclass(Sequence, NestedSequence)),
71+
("__subclasscheck__", False, lambda: issubclass(int, NestedSequence)),
72+
("__class_getitem__", True, lambda: bool(NestedSequence[int])),
73+
("__abstractmethods__", Sequence.__abstractmethods__,
74+
lambda: NestedSequence.__abstractmethods__),
75+
]
76+
)
77+
@pytest.mark.skipif(not HAVE_PROTOCOL, reason="requires the `Protocol` class")
78+
def test_method(self, name: str, ref: Any, func: Callable[[], Any]) -> None:
79+
"""Test that the ``NestedSequence`` methods return the intended values."""
80+
value = func()
81+
assert value == ref

0 commit comments

Comments
 (0)
0