8000 Add AsyncContextManager generic class (#438) · python/typing@74bc3ee · GitHub
[go: up one dir, main page]

Skip to content

Commit 74bc3ee

Browse files
ilevkivskyigvanrossum
authored andcommitted
Add AsyncContextManager generic class (#438)
Asynchronous context managers are defined by PEP 492, but there is no corresponding generic abstract class in typing. This PR adds it for Python 3.5+.
1 parent 47e0860 commit 74bc3ee

File tree

2 files changed

+71
-1
lines changed

2 files changed

+71
-1
lines changed

src/test_typing.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1552,6 +1552,12 @@ def __anext__(self) -> T_a:
15521552
return data
15531553
else:
15541554
raise StopAsyncIteration
1555+
1556+
class ACM:
1557+
async def __aenter__(self) -> int:
1558+
return 42
1559+
async def __aexit__(self, etype, eval, tb):
1560+
return None
15551561
"""
15561562

15571563
if ASYNCIO:
@@ -1562,12 +1568,13 @@ def __anext__(self) -> T_a:
15621568
else:
15631569
# fake names for the sake of static analysis
15641570
asyncio = None
1565-
AwaitableWrapper = AsyncIteratorWrapper = object
1571+
AwaitableWrapper = AsyncIteratorWrapper = ACM = object
15661572

15671573
PY36 = sys.version_info[:2] >= (3, 6)
15681574

15691575
PY36_TESTS = """
15701576
from test import ann_module, ann_module2, ann_module3
1577+
from typing import AsyncContextManager
15711578
15721579
class A:
15731580
y: float
@@ -1604,6 +1611,16 @@ def __str__(self):
16041611
return f'{self.x} -> {self.y}'
16051612
def __add__(self, other):
16061613
return 0
1614+
1615+
async def g_with(am: AsyncContextManager[int]):
1616+
x: int
1617+
async with am as x:
1618+
return x
1619+
1620+
try:
1621+
g_with(ACM()).send(None)
1622+
except StopIteration as e:
1623+
assert e.args[0] == 42
16071624
"""
16081625

16091626
if PY36:
@@ -2165,6 +2182,24 @@ def manager():
21652182
self.assertIsInstance(cm, typing.ContextManager)
21662183
self.assertNotIsInstance(42, typing.ContextManager)
21672184

2185+
@skipUnless(ASYNCIO, 'Python 3.5 required')
2186+
def test_async_contextmanager(self):
2187+
class NotACM:
2188+
pass
2189+
self.assertIsInstance(ACM(), typing.AsyncContextManager)
2190+
self.assertNotIsInstance(NotACM(), typing.AsyncContextManager)
2191+
@contextlib.contextmanager
2192+
def manager():
2193+
yield 42
2194+
2195+
cm = manager()
2196+
self.assertNotIsInstance(cm, typing.AsyncContextManager)
2197+
self.assertEqual(typing.AsyncContextManager[int].__args__, (int,))
2198+
with self.assertRaises(TypeError):
2199+
isinstance(42, typing.AsyncContextManager[int])
2200+
with self.assertRaises(TypeError):
2201+
typing.AsyncContextManager[int, str]
2202+
21682203

21692204
class TypeTests(BaseTestCase):
21702205

src/typing.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import collections.abc as collections_abc
1111
except ImportError:
1212
import collections as collections_abc # Fallback for PY3.2.
13+
if sys.version_info[:2] >= (3, 6):
14+
import _collections_abc # Needed for private function _check_methods # noqa
1315
try:
1416
from types import WrapperDescriptorType, MethodWrapperType, MethodDescriptorType
1517
except ImportError:
@@ -59,6 +61,7 @@
5961
# Coroutine,
6062
# Collection,
6163
# AsyncGenerator,
64+
# AsyncContextManager
6265

6366
< 8000 span class=pl-c># Structural checks, a.k.a. protocols.
6467
'Reversible',
@@ -1974,6 +1977,38 @@ def __subclasshook__(cls, C):
19741977
return NotImplemented
19751978

19761979

1980+
if hasattr(contextlib, 'AbstractAsyncContextManager'):
1981+
class AsyncContextManager(Generic[T_co],
1982+
extra=contextlib.AbstractAsyncContextManager):
1983+
__slots__ = ()
1984+
1985+
__all__.append('AsyncContextManager')
1986+
elif sys.version_info[:2] >= (3, 5):
1987+
exec("""
1988+
class AsyncContextManager(Generic[T_co]):
1989+
__slots__ = ()
1990+
1991+
async def __aenter__(self):
1992+
return self
1993+
1994+
@abc.abstractmethod
1995+
async def __aexit__(self, exc_type, exc_value, traceback):
1996+
return None
1997+
1998+
@classmethod
1999+
def __subclasshook__(cls, C):
2000+
if cls is AsyncContextManager:
2001+
if sys.version_info[:2] >= (3, 6):
2002+
return _collections_abc._check_methods(C, "__aenter__", "__aexit__")
2003+
if (any("__aenter__" in B.__dict__ for B in C.__mro__) and
2004+
any("__aexit__" in B.__dict__ for B in C.__mro__)):
2005+
return True
2006+
return NotImplemented
2007+
2008+
__all__.append('AsyncContextManager')
2009+
""")
2010+
2011+
19772012
class Dict(dict, MutableMapping[KT, VT], extra=dict):
19782013

19792014
__slots__ = ()

0 commit comments

Comments
 (0)
0