8000 Also support modifying the return type of the ufunc · numpy/numpy-stubs@a1b405e · GitHub
[go: up one dir, main page]

Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

Commit a1b405e

Browse files
committed
Also support modifying the return type of the ufunc
1 parent 463739c commit a1b405e

File tree

4 files changed

+119
-101
lines changed

4 files changed

+119
-101
lines changed

numpy-stubs/__init__.pyi

Lines changed: 93 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,18 @@ from typing import (
2828
TypeVar,
2929
Union,
3030
)
31-
from typing_extensions import Literal
3231

3332
if sys.version_info[0] < 3:
3433
class SupportsBytes: ...
3534

3635
else:
3736
from typing import SupportsBytes
3837

38+
if sys.version_info >= (3, 8):
39+
from typing import Literal
40+
else:
41+
from typing_extensions import Literal
42+
3943
_Shape = Tuple[int, ...]
4044

4145
# Anything that can be coerced to a shape tuple
@@ -650,11 +654,11 @@ class ufunc(Generic[_Nin], Generic[_Nout]):
650654
# int, an int, and a callable, but there's no way to express
651655
# that.
652656
extobj: List[Union[int, Callable]] = ...,
653-
) -> Union[ndarray, generic]: ...
657+
) -> Union[ndarray, generic, Tuple[Union[ndarray, generic], ...]]: ...
654658
@property
655-
def nin(self) -> int: ...
659+
def nin(self) -> _Nin: ...
656660
@property
657-
def nout(self) -> int: ...
661+
def nout(self) -> _Nout: ...
658662
@property
659663
def nargs(self) -> int: ...
660664
@property
@@ -693,92 +697,92 @@ class ufunc(Generic[_Nin], Generic[_Nout]):
693697
@property
694698
def at(self) -> Any: ...
695699

696-
absolute: ufunc
697-
add: ufunc
698-
arccos: ufunc
699-
arccosh: ufunc
700-
arcsin: ufunc
701-
arcsinh: ufunc
702-
arctan2: ufunc
703-
arctan: ufunc
704-
arctanh: ufunc
705-
bitwise_and: ufunc
706-
bitwise_or: ufunc
707-
bitwise_xor: ufunc
708-
cbrt: ufunc
709-
ceil: ufunc
710-
conjugate: ufunc
711-
copysign: ufunc
712-
cos: ufunc
713-
cosh: ufunc
714-
deg2rad: ufunc
715-
degrees: ufunc
716-
divmod: ufunc
717-
equal: ufunc
718-
exp2: ufunc
719-
exp: ufunc
720-
expm1: ufunc
721-
fabs: ufunc
722-
float_power: ufunc
723-
floor: ufunc
724-
floor_divide: ufunc
725-
fmax: ufunc
726-
fmin: ufunc
727-
fmod: ufunc
728-
frexp: ufunc
729-
gcd: ufunc
730-
greater: ufunc
731-
greater_equal: ufunc
732-
heaviside: ufunc
733-
hypot: ufunc
734-
invert: ufunc
735-
isfinite: ufunc
736-
isinf: ufunc
737-
isnan: ufunc
738-
isnat: ufunc
739-
lcm: ufunc
740-
ldexp: ufunc
741-
left_shift: ufunc
742-
less: ufunc
743-
less_equal: ufunc
744-
log10: ufunc
745-
log1p: ufunc
746-
log2: ufunc
747-
log: ufunc
748-
logaddexp2: ufunc
749-
logaddexp: ufunc
750-
logical_and: ufunc
751-
logical_not: ufunc
752-
logical_or: ufunc
753-
logical_xor: ufunc
754-
matmul: ufunc
755-
maximum: ufunc
756-
minimum: ufunc
757-
modf: ufunc
758-
multiply: ufunc
759-
negative: ufunc
760-
nextafter: ufunc
761-
not_equal: ufunc
762-
positive: ufunc
763-
power: ufunc
764-
rad2deg: ufunc
765-
radians: ufunc
766-
reciprocal: ufunc
767-
remainder: ufunc
768-
right_shift: ufunc
769-
rint: ufunc
770-
sign: ufunc
771-
signbit: ufunc
700+
absolute: ufunc[Literal[1], Literal[1]]
701+
add: ufunc[Literal[2], Literal[1]]
702+
arccos: ufunc[Literal[1], Literal[1]]
703+
arccosh: ufunc[Literal[1], Literal[1]]
704+
arcsin: ufunc[Literal[1], Literal[1]]
705+
arcsinh: ufunc[Literal[1], Literal[1]]
706+
arctan2: ufunc[Literal[2], Literal[1]]
707+
arctan: ufunc[Literal[1], Literal[1]]
708+
arctanh: ufunc[Literal[1], Literal[1]]
709+
bitwise_and: ufunc[Literal[2], Literal[1]]
710+
bitwise_or: ufunc[Literal[2], Literal[1]]
711+
bitwise_xor: ufunc[Literal[2], Literal[1]]
712+
cbrt: ufunc[Literal[1], Literal[1]]
713+
ceil: ufunc[Literal[1], Literal[1]]
714+
conjugate: ufunc[Literal[1], Literal[1]]
715+
copysign: ufunc[Literal[2], Literal[1]]
716+
cos: ufunc[Literal[1], Literal[1]]
717+
cosh: ufunc[Literal[1], Literal[1]]
718+
deg2rad: ufunc[Literal[1], Literal[1]]
719+
degrees: ufunc[Literal[1], Literal[1]]
720+
divmod: ufunc[Literal[2], Literal[2]]
721+
equal: ufunc[Literal[2], Literal[1]]
722+
exp2: ufunc[Literal[1], Literal[1]]
723+
exp: ufunc[Literal[1], Literal[1]]
724+
expm1: ufunc[Literal[1], Literal[1]]
725+
fabs: ufunc[Literal[1], Literal[1]]
726+
float_power: ufunc[Literal[2], Literal[1]]
727+
floor: ufunc[Literal[1], Literal[1]]
728+
floor_divide: ufunc[Literal[2], Literal[1]]
729+
fmax: ufunc[Literal[2], Literal[1]]
730+
fmin: ufunc[Literal[2], Literal[1]]
731+
fmod: ufunc[Literal[2], Literal[1]]
732+
frexp: ufunc[Literal[1], Literal[2]]
733+
gcd: ufunc[Literal[2], Literal[1]]
734+
greater: ufunc[Literal[2], Literal[1]]
735+
greater_equal: ufunc[Literal[2], Literal[1]]
736+
heaviside: ufunc[Literal[2], Literal[1]]
737+
hypot: ufunc[Literal[2], Literal[1]]
738+
invert: ufunc[Literal[1], Literal[1]]
739+
isfinite: ufunc[Literal[1], Literal[1]]
740+
isinf: ufunc[Literal[1], Literal[1]]
741+
isnan: ufunc[Literal[1], Literal[1]]
742+
isnat: ufunc[Literal[1], Literal[1]]
743+
lcm: ufunc[Literal[2], Literal[1]]
744+
ldexp: ufunc[Literal[2], Literal[1]]
745+
left_shift: ufunc[Literal[2], Literal[1]]
746+
less: ufunc[Literal[2], Literal[1]]
747+
less_equal: ufunc[Literal[2], Literal[1]]
748+
log10: ufunc[Literal[1], Literal[1]]
749+
log1p: ufunc[Literal[1], Literal[1]]
750+
log2: ufunc[Literal[1], Literal[1]]
751+
log: ufunc[Literal[1], Literal[1]]
752+
logaddexp2: ufunc[Literal[2], Literal[1]]
753+
logaddexp: ufunc[Literal[2], Literal[1]]
754+
logical_and: ufunc[Literal[2], Literal[1]]
755+
logical_not: ufunc[Literal[1], Literal[1]]
756+
logical_or: ufunc[Literal[2], Literal[1]]
757+
logical_xor: ufunc[Literal[2], Literal[1]]
758+
matmul: ufunc[Literal[2], Literal[1]]
759+
maximum: ufunc[Literal[2], Literal[1]]
760+
minimum: ufunc[Literal[2], Literal[1]]
761+
modf: ufunc[Literal[1], Literal[2]]
762+
multiply: ufunc[Literal[2], Literal[1]]
763+
negative: ufunc[Literal[1], Literal[1]]
764+
nextafter: ufunc[Literal[2], Literal[1]]
765+
not_equal: ufunc[Literal[2], Literal[1]]
766+
positive: ufunc[Literal[1], Literal[1]]
767+
power: ufunc[Literal[2], Literal[1]]
768+
rad2deg: ufunc[Literal[1], Literal[1]]
769+
radians: ufunc[Literal[1], Literal[1]]
770+
reciprocal: ufunc[Literal[1], Literal[1]]
771+
remainder: ufunc[Literal[2], Literal[1]]
772+
right_shift: ufunc[Literal[2], Literal[1]]
773+
rint: ufunc[Literal[1], Literal[1]]
774+
sign: ufunc[Literal[1], Literal[1]]
775+
signbit: ufunc[Literal[1], Literal[1]]
772776
sin: ufunc[Literal[1], Literal[1]]
773-
sinh: ufunc
774-
spacing: ufunc
775-
sqrt: ufunc
776-
square: ufunc
777-
subtract: ufunc
778-
tan: ufunc
779-
tanh: ufunc
780-
true_divide: ufunc
781-
trunc: ufunc
777+
sinh: ufunc[Literal[1], Literal[1]]
778+
spacing: ufunc[Literal[1], Literal[1]]
779+
sqrt: ufunc[Literal[1], Literal[1]]
780+
square: ufunc[Literal[1], Literal[1]]
781+
subtract: ufunc[Literal[2], Literal[1]]
782+
tan: ufunc[Literal[1], Literal[1]]
783+
tanh: ufunc[Literal[1], Literal[1]]
784+
true_divide: ufunc[Literal[2], Literal[1]]
785+
trunc: ufunc[Literal[1], Literal[1]]
782786

783787
# TODO(shoyer): remove when the full numpy namespace is defined
784788
def __getattr__(name: str) -> Any: ...

numpy_ufuncs_plugin.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
11
from mypy.nodes import ARG_POS
22
from mypy.plugin import Plugin
3-
from mypy.types import CallableType
3+
import mypy.types
4+
from mypy.types import CallableType, LiteralType, TupleType, UnionType
45

56

67
def ufunc_call_hook(ctx):
7-
ufunc_name = ctx.context.callee.name
8-
9-
type_info = ctx.type.serialize()
10-
nin_arg, nout_arg = type_info['args']
11-
if nin_arg['.class'] != 'LiteralType':
8+
print(ctx.type.args)
9+
nin_arg, nout_arg = ctx.type.args
10+
if not isinstance(nin_arg, LiteralType):
11+
# Not a literal; we can't make the signature any more precise.
1212
return ctx.default_signature
13-
if nout_arg['.class'] != 'LiteralType':
13+
if not isinstance(nout_arg, LiteralType):
1414
return ctx.default_signature
15+
nin, nout = nin_arg.value, nout_arg.value
1516

16-
nin = nin_arg['value']
17-
nout = nout_arg['value']
18-
19-
# Strip off the *args and replace it with the correct number of
17+
# Strip off *args and replace it with the correct number of
2018
# positional arguments.
2119
arg_kinds = [ARG_POS] * nin + ctx.default_signature.arg_kinds[1:]
2220
arg_names = (
@@ -27,10 +25,18 @@ def ufunc_call_hook(ctx):
2725
[ctx.default_signature.arg_types[0]] * nin +
2826
ctx.default_signature.arg_types[1:]
2927
)
28+
ndarray_type, generic_type, _ = ctx.default_signature.ret_type.items
29+
scalar_or_ndarray = UnionType([ndarray_type, generic_type])
30+
if nout == 1:
31+
ret_type = scalar_or_ndarray
32+
else:
33+
ret_type = TupleType([scalar_or_ndarray] * nout)
34+
3035
return ctx.default_signature.copy_modified(
3136
arg_kinds=arg_kinds,
3237
arg_names=arg_names,
3338
arg_types=arg_types,
39+
ret_type=ret_type,
3440
)
3541

3642

scripts/find_ufuncs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ def main():
1010

1111
ufunc_stubs = []
1212
for ufunc in set(ufuncs):
13-
ufunc_stubs.append(f'{ufunc.__name__}: ufunc')
13+
ufunc_stubs.append(
14+
f'{ufunc.__name__}: ufunc[Literal[{ufunc.nin}], Literal[{ufunc.nout}]]'
15+
)
1416
ufunc_stubs.sort()
1517

1618
for stub in ufunc_stubs:

tests/reveal/ufuncs.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import numpy as np
2+
3+
reveal_type(np.sin(1)) # E: Union[numpy.ndarray, numpy.generic]
4+
reveal_type(np.sin([1, 2, 3])) # E: Union[numpy.ndarray, numpy.generic]
5+
reveal_type(np.sin.nin) # E: Literal[1]
6+
reveal_type(np.sin.nout) # E: Literal[1]

0 commit comments

Comments
 (0)
0