8000 ENH: get more specific about _ArrayLike, make it public · numpy/numpy-stubs@f39c6ef · GitHub
[go: up one dir, main page]

Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

Commit f39c6ef

Browse files
committed
ENH: get more specific about _ArrayLike, make it public
Closes #37. Add tests to check various examples. Note that supporting __array__ also requires making _DtypeLike public too, so this does that as well.
1 parent caef625 commit f39c6ef

File tree

5 files changed

+177
-96
lines changed

5 files changed

+177
-96
lines changed

numpy-stubs/__init__.pyi

Lines changed: 51 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ import sys
33
import datetime as dt
44

55
from numpy.core._internal import _ctypes
6+
from numpy.typing import ArrayLike, DtypeLike
7+
68
from typing import (
79
Any,
810
ByteString,
@@ -36,9 +38,9 @@ else:
3638
from typing import SupportsBytes
3739

3840
if sys.version_info >= (3, 8):
39-
from typing import Literal
41+
from typing import Literal, Protocol
4042
else:
41-
from typing_extensions import Literal
43+
from typing_extensions import Literal, Protocol
4244

4345
# TODO: remove when the full numpy namespace is defined
4446
def __getattr__(name: str) -> Any: ...
@@ -48,57 +50,11 @@ _Shape = Tuple[int, ...]
4850
# Anything that can be coerced to a shape tuple
4951
_ShapeLike = Union[int, Sequence[int]]
5052

51-
_DtypeLikeNested = Any # TODO: wait for support for recursive types
52-
53-
# Anything that can be coerced into numpy.dtype.
54-
# Reference: https://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html
55-
_DtypeLike = Union[
56-
dtype,
57-
# default data type (float64)
58-
None,
59-
# array-scalar types and generic types
60-
type, # TODO: enumerate these when we add type hints for numpy scalars
61-
# TODO: add a protocol for anything with a dtype attribute
62-
# character codes, type strings or comma-separated fields, e.g., 'float64'
63-
str,
64-
# (flexible_dtype, itemsize)
65-
Tuple[_DtypeLikeNested, int],
66-
# (fixed_dtype, shape)
67-
Tuple[_DtypeLikeNested, _ShapeLike],
68-
# [(field_name, field_dtype, field_shape), ...]
69-
#
70-
# The type here is quite broad because NumPy accepts quite a wide
71-
# range of inputs inside the list; see the tests for some
72-
# examples.
73-
List[Any],
74-
# {'names': ..., 'formats': ..., 'offsets': ..., 'titles': ...,
75-
# 'itemsize': ...}
76-
# TODO: use TypedDict when/if it's officially supported
77-
Dict[
78-
str,
79-
Union[
80-
Sequence[str], # names
81-
Sequence[_DtypeLikeNested], # formats
82-
Sequence[int], # offsets
83-
Sequence[Union[bytes, Text, None]], # titles
84-
int, # itemsize
85-
],
86-
],
87-
# {'field1': ..., 'field2': ..., ...}
88-
Dict[str, Tuple[_DtypeLikeNested, int]],
89-
# (base_dtype, new_dtype)
90-
Tuple[_DtypeLikeNested, _DtypeLikeNested],
91-
]
92-
9353
_NdArraySubClass = TypeVar("_NdArraySubClass", bound=ndarray)
9454

95-
_ArrayLike = TypeVar("_ArrayLike")
96-
9755
class dtype:
9856
names: Optional[Tuple[str, ...]]
99-
def __init__(
100-
self, obj: _DtypeLike, align: bool = ..., copy: bool = ...
101-
) -> None: ...
57+
def __init__(self, obj: DtypeLike, align: bool = ..., copy: bool = ...) -> None: ...
10258
@property
10359
def alignment(self) -> int: ...
10460
@property
@@ -217,6 +173,7 @@ class _ArrayOrScalarCommon(
217173
def shape(self) -> _Shape: ...
218174
@property
219175
def strides(self) -> _Shape: ...
176+
def __array__(self, __dtype: DtypeLike = ...) -> ndarray: ...
220177
def __int__(self) -> int: ...
221178
def __float__(self) -> float: ...
222179
def __complex__(self) -> complex: ...
@@ -299,7 +256,7 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
299256
def __new__(
300257
cls,
301258
shape: Sequence[int],
302-
dtype: Union[_DtypeLike, str] = ...,
259+
dtype: DtypeLike = ...,
303260
buffer: _BufferType = ...,
304261
offset: int = ...,
305262
strides: _ShapeLike = ...,
@@ -338,7 +295,7 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
338295
def dumps(self) -> bytes: ...
339296
def astype(
340297
self,
341-
dtype: _DtypeLike,
298+
dtype: DtypeLike,
342299
order: str = ...,
343300
casting: str = ...,
344301
subok: bool = ...,
@@ -349,14 +306,14 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
349306
@overload
350307
def view(self, dtype: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
351308
@overload
352-
def view(self, dtype: _DtypeLike = ...) -> ndarray: ...
309+
def view(self, dtype: DtypeLike = ...) -> ndarray: ...
353310
@overload
354311
def view(
355-
self, dtype: _DtypeLike, type: Type[_NdArraySubClass]
312+
self, dtype: DtypeLike, type: Type[_NdArraySubClass]
356313
) -> _NdArraySubClass: ...
357314
@overload
358315
def view(self, *, type: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
359-
def getfield(self, dtype: Union[_DtypeLike, str], offset: int = ...) -> ndarray: ...
316+
def getfield(self, dtype: DtypeLike, offset: int = ...) -> ndarray: ...
360317
def setflags(
361318
self, write: bool = ..., align: bool = ..., uic: bool = ...
362319
) -> None: ...
@@ -501,26 +458,26 @@ class str_(character): ...
501458

502459
def array(
503460
object: object,
504-
dtype: _DtypeLike = ...,
461+
dtype: DtypeLike = ...,
505462
copy: bool = ...,
506463
subok: bool = ...,
507464
ndmin: int = ...,
508465
) -> ndarray: ...
509466
def zeros(
510-
shape: _ShapeLike, dtype: _DtypeLike = ..., order: Optional[str] = ...
467+
shape: _ShapeLike, dtype: DtypeLike = ..., order: Optional[str] = ...
511468
) -> ndarray: ...
512469
def ones(
513-
shape: _ShapeLike, dtype: _DtypeLike = ..., order: Optional[str] = ...
470+
shape: _ShapeLike, dtype: DtypeLike = ..., order: Optional[str] = ...
514471
) -> ndarray: ...
515472
def zeros_like(
516-
a: _ArrayLike,
473+
a: ArrayLike,
517474
dtype: Optional[dtype] = ...,
518475
order: str = ...,
519476
subok: bool = ...,
520477
shape: Optional[Union[int, Sequence[int]]] = ...,
521478
) -> ndarray: ...
522479
def ones_like(
523-
a: _ArrayLike,
480+
a: ArrayLike,
524481
dtype: Optional[dtype] = ...,
525482
order: str = ...,
526483
subok: bool = ...,
@@ -530,43 +487,43 @@ def full(
530487
shape: _ShapeLike, fill_value: Any, dtype: Optional[dtype] = ..., order: str = ...
531488
) -> ndarray: ...
532489
def full_like(
533-
a: _ArrayLike,
490+
a: ArrayLike,
534491
fill_value: Any,
535492
dtype: Optional[dtype] = ...,
536493
order: str = ...,
537494
subok: bool = ...,
538495
shape: Optional[_ShapeLike] = ...,
539496
) -> ndarray: ...
540497
def count_nonzero(
541-
a: _ArrayLike, axis: Optional[Union[int, Tuple[int], Tuple[int, int]]] = ...
498+
a: ArrayLike, axis: Optional[Union[int, Tuple[int], Tuple[int, int]]] = ...
542499
) -> Union[int, ndarray]: ...
543500
def isfortran(a: ndarray) -> bool: ...
544-
def argwhere(a: _ArrayLike) -> ndarray: ...
545-
def flatnonzero(a: _ArrayLike) -> ndarray: ...
546-
def correlate(a: _ArrayLike, v: _ArrayLike, mode: str = ...) -> ndarray: ...
547-
def convolve(a: _ArrayLike, v: _ArrayLike, mode: str = ...) -> ndarray: ...
548-
def outer(a: _ArrayLike, b: _ArrayLike, out: ndarray = ...) -> ndarray: ...
501+
def argwhere(a: ArrayLike) -> ndarray: ...
502+
def flatnonzero(a: ArrayLike) -> ndarray: ...
503+
def correlate(a: ArrayLike, v: ArrayLike, mode: str = ...) -> ndarray: ...
504+
def convolve(a: ArrayLike, v: ArrayLike, mode: str = ...) -> ndarray: ...
505+
def outer(a: ArrayLike, b: ArrayLike, out: ndarray = ...) -> ndarray: ...
549506
def tensordot(
550-
a: _ArrayLike,
551-
b: _ArrayLike,
507+
a: ArrayLike,
508+
b: ArrayLike,
552509
axes: Union[
553510
int, Tuple[int, int], Tuple[Tuple[int, int], ...], Tuple[List[int, int], ...]
554511
] = ...,
555512
) -> ndarray: ...
556513
def roll(
557-
a: _ArrayLike,
514+
a: ArrayLike,
558515
shift: Union[int, Tuple[int, ...]],
559516
axis: Optional[Union[int, Tuple[int, ...]]] = ...,
560517
) -> ndarray: ...
561-
def rollaxis(a: _ArrayLike, axis: int, start: int = ...) -> ndarray: ...
518+
def rollaxis(a: ArrayLike, axis: int, start: int = ...) -> ndarray: ...
562519
def moveaxis(
563520
a: ndarray,
564521
source: Union[int, Sequence[int]],
565522
destination: Union[int, Sequence[int]],
566523
) -> ndarray: ...
567524
def cross(
568-
a: _ArrayLike,
569-
b: _ArrayLike,
525+
a: ArrayLike,
526+
b: ArrayLike,
570527
axisa: int = ...,
571528
axisb: int = ...,
572529
axisc: int = ...,
@@ -581,21 +538,21 @@ def binary_repr(num: int, width: Optional[int] = ...) -> str: ...
581538
def base_repr(number: int, base: int = ..., padding: int = ...) -> str: ...
582539
def identity(n: int, dtype: Optional[dtype] = ...) -> ndarray: ...
583540
def allclose(
584-
a: _ArrayLike,
585-
b: _ArrayLike,
541+
a: ArrayLike,
542+
b: ArrayLike,
586543
rtol: float = ...,
587544
atol: float = ...,
588545
equal_nan: bool = ...,
589546
) -> bool: ...
590547
def isclose(
591-
a: _ArrayLike,
592-
b: _ArrayLike,
548+
a: ArrayLike,
549+
b: ArrayLike,
593550
rtol: float = ...,
594551
atol: float = ...,
595552
equal_nan: bool = ...,
596553
) -> Union[bool_, ndarray]: ...
597-
def array_equal(a1: _ArrayLike, a2: _ArrayLike) -> bool: ...
598-
def array_equiv(a1: _ArrayLike, a2: _ArrayLike) -> bool: ...
554+
def array_equal(a1: ArrayLike, a2: ArrayLike) -> bool: ...
555+
def array_equiv(a1: ArrayLike, a2: ArrayLike) -> bool: ...
599556

600557
#
601558
# Constants
@@ -649,7 +606,7 @@ class ufunc:
649606
def __name__(self) -> str: ...
650607
def __call__(
651608
self,
652-
*args: _ArrayLike,
609+
*args: ArrayLike,
653610
out: Optional[Union[ndarray, Tuple[ndarray, ...]]] = ...,
654611
where: Optional[ndarray] = ...,
655612
# The list should be a list of tuples of ints, but since we
@@ -664,7 +621,7 @@ class ufunc:
664621
casting: str = ...,
665622
# TODO: make this precise when we can use Literal.
666623
order: Optional[str] = ...,
667-
dtype: Optional[_DtypeLike] = ...,
624+
dtype: DtypeLike = ...,
668625
subok: bool = ...,
669626
signature: Union[str, Tuple[str]] = ...,
670627
# In reality this should be a length of list 3 containing an
@@ -876,56 +833,56 @@ def take(
876833
) -> _ScalarNumpy: ...
877834
@overload
878835
def take(
879-
a: _ArrayLike,
836+
a: ArrayLike,
880837
indices: int,
881838
axis: Optional[int] = ...,
882839
out: Optional[ndarray] = ...,
883840
mode: _Mode = ...,
884841
) -> _ScalarNumpy: ...
885842
@overload
886843
def take(
887-
a: _ArrayLike,
844+
a: ArrayLike,
888845
indices: _ArrayLikeIntOrBool,
889846
axis: Optional[int] = ...,
890847
out: Optional[ndarray] = ...,
891848
mode: _Mode = ...,
892849
) -> Union[_ScalarNumpy, ndarray]: ...
893-
def reshape(a: _ArrayLike, newshape: _ShapeLike, order: _Order = ...) -> ndarray: ...
850+
def reshape(a: ArrayLike, newshape: _ShapeLike, order: _Order = ...) -> ndarray: ...
894851
@overload
895852
def choose(
896853
a: _ScalarIntOrBool,
897-
choices: Union[Sequence[_ArrayLike], ndarray],
854+
choices: Union[Sequence[ArrayLike], ndarray],
898855
out: Optional[ndarray] = ...,
899856
mode: _Mode = ...,
900857
) -> _ScalarIntOrBool: ...
901858
@overload
902859
def choose(
903860
a: _IntOrBool,
904-
choices: Union[Sequence[_ArrayLike], ndarray],
861+
choices: Union[Sequence[ArrayLike], ndarray],
905862
out: Optional[ndarray] = ...,
906863
mode: _Mode = ...,
907864
) -> Union[integer, bool_]: ...
908865
@overload
909866
def choose(
910867
a: _ArrayLikeIntOrBool,
911-
choices: Union[Sequence[_ArrayLike], ndarray],
868+
choices: Union[Sequence[ArrayLike], ndarray],
912869
out: Optional[ndarray] = ...,
913870
mode: _Mode = ...,
914871
) -> ndarray: ...
915872
def repeat(
916-
a: _ArrayLike, repeats: _ArrayLikeIntOrBool, axis: Optional[int] = ...
873+
a: ArrayLike, repeats: _ArrayLikeIntOrBool, axis: Optional[int] = ...
917874
) -> ndarray: ...
918875
def put(
919-
a: ndarray, ind: _ArrayLikeIntOrBool, v: _ArrayLike, mode: _Mode = ...
876+
a: ndarray, ind: _ArrayLikeIntOrBool, v: ArrayLike, mode: _Mode = ...
920877
) -> None: ...
921878
def swapaxes(
922-
a: Union[Sequence[_ArrayLike], ndarray], axis1: int, axis2: int
879+
a: Union[Sequence[ArrayLike], ndarray], axis1: int, axis2: int
923880
) -> ndarray: ...
924881
def transpose(
925-
a: _ArrayLike, axes: Union[None, Sequence[int], ndarray] = ...
882+
a: ArrayLike, axes: Union[None, Sequence[int], ndarray] = ...
926883
) -> ndarray: ...
927884
def partition(
928-
a: _ArrayLike,
885+
a: ArrayLike,
929886
kth: _ArrayLikeIntOrBool,
930887
axis: Optional[int] = ...,
931888
kind: _PartitionKind = ...,
@@ -949,20 +906,20 @@ def argpartition(
949906
) -> ndarray: ...
950907
@overload
951908
def argpartition(
952-
a: _ArrayLike,
909+
a: ArrayLike,
953910
kth: _ArrayLikeIntOrBool,
954911
axis: Optional[int] = ...,
955912
kind: _PartitionKind = ...,
956913
order: Union[None, str, Sequence[str]] = ...,
957914
) -> ndarray: ...
958915
def sort(
959-
a: Union[Sequence[_ArrayLike], ndarray],
916+
a: Union[Sequence[ArrayLike], ndarray],
960917
axis: Optional[int] = ...,
961918
kind: Optional[_SortKind] = ...,
962919
order: Union[None, str, Sequence[str]] = ...,
963920
) -> ndarray: ...
964921
def argsort(
965-
a: Union[Sequence[_ArrayLike], ndarray],
922+
a: Union[Sequence[ArrayLike], ndarray],
966923
axis: Optional[int] = ...,
967924
kind: Optional[_SortKind] = ...,
968925
order: Union[None, str, Sequence[str]] = ...,

0 commit comments

Comments
 (0)
0