|
| 1 | +"""A module with runtime tests for `numpy.typing.NestedSequence`.""" |
| 2 | + |
| 3 | +import sys |
| 4 | +from typing import Callable, Any |
| 5 | +from collections.abc import Sequence |
| 6 | + |
| 7 | +import pytest |
| 8 | +import numpy as np |
| 9 | +from numpy.typing import NestedSequence |
| 10 | +from numpy.typing._nested_sequence import _ProtocolMixin |
| 11 | + |
| 12 | +if sys.version_info >= (3, 8): |
| 13 | + from typing import Protocol |
| 14 | + HAVE_PROTOCOL = True |
| 15 | +else: |
| 16 | + try: |
| 17 | + from typing_extensions import Protocol |
| 18 | + except ImportError: |
| 19 | + HAVE_PROTOCOL = False |
| 20 | + else: |
| 21 | + HAVE_PROTOCOL = True |
| 22 | + |
| 23 | +if HAVE_PROTOCOL: |
| 24 | + class _SubClass(NestedSequence[int]): |
| 25 | + def __init__(self, seq): |
| 26 | + self._seq = seq |
| 27 | + |
| 28 | + def __getitem__(self, s): |
| 29 | + return self._seq[s] |
| 30 | + |
| 31 | + def __len__(self): |
| 32 | + return len(self._seq) |
| 33 | + |
| 34 | + SEQ = _SubClass([0, 0, 1]) |
| 35 | +else: |
| 36 | + SEQ = NotImplemented |
| 37 | + |
| 38 | + |
| 39 | +class TestNestedSequence: |
| 40 | + """Runtime tests for `numpy.typing.NestedSequence`.""" |
| 41 | + |
| 42 | + @pytest.mark.parametrize( |
| 43 | + "name,func", |
| 44 | + [ |
| 45 | + ("__instancecheck__", lambda: isinstance(1, _ProtocolMixin)), |
| 46 | + ("__subclasscheck__", lambda: issubclass(int, _ProtocolMixin)), |
| 47 | + ("__init__", lambda: _ProtocolMixin()), |
| 48 | + ("__init_subclass__", lambda: type("SubClass", (_ProtocolMixin,), {})), |
| 49 | + ] |
| 50 | + ) |
| 51 | + def test_raises(self, name: str, func: Callable[[], Any]) -> None: |
| 52 | + """Test that the `_ProtocolMixin` methods successfully raise.""" |
| 53 | + with pytest.raises(RuntimeError): |
| 54 | + func() |
| 55 | + |
| 56 | + @pytest.mark.parametrize( |
| 57 | + "name,ref,func", |
| 58 | + [ |
| 59 | + ("__contains__", True, lambda: 0 in SEQ), |
| 60 | + ("__getitem__", 0, lambda: SEQ[0]), |
| 61 | + ("__getitem__", [0, 0, 1], lambda: SEQ[:]), |
| 62 | + ("__iter__", 0, lambda: next(iter(SEQ))), |
| 63 | + ("__len__", 3, lambda: len(SEQ)), |
| 64 | + ("__reversed__", 1, lambda: next(reversed(SEQ))), |
| 65 | + ("count", 2, lambda: SEQ.count(0)), |
| 66 | + ("index", 0, lambda: SEQ.index(0)), |
| 67 | + ("index", 1, lambda: SEQ.index(0, start=1)), |
| 68 | + ("__instancecheck__", True, lambda: isinstance([1], NestedSequence)), |
| 69 | + ("__instancecheck__", False, lambda: isinstance(1, NestedSequence)), |
| 70 | + ("__subclasscheck__", True, lambda: issubclass(Sequence, NestedSequence)), |
| 71 | + ("__subclasscheck__", False, lambda: issubclass(int, NestedSequence)), |
| 72 | + ("__class_getitem__", True, lambda: bool(NestedSequence[int])), |
| 73 | + ("__abstractmethods__", Sequence.__abstractmethods__,
6215
span> |
| 74 | + lambda: NestedSequence.__abstractmethods__), |
| 75 | + ] |
| 76 | + ) |
| 77 | + @pytest.mark.skipif(not HAVE_PROTOCOL, reason="requires the `Protocol` class") |
| 78 | + def test_method(self, name: str, ref: Any, func: Callable[[], Any]) -> None: |
| 79 | + """Test that the ``NestedSequence`` methods return the intended values.""" |
| 80 | + value = func() |
| 81 | + assert value == ref |
0 commit comments