8000 API: Add `npt.NDArray`, a runtime-subscriptable alias for `np.ndarray` by BvB93 · Pull Request #18935 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

API: Add npt.NDArray, a runtime-subscriptable alias for np.ndarray #18935

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 17, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
TST: Add tests for GenericAlias
  • Loading branch information
Bas van Beek committed May 12, 2021
commit 610edf23678a81a9a81d71d238a8f8c1eedcc78a
113 changes: 113 additions & 0 deletions numpy/typing/tests/test_generic_alias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import annotations

import sys
import types
import pickle
import weakref
from typing import TypeVar, Any, Callable, Tuple, Type, Union

import pytest
import numpy as np
from numpy.typing._generic_alias import _GenericAlias

ScalarType = TypeVar("ScalarType", bound=np.generic)
DType = _GenericAlias(np.dtype, (ScalarType,))
NDArray = _GenericAlias(np.ndarray, (Any, DType))

if sys.version_info >= (3, 9):
DType_ref = types.GenericAlias(np.dtype, (ScalarType,))
NDArray_ref = types.GenericAlias(np.ndarray, (Any, DType_ref))
FuncType = Callable[[Union[_GenericAlias, types.GenericAlias]], Any]
else:
DType_ref = NotImplemented
NDArray_ref = NotImplemented
FuncType = Callable[[_GenericAlias], Any]

GETATTR_NAMES = sorted(set(dir(np.ndarray)) - _GenericAlias._ATTR_EXCEPTIONS)


def _get_subclass_mro(base: type) -> Tuple[type, ...]:
class Subclass(base): # type: ignore[misc,valid-type]
pass
return Subclass.__mro__[1:]


class TestGenericAlias:
"""Tests for `numpy.typing._generic_alias._GenericAlias`."""

@pytest.mark.parametrize(
"name,func",
[
("__init__", lambda n: n),
("__origin__", lambda n: n.__origin__),
("__args__", lambda n: n.__args__),
("__parameters__", lambda n: n.__parameters__),
("__reduce__", lambda n: n.__reduce__()[1:]),
("__reduce_ex__", lambda n: n.__reduce_ex__(1)[1:]),
("__mro_entries__", lambda n: n.__mro_entries 8000 __([object])),
("__hash__", lambda n: hash(n)),
("__repr__", lambda n: repr(n)),
("__getitem__", lambda n: n[np.float64]),
("__getitem__", lambda n: n[ScalarType][np.float64]),
("__getitem__", lambda n: n[Union[np.int64, ScalarType]][np.float64]),
("__eq__", lambda n: n == n),
("__ne__", lambda n: n != np.ndarray),
("__dir__", lambda n: dir(n)),
("__call__", lambda n: n((5,), np.int64)),
("__call__", lambda n: n(shape=(5,), dtype=np.int64)),
("subclassing", lambda n: _get_subclass_mro(n)),
("pickle", lambda n: n == pickle.loads(pickle.dumps(n))),
("__weakref__", lambda n: n == weakref.ref(n)()),
]
)
def test_pass(self, name: str, func: FuncType) -> None:
"""Compare `types.GenericAlias` with its numpy-based backport.

Checker whether ``func`` runs as intended and that both `GenericAlias`
and `_GenericAlias` return the same result.

"""
value = func(NDArray)

if sys.version_info >= (3, 9):
value_ref = func(NDArray_ref)
assert value == value_ref

@pytest.mark.parametrize("name", GETATTR_NAMES)
def test_getattr(self, name: str) -> None:
"""Test that `getattr` wraps around the underlying type (``__origin__``)."""
value = getattr(NDArray, name)
value_ref1 = getattr(np.ndarray, name)

if sys.version_info >= (3, 9):
value_ref2 = getattr(NDArray_ref, name)
assert value == value_ref1 == value_ref2
else:
assert value == value_ref1

@pytest.mark.parametrize(
"name,exc_type,func",
[
("__getitem__", TypeError, lambda n: n[()]),
("__getitem__", TypeError, lambda n: n[Any, Any]),
("__getitem__", TypeError, lambda n: n[Any][Any]),
("__instancecheck__", TypeError, lambda n: isinstance(np.array(1), n)),
("__subclasscheck__", TypeError, lambda n: issubclass(np.ndarray, n)),
("__setattr__", AttributeError, lambda n: setattr(n, "__origin__", int)),
("__setattr__", AttributeError, lambda n: setattr(n, "test", int)),
("__getattribute__", AttributeError, lambda n: getattr(n, "test")),
]
)
def test_raise(
self,
name: str,
exc_type: Type[BaseException],
func: FuncType,
) -> None:
"""Test operations that are supposed to raise."""
with pytest.raises(exc_type):
func(NDArray)

if sys.version_info >= (3, 9):
with pytest.raises(exc_type):
func(NDArray_ref)
0