8000 Merge pull request #22388 from charris/backport-22357 · numpy/numpy@bd936cc · GitHub
[go: up one dir, main page]

Skip to content

Commit bd936cc

Browse files
authored
Merge pull request #22388 from charris/backport-22357
TYP,ENH: Mark ``numpy.typing`` protocols as runtime checkable
2 parents a35df7c + 3890441 commit bd936cc

File tree

5 files changed

+46
-3
lines changed

5 files changed

+46
-3
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
``numpy.typing`` protocols are now runtime checkable
2+
----------------------------------------------------
3+
4+
The protocols used in `~numpy.typing.ArrayLike` and `~numpy.typing.DTypeLike`
5+
are now properly marked as runtime checkable, making them easier to use for
6+
runtime type checkers.

numpy/_typing/_array_like.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# NOTE: Import `Sequence` from `typing` as we it is needed for a type-alias,
44
# not an annotation
55
from collections.abc import Collection, Callable
6-
from typing import Any, Sequence, Protocol, Union, TypeVar
6+
from typing import Any, Sequence, Protocol, Union, TypeVar, runtime_checkable
77
from numpy import (
88
ndarray,
99
dtype,
@@ -33,10 +33,12 @@
3333
# array.
3434
# Concrete implementations of the protocol are responsible for adding
3535
# any and all remaining overloads
36+
@runtime_checkable
3637
class _SupportsArray(Protocol[_DType_co]):
3738
def __array__(self) -> ndarray[Any, _DType_co]: ...
3839

3940

41+
@runtime_checkable
4042
class _SupportsArrayFunc(Protocol):
4143
"""A protocol class representing `~class.__array_function__`."""
4244
def __array_function__(
@@ -146,7 +148,7 @@ def __array_function__(
146148
# Used as the first overload, should only match NDArray[Any],
147149
# not any actual types.
148150
# https://github.com/numpy/numpy/pull/22193
149-
class _UnknownType:
151+
class _UnknownType:
150152
...
151153

152154

numpy/_typing/_dtype_like.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
TypeVar,
99
Protocol,
1010
TypedDict,
11+
runtime_checkable,
1112
)
1213

1314
import numpy as np
@@ -80,6 +81,7 @@ class _DTypeDict(_DTypeDictBase, total=False):
8081

8182

8283
# A protocol for anything with the dtype attribute
84+
@runtime_checkable
8385
class _SupportsDType(Protocol[_DType_co]):
8486
@property
8587
def dtype(self) -> _DType_co: ...

numpy/_typing/_nested_sequence.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
overload,
99
TypeVar,
1010
Protocol,
11+
runtime_checkable,
1112
)
1213

1314
__all__ = ["_NestedSequence"]
1415

1516
_T_co = TypeVar("_T_co", covariant=True)
1617

1718

19+
@runtime_checkable
1820
class _NestedSequence(Protocol[_T_co]):
1921
"""A protocol for representing nested sequences.
2022

numpy/typing/tests/test_runtime.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,19 @@
33
from __future__ import annotations
44

55
import sys
6-
from typing import get_type_hints, Union, NamedTuple, get_args, get_origin
6+
from typing import (
7+
get_type_hints,
8+
Union,
9+
NamedTuple,
10+
get_args,
11+
get_origin,
12+
Any,
13+
)
714

815
import pytest
916
import numpy as np
1017
import numpy.typing as npt
18+
import numpy._typing as _npt
1119

1220

1321
class TypeTup(NamedTuple):
@@ -80,3 +88,26 @@ def test_keys() -> None:
8088
keys = TYPES.keys()
8189
ref = set(npt.__all__)
8290
assert keys == ref
91+
92+
93+
PROTOCOLS: dict[str, tuple[type[Any], object]] = {
94+
"_SupportsDType": (_npt._SupportsDType, np.int64(1)),
95+
"_SupportsArray": (_npt._SupportsArray, np.arange(10)),
96+
"_SupportsArrayFunc": (_npt._SupportsArrayFunc, np.arange(10)),
97+
"_NestedSequence": (_npt._NestedSequence, [1]),
98+
}
99+
100+
101+
@pytest.mark.parametrize("cls,obj", PROTOCOLS.values(), ids=PROTOCOLS.keys())
102+
class TestRuntimeProtocol:
103+
def test_isinstance(self, cls: type[Any], obj: object) -> None:
104+
assert isinstance(obj, cls)
105+
assert not isinstance(None, cls)
106+
107+
def test_issubclass(self, cls: type[Any], obj: object) -> None:
108+
if cls is _npt._SupportsDType:
109+
pytest.xfail(
110+
"Protocols with non-method members don't support issubclass()"
111+
)
112+
assert issubclass(type(obj), cls)
113+
assert not issubclass(type(None), cls)

0 commit comments

Comments
 (0)
0