10000 ENH: get more specific about _ArrayLike, make it public · numpy/numpy-stubs@e30b16a · GitHub
[go: up one dir, main page]

Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

Commit e30b16a

Browse files
committed
ENH: get more specific about _ArrayLike, make it public
Closes #37. Add tests to check various examples. Note that supporting __array__ also requires making _DtypeLike public too, so this does that as well.
1 parent 2011153 commit e30b16a

File tree

5 files changed

+103
-55
lines changed

5 files changed

+103
-55
lines changed

numpy-stubs/__init__.pyi

Lines changed: 58 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ else:
3636
from typing import SupportsBytes
3737

3838
if sys.version_info >= (3, 8):
39-
from typing import Literal
39+
from typing import Literal, Protocol
4040
else:
41-
from typing_extensions import Literal
41+
from typing_extensions import Literal, Protocol
4242

4343
# TODO: remove when the full numpy namespace is defined
4444
def __getattr__(name: str) -> Any: ...
@@ -52,7 +52,7 @@ _DtypeLikeNested = Any # TODO: wait for support for recursive types
5252

5353
# Anything that can be coerced into numpy.dtype.
5454
# Reference: https://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html
55-
_DtypeLike = Union[
55+
DtypeLike = Union[
5656
dtype,
5757
# default data type (float64)
5858
None,
@@ -92,13 +92,17 @@ _DtypeLike = Union[
9292

9393
_NdArraySubClass = TypeVar("_NdArraySubClass", bound=ndarray)
9494

95-
_ArrayLike = TypeVar("_ArrayLike")
95+
class _SupportsArray(Protocol):
96+
@overload
97+
def __array__(self, __dtype: DtypeLike = ...) -> ndarray: ...
98+
@overload
99+
def __array__(self, dtype: Optional[DtypeLike] = ...) -> ndarray: ...
100+
101+
ArrayLike = Union[bool, int, float, complex, _SupportsArray, Sequence]
96102

97103
class dtype:
98104
names: Optional[Tuple[str, ...]]
99-
def __init__(
100-
self, obj: _DtypeLike, align: bool = ..., copy: bool = ...
101-
) -> None: ...
105+
def __init__(self, obj: DtypeLike, align: bool = ..., copy: bool = ...) -> None: ...
102106
@property
103107
def alignment(self) -> int: ...
104108
@property
@@ -217,6 +221,7 @@ class _ArrayOrScalarCommon(
217221
def shape(self) -> _Shape: ...
218222
@property
219223
def strides(self) -> _Shape: ...
224+
def __array__(self, __dtype: Optional[DtypeLike] = ...) -> ndarray: ...
220225
def __int__(self) -> int: ...
221226
def __float__(self) -> float: ...
222227
def __complex__(self) -> complex: ...
@@ -299,7 +304,7 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
299304
def __new__(
300305
cls,
301306
shape: Sequence[int],
302-
dtype: Union[_DtypeLike, str] = ...,
307+
dtype: Union[DtypeLike, str] = ...,
303308
buffer: _BufferType = ...,
304309
offset: int = ...,
305310
strides: _ShapeLike = ...,
@@ -338,7 +343,7 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
338343
def dumps(self) -> bytes: ...
339344
def astype(
340345
self,
341-
dtype: _DtypeLike,
346+
dtype: DtypeLike,
342347
order: str = ...,
343348
casting: str = ...,
344349
subok: bool = ...,
@@ -349,14 +354,14 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
349354
@overload
350355
def view(self, dtype: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
351356
@overload
352-
def view(self, dtype: _DtypeLike = ...) -> ndarray: ...
357+
def view(self, dtype: DtypeLike = ...) -> ndarray: ...
353358
@overload
354359
def view(
355-
self, dtype: _DtypeLike, type: Type[_NdArraySubClass]
360+
self, dtype: DtypeLike, type: Type[_NdArraySubClass]
356361
) -> _NdArraySubClass: ...
357362
@overload
358363
def view(self, *, type: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
359-
def getfield(self, dtype: Union[_DtypeLike, str], offset: int = ...) -> ndarray: ...
364+
def getfield(self, dtype: Union[DtypeLike, str], offset: int = ...) -> ndarray: ...
360365
def setflags(
361366
self, write: bool = ..., align: bool = ..., uic: bool = ...
362367
) -> None: ...
@@ -484,26 +489,26 @@ class str_(character): ...
484489

485490
def array(
486491
object: object,
487-
dtype: _DtypeLike = ...,
492+
dtype: DtypeLike = ...,
488493
copy: bool = ...,
489494
subok: bool = ...,
490495
ndmin: int = ...,
491496
) -> ndarray: ...
492497
def zeros(
493-
shape: _ShapeLike, dtype: _DtypeLike = ..., order: Optional[str] = ...
498+
shape: _ShapeLike, dtype: DtypeLike = ..., order: Optional[str] = ...
494499
) -> ndarray: ...
495500
def ones(
496-
shape: _ShapeLike, dtype: _DtypeLike = ..., order: Optional[str] = ...
501+
shape: _ShapeLike, dtype: DtypeLike = ..., order: Optional[str] = ...
497502
) -> ndarray: ...
498503
def zeros_like(
499-
a: _ArrayLike,
504+
a: ArrayLike,
500505
dtype: Optional[dtype] = ...,
501506
order: str = ...,
502507
subok: bool = ...,
503508
shape: Optional[Union[int, Sequence[int]]] = ...,
504509
) -> ndarray: ...
505510
def ones_like(
506-
a: _ArrayLike,
511+
a: ArrayLike,
507512
dtype: Optional[dtype] = ...,
508513
order: str = ...,
509514
subok: bool = ...,
@@ -513,43 +518,43 @@ def full(
513518
shape: _ShapeLike, fill_value: Any, dtype: Optional[dtype] = ..., order: str = ...
514519
) -> ndarray: ...
515520
def full_like(
516-
a: _ArrayLike,
521+
a: ArrayLike,
517522
fill_value: Any,
518523
dtype: Optional[dtype] = ...,
519524
order: str = ...,
520525
subok: bool = ...,
521526
shape: Optional[_ShapeLike] = ...,
522527
) -> ndarray: ...
523528
def count_nonzero(
524-
a: _ArrayLike, axis: Optional[Union[int, Tuple[int], Tuple[int, int]]] = ...
529+
a: ArrayLike, axis: Optional[Union[int, Tuple[int], Tuple[int, int]]] = ...
525530
) -> Union[int, ndarray]: ...
526531
def isfortran(a: ndarray) -> bool: ...
527-
def argwhere(a: _ArrayLike) -> ndarray: ...
528-
def flatnonzero(a: _ArrayLike) -> ndarray: ...
529-
def correlate(a: _ArrayLike, v: _ArrayLike, mode: str = ...) -> ndarray: ...
530-
def convolve(a: _ArrayLike, v: _ArrayLike, mode: str = ...) -> ndarray: ...
531-
def outer(a: _ArrayLike, b: _ArrayLike, out: ndarray = ...) -> ndarray: ...
532+
def argwhere(a: ArrayLike) -> ndarray: ...
533+
def flatnonzero(a: ArrayLike) -> ndarray: ...
534+
def correlate(a: ArrayLike, v: ArrayLike, mode: str = ...) -> ndarray: ...
535+
def convolve(a: ArrayLike, v: ArrayLike, mode: str = ...) -> ndarray: ...
536+
def outer(a: ArrayLike, b: ArrayLike, out: ndarray = ...) -> ndarray: ...
532537
def tensordot(
533-
a: _ArrayLike,
534-
b: _ArrayLike,
538+
a: ArrayLike,
539+
b: ArrayLike,
535540
axes: Union[
536541
int, Tuple[int, int], Tuple[Tuple[int, int], ...], Tuple[List[int, int], ...]
537542
] = ...,
538543
) -> ndarray: ...
539544
def roll(
540-
a: _ArrayLike,
545+
a: ArrayLike,
541546
shift: Union[int, Tuple[int, ...]],
542547
axis: Optional[Union[int, Tuple[int, ...]]] = ...,
543548
) -> ndarray: ...
544-
def rollaxis(a: _ArrayLike, axis: int, start: int = ...) -> ndarray: ...
549+
def rollaxis(a: ArrayLike, axis: int, start: int = ...) -> ndarray: ...
545550
def moveaxis(
546551
a: ndarray,
547552
source: Union[int, Sequence[int]],
548553
destination: Union[int, Sequence[int]],
549554
) -> ndarray: ...
550555
def cross(
551-
a: _ArrayLike,
552-
b: _ArrayLike,
556+
a: ArrayLike,
557+
b: ArrayLike,
553558
axisa: int = ...,
554559
axisb: int = ...,
555560
axisc: int = ...,
@@ -564,21 +569,21 @@ def binary_repr(num: int, width: Optional[int] = ...) -> str: ...
564569
def base_repr(number: int, base: int = ..., padding: int = ...) -> str: ...
565570
def identity(n: int, dtype: Optional[dtype] = ...) -> ndarray: ...
566571
def allclose(
567-
a: _ArrayLike,
568-
b: _ArrayLike,
572+
a: ArrayLike,
573+
b: ArrayLike,
569574
rtol: float = ...,
570575
atol: float = ...,
571576
equal_nan: bool = ...,
572577
) -> bool: ...
573578
def isclose(
574-
a: _ArrayLike,
575-
b: _ArrayLike,
579+
a: ArrayLike,
580+
b: ArrayLike,
576581
rtol: float = ...,
577582
atol: float = ...,
578583
equal_nan: bool = ...,
579584
) -> Union[bool_, ndarray]: ...
580-
def array_equal(a1: _ArrayLike, a2: _ArrayLike) -> bool: ...
581-
def array_equiv(a1: _ArrayLike, a2: _ArrayLike) -> bool: ...
585+
def array_equal(a1: ArrayLike, a2: ArrayLike) -> bool: ...
586+
def array_equiv(a1: ArrayLike, a2: ArrayLike) -> bool: ...
582587

583588
#
584589
# Constants
@@ -632,7 +637,7 @@ class ufunc:
632637
def __name__(self) -> str: ...
633638
def __call__(
634639
self,
635-
*args: _ArrayLike,
640+
*args: ArrayLike,
636641
out: Optional[Union[ndarray, Tuple[ndarray, ...]]] = ...,
637642
where: Optional[ndarray] = ...,
638643
# The list should be a list of tuples of ints, but since we
@@ -647,7 +652,7 @@ class ufunc:
647652
casting: str = ...,
648653
# TODO: make this precise when we can use Literal.
649654
order: Optional[str] = ...,
650-
dtype: Optional[_DtypeLike] = ...,
655+
dtype: Optional[DtypeLike] = ...,
651656
subok: bool = ...,
652657
signature: Union[str, Tuple[str]] = ...,
653658
# In reality this should be a length of list 3 containing an
@@ -845,74 +850,74 @@ def take(
845850
) -> _ScalarNumpy: ...
846851
@overload
847852
def take(
848-
a: _ArrayLike,
853+
a: ArrayLike,
849854
indices: int,
850855
axis: Optional[int] = ...,
851856
out: Optional[ndarray] = ...,
852857
mode: _Mode = ...,
853858
) -> _ScalarNumpy: ...
854859
@overload
855860
def take(
856-
a: _ArrayLike,
861+
a: ArrayLike,
857862
indices: _ArrayLikeInt,
858863
axis: Optional[int] = ...,
859864
out: Optional[ndarray] = ...,
860865
mode: _Mode = ...,
861866
) -> Union[_ScalarNumpy, ndarray]: ...
862-
def reshape(a: _ArrayLike, newshape: _ShapeLike, order: _Order = ...) -> ndarray: ...
867+
def reshape(a: ArrayLike, newshape: _ShapeLike, order: _Order = ...) -> ndarray: ...
863868
@overload
864869
def choose(
865870
a: _ScalarGeneric,
866-
choices: Union[Sequence[_ArrayLike], ndarray],
871+
choices: Union[Sequence[ArrayLike], ndarray],
867872
out: Optional[ndarray] = ...,
868873
mode: _Mode = ...,
869874
) -> _ScalarGeneric: ...
870875
@overload
871876
def choose(
872877
a: _Scalar,
873-
choices: Union[Sequence[_ArrayLike], ndarray],
878+
choices: Union[Sequence[ArrayLike], ndarray],
874879
out: Optional[ndarray] = ...,
875880
mode: _Mode = ...,
876881
) -> _ScalarNumpy: ...
877882
@overload
878883
def choose(
879-
a: _ArrayLike,
880-
choices: Union[Sequence[_ArrayLike], ndarray],
884+
a: ArrayLike,
885+
choices: Union[Sequence[ArrayLike], ndarray],
881886
out: Optional[ndarray] = ...,
882887
mode: _Mode = ...,
883888
) -> ndarray: ...
884889
def repeat(
885-
a: _ArrayLike, repeats: _ArrayLikeInt, axis: Optional[int] = ...
890+
a: ArrayLike, repeats: _ArrayLikeInt, axis: Optional[int] = ...
886891
) -> ndarray: ...
887-
def put(a: ndarray, ind: _ArrayLikeInt, v: _ArrayLike, mode: _Mode = ...) -> None: ...
892+
def put(a: ndarray, ind: _ArrayLikeInt, v: ArrayLike, mode: _Mode = ...) -> None: ...
888893
def swapaxes(
889-
a: Union[Sequence[_ArrayLike], ndarray], axis1: int, axis2: int
894+
a: Union[Sequence[ArrayLike], 10000 ndarray], axis1: int, axis2: int
890895
) -> ndarray: ...
891896
def transpose(
892-
a: _ArrayLike, axes: Union[None, Sequence[int], ndarray] = ...
897+
a: ArrayLike, axes: Union[None, Sequence[int], ndarray] = ...
893898
) -> ndarray: ...
894899
def partition(
895-
a: _ArrayLike,
900+
a: ArrayLike,
896901
kth: _ArrayLikeInt,
897902
axis: Optional[int] = ...,
898903
kind: _PartitionKind = ...,
899904
order: Union[None, str, Sequence[str]] = ...,
900905
) -> ndarray: ...
901906
def argpartition(
902-
a: _ArrayLike,
907+
a: ArrayLike,
903908
kth: _ArrayLikeInt,
904909
axis: Optional[int] = ...,
905910
kind: _PartitionKind = ...,
906911
order: Union[None, str, Sequence[str]] = ...,
907912
) -> ndarray: ...
908913
def sort(
909-
a: Union[Sequence[_ArrayLike], ndarray],
914+
a: Union[Sequence[ArrayLike], ndarray],
910915
axis: Optional[int] = ...,
911916
kind: Optional[_SortKind] = ...,
912917
order: Union[None, str, Sequence[str]] = ...,
913918
) -> ndarray: ...
914919
def argsort(
915-
a: Union[Sequence[_ArrayLike], ndarray],
920+
a: Union[Sequence[ArrayLike], ndarray],
916921
axis: Optional[int] = ...,
917922
kind: Optional[_SortKind] = ...,
918923
order: Union[None, str, Sequence[str]] = ...,

scripts/array_protocol.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from typing_extensions import Protocol
2+
3+
import numpy as np

tests/fail/array_like.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import numpy as np
2+
3+
4+
class A:
5+
pass
6+
7+
8+
x1: "np.ArrayLike" = (i for i in range(10)) # E: Incompatible types in assignment
9+
x2: "np.ArrayLike" = A() # E: Incompatible types in assignment
10+
x3: "np.ArrayLike" = {1: "foo", 2: "bar"} # E: Incompatible types in assignment

tests/pass/array_like.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from typing import Any, List, Optional
2+
3+
import numpy as np
4+
5+
x1: "np.ArrayLike" = True
6+
x2: "np.ArrayLike" = 5
7+
x3: "np.ArrayLike" = 1.0
8+
x4: "np.ArrayLike" = 1 + 1j
9+
x5: "np.ArrayLike" = np.int8(1)
10+
x6: "np.ArrayLike" = np.float64(1)
11+
x7: "np.ArrayLike" = np.complex128(1)
12+
x8: "np.ArrayLike" = np.array([1, 2, 3])
13+
x9: "np.ArrayLike" = [1, 2, 3]
14+
x10: "np.ArrayLike" = (1, 2, 3)
15+
x11: "np.ArrayLike" = "foo"
16+
17+
18+
class B:
19+
def __array__(self, dtype: Optional["np.DtypeLike"] = None) -> np.ndarray:
20+
return np.array([1, 2, 3])
21+
22+
23+
x12: "np.ArrayLike" = B()
24+
x13: "np._SupportsArray" = np.int64(1)
25+
x14: "np._SupportsArray" = np.array(1)
26+
27+
# Escape hatch for when you mean to make something like an object
28+
# array.
29+
object_array_scalar: Any = (i for i in range(10))
30+
np.array(object_array_scalar)

tests/reveal/fromnumeric.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@
2323
np.take(B, 0) # E: Union[numpy.generic, datetime.datetime, datetime.timedelta]
2424
)
2525
reveal_type(
26-
np.take( # E: Union[numpy.generic, datetime.datetime, datetime.timedelta, numpy.ndarray]
26+
np.take( # E: Union[Union[numpy.generic, datetime.datetime, datetime.timedelta], numpy.ndarray]
2727
A, [0]
2828
)
2929
)
3030
reveal_type(
31-
np.take( # E: Union[numpy.generic, datetime.datetime, datetime.timedelta, numpy.ndarray]
31+
np.take( # E: Union[Union[numpy.generic, datetime.datetime, datetime.timedelta], numpy.ndarray]
3232
B, [0]
3333
)
3434
)

0 commit comments

Comments
 (0)
0