8000 Merge pull request #26873 from jorenham/typing-array_api_info · numpy/numpy@5479f04 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5479f04

Browse files
authored
Merge pull request #26873 from jorenham/typing-array_api_info
TYP: improved `numpy._array_api_info` typing
2 parents d583c54 + bd7e849 commit 5479f04

File tree

2 files changed

+260
-51
lines changed

2 files changed

+260
-51
lines changed

numpy/_array_api_info.pyi

Lines changed: 192 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,213 @@
1-
from typing import TypedDict, Optional, Union, Tuple, List
2-
from numpy._typing import DtypeLike
1+
import sys
2+
from typing import (
3+
TYPE_CHECKING,
4+
ClassVar,
5+
Literal,
6+
TypeAlias,
7+
TypedDict,
8+
TypeVar,
9+
final,
10+
overload,
11+
)
12+
13+
import numpy as np
14+
15+
if sys.version_info >= (3, 11):
16+
from typing import Never
17+
elif TYPE_CHECKING:
18+
from typing_extensions import Never
19+
else:
20+
# `NoReturn` and `Never` are equivalent (but not equal) for type-checkers,
21+
# but are used in different places by convention
22+
from typing import NoReturn as Never
23+
24+
_Device: TypeAlias = Literal["cpu"]
25+
_DeviceLike: TypeAlias = None | _Device
326

4-
Capabilities = TypedDict(
5-
"Capabilities",
27+
_Capabilities = TypedDict(
28+
"_Capabilities",
629
{
7-
"boolean indexing": bool,
8-
"data-dependent shapes": bool,
30+
"boolean indexing": Literal[True],
31+
"data-dependent shapes": Literal[True],
932
},
1033
)
1134

12-
DefaultDataTypes = TypedDict(
13-
"DefaultDataTypes",
35+
_DefaultDTypes = TypedDict(
36+
"_DefaultDTypes",
1437
{
15-
"real floating": DtypeLike,
16-
"complex floating": DtypeLike,
17-
"integral": DtypeLike,
18-
"indexing": DtypeLike,
38+
"real floating": np.dtype[np.float64],
39+
"complex floating": np.dtype[np.complex128],
40+
"integral": np.dtype[np.intp],
41+
"indexing": np.dtype[np.intp],
1942
},
2043
)
2144

22-
DataTypes = TypedDict(
23-
"DataTypes",
24-
{
25-
"bool": DtypeLike,
26-
"float32": DtypeLike,
27-
"float64": DtypeLike,
28-
"complex64": DtypeLike,
29-
"complex128": DtypeLike,
30-
"int8": DtypeLike,
31-
"int16": DtypeLike,
32-
"int32": DtypeLike,
33-
"int64": DtypeLike,
34-
"uint8": DtypeLike,
35-
"uint16": DtypeLike,
36-
"uint32": DtypeLike,
37-
"uint64": DtypeLike,
38-
},
39-
total=False,
45+
46+
_KindBool: TypeAlias = Literal["bool"]
47+
_KindInt: TypeAlias = Literal["signed integer"]
48+
_KindUInt: TypeAlias = Literal["unsigned integer"]
49+
_KindInteger: TypeAlias = Literal["integral"]
50+
_KindFloat: TypeAlias = Literal["real floating"]
51+
_KindComplex: TypeAlias = Literal["complex floating"]
52+
_KindNumber: TypeAlias = Literal["numeric"]
53+
_Kind: TypeAlias = (
54+
_KindBool
55+
| _KindInt
56+
| _KindUInt
57+
| _KindInteger
58+
| _KindFloat
59+
| _KindComplex
60+
| _KindNumber
4061
)
4162

42-
class __array_namespace_info__:
43-
__module__: str
4463

45-
def capabilities(self) -> Capabilities: ...
64+
_T1 = TypeVar("_T1")
65+
_T2 = TypeVar("_T2")
66+
_T3 = TypeVar("_T3")
67+
_Permute1: TypeAlias = _T1 | tuple[_T1]
68+
_Permute2: TypeAlias = tuple[_T1, _T2] | tuple[_T2, _T1]
69+
_Permute3: TypeAlias = (
70+
tuple[_T1, _T2, _T3] | tuple[_T1, _T3, _T2]
71+
| tuple[_T2, _T1, _T3] | tuple[_T2, _T3, _T1]
72+
| tuple[_T3, _T1, _T2] | tuple[_T3, _T2, _T1]
73+
)
74+
75+
class _DTypesBool(TypedDict):
76+
bool: np.dtype[np.bool]
77+
78+
class _DTypesInt(TypedDict):
79+
int8: np.dtype[np.int8]
80+
int16: np.dtype[np.int16]
81+
int32: np.dtype[np.int32]
82+
int64: np.dtype[np.int64]
83+
84+
class _DTypesUInt(TypedDict):
85+
uint8: np.dtype[np.uint8]
86+
uint16: np.dtype[np.uint16]
87+
uint32: np.dtype[np.uint32]
88+
uint64: np.dtype[np.uint64]
89+
90+
class _DTypesInteger(_DTypesInt, _DTypesUInt):
91+
...
4692

47-
def default_device(self) -> str: ...
93+
class _DTypesFloat(TypedDict):
94+
float32: np.dtype[np.float32]
95+
float64: np.dtype[np.float64]
4896

97+
class _DTypesComplex(TypedDict):
98+
complex64: np.dtype[np.complex64]
99+
complex128: np.dtype[np.complex128]
100+
101+
class _DTypesNumber(_DTypesInteger, _DTypesFloat, _DTypesComplex):
102+
...
103+
104+
class _DTypes(_DTypesBool, _DTypesNumber):
105+
...
106+
107+
class _DTypesUnion(TypedDict, total=False):
108+
bool: np.dtype[np.bool]
109+
int8: np.dtype[np.int8]
110+
int16: np.dtype[np.int16]
111+
int32: np.dtype[np.int32]
112+
int64: np.dtype[np.int64]
113+
uint8: np.dtype[np.uint8]
114+
uint16: np.dtype[np.uint16]
115+
uint32: np.dtype[np.uint32]
116+
uint64: np.dtype[np.uint64]
117+
float32: np.dtype[np.float32]
118+
float64: np.dtype[np.float64]
119+
complex64: np.dtype[np.complex64]
120+
complex128: np.dtype[np.complex128]
121+
122+
_EmptyDict: TypeAlias = dict[Never, Never]
123+
124+
125+
@final
126+
class __array_namespace_info__:
127+
__module__: ClassVar[Literal['numpy']]
128+
129+
def capabilities(self) -> _Capabilities: ...
130+
def default_device(self) -> _Device: ...
49131
def default_dtypes(
50132
self,
51133
*,
52-
device: Optional[str] = None,
53-
) -> DefaultDataTypes: ...
134+
device: _DeviceLike = ...,
135+
) -> _DefaultDTypes: ...
136+
def devices(self) -> list[_Device]: ...
54137

138+
@overload
55139
def dtypes(
56140
self,
57141
*,
58-
device: Optional[str] = None,
59-
kind: Optional[Union[str, Tuple[str, ...]]] = None,
60-
) -> DataTypes: ...
61-
62-
def devices(self) -> List[str]: ...
142+
device: _DeviceLike = ...,
143+
kind: None = ...,
144+
) -> _DTypes: ...
145+
@overload
146+
def dtypes(
147+
self,
148+
*,
149+
device: _DeviceLike = ...,
150+
kind: _Permute1[_KindBool],
151+
) -> _DTypesBool: ...
152+
@overload
153+
def dtypes(
154+
self,
155+
*,
156+
device: _DeviceLike = ...,
157+
kind: _Permute1[_KindInt],
158+
) -> _DTypesInt: ...
159+
@overload
160+
def dtypes(
161+
self,
162+
*,
163+
device: _DeviceLike = ...,
164+
kind: _Permute1[_KindUInt],
165+
) -> _DTypesUInt: ...
166+
@overload
167+
def dtypes(
168+
self,
169+
*,
170+
device: _DeviceLike = ...,
171+
kind: _Permute1[_KindFloat],
172+
) -> _DTypesFloat: ...
173+
@overload
174+
def dtypes(
175+
self,
176+
*,
177+
device: _DeviceLike = ...,
178+
kind: _Permute1[_KindComplex],
179+
) -> _DTypesComplex: ...
180+
@overload
181+
def dtypes(
182+
self,
183+
*,
184+
device: _DeviceLike = ...,
185+
kind: (
186+
_Permute1[_KindInteger]
187+
| _Permute2[_KindInt, _KindUInt]
188+
),
189+
) -> _DTypesInteger: ...
190+
@overload
191+
def dtypes(
192+
self,
193+
*,
194+
device: _DeviceLike = ...,
195+
kind: (
196+
_Permute1[_KindNumber]
197+
| _Permute3[_KindInteger, _KindFloat, _KindComplex]
198+
),
199+
) -> _DTypesNumber: ...
200+
@overload
201+
def dtypes(
202+
self,
203+
*,
204+
device: _DeviceLike = ...,
205+
kind: tuple[()],
206+
) -> _EmptyDict: ...
207+
@overload
208+
def dtypes(
209+
self,
210+
*,
211+
device: _DeviceLike = ...,
212+
kind: F438 tuple[_Kind, ...],
213+
) -> _DTypesUnion: ...
Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,76 @@
11
import sys
2-
from typing import List
2+
from typing import Literal
33

44
import numpy as np
55

66
if sys.version_info >= (3, 11):
7-
from typing import assert_type
7+
from typing import Never, assert_type
88
else:
9-
from typing_extensions import assert_type
9+
from typing_extensions import Never, assert_type
1010

11-
array_namespace_info = np.__array_namespace_info__()
11+
info = np.__array_namespace_info__()
1212

13-
assert_type(array_namespace_info.__module__, str)
14-
assert_type(array_namespace_info.capabilities(), np._array_api_info.Capabilities)
15-
assert_type(array_namespace_info.default_device(), str)
16-
assert_type(array_namespace_info.default_dtypes(), np._array_api_info.DefaultDataTypes)
17-
assert_type(array_namespace_info.dtypes(), np._array_api_info.DataTypes)
18-
assert_type(array_namespace_info.devices(), List[str])
13+
assert_type(info.__module__, Literal["numpy"])
14+
15+
assert_type(info.default_device(), Literal["cpu"])
16+
assert_type(info.devices()[0], Literal["cpu"])
17+
assert_type(info.devices()[-1], Literal["cpu"])
18+
19+
assert_type(info.capabilities()["boolean indexing"], Literal[True])
20+
assert_type(info.capabilities()["data-dependent shapes"], Literal[True])
21+
22+
assert_type(info.default_dtypes()["real floating"], np.dtype[np.float64])
23+
assert_type(info.default_dtypes()["complex floating"], np.dtype[np.complex128])
24+
assert_type(info.default_dtypes()["integral"], np.dtype[np.int_])
25+
assert_type(info.default_dtypes()["indexing"], np.dtype[np.intp])
26+
27+
assert_type(info.dtypes()["bool"], np.dtype[np.bool])
28+
assert_type(info.dtypes()["int8"], np.dtype[np.int8])
29+
assert_type(info.dtypes()["uint8"], np.dtype[np.uint8])
30+
assert_type(info.dtypes()["float32"], np.dtype[np.float32])
31+
assert_type(info.dtypes()["complex64"], np.dtype[np.complex64])
32+
33+
assert_type(info.dtypes(kind="bool")["bool"], np.dtype[np.bool])
34+
assert_type(info.dtypes(kind="signed integer")["int64"], np.dtype[np.int64])
35+
assert_type(info.dtypes(kind="unsigned integer")["uint64"], np.dtype[np.uint64])
36+
assert_type(info.dtypes(kind="integral")["int32"], np.dtype[np.int32])
37+
assert_type(info.dtypes(kind="integral")["uint32"], np.dtype[np.uint32])
38+
assert_type(info.dtypes(kind="real floating")["float64"], np.dtype[np.float64])
39+
assert_type(info.dtypes(kind="complex floating")["complex128"], np.dtype[np.complex128])
40+
assert_type(info.dtypes(kind="numeric")["int16"], np.dtype[np.int16])
41+
assert_type(info.dtypes(kind="numeric")["uint16"], np.dtype[np.uint16])
42+
assert_type(info.dtypes(kind="numeric")["float64"], np.dtype[np.float64])
43+
assert_type(info.dtypes(kind="numeric")["complex128"], np.dtype[np.complex128])
44+
45+
assert_type(info.dtypes(kind=()), dict[Never, Never])
46+
47+
assert_type(info.dtypes(kind=("bool",))["bool"], np.dtype[np.bool])
48+
assert_type(info.dtypes(kind=("signed integer",))["int64"], np.dtype[np.int64])
49+
assert_type(info.dtypes(kind=("integral",))["uint32"], np.dtype[np.uint32])
50+
assert_type(info.dtypes(kind=("complex floating",))["complex128"], np.dtype[np.complex128])
51+
assert_type(info.dtypes(kind=("numeric",))["float64"], np.dtype[np.float64])
52+
53+
assert_type(
54+
info.dtypes(kind=("signed integer", "unsigned integer"))["int8"],
55+
np.dtype[np.int8],
56+
)
57+
assert_type(
58+
info.dtypes(kind=("signed integer", "unsigned integer"))["uint8"],
59+
np.dtype[np.uint8],
60+
)
61+
assert_type(
62+
info.dtypes(kind=("integral", "real floating", "complex floating"))["int16"],
63+
np.dtype[np.int16],
64+
)
65+
assert_type(
66+
info.dtypes(kind=("integral", "real floating", "complex floating"))["uint16"],
67+
np.dtype[np.uint16],
68+
)
69+
assert_type(
70+
info.dtypes(kind=("integral", "real floating", "complex floating"))["float32"],
71+
np.dtype[np.float32],
72+
)
73+
assert_type(
74+
info.dtypes(kind=("integral", "real floating", "complex floating"))["complex64"],
75+
np.dtype[np.complex64],
76+
)

0 commit comments

Comments
 (0)
0