8000 Merge pull request #27653 from jorenham/typing/ndarray-array-api · walshb/numpy@70fde29 · GitHub
[go: up one dir, main page]

8000
Skip to content

Commit 70fde29

Browse files
authored
Merge pull request numpy#27653 from jorenham/typing/ndarray-array-api
TYP: Fix Array API method signatures
2 parents 7ed62d2 + 8d0a319 commit 70fde29

File tree

2 files changed

+22
-26
lines changed

2 files changed

+22
-26
lines changed

numpy/__init__.pyi

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import array as _array
77
import datetime as dt
88
import enum
99
from abc import abstractmethod
10-
from types import EllipsisType, TracebackType, MappingProxyType, GenericAlias
10+
from types import EllipsisType, ModuleType, TracebackType, MappingProxyType, GenericAlias
1111
from decimal import Decimal
1212
from fractions import Fraction
1313
from uuid import UUID
@@ -210,7 +210,7 @@ from typing import (
210210
# This is because the `typeshed` stubs for the standard library include
211211
# `typing_extensions` stubs:
212212
# https://github.com/python/typeshed/blob/main/stdlib/typing_extensions.pyi
213-
from typing_extensions import Generic, LiteralString, Protocol, Self, TypeVar, overload
213+
from typing_extensions import CapsuleType, Generic, LiteralString, Protocol, Self, TypeVar, overload
214214

215215
from numpy import (
216216
core,
@@ -763,7 +763,7 @@ class _SupportsWrite(Protocol[_AnyStr_contra]):
763763
def write(self, s: _AnyStr_contra, /) -> object: ...
764764

765765
__version__: LiteralString
766-
__array_api_version__: LiteralString
766+
__array_api_version__: Final = "2023.12"
767767
test: PytestTester
768768

769769

@@ -1431,7 +1431,7 @@ class _ArrayOrScalarCommon:
14311431
def __array_priority__(self) -> float: ...
14321432
@property
14331433
def __array_struct__(self) -> Any: ... # builtins.PyCapsule
1434-
def __array_namespace__(self, *, api_version: None | _ArrayAPIVersion = ...) -> Any: ...
1434+
def __array_namespace__(self, /, *, api_version: _ArrayAPIVersion | None = None) -> ModuleType: ...
14351435
def __setstate__(self, state: tuple[
14361436
SupportsIndex, # version
14371437
_ShapeLike, # Shape
@@ -1798,11 +1798,6 @@ _ArrayTD64_co: TypeAlias = NDArray[np.bool | integer[Any] | timedelta64]
17981798
# Introduce an alias for `dtype` to avoid naming conflicts.
17991799
_dtype: TypeAlias = dtype[_ScalarType]
18001800

1801-
if sys.version_info >= (3, 13):
1802-
from types import CapsuleType as _PyCapsule
1803-
else:
1804-
_PyCapsule: TypeAlias = Any
1805-
18061801
_ArrayAPIVersion: TypeAlias = L["2021.12", "2022.12", "2023.12"]
18071802

18081803
@type_check_only
@@ -3063,14 +3058,14 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
30633058

30643059
def __dlpack__(
30653060
self: NDArray[number[Any]],
3061+
/,
30663062
*,
3067-
stream: int | Any | None = ...,
3068-
max_version: tuple[int, int] | None = ...,
3069-
dl_device: tuple[int, L[0]] | None = ...,
3070-
copy: bool | None = ...,
3071-
) -> _PyCapsule: ...
3072-
3073-
def __dlpack_device__(self) -> tuple[int, L[0]]: ...
3063+
stream: int | Any | None = None,
3064+
max_version: tuple[int, int] | None = None,
3065+
dl_device: tuple[int, int] | None = None,
3066+
copy: builtins.bool | None = None,
3067+
) -> CapsuleType: ...
3068+
def __dlpack_device__(self, /) -> tuple[L[1], L[0]]: ...
30743069

30753070
def bitwise_count(
30763071
self,
@@ -4727,12 +4722,12 @@ class matrix(ndarray[_Shape2DType_co, _DType_co]):
47274722

47284723
@type_check_only
47294724
class _SupportsDLPack(Protocol[_T_contra]):
4730-
def __dlpack__(self, *, stream: None | _T_contra = ...) -> _PyCapsule: ...
4725+
def __dlpack__(self, /, *, stream: _T_contra | None = None) -> CapsuleType: ...
47314726

47324727
def from_dlpack(
4733-
obj: _SupportsDLPack[None],
4728+
x: _SupportsDLPack[None],
47344729
/,
47354730
*,
4736-
device: L["cpu"] | None = ...,
4737-
copy: bool | None = ...,
4738-
) -> NDArray[Any]: ...
4731+
device: L["cpu"] | None = None,
4732+
copy: builtins.bool | None = None,
4733+
) -> NDArray[number[Any] | np.bool]: ...

numpy/typing/tests/data/reveal/ndarray_misc.pyi

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@ function-based counterpart in `../from_numeric.py`.
88

99
import operator
1010
import ctypes as ct
11+
from types import ModuleType
1112
from typing import Any, Literal
1213

1314
import numpy as np
1415
import numpy.typing as npt
1516

16-
from typing_extensions import assert_type
17+
from typing_extensions import CapsuleType, assert_type
1718

1819
class SubClass(npt.NDArray[np.object_]): ...
1920

@@ -30,8 +31,8 @@ AR_V: npt.NDArray[np.void]
3031

3132
ctypes_obj = AR_f8.ctypes
3233

33-
assert_type(AR_f8.__dlpack__(), Any)
34-
assert_type(AR_f8.__dlpack_device__(), tuple[int, Literal[0]])
34+
assert_type(AR_f8.__dlpack__(), CapsuleType)
35+
assert_type(AR_f8.__dlpack_device__(), tuple[Literal[1], Literal[0]])
3536

3637
assert_type(ctypes_obj.data, int)
3738
assert_type(ctypes_obj.shape, ct.Array[np.ctypeslib.c_intp])
@@ -225,5 +226,5 @@ assert_type(AR_u1.to_device("cpu"), npt.NDArray[np.uint8])
225226
assert_type(AR_c8.to_device("cpu"), npt.NDArray[np.complex64])
226227
assert_type(AR_m.to_device("cpu"), npt.NDArray[np.timedelta64])
227228

228-
assert_type(f8.__array_namespace__(), Any)
229-
assert_type(AR_f8.__array_namespace__(), Any)
229+
assert_type(f8.__array_namespace__(), ModuleType)
230+
assert_type(AR_f8.__array_namespace__(), ModuleType)

0 commit comments

Comments
 (0)
0