diff --git a/mypy/checker.py b/mypy/checker.py index 307afe8568d5..efadddb2a491 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3,6 +3,7 @@ import itertools import fnmatch from contextlib import contextmanager +import sys from typing import ( Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator @@ -311,12 +312,19 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: # for functions decorated with `@types.coroutine` or # `@asyncio.coroutine`. Its single parameter corresponds to tr. # + # PEP 525 adds a new type, the asynchronous generator, which was + # first released in Python 3.6. Async generators are `async def` + # functions that can also `yield` values. They can be parameterized + # with two types, ty and tc, because they cannot return a value. + # # There are several useful methods, each taking a type t and a # flag c indicating whether it's for a generator or coroutine: # # - is_generator_return_type(t, c) returns whether t is a Generator, # Iterator, Iterable (if not c), or Awaitable (if c), or # AwaitableGenerator (regardless of c). + # - is_async_generator_return_type(t) returns whether t is an + # AsyncGenerator. # - get_generator_yield_type(t, c) returns ty. # - get_generator_receive_type(t, c) returns tc. # - get_generator_return_type(t, c) returns tr. @@ -338,11 +346,24 @@ def is_generator_return_type(self, typ: Type, is_coroutine: bool) -> bool: return True return isinstance(typ, Instance) and typ.type.fullname() == 'typing.AwaitableGenerator' + def is_async_generator_return_type(self, typ: Type) -> bool: + """Is `typ` a valid type for an async generator? + + True if `typ` is a supertype of AsyncGenerator. + """ + try: + agt = self.named_generic_type('typing.AsyncGenerator', [AnyType(), AnyType()]) + except KeyError: + # we're running on a version of typing that doesn't have AsyncGenerator yet + return False + return is_subtype(agt, typ) + def get_generator_yield_type(self, return_type: Type, is_coroutine: bool) -> Type: """Given the declared return type of a generator (t), return the type it yields (ty).""" if isinstance(return_type, AnyType): return AnyType() - elif not self.is_generator_return_type(return_type, is_coroutine): + elif (not self.is_generator_return_type(return_type, is_coroutine) + and not self.is_async_generator_return_type(return_type)): # If the function doesn't have a proper Generator (or # Awaitable) return type, anything is permissible. return AnyType() @@ -353,7 +374,7 @@ def get_generator_yield_type(self, return_type: Type, is_coroutine: bool) -> Typ # Awaitable: ty is Any. return AnyType() elif return_type.args: - # AwaitableGenerator, Generator, Iterator, or Iterable; ty is args[0]. + # AwaitableGenerator, Generator, AsyncGenerator, Iterator, or Iterable; ty is args[0]. ret_type = return_type.args[0] # TODO not best fix, better have dedicated yield token if isinstance(ret_type, NoneTyp): @@ -373,7 +394,8 @@ def get_generator_receive_type(self, return_type: Type, is_coroutine: bool) -> T """Given a declared generator return type (t), return the type its yield receives (tc).""" if isinstance(return_type, AnyType): return AnyType() - elif not self.is_generator_return_type(return_type, is_coroutine): + elif (not self.is_generator_return_type(return_type, is_coroutine) + and not self.is_async_generator_return_type(return_type)): # If the function doesn't have a proper Generator (or # Awaitable) return type, anything is permissible. return AnyType() @@ -387,6 +409,8 @@ def get_generator_receive_type(self, return_type: Type, is_coroutine: bool) -> T and len(return_type.args) >= 3): # Generator: tc is args[1]. return return_type.args[1] + elif return_type.type.fullname() == 'typing.AsyncGenerator' and len(return_type.args) >= 2: + return return_type.args[1] else: # `return_type` is a supertype of Generator, so callers won't be able to send it # values. IOW, tc is None. @@ -537,8 +561,12 @@ def is_implicit_any(t: Type) -> bool: # Check that Generator functions have the appropriate return type. if defn.is_generator: - if not self.is_generator_return_type(typ.ret_type, defn.is_coroutine): - self.fail(messages.INVALID_RETURN_TYPE_FOR_GENERATOR, typ) + if defn.is_async_generator: + if not self.is_async_generator_return_type(typ.ret_type): + self.fail(messages.INVALID_RETURN_TYPE_FOR_ASYNC_GENERATOR, typ) + else: + if not self.is_generator_return_type(typ.ret_type, defn.is_coroutine): + self.fail(messages.INVALID_RETURN_TYPE_FOR_GENERATOR, typ) # Python 2 generators aren't allowed to return values. if (self.options.python_version[0] == 2 and @@ -1743,6 +1771,11 @@ def check_return_stmt(self, s: ReturnStmt) -> None: if s.expr: # Return with a value. typ = self.expr_checker.accept(s.expr, return_type) + + if defn.is_async_generator: + self.fail("'return' with value in async generator is not allowed", s) + return + # Returning a value of type Any is always fine. if isinstance(typ, AnyType): # (Unless you asked to be warned in that case, and the diff --git a/mypy/messages.py b/mypy/messages.py index 782df03f5bd6..88edba48a38c 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -34,6 +34,9 @@ INVALID_EXCEPTION_TYPE = 'Exception type must be derived from BaseException' INVALID_RETURN_TYPE_FOR_GENERATOR = \ 'The return type of a generator function should be "Generator" or one of its supertypes' +INVALID_RETURN_TYPE_FOR_ASYNC_GENERATOR = \ + 'The return type of an async generator function should be "AsyncGenerator" or one of its ' \ + 'supertypes' INVALID_GENERATOR_RETURN_ITEM_TYPE = \ 'The return type of a generator function must be None in its third type parameter in Python 2' YIELD_VALUE_EXPECTED = 'Yield value expected' diff --git a/mypy/nodes.py b/mypy/nodes.py index 1346581ef8d3..f9599c13cf4b 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -481,6 +481,7 @@ class FuncItem(FuncBase): is_overload = False is_generator = False # Contains a yield statement? is_coroutine = False # Defined using 'async def' syntax? + is_async_generator = False # Is an async def generator? is_awaitable_coroutine = False # Decorated with '@{typing,asyncio}.coroutine'? is_static = False # Uses @staticmethod? is_class = False # Uses @classmethod? @@ -488,8 +489,8 @@ class FuncItem(FuncBase): expanded = None # type: List[FuncItem] FLAGS = [ - 'is_overload', 'is_generator', 'is_coroutine', 'is_awaitable_coroutine', - 'is_static', 'is_class', + 'is_overload', 'is_generator', 'is_coroutine', 'is_async_generator', + 'is_awaitable_coroutine', 'is_static', 'is_class', ] def __init__(self, arguments: List[Argument], body: 'Block', diff --git a/mypy/semanal.py b/mypy/semanal.py index 0236bddf5270..e11b89574aaa 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -324,11 +324,15 @@ def visit_func_def(self, defn: FuncDef) -> None: self.errors.push_function(defn.name()) self.analyze_function(defn) if defn.is_coroutine and isinstance(defn.type, CallableType): - # A coroutine defined as `async def foo(...) -> T: ...` - # has external return type `Awaitable[T]`. - defn.type = defn.type.copy_modified( - ret_type = self.named_type_or_none('typing.Awaitable', - [defn.type.ret_type])) + if defn.is_async_generator: + # Async generator types are handled elsewhere + pass + else: + # A coroutine defined as `async def foo(...) -> T: ...` + # has external return type `Awaitable[T]`. + defn.type = defn.type.copy_modified( + ret_type = self.named_type_or_none('typing.Awaitable', + [defn.type.ret_type])) self.errors.pop_function() def prepare_method_signature(self, func: FuncDef) -> None: @@ -2842,7 +2846,11 @@ def visit_yield_expr(self, expr: YieldExpr) -> None: self.fail("'yield' outside function", expr, True, blocker=True) else: if self.function_stack[-1].is_coroutine: - self.fail("'yield' in async function", expr, True, blocker=True) + if self.options.python_version < (3, 6): + self.fail("'yield' in async function", expr, True, blocker=True) + else: + self.function_stack[-1].is_generator = True + self.function_stack[-1].is_async_generator = True else: self.function_stack[-1].is_generator = True if expr.expr: diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index ac2ada81d7d7..ebae58ee6b1d 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -326,18 +326,15 @@ async def f() -> None: [builtins fixtures/async_await.pyi] [case testNoYieldInAsyncDef] +# flags: --python-version 3.5 async def f(): - yield None + yield None # E: 'yield' in async function async def g(): - yield + yield # E: 'yield' in async function async def h(): - x = yield + x = yield # E: 'yield' in async function [builtins fixtures/async_await.pyi] -[out] -main:3: error: 'yield' in async function -main:5: error: 'yield' in async function -main:7: error: 'yield' in async function [case testNoYieldFromInAsyncDef] @@ -410,6 +407,156 @@ def f() -> Generator[int, str, int]: [builtins fixtures/async_await.pyi] [out] +-- Async generators (PEP 525), some test cases adapted from the PEP text +-- --------------------------------------------------------------------- + +[case testAsyncGenerator] +# flags: --python-version 3.6 +from typing import AsyncGenerator, Generator + +async def f() -> int: + return 42 + +async def g() -> AsyncGenerator[int, None]: + value = await f() + reveal_type(value) # E: Revealed type is 'builtins.int*' + yield value + + yield 'not an int' # E: Incompatible types in yield (actual type "str", expected type "int") + # return without a value is fine + return +reveal_type(g) # E: Revealed type is 'def () -> typing.AsyncGenerator[builtins.int, void]' +reveal_type(g()) # E: Revealed type is 'typing.AsyncGenerator[builtins.int, void]' + +async def h() -> None: + async for item in g(): + reveal_type(item) # E: Revealed type is 'builtins.int*' + +async def wrong_return() -> Generator[int, None, None]: # E: The return type of an async generator function should be "AsyncGenerator" or one of its supertypes + yield 3 + +[builtins fixtures/dict.pyi] + +[case testAsyncGeneratorReturnIterator] +# flags: --python-version 3.6 +from typing import AsyncIterator + +async def gen() -> AsyncIterator[int]: + yield 3 + + yield 'not an int' # E: Incompatible types in yield (actual type "str", expected type "int") + +async def use_gen() -> None: + async for item in gen(): + reveal_type(item) # E: Revealed type is 'builtins.int*' + +[builtins fixtures/dict.pyi] + +[case testAsyncGeneratorManualIter] +# flags: --python-version 3.6 +from typing import AsyncGenerator + +async def genfunc() -> AsyncGenerator[int, None]: + yield 1 + yield 2 + +async def user() -> None: + gen = genfunc() + + reveal_type(gen.__aiter__()) # E: Revealed type is 'typing.AsyncGenerator[builtins.int*, void]' + + reveal_type(await gen.__anext__()) # E: Revealed type is 'builtins.int*' + +[builtins fixtures/dict.pyi] + +[case testAsyncGeneratorAsend] +# flags: --fast-parser --python-version 3.6 +from typing import AsyncGenerator + +async def f() -> None: + pass + +async def gen() -> AsyncGenerator[int, str]: + await f() + v = yield 42 + reveal_type(v) # E: Revealed type is 'builtins.str' + await f() + +async def h() -> None: + g = gen() + await g.asend(()) # E: Argument 1 to "asend" of "AsyncGenerator" has incompatible type "Tuple[]"; expected "str" + reveal_type(await g.asend('hello')) # E: Revealed type is 'builtins.int*' + +[builtins fixtures/dict.pyi] + +[case testAsyncGeneratorAthrow] +# flags: --fast-parser --python-version 3.6 +from typing import AsyncGenerator + +async def gen() -> AsyncGenerator[str, int]: + try: + yield 'hello' + except BaseException: + yield 'world' + +async def h() -> None: + g = gen() + v = await g.asend(1) + reveal_type(v) # E: Revealed type is 'builtins.str*' + reveal_type(await g.athrow(BaseException)) # E: Revealed type is 'builtins.str*' + +[builtins fixtures/dict.pyi] + +[case testAsyncGeneratorNoSyncIteration] +# flags: --fast-parser --python-version 3.6 +from typing import AsyncGenerator + +async def gen() -> AsyncGenerator[int, None]: + for i in (1, 2, 3): + yield i + +def h() -> None: + for i in gen(): + pass + +[builtins fixtures/dict.pyi] + +[out] +main:9: error: Iterable expected +main:9: error: AsyncGenerator[int, None] has no attribute "__iter__"; maybe "__aiter__"? + +[case testAsyncGeneratorNoYieldFrom] +# flags: --fast-parser --python-version 3.6 +from typing import AsyncGenerator + +async def f() -> AsyncGenerator[int, None]: + pass + +async def gen() -> AsyncGenerator[int, None]: + yield from f() # E: 'yield from' in async function + +[builtins fixtures/dict.pyi] + +[case testAsyncGeneratorNoReturnWithValue] +# flags: --fast-parser --python-version 3.6 +from typing import AsyncGenerator + +async def return_int() -> AsyncGenerator[int, None]: + yield 1 + return 42 # E: 'return' with value in async generator is not allowed + +async def return_none() -> AsyncGenerator[int, None]: + yield 1 + return None # E: 'return' with value in async generator is not allowed + +def f() -> None: + return + +async def return_f() -> AsyncGenerator[int, None]: + yield 1 + return f() # E: 'return' with value in async generator is not allowed + +[builtins fixtures/dict.pyi] -- The full matrix of coroutine compatibility -- ------------------------------------------ diff --git a/test-data/unit/fixtures/dict.pyi b/test-data/unit/fixtures/dict.pyi index d36145fa0654..dc89366c1133 100644 --- a/test-data/unit/fixtures/dict.pyi +++ b/test-data/unit/fixtures/dict.pyi @@ -38,3 +38,5 @@ class tuple: pass class function: pass class float: pass class bool: pass + +class BaseException: pass diff --git a/test-data/unit/lib-stub/typing.pyi b/test-data/unit/lib-stub/typing.pyi index 0377f19a1e5d..8d252dbddf2e 100644 --- a/test-data/unit/lib-stub/typing.pyi +++ b/test-data/unit/lib-stub/typing.pyi @@ -59,6 +59,22 @@ class Generator(Iterator[T], Generic[T, U, V]): @abstractmethod def __iter__(self) -> 'Generator[T, U, V]': pass +class AsyncGenerator(AsyncIterator[T], Generic[T, U]): + @abstractmethod + def __anext__(self) -> Awaitable[T]: pass + + @abstractmethod + def asend(self, value: U) -> Awaitable[T]: pass + + @abstractmethod + def athrow(self, typ: Any, val: Any=None, tb: Any=None) -> Awaitable[T]: pass + + @abstractmethod + def aclose(self) -> Awaitable[T]: pass + + @abstractmethod + def __aiter__(self) -> 'AsyncGenerator[T, U]': pass + class Awaitable(Generic[T]): @abstractmethod def __await__(self) -> Generator[Any, Any, T]: pass