8000 Make coroutine function return type more specific (#5052) · python/mypy@6519eb6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6519eb6

Browse files
AndreLouisCaronilevkivskyi
authored andcommitted
Make coroutine function return type more specific (#5052)
* Change the return type of coroutine functions to coroutine object Previously, the return type was `Awaitable`, which is correct, but not specific enough for some use cases. For example, if you have a function parameter that should be a coroutine object (but not a `Future`, `Task` or other awaitable), then `mypy` is unable to detect incorrect invocation 10000 s of the function. This change is (deliberately) imcomplete as is. It seems like this breaks quite a few tests in `mypy`. The first symptom is that coroutine objects are now (incorrectly) detected as generators.. * Prevent coroutine functions from being classified as generators This change removes the undesired "return type of a generator function should be `Generator` or one of its subtypes for coroutine function definitions. However, it introduces a new error where the return type of coroutine functions is expected to be `Coroutine[T, Any, Any]` instead of the desired `T`. It looks like we were hijacking a generator-specific path for checking the return type of coroutine functinos. I added an explicit path for coroutine functions, allowing our test to pass. However, lots of tests now fail. A few of them were simply places that were incidentally relying on coroutine functions to have type `Awaitable`. I fix them. The remaining failures all seem to be about coroutine functions with return type `None` without an explicit return statement. Seems like this is also something for which we were relying on implicit classification as generators. * Allow implicit return for coroutine functions that return `None` Most of the tests are fixed, but two tests still fail. One about not detecting invalid `yield from` on `AwaitableGenerator`. The other about types being erased in call to `asyncio.gather()`. * Fix return type for coroutine functions decorated with @coroutine * Fix detection of await expression on direct coroutine function call Changing the return type of coroutine functions to `Coroutine` introduced a regression in the expression checks. * Fix regression after change of coroutine function return type * Fix position of return type in `Coroutine` This fixes the type inference logic that was causing the last failing test to fail. Build should now be green :-) * Fix issues raised in code review Fixes #3569. Fixes #4460. Special thanks to @ilevkivskyi and @gvanrossum for their time and their infinite patience with all my questions :-)
1 parent e42d600 commit 6519eb6

File tree

7 files changed

+38
-13
lines changed

7 files changed

+38
-13
lines changed

mypy/checker.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,13 @@ def get_generator_receive_type(self, return_type: Type, is_coroutine: bool) -> T
578578
# values. IOW, tc is None.
579579
return NoneTyp()
580580

581+
def get_coroutine_return_type(self, return_type: Type) -> Type:
582+
if isinstance(return_type, AnyType):
583+
return AnyType(TypeOfAny.from_another_any, source_any=return_type)
584+
assert isinstance(return_type, Instance), "Should only be called on coroutine functions."
585+
# Note: return type is the 3rd type parameter of Coroutine.
586+
return return_type.args[2]
587+
581588
def get_generator_return_type(self, return_type: Type, is_coroutine: bool) -> Type:
582589
"""Given the declared return type of a generator (t), return the type it returns (tr)."""
583590
if isinstance(return_type, AnyType):
@@ -756,7 +763,10 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])
756763
c = defn.is_coroutine
757764
ty = self.get_generator_yield_type(t, c)
758765
tc = self.get_generator_receive_type(t, c)
759-
tr = self.get_generator_return_type(t, c)
766+
if c:
767+
tr = self.get_coroutine_return_type(t)
768+
else:
769+
tr = self.get_generator_return_type(t, c)
760770
ret_type = self.named_generic_type('typing.AwaitableGenerator',
761771
[ty, tc, tr, t])
762772
typ = typ.copy_modified(ret_type=ret_type)
@@ -841,6 +851,8 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])
841851
is_named_instance(self.return_types[-1], 'typing.AwaitableGenerator')):
842852
return_type = self.get_generator_return_type(self.return_types[-1],
843853
defn.is_coroutine)
854+
elif defn.is_coroutine:
855+
return_type = self.get_coroutine_return_type(self.return_types[-1])
844856
else:
845857
return_type = self.return_types[-1]
846858

@@ -878,7 +890,7 @@ def is_unannotated_any(t: Type) -> bool:
878890
if is_unannotated_any(ret_type):
879891
self.fail(messages.RETURN_TYPE_EXPECTED, fdef)
880892
elif (fdef.is_coroutine and isinstance(ret_type, Instance) and
881-
is_unannotated_any(ret_type.args[0])):
893+
is_unannotated_any(self.get_coroutine_return_type(ret_type))):
882894
self.fail(messages.RETURN_TYPE_EXPECTED, fdef)
883895
if any(is_unannotated_any(t) for t in fdef.type.arg_types):
884896
self.fail(messages.ARGUMENT_TYPE_EXPECTED, fdef)
@@ -2211,6 +2223,8 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
22112223
if defn.is_generator:
22122224
return_type = self.get_generator_return_type(self.return_types[-1],
22132225
defn.is_coroutine)
2226+
elif defn.is_coroutine:
2227+
return_type = self.get_coroutine_return_type(self.return_types[-1])
22142228
else:
22152229
return_type = self.return_types[-1]
22162230

mypy/checkexpr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2688,7 +2688,7 @@ def is_async_def(t: Type) -> bool:
26882688
and t.type.fullname() == 'typing.AwaitableGenerator'
26892689
and len(t.args) >= 4):
26902690
t = t.args[3]
2691-
return isinstance(t, Instance) and t.type.fullname() == 'typing.Awaitable'
2691+
return isinstance(t, Instance) and t.type.fullname() == 'typing.Coroutine'
26922692

26932693

26942694
def map_actuals_to_formals(caller_kinds: List[int],

mypy/fastparse.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,8 +398,7 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef],
398398
self.as_required_block(n.body, n.lineno),
399399
func_type)
400400
if is_coroutine:
401-
# A coroutine is also a generator, mostly for internal reasons.
402-
func_def.is_generator = func_def.is_coroutine = True
401+
func_def.is_coroutine = True
403402
if func_type is not None:
404403
func_type.definition = func_def
405404
func_type.line = n.lineno

mypy/semanal.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -431,9 +431,11 @@ def _visit_func_def(self, defn: FuncDef) -> None:
431431
pass
432432
else:
433433
# A coroutine defined as `async def foo(...) -> T: ...`
434-
# has external return type `Awaitable[T]`.
435-
ret_type = self.named_type_or_none('typing.Awaitable', [defn.type.ret_type])
436-
assert ret_type is not None, "Internal error: typing.Awaitable not found"
434+
# has external return type `Coroutine[Any, Any, T]`.
435+
any_type = AnyType(TypeOfAny.special_form)
436+
ret_type = self.named_type_or_none('typing.Coroutine',
437+
[any_type, any_type, defn.type.ret_type])
438+
assert ret_type is not None, "Internal error: typing.Coroutine not found"
437439
defn.type = defn.type.copy_modified(ret_type=ret_type)
438440

439441
def prepare_method_signature(self, func: FuncDef, info: TypeInfo) -> None:

test-data/unit/check-async-await.test

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ async def f() -> int:
1212

1313
async def f() -> int:
1414
return 0
15-
reveal_type(f()) # E: Revealed type is 'typing.Awaitable[builtins.int]'
15+
reveal_type(f()) # E: Revealed type is 'typing.Coroutine[Any, Any, builtins.int]'
1616
[builtins fixtures/async_await.pyi]
1717
[typing fixtures/typing-full.pyi]
1818

@@ -378,7 +378,7 @@ def g() -> Generator[Any, None, str]:
378378
[builtins fixtures/async_await.pyi]
379379
[typing fixtures/typing-full.pyi]
380380
[out]
381-
main:6: error: "yield from" can't be applied to "Awaitable[str]"
381+
main:6: error: "yield from" can't be applied to "Coroutine[Any, Any, str]"
382382

383383
[case testAwaitableSubclass]
384384

@@ -630,9 +630,9 @@ def plain_host_generator() -> Generator[str, None, None]:
630630
yield 'a'
631631
x = 0
632632
x = yield from plain_generator()
633-
x = yield from plain_coroutine() # E: "yield from" can't be applied to "Awaitable[int]"
633+
x = yield from plain_coroutine() # E: "yield from" can't be applied to "Coroutine[Any, Any, int]"
634634
x = yield from decorated_generator()
635-
x = yield from decorated_coroutine() # E: "yield from" can't be applied to "AwaitableGenerator[Any, Any, int, Awaitable[int]]"
635+
x = yield from decorated_coroutine() # E: "yield from" can't be applied to "AwaitableGenerator[Any, Any, int, Coroutine[Any, Any, int]]"
636636
x = yield from other_iterator()
637637
x = yield from other_coroutine() # E: "yield from" can't be applied to "Aw"
638638

test-data/unit/check-class-namedtuple.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ class XRepr(NamedTuple):
503503
return 0
504504

505505
reveal_type(XMeth(1).double()) # E: Revealed type is 'builtins.int'
506-
reveal_type(XMeth(1).asyncdouble()) # E: Revealed type is 'typing.Awaitable[builtins.int]'
506+
reveal_type(XMeth(1).asyncdouble()) # E: Revealed type is 'typing.Coroutine[Any, Any, builtins.int]'
507507
reveal_type(XMeth(42).x) # E: Revealed type is 'builtins.int'
508508
reveal_type(XRepr(42).__str__()) # E: Revealed type is 'builtins.str'
509509
reveal_type(XRepr(1, 2).__add__(XRepr(3))) # E: Revealed type is 'builtins.int'

test-data/unit/fixtures/typing-full.pyi

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,16 @@ class Awaitable(Protocol[T]):
101101
class AwaitableGenerator(Generator[T, U, V], Awaitable[V], Generic[T, U, V, S]):
102102
pass
103103

104+
class Coroutine(Awaitable[V], Generic[T, U, V]):
105+
@abstractmethod
106+
def send(self, value: U) -> T: pass
107+
108+
@abstractmethod
109+
def throw(self, typ: Any, val: Any=None, tb: Any=None) -> None: pass
110+
111+
@abstractmethod
112+
def close(self) -> None: pass
113+
104114
@runtime
105115
class AsyncIterable(Protocol[T]):
106116
@abstractmethod

0 commit comments

Comments
 (0)
0