diff --git a/Lib/asyncio/__init__.py b/Lib/asyncio/__init__.py index ff69378ba9..03165a425e 100644 --- a/Lib/asyncio/__init__.py +++ b/Lib/asyncio/__init__.py @@ -1,22 +1,14 @@ """The asyncio package, tracking PEP 3156.""" # flake8: noqa -import sys - -import selectors -# XXX RustPython TODO: _overlapped -if sys.platform == 'win32' and False: - # Similar thing for _overlapped. - try: - from . import _overlapped - except ImportError: - import _overlapped # Will also be exported. +import sys # This relies on each of the submodules having an __all__ variable. from .base_events import * from .coroutines import * from .events import * +from .exceptions import * from .futures import * from .locks import * from .protocols import * @@ -25,11 +17,15 @@ from .streams import * from .subprocess import * from .tasks import * +from .taskgroups import * +from .timeouts import * +from .threads import * from .transports import * __all__ = (base_events.__all__ + coroutines.__all__ + events.__all__ + + exceptions.__all__ + futures.__all__ + locks.__all__ + protocols.__all__ + @@ -38,6 +34,9 @@ streams.__all__ + subprocess.__all__ + tasks.__all__ + + taskgroups.__all__ + + threads.__all__ + + timeouts.__all__ + transports.__all__) if sys.platform == 'win32': # pragma: no cover diff --git a/Lib/asyncio/__main__.py b/Lib/asyncio/__main__.py new file mode 100644 index 0000000000..18bb87a5bc --- /dev/null +++ b/Lib/asyncio/__main__.py @@ -0,0 +1,125 @@ +import ast +import asyncio +import code +import concurrent.futures +import inspect +import sys +import threading +import types +import warnings + +from . import futures + + +class AsyncIOInteractiveConsole(code.InteractiveConsole): + + def __init__(self, locals, loop): + super().__init__(locals) + self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT + + self.loop = loop + + def runcode(self, code): + future = concurrent.futures.Future() + + def callback(): + global repl_future + global repl_future_interrupted + + repl_future = None + repl_future_interrupted = False + + func = types.FunctionType(code, self.locals) + try: + coro = func() + except SystemExit: + raise + except KeyboardInterrupt as ex: + repl_future_interrupted = True + future.set_exception(ex) + return + except BaseException as ex: + future.set_exception(ex) + return + + if not inspect.iscoroutine(coro): + future.set_result(coro) + return + + try: + repl_future = self.loop.create_task(coro) + futures._chain_future(repl_future, future) + except BaseException as exc: + future.set_exception(exc) + + loop.call_soon_threadsafe(callback) + + try: + return future.result() + except SystemExit: + raise + except BaseException: + if repl_future_interrupted: + self.write("\nKeyboardInterrupt\n") + else: + self.showtraceback() + + +class REPLThread(threading.Thread): + + def run(self): + try: + banner = ( + f'asyncio REPL {sys.version} on {sys.platform}\n' + f'Use "await" directly instead of "asyncio.run()".\n' + f'Type "help", "copyright", "credits" or "license" ' + f'for more information.\n' + f'{getattr(sys, "ps1", ">>> ")}import asyncio' + ) + + console.interact( + banner=banner, + exitmsg='exiting asyncio REPL...') + finally: + warnings.filterwarnings( + 'ignore', + message=r'^coroutine .* was never awaited$', + category=RuntimeWarning) + + loop.call_soon_threadsafe(loop.stop) + + +if __name__ == '__main__': + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + repl_locals = {'asyncio': asyncio} + for key in {'__name__', '__package__', + '__loader__', '__spec__', + '__builtins__', 
'__file__'}:
+        repl_locals[key] = locals()[key]
+
+    console = AsyncIOInteractiveConsole(repl_locals, loop)
+
+    repl_future = None
+    repl_future_interrupted = False
+
+    try:
+        import readline  # NoQA
+    except ImportError:
+        pass
+
+    repl_thread = REPLThread()
+    repl_thread.daemon = True
+    repl_thread.start()
+
+    while True:
+        try:
+            loop.run_forever()
+        except KeyboardInterrupt:
+            if repl_future and not repl_future.done():
+                repl_future.cancel()
+                repl_future_interrupted = True
+            continue
+        else:
+            break
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index 2df379933c..29eff0499c 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -14,13 +14,15 @@
 """
 
 import collections
+import collections.abc
 import concurrent.futures
+import errno
+import functools
 import heapq
-import inspect
 import itertools
-import logging
 import os
 import socket
+import stat
 import subprocess
 import threading
 import time
@@ -29,16 +31,27 @@
 import warnings
 import weakref
 
-from . import compat
+try:
+    import ssl
+except ImportError:  # pragma: no cover
+    ssl = None
+
+from . import constants
 from . import coroutines
 from . import events
+from . import exceptions
 from . import futures
+from . import protocols
+from . import sslproto
+from . import staggered
 from . import tasks
-from .coroutines import coroutine
+from . import timeouts
+from . import transports
+from . import trsock
 from .log import logger
 
 
-__all__ = ['BaseEventLoop']
+__all__ = 'BaseEventLoop', 'Server'
 
 
 # Minimum number of _scheduled timer handles before cleanup of
@@ -49,10 +62,11 @@
 # before cleanup of cancelled handles is performed.
 _MIN_CANCELLED_TIMER_HANDLES_FRACTION = 0.5
 
-# Exceptions which must not call the exception handler in fatal error
-# methods (_fatal_error())
-_FATAL_ERROR_IGNORE = (BrokenPipeError,
-                       ConnectionResetError, ConnectionAbortedError)
+
+_HAS_IPv6 = hasattr(socket, 'AF_INET6')
+
+# Maximum timeout passed to select to avoid OS limitations
+MAXIMUM_SELECT_TIMEOUT = 24 * 3600
 
 
 def _format_handle(handle):
@@ -84,21 +98,7 @@ def _set_reuseport(sock):
             'SO_REUSEPORT defined but not implemented.')
 
 
-def _is_stream_socket(sock):
-    # Linux's socket.type is a bitmask that can include extra info
-    # about socket, therefore we can't do simple
-    # `sock_type == socket.SOCK_STREAM`.
-    return (sock.type & socket.SOCK_STREAM) == socket.SOCK_STREAM
-
-
-def _is_dgram_socket(sock):
-    # Linux's socket.type is a bitmask that can include extra info
-    # about socket, therefore we can't do simple
-    # `sock_type == socket.SOCK_DGRAM`.
-    return (sock.type & socket.SOCK_DGRAM) == socket.SOCK_DGRAM
-
-
-def _ipaddr_info(host, port, family, type, proto):
+def _ipaddr_info(host, port, family, type, proto, flowinfo=0, scopeid=0):
     # Try to skip getaddrinfo if "host" is already an IP. Users might have
     # handled name resolution in their own code and pass in resolved IPs.
     if not hasattr(socket, 'inet_pton'):
         return None
@@ -109,11 +109,6 @@ def _ipaddr_info(host, port, family, type, proto):
         return None
 
     if type == socket.SOCK_STREAM:
-        # Linux only:
-        # getaddrinfo() can raise when socket.type is a bit mask.
-        # So if socket.type is a bit mask of SOCK_STREAM, and say
-        # SOCK_NONBLOCK, we simply return None, which will trigger
-        # a call to getaddrinfo() letting it process this request.
proto = socket.IPPROTO_TCP elif type == socket.SOCK_DGRAM: proto = socket.IPPROTO_UDP @@ -135,7 +130,7 @@ def _ipaddr_info(host, port, family, type, proto): if family == socket.AF_UNSPEC: afs = [socket.AF_INET] - if hasattr(socket, 'AF_INET6'): + if _HAS_IPv6: afs.append(socket.AF_INET6) else: afs = [family] @@ -151,7 +146,10 @@ def _ipaddr_info(host, port, family, type, proto): try: socket.inet_pton(af, host) # The host has already been resolved. - return af, type, proto, '', (host, port) + if _HAS_IPv6 and af == socket.AF_INET6: + return af, type, proto, '', (host, port, flowinfo, scopeid) + else: + return af, type, proto, '', (host, port) except OSError: pass @@ -159,75 +157,253 @@ def _ipaddr_info(host, port, family, type, proto): return None -def _ensure_resolved(address, *, family=0, type=socket.SOCK_STREAM, proto=0, - flags=0, loop): - host, port = address[:2] - info = _ipaddr_info(host, port, family, type, proto) - if info is not None: - # "host" is already a resolved IP. - fut = loop.create_future() - fut.set_result([info]) - return fut - else: - return loop.getaddrinfo(host, port, family=family, type=type, - proto=proto, flags=flags) +def _interleave_addrinfos(addrinfos, first_address_family_count=1): + """Interleave list of addrinfo tuples by family.""" + # Group addresses by family + addrinfos_by_family = collections.OrderedDict() + for addr in addrinfos: + family = addr[0] + if family not in addrinfos_by_family: + addrinfos_by_family[family] = [] + addrinfos_by_family[family].append(addr) + addrinfos_lists = list(addrinfos_by_family.values()) + + reordered = [] + if first_address_family_count > 1: + reordered.extend(addrinfos_lists[0][:first_address_family_count - 1]) + del addrinfos_lists[0][:first_address_family_count - 1] + reordered.extend( + a for a in itertools.chain.from_iterable( + itertools.zip_longest(*addrinfos_lists) + ) if a is not None) + return reordered def _run_until_complete_cb(fut): - exc = fut._exception - if (isinstance(exc, BaseException) - and not isinstance(exc, Exception)): - # Issue #22429: run_forever() already finished, no need to - # stop it. - return - fut._loop.stop() + if not fut.cancelled(): + exc = fut.exception() + if isinstance(exc, (SystemExit, KeyboardInterrupt)): + # Issue #22429: run_forever() already finished, no need to + # stop it. 
+ return + futures._get_loop(fut).stop() + + +if hasattr(socket, 'TCP_NODELAY'): + def _set_nodelay(sock): + if (sock.family in {socket.AF_INET, socket.AF_INET6} and + sock.type == socket.SOCK_STREAM and + sock.proto == socket.IPPROTO_TCP): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) +else: + def _set_nodelay(sock): + pass + + +def _check_ssl_socket(sock): + if ssl is not None and isinstance(sock, ssl.SSLSocket): + raise TypeError("Socket cannot be of type SSLSocket") + + +class _SendfileFallbackProtocol(protocols.Protocol): + def __init__(self, transp): + if not isinstance(transp, transports._FlowControlMixin): + raise TypeError("transport should be _FlowControlMixin instance") + self._transport = transp + self._proto = transp.get_protocol() + self._should_resume_reading = transp.is_reading() + self._should_resume_writing = transp._protocol_paused + transp.pause_reading() + transp.set_protocol(self) + if self._should_resume_writing: + self._write_ready_fut = self._transport._loop.create_future() + else: + self._write_ready_fut = None + + async def drain(self): + if self._transport.is_closing(): + raise ConnectionError("Connection closed by peer") + fut = self._write_ready_fut + if fut is None: + return + await fut + + def connection_made(self, transport): + raise RuntimeError("Invalid state: " + "connection should have been established already.") + + def connection_lost(self, exc): + if self._write_ready_fut is not None: + # Never happens if peer disconnects after sending the whole content + # Thus disconnection is always an exception from user perspective + if exc is None: + self._write_ready_fut.set_exception( + ConnectionError("Connection is closed by peer")) + else: + self._write_ready_fut.set_exception(exc) + self._proto.connection_lost(exc) + + def pause_writing(self): + if self._write_ready_fut is not None: + return + self._write_ready_fut = self._transport._loop.create_future() + + def resume_writing(self): + if self._write_ready_fut is None: + return + self._write_ready_fut.set_result(False) + self._write_ready_fut = None + + def data_received(self, data): + raise RuntimeError("Invalid state: reading should be paused") + + def eof_received(self): + raise RuntimeError("Invalid state: reading should be paused") + + async def restore(self): + self._transport.set_protocol(self._proto) + if self._should_resume_reading: + self._transport.resume_reading() + if self._write_ready_fut is not None: + # Cancel the future. + # Basically it has no effect because protocol is switched back, + # no code should wait for it anymore. 
+ self._write_ready_fut.cancel() + if self._should_resume_writing: + self._proto.resume_writing() class Server(events.AbstractServer): - def __init__(self, loop, sockets): + def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog, + ssl_handshake_timeout, ssl_shutdown_timeout=None): self._loop = loop - self.sockets = sockets + self._sockets = sockets self._active_count = 0 self._waiters = [] + self._protocol_factory = protocol_factory + self._backlog = backlog + self._ssl_context = ssl_context + self._ssl_handshake_timeout = ssl_handshake_timeout + self._ssl_shutdown_timeout = ssl_shutdown_timeout + self._serving = False + self._serving_forever_fut = None def __repr__(self): - return '<%s sockets=%r>' % (self.__class__.__name__, self.sockets) + return f'<{self.__class__.__name__} sockets={self.sockets!r}>' def _attach(self): - assert self.sockets is not None + assert self._sockets is not None self._active_count += 1 def _detach(self): assert self._active_count > 0 self._active_count -= 1 - if self._active_count == 0 and self.sockets is None: + if self._active_count == 0 and self._sockets is None: self._wakeup() + def _wakeup(self): + waiters = self._waiters + self._waiters = None + for waiter in waiters: + if not waiter.done(): + waiter.set_result(None) + + def _start_serving(self): + if self._serving: + return + self._serving = True + for sock in self._sockets: + sock.listen(self._backlog) + self._loop._start_serving( + self._protocol_factory, sock, self._ssl_context, + self, self._backlog, self._ssl_handshake_timeout, + self._ssl_shutdown_timeout) + + def get_loop(self): + return self._loop + + def is_serving(self): + return self._serving + + @property + def sockets(self): + if self._sockets is None: + return () + return tuple(trsock.TransportSocket(s) for s in self._sockets) + def close(self): - sockets = self.sockets + sockets = self._sockets if sockets is None: return - self.sockets = None + self._sockets = None + for sock in sockets: self._loop._stop_serving(sock) + + self._serving = False + + if (self._serving_forever_fut is not None and + not self._serving_forever_fut.done()): + self._serving_forever_fut.cancel() + self._serving_forever_fut = None + if self._active_count == 0: self._wakeup() - def _wakeup(self): - waiters = self._waiters - self._waiters = None - for waiter in waiters: - if not waiter.done(): - waiter.set_result(waiter) + async def start_serving(self): + self._start_serving() + # Skip one loop iteration so that all 'loop.add_reader' + # go through. + await tasks.sleep(0) + + async def serve_forever(self): + if self._serving_forever_fut is not None: + raise RuntimeError( + f'server {self!r} is already being awaited on serve_forever()') + if self._sockets is None: + raise RuntimeError(f'server {self!r} is closed') + + self._start_serving() + self._serving_forever_fut = self._loop.create_future() + + try: + await self._serving_forever_fut + except exceptions.CancelledError: + try: + self.close() + await self.wait_closed() + finally: + raise + finally: + self._serving_forever_fut = None - @coroutine - def wait_closed(self): - if self.sockets is None or self._waiters is None: + async def wait_closed(self): + """Wait until server is closed and all connections are dropped. + + - If the server is not closed, wait. + - If it is closed, but there are still active connections, wait. + + Anyone waiting here will be unblocked once both conditions + (server is closed and all connections have been dropped) + have become true, in either order. 
+ + Historical note: In 3.11 and before, this was broken, returning + immediately if the server was already closed, even if there + were still active connections. An attempted fix in 3.12.0 was + still broken, returning immediately if the server was still + open and there were no active connections. Hopefully in 3.12.1 + we have it right. + """ + # Waiters are unblocked by self._wakeup(), which is called + # from two places: self.close() and self._detach(), but only + # when both conditions have become true. To signal that this + # has happened, self._wakeup() sets self._waiters to None. + if self._waiters is None: return waiter = self._loop.create_future() self._waiters.append(waiter) - yield from waiter + await waiter class BaseEventLoop(events.AbstractEventLoop): @@ -243,49 +419,54 @@ def __init__(self): # Identifier of the thread running the event loop, or None if the # event loop is not running self._thread_id = None - self._clock_resolution = 1e-06 #time.get_clock_info('monotonic').resolution + self._clock_resolution = time.get_clock_info('monotonic').resolution self._exception_handler = None - self.set_debug((not sys.flags.ignore_environment - and bool(os.environ.get('PYTHONASYNCIODEBUG')))) + self.set_debug(coroutines._is_debug_mode()) # In debug mode, if the execution of a callback or a step of a task # exceed this duration in seconds, the slow callback/task is logged. self.slow_callback_duration = 0.1 self._current_handle = None self._task_factory = None - self._coroutine_wrapper_set = False - - if hasattr(sys, 'get_asyncgen_hooks'): - # Python >= 3.6 - # A weak set of all asynchronous generators that are - # being iterated by the loop. - self._asyncgens = weakref.WeakSet() - else: - self._asyncgens = None + self._coroutine_origin_tracking_enabled = False + self._coroutine_origin_tracking_saved_depth = None + # A weak set of all asynchronous generators that are + # being iterated by the loop. + self._asyncgens = weakref.WeakSet() # Set to True when `loop.shutdown_asyncgens` is called. self._asyncgens_shutdown_called = False + # Set to True when `loop.shutdown_default_executor` is called. + self._executor_shutdown_called = False def __repr__(self): - return ('<%s running=%s closed=%s debug=%s>' - % (self.__class__.__name__, self.is_running(), - self.is_closed(), self.get_debug())) + return ( + f'<{self.__class__.__name__} running={self.is_running()} ' + f'closed={self.is_closed()} debug={self.get_debug()}>' + ) def create_future(self): """Create a Future object attached to the loop.""" return futures.Future(loop=self) - def create_task(self, coro): + def create_task(self, coro, *, name=None, context=None): """Schedule a coroutine object. Return a task object. 
""" self._check_closed() if self._task_factory is None: - task = tasks.Task(coro, loop=self) + task = tasks.Task(coro, loop=self, name=name, context=context) if task._source_traceback: del task._source_traceback[-1] else: - task = self._task_factory(self, coro) + if context is None: + # Use legacy API if context is not needed + task = self._task_factory(self, coro) + else: + task = self._task_factory(self, coro, context=context) + + tasks._set_task_name(task, name) + return task def set_task_factory(self, factory): @@ -311,9 +492,13 @@ def _make_socket_transport(self, sock, protocol, waiter=None, *, """Create socket transport.""" raise NotImplementedError - def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, - *, server_side=False, server_hostname=None, - extra=None, server=None): + def _make_ssl_transport( + self, rawsock, protocol, sslcontext, waiter=None, + *, server_side=False, server_hostname=None, + extra=None, server=None, + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, + call_connection_made=True): """Create SSL transport.""" raise NotImplementedError @@ -332,10 +517,9 @@ def _make_write_pipe_transport(self, pipe, protocol, waiter=None, """Create write pipe transport.""" raise NotImplementedError - @coroutine - def _make_subprocess_transport(self, protocol, args, shell, - stdin, stdout, stderr, bufsize, - extra=None, **kwargs): + async def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): """Create subprocess transport.""" raise NotImplementedError @@ -356,29 +540,29 @@ def _check_closed(self): if self._closed: raise RuntimeError('Event loop is closed') + def _check_default_executor(self): + if self._executor_shutdown_called: + raise RuntimeError('Executor shutdown has been called') + def _asyncgen_finalizer_hook(self, agen): self._asyncgens.discard(agen) if not self.is_closed(): - self.create_task(agen.aclose()) - # Wake up the loop if the finalizer was called from - # a different thread. - self._write_to_self() + self.call_soon_threadsafe(self.create_task, agen.aclose()) def _asyncgen_firstiter_hook(self, agen): if self._asyncgens_shutdown_called: warnings.warn( - "asynchronous generator {!r} was scheduled after " - "loop.shutdown_asyncgens() call".format(agen), + f"asynchronous generator {agen!r} was scheduled after " + f"loop.shutdown_asyncgens() call", ResourceWarning, source=self) self._asyncgens.add(agen) - @coroutine - def shutdown_asyncgens(self): + async def shutdown_asyncgens(self): """Shutdown all active asynchronous generators.""" self._asyncgens_shutdown_called = True - if self._asyncgens is None or not len(self._asyncgens): + if not len(self._asyncgens): # If Python version is <3.6 or we don't have any asynchronous # generators alive. 
return @@ -386,36 +570,72 @@ def shutdown_asyncgens(self): closing_agens = list(self._asyncgens) self._asyncgens.clear() - shutdown_coro = tasks.gather( + results = await tasks.gather( *[ag.aclose() for ag in closing_agens], - return_exceptions=True, - loop=self) + return_exceptions=True) - results = yield from shutdown_coro for result, agen in zip(results, closing_agens): if isinstance(result, Exception): self.call_exception_handler({ - 'message': 'an error occurred during closing of ' - 'asynchronous generator {!r}'.format(agen), + 'message': f'an error occurred during closing of ' + f'asynchronous generator {agen!r}', 'exception': result, 'asyncgen': agen }) - def run_forever(self): - """Run until stop() is called.""" - self._check_closed() + async def shutdown_default_executor(self, timeout=None): + """Schedule the shutdown of the default executor. + + The timeout parameter specifies the amount of time the executor will + be given to finish joining. The default value is None, which means + that the executor will be given an unlimited amount of time. + """ + self._executor_shutdown_called = True + if self._default_executor is None: + return + future = self.create_future() + thread = threading.Thread(target=self._do_shutdown, args=(future,)) + thread.start() + try: + async with timeouts.timeout(timeout): + await future + except TimeoutError: + warnings.warn("The executor did not finishing joining " + f"its threads within {timeout} seconds.", + RuntimeWarning, stacklevel=2) + self._default_executor.shutdown(wait=False) + else: + thread.join() + + def _do_shutdown(self, future): + try: + self._default_executor.shutdown(wait=True) + if not self.is_closed(): + self.call_soon_threadsafe(futures._set_result_unless_cancelled, + future, None) + except Exception as ex: + if not self.is_closed() and not future.cancelled(): + self.call_soon_threadsafe(future.set_exception, ex) + + def _check_running(self): if self.is_running(): raise RuntimeError('This event loop is already running') if events._get_running_loop() is not None: raise RuntimeError( 'Cannot run the event loop while another loop is running') - self._set_coroutine_wrapper(self._debug) - self._thread_id = threading.get_ident() - if self._asyncgens is not None: - old_agen_hooks = sys.get_asyncgen_hooks() + + def run_forever(self): + """Run until stop() is called.""" + self._check_closed() + self._check_running() + self._set_coroutine_origin_tracking(self._debug) + + old_agen_hooks = sys.get_asyncgen_hooks() + try: + self._thread_id = threading.get_ident() sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook, finalizer=self._asyncgen_finalizer_hook) - try: + events._set_running_loop(self) while True: self._run_once() @@ -425,9 +645,8 @@ def run_forever(self): self._stopping = False self._thread_id = None events._set_running_loop(None) - self._set_coroutine_wrapper(False) - if self._asyncgens is not None: - sys.set_asyncgen_hooks(*old_agen_hooks) + self._set_coroutine_origin_tracking(False) + sys.set_asyncgen_hooks(*old_agen_hooks) def run_until_complete(self, future): """Run until the Future is done. @@ -441,6 +660,7 @@ def run_until_complete(self, future): Return the Future's result, or raise its exception. """ self._check_closed() + self._check_running() new_task = not futures.isfuture(future) future = tasks.ensure_future(future, loop=self) @@ -459,7 +679,8 @@ def run_until_complete(self, future): # local task. 
future.exception() raise - future.remove_done_callback(_run_until_complete_cb) + finally: + future.remove_done_callback(_run_until_complete_cb) if not future.done(): raise RuntimeError('Event loop stopped before Future completed.') @@ -490,6 +711,7 @@ def close(self): self._closed = True self._ready.clear() self._scheduled.clear() + self._executor_shutdown_called = True executor = self._default_executor if executor is not None: self._default_executor = None @@ -499,16 +721,11 @@ def is_closed(self): """Returns True if the event loop was closed.""" return self._closed - # On Python 3.3 and older, objects with a destructor part of a reference - # cycle are never destroyed. It's not more the case on Python 3.4 thanks - # to the PEP 442. - if compat.PY34: - def __del__(self): - if not self.is_closed(): - warnings.warn("unclosed event loop %r" % self, ResourceWarning, - source=self) - if not self.is_running(): - self.close() + def __del__(self, _warn=warnings.warn): + if not self.is_closed(): + _warn(f"unclosed event loop {self!r}", ResourceWarning, source=self) + if not self.is_running(): + self.close() def is_running(self): """Returns True if the event loop is running.""" @@ -523,7 +740,7 @@ def time(self): """ return time.monotonic() - def call_later(self, delay, callback, *args): + def call_later(self, delay, callback, *args, context=None): """Arrange for a callback to be called at a given time. Return a Handle: an opaque object with a cancel() method that @@ -533,34 +750,39 @@ def call_later(self, delay, callback, *args): always relative to the current time. Each callback will be called exactly once. If two callbacks - are scheduled for exactly the same time, it undefined which + are scheduled for exactly the same time, it is undefined which will be called first. Any positional arguments after the callback will be passed to the callback when it is called. """ - timer = self.call_at(self.time() + delay, callback, *args) + if delay is None: + raise TypeError('delay must not be None') + timer = self.call_at(self.time() + delay, callback, *args, + context=context) if timer._source_traceback: del timer._source_traceback[-1] return timer - def call_at(self, when, callback, *args): + def call_at(self, when, callback, *args, context=None): """Like call_later(), but uses an absolute time. Absolute time corresponds to the event loop's time() method. """ + if when is None: + raise TypeError("when cannot be None") self._check_closed() if self._debug: self._check_thread() self._check_callback(callback, 'call_at') - timer = events.TimerHandle(when, callback, args, self) + timer = events.TimerHandle(when, callback, args, self, context) if timer._source_traceback: del timer._source_traceback[-1] heapq.heappush(self._scheduled, timer) timer._scheduled = True return timer - def call_soon(self, callback, *args): + def call_soon(self, callback, *args, context=None): """Arrange for a callback to be called as soon as possible. 
This operates as a FIFO queue: callbacks are called in the @@ -574,7 +796,7 @@ def call_soon(self, callback, *args): if self._debug: self._check_thread() self._check_callback(callback, 'call_soon') - handle = self._call_soon(callback, args) + handle = self._call_soon(callback, args, context) if handle._source_traceback: del handle._source_traceback[-1] return handle @@ -583,15 +805,14 @@ def _check_callback(self, callback, method): if (coroutines.iscoroutine(callback) or coroutines.iscoroutinefunction(callback)): raise TypeError( - "coroutines cannot be used with {}()".format(method)) + f"coroutines cannot be used with {method}()") if not callable(callback): raise TypeError( - 'a callable object was expected by {}(), got {!r}'.format( - method, callback)) + f'a callable object was expected by {method}(), ' + f'got {callback!r}') - - def _call_soon(self, callback, args): - handle = events.Handle(callback, args, self) + def _call_soon(self, callback, args, context): + handle = events.Handle(callback, args, self, context) if handle._source_traceback: del handle._source_traceback[-1] self._ready.append(handle) @@ -614,12 +835,12 @@ def _check_thread(self): "Non-thread-safe operation invoked on an event loop other " "than the current one") - def call_soon_threadsafe(self, callback, *args): + def call_soon_threadsafe(self, callback, *args, context=None): """Like call_soon(), but thread-safe.""" self._check_closed() if self._debug: self._check_callback(callback, 'call_soon_threadsafe') - handle = self._call_soon(callback, args) + handle = self._call_soon(callback, args, context) if handle._source_traceback: del handle._source_traceback[-1] self._write_to_self() @@ -631,24 +852,31 @@ def run_in_executor(self, executor, func, *args): self._check_callback(func, 'run_in_executor') if executor is None: executor = self._default_executor + # Only check when the default executor is being used + self._check_default_executor() if executor is None: - executor = concurrent.futures.ThreadPoolExecutor() + executor = concurrent.futures.ThreadPoolExecutor( + thread_name_prefix='asyncio' + ) self._default_executor = executor - return futures.wrap_future(executor.submit(func, *args), loop=self) + return futures.wrap_future( + executor.submit(func, *args), loop=self) def set_default_executor(self, executor): + if not isinstance(executor, concurrent.futures.ThreadPoolExecutor): + raise TypeError('executor must be ThreadPoolExecutor instance') self._default_executor = executor def _getaddrinfo_debug(self, host, port, family, type, proto, flags): - msg = ["%s:%r" % (host, port)] + msg = [f"{host}:{port!r}"] if family: - msg.append('family=%r' % family) + msg.append(f'family={family!r}') if type: - msg.append('type=%r' % type) + msg.append(f'type={type!r}') if proto: - msg.append('proto=%r' % proto) + msg.append(f'proto={proto!r}') if flags: - msg.append('flags=%r' % flags) + msg.append(f'flags={flags!r}') msg = ', '.join(msg) logger.debug('Get address info %s', msg) @@ -656,33 +884,152 @@ def _getaddrinfo_debug(self, host, port, family, type, proto, flags): addrinfo = socket.getaddrinfo(host, port, family, type, proto, flags) dt = self.time() - t0 - msg = ('Getting address info %s took %.3f ms: %r' - % (msg, dt * 1e3, addrinfo)) + msg = f'Getting address info {msg} took {dt * 1e3:.3f}ms: {addrinfo!r}' if dt >= self.slow_callback_duration: logger.info(msg) else: logger.debug(msg) return addrinfo - def getaddrinfo(self, host, port, *, - family=0, type=0, proto=0, flags=0): + async def getaddrinfo(self, host, port, *, + 
family=0, type=0, proto=0, flags=0): if self._debug: - return self.run_in_executor(None, self._getaddrinfo_debug, - host, port, family, type, proto, flags) + getaddr_func = self._getaddrinfo_debug else: - return self.run_in_executor(None, socket.getaddrinfo, - host, port, family, type, proto, flags) + getaddr_func = socket.getaddrinfo + + return await self.run_in_executor( + None, getaddr_func, host, port, family, type, proto, flags) - def getnameinfo(self, sockaddr, flags=0): - return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + async def getnameinfo(self, sockaddr, flags=0): + return await self.run_in_executor( + None, socket.getnameinfo, sockaddr, flags) - @coroutine - def create_connection(self, protocol_factory, host=None, port=None, *, - ssl=None, family=0, proto=0, flags=0, sock=None, - local_addr=None, server_hostname=None): + async def sock_sendfile(self, sock, file, offset=0, count=None, + *, fallback=True): + if self._debug and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + _check_ssl_socket(sock) + self._check_sendfile_params(sock, file, offset, count) + try: + return await self._sock_sendfile_native(sock, file, + offset, count) + except exceptions.SendfileNotAvailableError as exc: + if not fallback: + raise + return await self._sock_sendfile_fallback(sock, file, + offset, count) + + async def _sock_sendfile_native(self, sock, file, offset, count): + # NB: sendfile syscall is not supported for SSL sockets and + # non-mmap files even if sendfile is supported by OS + raise exceptions.SendfileNotAvailableError( + f"syscall sendfile is not available for socket {sock!r} " + f"and file {file!r} combination") + + async def _sock_sendfile_fallback(self, sock, file, offset, count): + if offset: + file.seek(offset) + blocksize = ( + min(count, constants.SENDFILE_FALLBACK_READBUFFER_SIZE) + if count else constants.SENDFILE_FALLBACK_READBUFFER_SIZE + ) + buf = bytearray(blocksize) + total_sent = 0 + try: + while True: + if count: + blocksize = min(count - total_sent, blocksize) + if blocksize <= 0: + break + view = memoryview(buf)[:blocksize] + read = await self.run_in_executor(None, file.readinto, view) + if not read: + break # EOF + await self.sock_sendall(sock, view[:read]) + total_sent += read + return total_sent + finally: + if total_sent > 0 and hasattr(file, 'seek'): + file.seek(offset + total_sent) + + def _check_sendfile_params(self, sock, file, offset, count): + if 'b' not in getattr(file, 'mode', 'b'): + raise ValueError("file should be opened in binary mode") + if not sock.type == socket.SOCK_STREAM: + raise ValueError("only SOCK_STREAM type sockets are supported") + if count is not None: + if not isinstance(count, int): + raise TypeError( + "count must be a positive integer (got {!r})".format(count)) + if count <= 0: + raise ValueError( + "count must be a positive integer (got {!r})".format(count)) + if not isinstance(offset, int): + raise TypeError( + "offset must be a non-negative integer (got {!r})".format( + offset)) + if offset < 0: + raise ValueError( + "offset must be a non-negative integer (got {!r})".format( + offset)) + + async def _connect_sock(self, exceptions, addr_info, local_addr_infos=None): + """Create, bind and connect one socket.""" + my_exceptions = [] + exceptions.append(my_exceptions) + family, type_, proto, _, address = addr_info + sock = None + try: + sock = socket.socket(family=family, type=type_, proto=proto) + sock.setblocking(False) + if local_addr_infos is not None: + for lfamily, _, _, _, laddr 
in local_addr_infos: + # skip local addresses of different family + if lfamily != family: + continue + try: + sock.bind(laddr) + break + except OSError as exc: + msg = ( + f'error while attempting to bind on ' + f'address {laddr!r}: ' + f'{exc.strerror.lower()}' + ) + exc = OSError(exc.errno, msg) + my_exceptions.append(exc) + else: # all bind attempts failed + if my_exceptions: + raise my_exceptions.pop() + else: + raise OSError(f"no matching local address with {family=} found") + await self.sock_connect(sock, address) + return sock + except OSError as exc: + my_exceptions.append(exc) + if sock is not None: + sock.close() + raise + except: + if sock is not None: + sock.close() + raise + finally: + exceptions = my_exceptions = None + + async def create_connection( + self, protocol_factory, host=None, port=None, + *, ssl=None, family=0, + proto=0, flags=0, sock=None, + local_addr=None, server_hostname=None, + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, + happy_eyeballs_delay=None, interleave=None, + all_errors=False): """Connect to a TCP server. - Create a streaming transport connection to a given Internet host and + Create a streaming transport connection to a given internet host and port: socket family AF_INET or socket.AF_INET6 depending on host (or family if specified), socket type SOCK_STREAM. protocol_factory must be a callable returning a protocol instance. @@ -710,85 +1057,86 @@ def create_connection(self, protocol_factory, host=None, port=None, *, 'when using ssl without a host') server_hostname = host + if ssl_handshake_timeout is not None and not ssl: + raise ValueError( + 'ssl_handshake_timeout is only meaningful with ssl') + + if ssl_shutdown_timeout is not None and not ssl: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') + + if sock is not None: + _check_ssl_socket(sock) + + if happy_eyeballs_delay is not None and interleave is None: + # If using happy eyeballs, default to interleave addresses by family + interleave = 1 + if host is not None or port is not None: if sock is not None: raise ValueError( 'host/port and sock can not be specified at the same time') - f1 = _ensure_resolved((host, port), family=family, - type=socket.SOCK_STREAM, proto=proto, - flags=flags, loop=self) - fs = [f1] - if local_addr is not None: - f2 = _ensure_resolved(local_addr, family=family, - type=socket.SOCK_STREAM, proto=proto, - flags=flags, loop=self) - fs.append(f2) - else: - f2 = None - - yield from tasks.wait(fs, loop=self) - - infos = f1.result() + infos = await self._ensure_resolved( + (host, port), family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags, loop=self) if not infos: raise OSError('getaddrinfo() returned empty list') - if f2 is not None: - laddr_infos = f2.result() + + if local_addr is not None: + laddr_infos = await self._ensure_resolved( + local_addr, family=family, + type=socket.SOCK_STREAM, proto=proto, + flags=flags, loop=self) if not laddr_infos: raise OSError('getaddrinfo() returned empty list') + else: + laddr_infos = None + + if interleave: + infos = _interleave_addrinfos(infos, interleave) exceptions = [] - for family, type, proto, cname, address in infos: + if happy_eyeballs_delay is None: + # not using happy eyeballs + for addrinfo in infos: + try: + sock = await self._connect_sock( + exceptions, addrinfo, laddr_infos) + break + except OSError: + continue + else: # using happy eyeballs + sock, _, _ = await staggered.staggered_race( + (functools.partial(self._connect_sock, + exceptions, addrinfo, laddr_infos) + for 
addrinfo in infos), + happy_eyeballs_delay, loop=self) + + if sock is None: + exceptions = [exc for sub in exceptions for exc in sub] try: - sock = socket.socket(family=family, type=type, proto=proto) - sock.setblocking(False) - if f2 is not None: - for _, _, _, _, laddr in laddr_infos: - try: - sock.bind(laddr) - break - except OSError as exc: - exc = OSError( - exc.errno, 'error while ' - 'attempting to bind on address ' - '{!r}: {}'.format( - laddr, exc.strerror.lower())) - exceptions.append(exc) - else: - sock.close() - sock = None - continue - if self._debug: - logger.debug("connect %r to %r", sock, address) - yield from self.sock_connect(sock, address) - except OSError as exc: - if sock is not None: - sock.close() - exceptions.append(exc) - except: - if sock is not None: - sock.close() - raise - else: - break - else: - if len(exceptions) == 1: - raise exceptions[0] - else: - # If they all have the same str(), raise one. - model = str(exceptions[0]) - if all(str(exc) == model for exc in exceptions): + if all_errors: + raise ExceptionGroup("create_connection failed", exceptions) + if len(exceptions) == 1: raise exceptions[0] - # Raise a combined exception so the user can see all - # the various error messages. - raise OSError('Multiple exceptions: {}'.format( - ', '.join(str(exc) for exc in exceptions))) + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise OSError('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + finally: + exceptions = None else: if sock is None: raise ValueError( 'host and port was not specified and no sock specified') - if not _is_stream_socket(sock): + if sock.type != socket.SOCK_STREAM: # We allow AF_INET, AF_INET6, AF_UNIX as long as they # are SOCK_STREAM. # We support passing AF_UNIX sockets even though we have @@ -796,10 +1144,12 @@ def create_connection(self, protocol_factory, host=None, port=None, *, # Disallowing AF_UNIX in this method, breaks backwards # compatibility. 
raise ValueError( - 'A Stream Socket was expected, got {!r}'.format(sock)) + f'A Stream Socket was expected, got {sock!r}') - transport, protocol = yield from self._create_connection_transport( - sock, protocol_factory, ssl, server_hostname) + transport, protocol = await self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) if self._debug: # Get the socket from the transport because SSL transport closes # the old socket and creates a new SSL socket @@ -808,9 +1158,11 @@ def create_connection(self, protocol_factory, host=None, port=None, *, sock, host, port, transport, protocol) return transport, protocol - @coroutine - def _create_connection_transport(self, sock, protocol_factory, ssl, - server_hostname, server_side=False): + async def _create_connection_transport( + self, sock, protocol_factory, ssl, + server_hostname, server_side=False, + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): sock.setblocking(False) @@ -820,42 +1172,166 @@ def _create_connection_transport(self, sock, protocol_factory, ssl, sslcontext = None if isinstance(ssl, bool) else ssl transport = self._make_ssl_transport( sock, protocol, sslcontext, waiter, - server_side=server_side, server_hostname=server_hostname) + server_side=server_side, server_hostname=server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) else: transport = self._make_socket_transport(sock, protocol, waiter) try: - yield from waiter + await waiter except: transport.close() raise return transport, protocol - @coroutine - def create_datagram_endpoint(self, protocol_factory, - local_addr=None, remote_addr=None, *, - family=0, proto=0, flags=0, - reuse_address=None, reuse_port=None, - allow_broadcast=None, sock=None): + async def sendfile(self, transport, file, offset=0, count=None, + *, fallback=True): + """Send a file to transport. + + Return the total number of bytes which were sent. + + The method uses high-performance os.sendfile if available. + + file must be a regular file object opened in binary mode. + + offset tells from where to start reading the file. If specified, + count is the total number of bytes to transmit as opposed to + sending the file until EOF is reached. File position is updated on + return or also in case of error in which case file.tell() + can be used to figure out the number of bytes + which were sent. + + fallback set to True makes asyncio to manually read and send + the file when the platform does not support the sendfile syscall + (e.g. Windows or SSL socket on Unix). + + Raise SendfileNotAvailableError if the system does not support + sendfile syscall and fallback is False. 
+ """ + if transport.is_closing(): + raise RuntimeError("Transport is closing") + mode = getattr(transport, '_sendfile_compatible', + constants._SendfileMode.UNSUPPORTED) + if mode is constants._SendfileMode.UNSUPPORTED: + raise RuntimeError( + f"sendfile is not supported for transport {transport!r}") + if mode is constants._SendfileMode.TRY_NATIVE: + try: + return await self._sendfile_native(transport, file, + offset, count) + except exceptions.SendfileNotAvailableError as exc: + if not fallback: + raise + + if not fallback: + raise RuntimeError( + f"fallback is disabled and native sendfile is not " + f"supported for transport {transport!r}") + + return await self._sendfile_fallback(transport, file, + offset, count) + + async def _sendfile_native(self, transp, file, offset, count): + raise exceptions.SendfileNotAvailableError( + "sendfile syscall is not supported") + + async def _sendfile_fallback(self, transp, file, offset, count): + if offset: + file.seek(offset) + blocksize = min(count, 16384) if count else 16384 + buf = bytearray(blocksize) + total_sent = 0 + proto = _SendfileFallbackProtocol(transp) + try: + while True: + if count: + blocksize = min(count - total_sent, blocksize) + if blocksize <= 0: + return total_sent + view = memoryview(buf)[:blocksize] + read = await self.run_in_executor(None, file.readinto, view) + if not read: + return total_sent # EOF + await proto.drain() + transp.write(view[:read]) + total_sent += read + finally: + if total_sent > 0 and hasattr(file, 'seek'): + file.seek(offset + total_sent) + await proto.restore() + + async def start_tls(self, transport, protocol, sslcontext, *, + server_side=False, + server_hostname=None, + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): + """Upgrade transport to TLS. + + Return a new transport that *protocol* should start using + immediately. + """ + if ssl is None: + raise RuntimeError('Python ssl module is not available') + + if not isinstance(sslcontext, ssl.SSLContext): + raise TypeError( + f'sslcontext is expected to be an instance of ssl.SSLContext, ' + f'got {sslcontext!r}') + + if not getattr(transport, '_start_tls_compatible', False): + raise TypeError( + f'transport {transport!r} is not supported by start_tls()') + + waiter = self.create_future() + ssl_protocol = sslproto.SSLProtocol( + self, protocol, sslcontext, waiter, + server_side, server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout, + call_connection_made=False) + + # Pause early so that "ssl_protocol.data_received()" doesn't + # have a chance to get called before "ssl_protocol.connection_made()". 
+ transport.pause_reading() + + transport.set_protocol(ssl_protocol) + conmade_cb = self.call_soon(ssl_protocol.connection_made, transport) + resume_cb = self.call_soon(transport.resume_reading) + + try: + await waiter + except BaseException: + transport.close() + conmade_cb.cancel() + resume_cb.cancel() + raise + + return ssl_protocol._app_transport + + async def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0, + reuse_port=None, + allow_broadcast=None, sock=None): """Create datagram connection.""" if sock is not None: - if not _is_dgram_socket(sock): + if sock.type == socket.SOCK_STREAM: raise ValueError( - 'A UDP Socket was expected, got {!r}'.format(sock)) + f'A datagram socket was expected, got {sock!r}') if (local_addr or remote_addr or family or proto or flags or - reuse_address or reuse_port or allow_broadcast): + reuse_port or allow_broadcast): # show the problematic kwargs in exception msg opts = dict(local_addr=local_addr, remote_addr=remote_addr, family=family, proto=proto, flags=flags, - reuse_address=reuse_address, reuse_port=reuse_port, + reuse_port=reuse_port, allow_broadcast=allow_broadcast) - problems = ', '.join( - '{}={}'.format(k, v) for k, v in opts.items() if v) + problems = ', '.join(f'{k}={v}' for k, v in opts.items() if v) raise ValueError( - 'socket modifier keyword arguments can not be used ' - 'when sock is specified. ({})'.format(problems)) + f'socket modifier keyword arguments can not be used ' + f'when sock is specified. ({problems})') sock.setblocking(False) r_addr = None else: @@ -863,15 +1339,34 @@ def create_datagram_endpoint(self, protocol_factory, if family == 0: raise ValueError('unexpected address family') addr_pairs_info = (((family, proto), (None, None)),) + elif hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX: + for addr in (local_addr, remote_addr): + if addr is not None and not isinstance(addr, str): + raise TypeError('string is expected') + + if local_addr and local_addr[0] not in (0, '\x00'): + try: + if stat.S_ISSOCK(os.stat(local_addr).st_mode): + os.remove(local_addr) + except FileNotFoundError: + pass + except OSError as err: + # Directory may have permissions only to create socket. 
+ logger.error('Unable to check or remove stale UNIX ' + 'socket %r: %r', + local_addr, err) + + addr_pairs_info = (((family, proto), + (local_addr, remote_addr)), ) else: # join address by (family, protocol) - addr_infos = collections.OrderedDict() + addr_infos = {} # Using order preserving dict for idx, addr in ((0, local_addr), (1, remote_addr)): if addr is not None: - assert isinstance(addr, tuple) and len(addr) == 2, ( - '2-tuple is expected') + if not (isinstance(addr, tuple) and len(addr) == 2): + raise TypeError('2-tuple is expected') - infos = yield from _ensure_resolved( + infos = await self._ensure_resolved( addr, family=family, type=socket.SOCK_DGRAM, proto=proto, flags=flags, loop=self) if not infos: @@ -894,9 +1389,6 @@ def create_datagram_endpoint(self, protocol_factory, exceptions = [] - if reuse_address is None: - reuse_address = os.name == 'posix' and sys.platform != 'cygwin' - for ((family, proto), (local_address, remote_address)) in addr_pairs_info: sock = None @@ -904,9 +1396,6 @@ def create_datagram_endpoint(self, protocol_factory, try: sock = socket.socket( family=family, type=socket.SOCK_DGRAM, proto=proto) - if reuse_address: - sock.setsockopt( - socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) if reuse_port: _set_reuseport(sock) if allow_broadcast: @@ -917,7 +1406,8 @@ def create_datagram_endpoint(self, protocol_factory, if local_addr: sock.bind(local_address) if remote_addr: - yield from self.sock_connect(sock, remote_address) + if not allow_broadcast: + await self.sock_connect(sock, remote_address) r_addr = remote_address except OSError as exc: if sock is not None: @@ -947,36 +1437,50 @@ def create_datagram_endpoint(self, protocol_factory, remote_addr, transport, protocol) try: - yield from waiter + await waiter except: transport.close() raise return transport, protocol - @coroutine - def _create_server_getaddrinfo(self, host, port, family, flags): - infos = yield from _ensure_resolved((host, port), family=family, + async def _ensure_resolved(self, address, *, + family=0, type=socket.SOCK_STREAM, + proto=0, flags=0, loop): + host, port = address[:2] + info = _ipaddr_info(host, port, family, type, proto, *address[2:]) + if info is not None: + # "host" is already a resolved IP. + return [info] + else: + return await loop.getaddrinfo(host, port, family=family, type=type, + proto=proto, flags=flags) + + async def _create_server_getaddrinfo(self, host, port, family, flags): + infos = await self._ensure_resolved((host, port), family=family, type=socket.SOCK_STREAM, flags=flags, loop=self) if not infos: - raise OSError('getaddrinfo({!r}) returned empty list'.format(host)) + raise OSError(f'getaddrinfo({host!r}) returned empty list') return infos - @coroutine - def create_server(self, protocol_factory, host=None, port=None, - *, - family=socket.AF_UNSPEC, - flags=socket.AI_PASSIVE, - sock=None, - backlog=100, - ssl=None, - reuse_address=None, - reuse_port=None): + async def create_server( + self, protocol_factory, host=None, port=None, + *, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, + sock=None, + backlog=100, + ssl=None, + reuse_address=None, + reuse_port=None, + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, + start_serving=True): """Create a TCP server. - The host parameter can be a string, in that case the TCP server is bound - to host and port. + The host parameter can be a string, in that case the TCP server is + bound to host and port. 
The host parameter can also be a sequence of strings and in that case the TCP server is bound to all hosts of the sequence. If a host @@ -990,19 +1494,30 @@ def create_server(self, protocol_factory, host=None, port=None, """ if isinstance(ssl, bool): raise TypeError('ssl argument must be an SSLContext or None') + + if ssl_handshake_timeout is not None and ssl is None: + raise ValueError( + 'ssl_handshake_timeout is only meaningful with ssl') + + if ssl_shutdown_timeout is not None and ssl is None: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') + + if sock is not None: + _check_ssl_socket(sock) + if host is not None or port is not None: if sock is not None: raise ValueError( 'host/port and sock can not be specified at the same time') - AF_INET6 = getattr(socket, 'AF_INET6', 0) if reuse_address is None: - reuse_address = os.name == 'posix' and sys.platform != 'cygwin' + reuse_address = os.name == "posix" and sys.platform != "cygwin" sockets = [] if host == '': hosts = [None] elif (isinstance(host, str) or - not isinstance(host, collections.Iterable)): + not isinstance(host, collections.abc.Iterable)): hosts = [host] else: hosts = host @@ -1010,7 +1525,7 @@ def create_server(self, protocol_factory, host=None, port=None, fs = [self._create_server_getaddrinfo(host, port, family=family, flags=flags) for host in hosts] - infos = yield from tasks.gather(*fs, loop=self) + infos = await tasks.gather(*fs) infos = set(itertools.chain.from_iterable(infos)) completed = False @@ -1035,16 +1550,31 @@ def create_server(self, protocol_factory, host=None, port=None, # Disable IPv4/IPv6 dual stack support (enabled by # default on Linux) which makes a single socket # listen on both address families. - if af == AF_INET6 and hasattr(socket, 'IPPROTO_IPV6'): + if (_HAS_IPv6 and + af == socket.AF_INET6 and + hasattr(socket, 'IPPROTO_IPV6')): sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, True) try: sock.bind(sa) except OSError as err: - raise OSError(err.errno, 'error while attempting ' - 'to bind on address %r: %s' - % (sa, err.strerror.lower())) + msg = ('error while attempting ' + 'to bind on address %r: %s' + % (sa, err.strerror.lower())) + if err.errno == errno.EADDRNOTAVAIL: + # Assume the family is not enabled (bpo-30945) + sockets.pop() + sock.close() + if self._debug: + logger.warning(msg) + continue + raise OSError(err.errno, msg) from None + + if not sockets: + raise OSError('could not bind on any address out of %r' + % ([info[4] for info in infos],)) + completed = True finally: if not completed: @@ -1053,36 +1583,49 @@ def create_server(self, protocol_factory, host=None, port=None, else: if sock is None: raise ValueError('Neither host/port nor sock were specified') - if not _is_stream_socket(sock): - raise ValueError( - 'A Stream Socket was expected, got {!r}'.format(sock)) + if sock.type != socket.SOCK_STREAM: + raise ValueError(f'A Stream Socket was expected, got {sock!r}') sockets = [sock] - server = Server(self, sockets) for sock in sockets: - sock.listen(backlog) sock.setblocking(False) - self._start_serving(protocol_factory, sock, ssl, server, backlog) + + server = Server(self, sockets, protocol_factory, + ssl, backlog, ssl_handshake_timeout, + ssl_shutdown_timeout) + if start_serving: + server._start_serving() + # Skip one loop iteration so that all 'loop.add_reader' + # go through. 
+ await tasks.sleep(0) + if self._debug: logger.info("%r is serving", server) return server - @coroutine - def connect_accepted_socket(self, protocol_factory, sock, *, ssl=None): - """Handle an accepted connection. + async def connect_accepted_socket( + self, protocol_factory, sock, + *, ssl=None, + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): + if sock.type != socket.SOCK_STREAM: + raise ValueError(f'A Stream Socket was expected, got {sock!r}') - This is used by servers that accept connections outside of - asyncio but that use asyncio to handle connections. + if ssl_handshake_timeout is not None and not ssl: + raise ValueError( + 'ssl_handshake_timeout is only meaningful with ssl') - This method is a coroutine. When completed, the coroutine - returns a (transport, protocol) pair. - """ - if not _is_stream_socket(sock): + if ssl_shutdown_timeout is not None and not ssl: raise ValueError( - 'A Stream Socket was expected, got {!r}'.format(sock)) + 'ssl_shutdown_timeout is only meaningful with ssl') - transport, protocol = yield from self._create_connection_transport( - sock, protocol_factory, ssl, '', server_side=True) + if sock is not None: + _check_ssl_socket(sock) + + transport, protocol = await self._create_connection_transport( + sock, protocol_factory, ssl, '', server_side=True, + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) if self._debug: # Get the socket from the transport because SSL transport closes # the old socket and creates a new SSL socket @@ -1090,14 +1633,13 @@ def connect_accepted_socket(self, protocol_factory, sock, *, ssl=None): logger.debug("%r handled: (%r, %r)", sock, transport, protocol) return transport, protocol - @coroutine - def connect_read_pipe(self, protocol_factory, pipe): + async def connect_read_pipe(self, protocol_factory, pipe): protocol = protocol_factory() waiter = self.create_future() transport = self._make_read_pipe_transport(pipe, protocol, waiter) try: - yield from waiter + await waiter except: transport.close() raise @@ -1107,14 +1649,13 @@ def connect_read_pipe(self, protocol_factory, pipe): pipe.fileno(), transport, protocol) return transport, protocol - @coroutine - def connect_write_pipe(self, protocol_factory, pipe): + async def connect_write_pipe(self, protocol_factory, pipe): protocol = protocol_factory() waiter = self.create_future() transport = self._make_write_pipe_transport(pipe, protocol, waiter) try: - yield from waiter + await waiter except: transport.close() raise @@ -1127,21 +1668,24 @@ def connect_write_pipe(self, protocol_factory, pipe): def _log_subprocess(self, msg, stdin, stdout, stderr): info = [msg] if stdin is not None: - info.append('stdin=%s' % _format_pipe(stdin)) + info.append(f'stdin={_format_pipe(stdin)}') if stdout is not None and stderr == subprocess.STDOUT: - info.append('stdout=stderr=%s' % _format_pipe(stdout)) + info.append(f'stdout=stderr={_format_pipe(stdout)}') else: if stdout is not None: - info.append('stdout=%s' % _format_pipe(stdout)) + info.append(f'stdout={_format_pipe(stdout)}') if stderr is not None: - info.append('stderr=%s' % _format_pipe(stderr)) + info.append(f'stderr={_format_pipe(stderr)}') logger.debug(' '.join(info)) - @coroutine - def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, - stdout=subprocess.PIPE, stderr=subprocess.PIPE, - universal_newlines=False, shell=True, bufsize=0, - **kwargs): + async def subprocess_shell(self, protocol_factory, cmd, *, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + 
stderr=subprocess.PIPE, + universal_newlines=False, + shell=True, bufsize=0, + encoding=None, errors=None, text=None, + **kwargs): if not isinstance(cmd, (bytes, str)): raise ValueError("cmd must be a string") if universal_newlines: @@ -1150,45 +1694,57 @@ def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, raise ValueError("shell must be True") if bufsize != 0: raise ValueError("bufsize must be 0") + if text: + raise ValueError("text must be False") + if encoding is not None: + raise ValueError("encoding must be None") + if errors is not None: + raise ValueError("errors must be None") + protocol = protocol_factory() + debug_log = None if self._debug: # don't log parameters: they may contain sensitive information # (password) and may be too long debug_log = 'run shell command %r' % cmd self._log_subprocess(debug_log, stdin, stdout, stderr) - transport = yield from self._make_subprocess_transport( + transport = await self._make_subprocess_transport( protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs) - if self._debug: + if self._debug and debug_log is not None: logger.info('%s: %r', debug_log, transport) return transport, protocol - @coroutine - def subprocess_exec(self, protocol_factory, program, *args, - stdin=subprocess.PIPE, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, universal_newlines=False, - shell=False, bufsize=0, **kwargs): + async def subprocess_exec(self, protocol_factory, program, *args, + stdin=subprocess.PIPE, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, universal_newlines=False, + shell=False, bufsize=0, + encoding=None, errors=None, text=None, + **kwargs): if universal_newlines: raise ValueError("universal_newlines must be False") if shell: raise ValueError("shell must be False") if bufsize != 0: raise ValueError("bufsize must be 0") + if text: + raise ValueError("text must be False") + if encoding is not None: + raise ValueError("encoding must be None") + if errors is not None: + raise ValueError("errors must be None") + popen_args = (program,) + args - for arg in popen_args: - if not isinstance(arg, (str, bytes)): - raise TypeError("program arguments must be " - "a bytes or text string, not %s" - % type(arg).__name__) protocol = protocol_factory() + debug_log = None if self._debug: # don't log parameters: they may contain sensitive information # (password) and may be too long - debug_log = 'execute program %r' % program + debug_log = f'execute program {program!r}' self._log_subprocess(debug_log, stdin, stdout, stderr) - transport = yield from self._make_subprocess_transport( + transport = await self._make_subprocess_transport( protocol, popen_args, False, stdin, stdout, stderr, bufsize, **kwargs) - if self._debug: + if self._debug and debug_log is not None: logger.info('%s: %r', debug_log, transport) return transport, protocol @@ -1210,8 +1766,8 @@ def set_exception_handler(self, handler): documentation for details about context). """ if handler is not None and not callable(handler): - raise TypeError('A callable object or None is expected, ' - 'got {!r}'.format(handler)) + raise TypeError(f'A callable object or None is expected, ' + f'got {handler!r}') self._exception_handler = handler def default_exception_handler(self, context): @@ -1221,6 +1777,11 @@ def default_exception_handler(self, context): handler is set, and can be called by a custom exception handler that wants to defer to the default behavior. + This default handler logs the error message and other + context-dependent information. 
In debug mode, a truncated + stack trace is also appended showing where the given object + (e.g. a handle or future or task) was created, if any. + The context parameter has the same meaning as in `call_exception_handler()`. """ @@ -1234,10 +1795,11 @@ def default_exception_handler(self, context): else: exc_info = False - if ('source_traceback' not in context - and self._current_handle is not None - and self._current_handle._source_traceback): - context['handle_traceback'] = self._current_handle._source_traceback + if ('source_traceback' not in context and + self._current_handle is not None and + self._current_handle._source_traceback): + context['handle_traceback'] = \ + self._current_handle._source_traceback log_lines = [message] for key in sorted(context): @@ -1254,7 +1816,7 @@ def default_exception_handler(self, context): value += tb.rstrip() else: value = repr(value) - log_lines.append('{}: {}'.format(key, value)) + log_lines.append(f'{key}: {value}') logger.error('\n'.join(log_lines), exc_info=exc_info) @@ -1266,6 +1828,7 @@ def call_exception_handler(self, context): - 'message': Error message; - 'exception' (optional): Exception object; - 'future' (optional): Future instance; + - 'task' (optional): Task instance; - 'handle' (optional): Handle instance; - 'protocol' (optional): Protocol instance; - 'transport' (optional): Transport instance; @@ -1282,7 +1845,9 @@ def call_exception_handler(self, context): if self._exception_handler is None: try: self.default_exception_handler(context) - except Exception: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException: # Second protection layer for unexpected errors # in the default implementation, as well as for subclassed # event loops with overloaded "default_exception_handler". @@ -1290,8 +1855,25 @@ def call_exception_handler(self, context): exc_info=True) else: try: - self._exception_handler(self, context) - except Exception as exc: + ctx = None + thing = context.get("task") + if thing is None: + # Even though Futures don't have a context, + # Task is a subclass of Future, + # and sometimes the 'future' key holds a Task. + thing = context.get("future") + if thing is None: + # Handles also have a context. + thing = context.get("handle") + if thing is not None and hasattr(thing, "get_context"): + ctx = thing.get_context() + if ctx is not None and hasattr(ctx, "run"): + ctx.run(self._exception_handler, self, context) + else: + self._exception_handler(self, context) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: # Exception in the user set custom exception handler. try: # Let's try default handler. @@ -1300,7 +1882,9 @@ def call_exception_handler(self, context): 'exception': exc, 'context': context, }) - except Exception: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException: # Guard 'default_exception_handler' in case it is # overloaded. 
logger.error('Exception in default exception handler ' @@ -1309,12 +1893,9 @@ def call_exception_handler(self, context): exc_info=True) def _add_callback(self, handle): - """Add a Handle to _scheduled (TimerHandle) or _ready.""" - assert isinstance(handle, events.Handle), 'A Handle is required here' - if handle._cancelled: - return - assert not isinstance(handle, events.TimerHandle) - self._ready.append(handle) + """Add a Handle to _ready.""" + if not handle._cancelled: + self._ready.append(handle) def _add_callback_signalsafe(self, handle): """Like _add_callback() but called from a signal handler.""" @@ -1363,31 +1944,12 @@ def _run_once(self): elif self._scheduled: # Compute the desired timeout. when = self._scheduled[0]._when - timeout = max(0, when - self.time()) - - if self._debug and timeout != 0: - t0 = self.time() - event_list = self._selector.select(timeout) - dt = self.time() - t0 - if dt >= 1.0: - level = logging.INFO - else: - level = logging.DEBUG - nevent = len(event_list) - if timeout is None: - logger.log(level, 'poll took %.3f ms: %s events', - dt * 1e3, nevent) - elif nevent: - logger.log(level, - 'poll %.3f ms took %.3f ms: %s events', - timeout * 1e3, dt * 1e3, nevent) - elif dt >= 1.0: - logger.log(level, - 'poll %.3f ms took %.3f ms: timeout', - timeout * 1e3, dt * 1e3) - else: - event_list = self._selector.select(timeout) + timeout = min(max(0, when - self.time()), MAXIMUM_SELECT_TIMEOUT) + + event_list = self._selector.select(timeout) self._process_events(event_list) + # Needed to break cycles when an exception occurs. + event_list = None # Handle 'later' callbacks that are ready. end_time = self.time() + self._clock_resolution @@ -1425,38 +1987,20 @@ def _run_once(self): handle._run() handle = None # Needed to break cycles when an exception occurs. 
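The context-dictionary keys documented for call_exception_handler() above are easiest to see from a user-supplied handler. A minimal sketch (the handler name and the deliberately unawaited task are illustrative, not part of this patch):

```python
import asyncio
import gc

def my_handler(loop, context):
    # 'message' is always present; 'exception', 'task', 'future' or
    # 'handle' may also appear, matching the keys documented above.
    print('caught:', context['message'], '->', repr(context.get('exception')))

async def boom():
    raise RuntimeError('never retrieved')

async def main():
    asyncio.get_running_loop().set_exception_handler(my_handler)
    task = asyncio.ensure_future(boom())
    await asyncio.sleep(0)   # let boom() run and store its exception
    del task                 # drop the last reference so Task.__del__
    gc.collect()             # reports "Task exception was never retrieved"

asyncio.run(main())
```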
-    def _set_coroutine_wrapper(self, enabled):
-        try:
-            set_wrapper = sys.set_coroutine_wrapper
-            get_wrapper = sys.get_coroutine_wrapper
-        except AttributeError:
+    def _set_coroutine_origin_tracking(self, enabled):
+        if bool(enabled) == bool(self._coroutine_origin_tracking_enabled):
             return

-        enabled = bool(enabled)
-        if self._coroutine_wrapper_set == enabled:
-            return
-
-        wrapper = coroutines.debug_wrapper
-        current_wrapper = get_wrapper()
-
         if enabled:
-            if current_wrapper not in (None, wrapper):
-                warnings.warn(
-                    "loop.set_debug(True): cannot set debug coroutine "
-                    "wrapper; another wrapper is already set %r" %
-                    current_wrapper, RuntimeWarning)
-            else:
-                set_wrapper(wrapper)
-                self._coroutine_wrapper_set = True
+            self._coroutine_origin_tracking_saved_depth = (
+                sys.get_coroutine_origin_tracking_depth())
+            sys.set_coroutine_origin_tracking_depth(
+                constants.DEBUG_STACK_DEPTH)
         else:
-            if current_wrapper not in (None, wrapper):
-                warnings.warn(
-                    "loop.set_debug(False): cannot unset debug coroutine "
-                    "wrapper; another wrapper was set %r" %
-                    current_wrapper, RuntimeWarning)
-            else:
-                set_wrapper(None)
-                self._coroutine_wrapper_set = False
+            sys.set_coroutine_origin_tracking_depth(
+                self._coroutine_origin_tracking_saved_depth)
+
+        self._coroutine_origin_tracking_enabled = enabled

     def get_debug(self):
         return self._debug
@@ -1465,4 +2009,4 @@ def set_debug(self, enabled):
         self._debug = enabled

         if self.is_running():
-            self._set_coroutine_wrapper(enabled)
+            self.call_soon_threadsafe(self._set_coroutine_origin_tracking, enabled)
diff --git a/Lib/asyncio/base_futures.py b/Lib/asyncio/base_futures.py
index 01259a062e..7987963bd9 100644
--- a/Lib/asyncio/base_futures.py
+++ b/Lib/asyncio/base_futures.py
@@ -1,18 +1,8 @@
-__all__ = []
+__all__ = ()

-import concurrent.futures._base
 import reprlib

-from . import events
-
-Error = concurrent.futures._base.Error
-CancelledError = concurrent.futures.CancelledError
-TimeoutError = concurrent.futures.TimeoutError
-
-
-class InvalidStateError(Error):
-    """The operation is not allowed in this state."""
-
+from . import format_helpers

 # States for Future.
_PENDING = 'PENDING' @@ -38,17 +28,17 @@ def _format_callbacks(cb): cb = '' def format_cb(callback): - return events._format_callback_source(callback, ()) + return format_helpers._format_callback_source(callback, ()) if size == 1: - cb = format_cb(cb[0]) + cb = format_cb(cb[0][0]) elif size == 2: - cb = '{}, {}'.format(format_cb(cb[0]), format_cb(cb[1])) + cb = '{}, {}'.format(format_cb(cb[0][0]), format_cb(cb[1][0])) elif size > 2: - cb = '{}, <{} more>, {}'.format(format_cb(cb[0]), + cb = '{}, <{} more>, {}'.format(format_cb(cb[0][0]), size - 2, - format_cb(cb[-1])) - return 'cb=[%s]' % cb + format_cb(cb[-1][0])) + return f'cb=[{cb}]' def _future_repr_info(future): @@ -57,15 +47,21 @@ def _future_repr_info(future): info = [future._state.lower()] if future._state == _FINISHED: if future._exception is not None: - info.append('exception={!r}'.format(future._exception)) + info.append(f'exception={future._exception!r}') else: # use reprlib to limit the length of the output, especially # for very long strings result = reprlib.repr(future._result) - info.append('result={}'.format(result)) + info.append(f'result={result}') if future._callbacks: info.append(_format_callbacks(future._callbacks)) if future._source_traceback: frame = future._source_traceback[-1] - info.append('created at %s:%s' % (frame[0], frame[1])) + info.append(f'created at {frame[0]}:{frame[1]}') return info + + +@reprlib.recursive_repr() +def _future_repr(future): + info = ' '.join(_future_repr_info(future)) + return f'<{future.__class__.__name__} {info}>' diff --git a/Lib/asyncio/base_subprocess.py b/Lib/asyncio/base_subprocess.py index a00d9d5732..4c9b0dd565 100644 --- a/Lib/asyncio/base_subprocess.py +++ b/Lib/asyncio/base_subprocess.py @@ -2,10 +2,8 @@ import subprocess import warnings -from . import compat from . import protocols from . import transports -from .coroutines import coroutine from .log import logger @@ -59,9 +57,9 @@ def __repr__(self): if self._closed: info.append('closed') if self._pid is not None: - info.append('pid=%s' % self._pid) + info.append(f'pid={self._pid}') if self._returncode is not None: - info.append('returncode=%s' % self._returncode) + info.append(f'returncode={self._returncode}') elif self._pid is not None: info.append('running') else: @@ -69,19 +67,19 @@ def __repr__(self): stdin = self._pipes.get(0) if stdin is not None: - info.append('stdin=%s' % stdin.pipe) + info.append(f'stdin={stdin.pipe}') stdout = self._pipes.get(1) stderr = self._pipes.get(2) if stdout is not None and stderr is stdout: - info.append('stdout=stderr=%s' % stdout.pipe) + info.append(f'stdout=stderr={stdout.pipe}') else: if stdout is not None: - info.append('stdout=%s' % stdout.pipe) + info.append(f'stdout={stdout.pipe}') if stderr is not None: - info.append('stderr=%s' % stderr.pipe) + info.append(f'stderr={stderr.pipe}') - return '<%s>' % ' '.join(info) + return '<{}>'.format(' '.join(info)) def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): raise NotImplementedError @@ -105,12 +103,13 @@ def close(self): continue proto.pipe.close() - if (self._proc is not None - # the child process finished? - and self._returncode is None - # the child process finished but the transport was not notified yet? - and self._proc.poll() is None - ): + if (self._proc is not None and + # has the child process finished? + self._returncode is None and + # the child process has finished, but the + # transport hasn't been notified yet? 
+ self._proc.poll() is None): + if self._loop.get_debug(): logger.warning('Close running child process: kill %r', self) @@ -121,15 +120,10 @@ def close(self): # Don't clear the _proc reference yet: _post_init() may still run - # On Python 3.3 and older, objects with a destructor part of a reference - # cycle are never destroyed. It's not more the case on Python 3.4 thanks - # to the PEP 442. - if compat.PY34: - def __del__(self): - if not self._closed: - warnings.warn("unclosed transport %r" % self, ResourceWarning, - source=self) - self.close() + def __del__(self, _warn=warnings.warn): + if not self._closed: + _warn(f"unclosed transport {self!r}", ResourceWarning, source=self) + self.close() def get_pid(self): return self._pid @@ -159,26 +153,25 @@ def kill(self): self._check_proc() self._proc.kill() - @coroutine - def _connect_pipes(self, waiter): + async def _connect_pipes(self, waiter): try: proc = self._proc loop = self._loop if proc.stdin is not None: - _, pipe = yield from loop.connect_write_pipe( + _, pipe = await loop.connect_write_pipe( lambda: WriteSubprocessPipeProto(self, 0), proc.stdin) self._pipes[0] = pipe if proc.stdout is not None: - _, pipe = yield from loop.connect_read_pipe( + _, pipe = await loop.connect_read_pipe( lambda: ReadSubprocessPipeProto(self, 1), proc.stdout) self._pipes[1] = pipe if proc.stderr is not None: - _, pipe = yield from loop.connect_read_pipe( + _, pipe = await loop.connect_read_pipe( lambda: ReadSubprocessPipeProto(self, 2), proc.stderr) self._pipes[2] = pipe @@ -189,7 +182,9 @@ def _connect_pipes(self, waiter): for callback, data in self._pending_calls: loop.call_soon(callback, *data) self._pending_calls = None - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: if waiter is not None and not waiter.cancelled(): waiter.set_exception(exc) else: @@ -213,24 +208,17 @@ def _process_exited(self, returncode): assert returncode is not None, returncode assert self._returncode is None, self._returncode if self._loop.get_debug(): - logger.info('%r exited with return code %r', - self, returncode) + logger.info('%r exited with return code %r', self, returncode) self._returncode = returncode if self._proc.returncode is None: # asyncio uses a child watcher: copy the status into the Popen # object. On Python 3.6, it is required to avoid a ResourceWarning. self._proc.returncode = returncode self._call(self._protocol.process_exited) - self._try_finish() - # wake up futures waiting for wait() - for waiter in self._exit_waiters: - if not waiter.cancelled(): - waiter.set_result(returncode) - self._exit_waiters = None + self._try_finish() - @coroutine - def _wait(self): + async def _wait(self): """Wait until the process exit and return the process return code. 
This method is a coroutine.""" @@ -239,7 +227,7 @@ def _wait(self): waiter = self._loop.create_future() self._exit_waiters.append(waiter) - return (yield from waiter) + return await waiter def _try_finish(self): assert not self._finished @@ -254,6 +242,11 @@ def _call_connection_lost(self, exc): try: self._protocol.connection_lost(exc) finally: + # wake up futures waiting for wait() + for waiter in self._exit_waiters: + if not waiter.cancelled(): + waiter.set_result(self._returncode) + self._exit_waiters = None self._loop = None self._proc = None self._protocol = None @@ -271,8 +264,7 @@ def connection_made(self, transport): self.pipe = transport def __repr__(self): - return ('<%s fd=%s pipe=%r>' - % (self.__class__.__name__, self.fd, self.pipe)) + return f'<{self.__class__.__name__} fd={self.fd} pipe={self.pipe!r}>' def connection_lost(self, exc): self.disconnected = True diff --git a/Lib/asyncio/base_tasks.py b/Lib/asyncio/base_tasks.py index 5f34434c57..c907b68341 100644 --- a/Lib/asyncio/base_tasks.py +++ b/Lib/asyncio/base_tasks.py @@ -1,4 +1,5 @@ import linecache +import reprlib import traceback from . import base_futures @@ -8,25 +9,42 @@ def _task_repr_info(task): info = base_futures._future_repr_info(task) - if task._must_cancel: + if task.cancelling() and not task.done(): # replace status info[0] = 'cancelling' - coro = coroutines._format_coroutine(task._coro) - info.insert(1, 'coro=<%s>' % coro) + info.insert(1, 'name=%r' % task.get_name()) if task._fut_waiter is not None: - info.insert(2, 'wait_for=%r' % task._fut_waiter) + info.insert(2, f'wait_for={task._fut_waiter!r}') + + if task._coro: + coro = coroutines._format_coroutine(task._coro) + info.insert(2, f'coro=<{coro}>') + return info +@reprlib.recursive_repr() +def _task_repr(task): + info = ' '.join(_task_repr_info(task)) + return f'<{task.__class__.__name__} {info}>' + + def _task_get_stack(task, limit): frames = [] - try: - # 'async def' coroutines + if hasattr(task._coro, 'cr_frame'): + # case 1: 'async def' coroutines f = task._coro.cr_frame - except AttributeError: + elif hasattr(task._coro, 'gi_frame'): + # case 2: legacy coroutines f = task._coro.gi_frame + elif hasattr(task._coro, 'ag_frame'): + # case 3: async generators + f = task._coro.ag_frame + else: + # case 4: unknown objects + f = None if f is not None: while f is not None: if limit is not None: @@ -61,15 +79,15 @@ def _task_print_stack(task, limit, file): linecache.checkcache(filename) line = linecache.getline(filename, lineno, f.f_globals) extracted_list.append((filename, lineno, name, line)) + exc = task._exception if not extracted_list: - print('No stack for %r' % task, file=file) + print(f'No stack for {task!r}', file=file) elif exc is not None: - print('Traceback for %r (most recent call last):' % task, - file=file) + print(f'Traceback for {task!r} (most recent call last):', file=file) else: - print('Stack for %r (most recent call last):' % task, - file=file) + print(f'Stack for {task!r} (most recent call last):', file=file) + traceback.print_list(extracted_list, file=file) if exc is not None: for line in traceback.format_exception_only(exc.__class__, exc): diff --git a/Lib/asyncio/compat.py b/Lib/asyncio/compat.py deleted file mode 100644 index 4790bb4a35..0000000000 --- a/Lib/asyncio/compat.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Compatibility helpers for the different Python versions.""" - -import sys - -PY34 = sys.version_info >= (3, 4) -PY35 = sys.version_info >= (3, 5) -PY352 = sys.version_info >= (3, 5, 2) - - -def 
flatten_list_bytes(list_of_data): - """Concatenate a sequence of bytes-like objects.""" - if not PY34: - # On Python 3.3 and older, bytes.join() doesn't handle - # memoryview. - list_of_data = ( - bytes(data) if isinstance(data, memoryview) else data - for data in list_of_data) - return b''.join(list_of_data) diff --git a/Lib/asyncio/constants.py b/Lib/asyncio/constants.py index f9e123281e..b60c1e4236 100644 --- a/Lib/asyncio/constants.py +++ b/Lib/asyncio/constants.py @@ -1,7 +1,41 @@ -"""Constants.""" +# Contains code from https://github.com/MagicStack/uvloop/tree/v0.16.0 +# SPDX-License-Identifier: PSF-2.0 AND (MIT OR Apache-2.0) +# SPDX-FileCopyrightText: Copyright (c) 2015-2021 MagicStack Inc. http://magic.io + +import enum # After the connection is lost, log warnings after this many write()s. LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5 # Seconds to wait before retrying accept(). ACCEPT_RETRY_DELAY = 1 + +# Number of stack entries to capture in debug mode. +# The larger the number, the slower the operation in debug mode +# (see extract_stack() in format_helpers.py). +DEBUG_STACK_DEPTH = 10 + +# Number of seconds to wait for SSL handshake to complete +# The default timeout matches that of Nginx. +SSL_HANDSHAKE_TIMEOUT = 60.0 + +# Number of seconds to wait for SSL shutdown to complete +# The default timeout mimics lingering_time +SSL_SHUTDOWN_TIMEOUT = 30.0 + +# Used in sendfile fallback code. We use fallback for platforms +# that don't support sendfile, or for TLS connections. +SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 256 + +FLOW_CONTROL_HIGH_WATER_SSL_READ = 256 # KiB +FLOW_CONTROL_HIGH_WATER_SSL_WRITE = 512 # KiB + +# Default timeout for joining the threads in the threadpool +THREAD_JOIN_TIMEOUT = 300 + +# The enum should be here to break circular dependencies between +# base_events and sslproto +class _SendfileMode(enum.Enum): + UNSUPPORTED = enum.auto() + TRY_NATIVE = enum.auto() + FALLBACK = enum.auto() diff --git a/Lib/asyncio/coroutines.py b/Lib/asyncio/coroutines.py index 08e94412b3..ab4f30eb51 100644 --- a/Lib/asyncio/coroutines.py +++ b/Lib/asyncio/coroutines.py @@ -1,249 +1,16 @@ -__all__ = ['coroutine', - 'iscoroutinefunction', 'iscoroutine'] +__all__ = 'iscoroutinefunction', 'iscoroutine' -import functools +import collections.abc import inspect -import opcode import os import sys -import traceback import types -from . import compat -from . import events -from . import base_futures -from .log import logger - -# Opcode of "yield from" instruction -_YIELD_FROM = opcode.opmap['YIELD_FROM'] - -# If you set _DEBUG to true, @coroutine will wrap the resulting -# generator objects in a CoroWrapper instance (defined below). That -# instance will log a message when the generator is never iterated -# over, which may happen when you forget to use "yield from" with a -# coroutine call. Note that the value of the _DEBUG flag is taken -# when the decorator is used, so to be of any use it must be set -# before you define your coroutines. A downside of using this feature -# is that tracebacks show entries for the CoroWrapper.__next__ method -# when _DEBUG is true. 
-_DEBUG = (not sys.flags.ignore_environment and - bool(os.environ.get('PYTHONASYNCIODEBUG'))) - - -try: - _types_coroutine = types.coroutine - _types_CoroutineType = types.CoroutineType -except AttributeError: - # Python 3.4 - _types_coroutine = None - _types_CoroutineType = None - -try: - _inspect_iscoroutinefunction = inspect.iscoroutinefunction -except AttributeError: - # Python 3.4 - _inspect_iscoroutinefunction = lambda func: False - -try: - from collections.abc import Coroutine as _CoroutineABC, \ - Awaitable as _AwaitableABC -except ImportError: - _CoroutineABC = _AwaitableABC = None - - -# Check for CPython issue #21209 -def has_yield_from_bug(): - class MyGen: - def __init__(self): - self.send_args = None - def __iter__(self): - return self - def __next__(self): - return 42 - def send(self, *what): - self.send_args = what - return None - def yield_from_gen(gen): - yield from gen - value = (1, 2, 3) - gen = MyGen() - coro = yield_from_gen(gen) - next(coro) - coro.send(value) - return gen.send_args != (value,) -_YIELD_FROM_BUG = has_yield_from_bug() -del has_yield_from_bug - - -def debug_wrapper(gen): - # This function is called from 'sys.set_coroutine_wrapper'. - # We only wrap here coroutines defined via 'async def' syntax. - # Generator-based coroutines are wrapped in @coroutine - # decorator. - return CoroWrapper(gen, None) - - -class CoroWrapper: - # Wrapper for coroutine object in _DEBUG mode. - - def __init__(self, gen, func=None): - assert inspect.isgenerator(gen) or inspect.iscoroutine(gen), gen - self.gen = gen - self.func = func # Used to unwrap @coroutine decorator - self._source_traceback = traceback.extract_stack(sys._getframe(1)) - self.__name__ = getattr(gen, '__name__', None) - self.__qualname__ = getattr(gen, '__qualname__', None) - - def __repr__(self): - coro_repr = _format_coroutine(self) - if self._source_traceback: - frame = self._source_traceback[-1] - coro_repr += ', created at %s:%s' % (frame[0], frame[1]) - return '<%s %s>' % (self.__class__.__name__, coro_repr) - - def __iter__(self): - return self - - def __next__(self): - return self.gen.send(None) - - if _YIELD_FROM_BUG: - # For for CPython issue #21209: using "yield from" and a custom - # generator, generator.send(tuple) unpacks the tuple instead of passing - # the tuple unchanged. Check if the caller is a generator using "yield - # from" to decide if the parameter should be unpacked or not. 
- def send(self, *value): - frame = sys._getframe() - caller = frame.f_back - assert caller.f_lasti >= 0 - if caller.f_code.co_code[caller.f_lasti] != _YIELD_FROM: - value = value[0] - return self.gen.send(value) - else: - def send(self, value): - return self.gen.send(value) - - def throw(self, type, value=None, traceback=None): - return self.gen.throw(type, value, traceback) - - def close(self): - return self.gen.close() - - @property - def gi_frame(self): - return self.gen.gi_frame - - @property - def gi_running(self): - return self.gen.gi_running - - @property - def gi_code(self): - return self.gen.gi_code - - if compat.PY35: - - def __await__(self): - cr_await = getattr(self.gen, 'cr_await', None) - if cr_await is not None: - raise RuntimeError( - "Cannot await on coroutine {!r} while it's " - "awaiting for {!r}".format(self.gen, cr_await)) - return self - - @property - def gi_yieldfrom(self): - return self.gen.gi_yieldfrom - - @property - def cr_await(self): - return self.gen.cr_await - - @property - def cr_running(self): - return self.gen.cr_running - - @property - def cr_code(self): - return self.gen.cr_code - - @property - def cr_frame(self): - return self.gen.cr_frame - - def __del__(self): - # Be careful accessing self.gen.frame -- self.gen might not exist. - gen = getattr(self, 'gen', None) - frame = getattr(gen, 'gi_frame', None) - if frame is None: - frame = getattr(gen, 'cr_frame', None) - if frame is not None and frame.f_lasti == -1: - msg = '%r was never yielded from' % self - tb = getattr(self, '_source_traceback', ()) - if tb: - tb = ''.join(traceback.format_list(tb)) - msg += ('\nCoroutine object created at ' - '(most recent call last):\n') - msg += tb.rstrip() - logger.error(msg) - - -def coroutine(func): - """Decorator to mark coroutines. - - If the coroutine is not yielded from before it is destroyed, - an error message is logged. - """ - if _inspect_iscoroutinefunction(func): - # In Python 3.5 that's all we need to do for coroutines - # defiend with "async def". - # Wrapping in CoroWrapper will happen via - # 'sys.set_coroutine_wrapper' function. - return func - - if inspect.isgeneratorfunction(func): - coro = func - else: - @functools.wraps(func) - def coro(*args, **kw): - res = func(*args, **kw) - if (base_futures.isfuture(res) or inspect.isgenerator(res) or - isinstance(res, CoroWrapper)): - res = yield from res - elif _AwaitableABC is not None: - # If 'func' returns an Awaitable (new in 3.5) we - # want to run it. - try: - await_meth = res.__await__ - except AttributeError: - pass - else: - if isinstance(res, _AwaitableABC): - res = yield from await_meth() - return res - - if not _DEBUG: - if _types_coroutine is None: - wrapper = coro - else: - wrapper = _types_coroutine(coro) - else: - @functools.wraps(func) - def wrapper(*args, **kwds): - w = CoroWrapper(coro(*args, **kwds), func=func) - if w._source_traceback: - del w._source_traceback[-1] - # Python < 3.5 does not implement __qualname__ - # on generator objects, so we set it manually. - # We use getattr as some callables (such as - # functools.partial may lack __qualname__). - w.__name__ = getattr(func, '__name__', None) - w.__qualname__ = getattr(func, '__qualname__', None) - return w - - wrapper._is_coroutine = _is_coroutine # For iscoroutinefunction(). - return wrapper +def _is_debug_mode(): + # See: https://docs.python.org/3/library/asyncio-dev.html#asyncio-debug-mode. 
+ return sys.flags.dev_mode or (not sys.flags.ignore_environment and + bool(os.environ.get('PYTHONASYNCIODEBUG'))) # A marker for iscoroutinefunction. @@ -252,93 +19,91 @@ def wrapper(*args, **kwds): def iscoroutinefunction(func): """Return True if func is a decorated coroutine function.""" - return (getattr(func, '_is_coroutine', None) is _is_coroutine or - _inspect_iscoroutinefunction(func)) + return (inspect.iscoroutinefunction(func) or + getattr(func, '_is_coroutine', None) is _is_coroutine) -_COROUTINE_TYPES = (types.GeneratorType, CoroWrapper) -if _CoroutineABC is not None: - _COROUTINE_TYPES += (_CoroutineABC,) -if _types_CoroutineType is not None: - # Prioritize native coroutine check to speed-up - # asyncio.iscoroutine. - _COROUTINE_TYPES = (_types_CoroutineType,) + _COROUTINE_TYPES +# Prioritize native coroutine check to speed-up +# asyncio.iscoroutine. +_COROUTINE_TYPES = (types.CoroutineType, collections.abc.Coroutine) +_iscoroutine_typecache = set() def iscoroutine(obj): """Return True if obj is a coroutine object.""" - return isinstance(obj, _COROUTINE_TYPES) + if type(obj) in _iscoroutine_typecache: + return True + + if isinstance(obj, _COROUTINE_TYPES): + # Just in case we don't want to cache more than 100 + # positive types. That shouldn't ever happen, unless + # someone stressing the system on purpose. + if len(_iscoroutine_typecache) < 100: + _iscoroutine_typecache.add(type(obj)) + return True + else: + return False def _format_coroutine(coro): assert iscoroutine(coro) - if not hasattr(coro, 'cr_code') and not hasattr(coro, 'gi_code'): - # Most likely a built-in type or a Cython coroutine. - - # Built-in types might not have __qualname__ or __name__. - coro_name = getattr( - coro, '__qualname__', - getattr(coro, '__name__', type(coro).__name__)) - coro_name = '{}()'.format(coro_name) + def get_name(coro): + # Coroutines compiled with Cython sometimes don't have + # proper __qualname__ or __name__. While that is a bug + # in Cython, asyncio shouldn't crash with an AttributeError + # in its __repr__ functions. + if hasattr(coro, '__qualname__') and coro.__qualname__: + coro_name = coro.__qualname__ + elif hasattr(coro, '__name__') and coro.__name__: + coro_name = coro.__name__ + else: + # Stop masking Cython bugs, expose them in a friendly way. + coro_name = f'<{type(coro).__name__} without __name__>' + return f'{coro_name}()' - running = False + def is_running(coro): try: - running = coro.cr_running + return coro.cr_running except AttributeError: try: - running = coro.gi_running + return coro.gi_running except AttributeError: - pass + return False - if running: - return '{} running'.format(coro_name) - else: - return coro_name - - coro_name = None - if isinstance(coro, CoroWrapper): - func = coro.func - coro_name = coro.__qualname__ - if coro_name is not None: - coro_name = '{}()'.format(coro_name) - else: - func = coro + coro_code = None + if hasattr(coro, 'cr_code') and coro.cr_code: + coro_code = coro.cr_code + elif hasattr(coro, 'gi_code') and coro.gi_code: + coro_code = coro.gi_code - if coro_name is None: - coro_name = events._format_callback(func, (), {}) + coro_name = get_name(coro) - try: - coro_code = coro.gi_code - except AttributeError: - coro_code = coro.cr_code + if not coro_code: + # Built-in types might not have __qualname__ or __name__. 
+    if is_running(coro):
+        return f'{coro_name} running'
+    else:
+        return coro_name

-    try:
+    coro_frame = None
+    if hasattr(coro, 'gi_frame') and coro.gi_frame:
         coro_frame = coro.gi_frame
-    except AttributeError:
+    elif hasattr(coro, 'cr_frame') and coro.cr_frame:
         coro_frame = coro.cr_frame

-    filename = coro_code.co_filename
+    # If Cython's coroutine has a fake code object without proper
+    # co_filename -- expose that.
+    filename = coro_code.co_filename or '<empty co_filename>'
+
     lineno = 0
-    if (isinstance(coro, CoroWrapper) and
-            not inspect.isgeneratorfunction(coro.func) and
-            coro.func is not None):
-        source = events._get_function_source(coro.func)
-        if source is not None:
-            filename, lineno = source
-        if coro_frame is None:
-            coro_repr = ('%s done, defined at %s:%s'
-                         % (coro_name, filename, lineno))
-        else:
-            coro_repr = ('%s running, defined at %s:%s'
-                         % (coro_name, filename, lineno))
-    elif coro_frame is not None:
+
+    if coro_frame is not None:
         lineno = coro_frame.f_lineno
-        coro_repr = ('%s running at %s:%s'
-                     % (coro_name, filename, lineno))
+        coro_repr = f'{coro_name} running at {filename}:{lineno}'
+
     else:
         lineno = coro_code.co_firstlineno
-        coro_repr = ('%s done, defined at %s:%s'
-                     % (coro_name, filename, lineno))
+        coro_repr = f'{coro_name} done, defined at {filename}:{lineno}'

     return coro_repr
diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py
index 466db6d9a3..016852880c 100644
--- a/Lib/asyncio/events.py
+++ b/Lib/asyncio/events.py
@@ -1,96 +1,50 @@
 """Event loop and event loop policy."""

-__all__ = ['AbstractEventLoopPolicy',
-           'AbstractEventLoop', 'AbstractServer',
-           'Handle', 'TimerHandle',
-           'get_event_loop_policy', 'set_event_loop_policy',
-           'get_event_loop', 'set_event_loop', 'new_event_loop',
-           'get_child_watcher', 'set_child_watcher',
-           '_set_running_loop', 'get_running_loop',
-           '_get_running_loop',
-           ]
-
-import functools
-import inspect
-import reprlib
+# Contains code from https://github.com/MagicStack/uvloop/tree/v0.16.0
+# SPDX-License-Identifier: PSF-2.0 AND (MIT OR Apache-2.0)
+# SPDX-FileCopyrightText: Copyright (c) 2015-2021 MagicStack Inc. http://magic.io
+
+__all__ = (
+    'AbstractEventLoopPolicy',
+    'AbstractEventLoop', 'AbstractServer',
+    'Handle', 'TimerHandle',
+    'get_event_loop_policy', 'set_event_loop_policy',
+    'get_event_loop', 'set_event_loop', 'new_event_loop',
+    'get_child_watcher', 'set_child_watcher',
+    '_set_running_loop', 'get_running_loop',
+    '_get_running_loop',
+)
+
+import contextvars
+import os
+import signal
 import socket
 import subprocess
 import sys
 import threading
-import traceback

-from asyncio import compat
-
-
-def _get_function_source(func):
-    if compat.PY34:
-        func = inspect.unwrap(func)
-    elif hasattr(func, '__wrapped__'):
-        func = func.__wrapped__
-    if inspect.isfunction(func):
-        code = func.__code__
-        return (code.co_filename, code.co_firstlineno)
-    if isinstance(func, functools.partial):
-        return _get_function_source(func.func)
-    if compat.PY34 and isinstance(func, functools.partialmethod):
-        return _get_function_source(func.func)
-    return None
-
-
-def _format_args_and_kwargs(args, kwargs):
-    """Format function arguments and keyword arguments.
-
-    Special case for a single parameter: ('hello',) is formatted as ('hello').
- """ - # use reprlib to limit the length of the output - items = [] - if args: - items.extend(reprlib.repr(arg) for arg in args) - if kwargs: - items.extend('{}={}'.format(k, reprlib.repr(v)) - for k, v in kwargs.items()) - return '(' + ', '.join(items) + ')' - - -def _format_callback(func, args, kwargs, suffix=''): - if isinstance(func, functools.partial): - suffix = _format_args_and_kwargs(args, kwargs) + suffix - return _format_callback(func.func, func.args, func.keywords, suffix) - - if hasattr(func, '__qualname__'): - func_repr = getattr(func, '__qualname__') - elif hasattr(func, '__name__'): - func_repr = getattr(func, '__name__') - else: - func_repr = repr(func) - - func_repr += _format_args_and_kwargs(args, kwargs) - if suffix: - func_repr += suffix - return func_repr - -def _format_callback_source(func, args): - func_repr = _format_callback(func, args, None) - source = _get_function_source(func) - if source: - func_repr += ' at %s:%s' % source - return func_repr +from . import format_helpers class Handle: """Object returned by callback registration methods.""" __slots__ = ('_callback', '_args', '_cancelled', '_loop', - '_source_traceback', '_repr', '__weakref__') + '_source_traceback', '_repr', '__weakref__', + '_context') - def __init__(self, callback, args, loop): + def __init__(self, callback, args, loop, context=None): + if context is None: + context = contextvars.copy_context() + self._context = context self._loop = loop self._callback = callback self._args = args self._cancelled = False self._repr = None if self._loop.get_debug(): - self._source_traceback = traceback.extract_stack(sys._getframe(1)) + self._source_traceback = format_helpers.extract_stack( + sys._getframe(1)) else: self._source_traceback = None @@ -99,17 +53,21 @@ def _repr_info(self): if self._cancelled: info.append('cancelled') if self._callback is not None: - info.append(_format_callback_source(self._callback, self._args)) + info.append(format_helpers._format_callback_source( + self._callback, self._args)) if self._source_traceback: frame = self._source_traceback[-1] - info.append('created at %s:%s' % (frame[0], frame[1])) + info.append(f'created at {frame[0]}:{frame[1]}') return info def __repr__(self): if self._repr is not None: return self._repr info = self._repr_info() - return '<%s>' % ' '.join(info) + return '<{}>'.format(' '.join(info)) + + def get_context(self): + return self._context def cancel(self): if not self._cancelled: @@ -122,12 +80,18 @@ def cancel(self): self._callback = None self._args = None + def cancelled(self): + return self._cancelled + def _run(self): try: - self._callback(*self._args) - except Exception as exc: - cb = _format_callback_source(self._callback, self._args) - msg = 'Exception in callback {}'.format(cb) + self._context.run(self._callback, *self._args) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + cb = format_helpers._format_callback_source( + self._callback, self._args) + msg = f'Exception in callback {cb}' context = { 'message': msg, 'exception': exc, @@ -144,9 +108,8 @@ class TimerHandle(Handle): __slots__ = ['_scheduled', '_when'] - def __init__(self, when, callback, args, loop): - assert when is not None - super().__init__(callback, args, loop) + def __init__(self, when, callback, args, loop, context=None): + super().__init__(callback, args, loop, context) if self._source_traceback: del self._source_traceback[-1] self._when = when @@ -155,27 +118,31 @@ def __init__(self, when, callback, args, loop): def _repr_info(self): info = 
super()._repr_info()
         pos = 2 if self._cancelled else 1
-        info.insert(pos, 'when=%s' % self._when)
+        info.insert(pos, f'when={self._when}')
         return info

     def __hash__(self):
         return hash(self._when)

     def __lt__(self, other):
-        return self._when < other._when
+        if isinstance(other, TimerHandle):
+            return self._when < other._when
+        return NotImplemented

     def __le__(self, other):
-        if self._when < other._when:
-            return True
-        return self.__eq__(other)
+        if isinstance(other, TimerHandle):
+            return self._when < other._when or self.__eq__(other)
+        return NotImplemented

     def __gt__(self, other):
-        return self._when > other._when
+        if isinstance(other, TimerHandle):
+            return self._when > other._when
+        return NotImplemented

     def __ge__(self, other):
-        if self._when > other._when:
-            return True
-        return self.__eq__(other)
+        if isinstance(other, TimerHandle):
+            return self._when > other._when or self.__eq__(other)
+        return NotImplemented

     def __eq__(self, other):
         if isinstance(other, TimerHandle):
@@ -185,26 +152,60 @@ def __eq__(self, other):
                     self._cancelled == other._cancelled)
         return NotImplemented

-    def __ne__(self, other):
-        equal = self.__eq__(other)
-        return NotImplemented if equal is NotImplemented else not equal
-
     def cancel(self):
         if not self._cancelled:
             self._loop._timer_handle_cancelled(self)
         super().cancel()

+    def when(self):
+        """Return a scheduled callback time.
+
+        The time is an absolute timestamp, using the same time
+        reference as loop.time().
+        """
+        return self._when
+

 class AbstractServer:
     """Abstract server returned by create_server()."""

     def close(self):
         """Stop serving.  This leaves existing connections open."""
-        return NotImplemented
+        raise NotImplementedError
+
+    def get_loop(self):
+        """Get the event loop the Server object is attached to."""
+        raise NotImplementedError
+
+    def is_serving(self):
+        """Return True if the server is accepting connections."""
+        raise NotImplementedError
+
+    async def start_serving(self):
+        """Start accepting connections.
+
+        This method is idempotent, so it can be called when
+        the server is already serving.
+        """
+        raise NotImplementedError

-    def wait_closed(self):
+    async def serve_forever(self):
+        """Start accepting connections until the coroutine is cancelled.
+
+        The server is closed when the coroutine is cancelled.
+        """
+        raise NotImplementedError
+
+    async def wait_closed(self):
         """Coroutine to wait until service is closed."""
-        return NotImplemented
+        raise NotImplementedError
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, *exc):
+        self.close()
+        await self.wait_closed()


 class AbstractEventLoop:
@@ -250,23 +251,27 @@ def close(self):
         """
         raise NotImplementedError

-    def shutdown_asyncgens(self):
+    async def shutdown_asyncgens(self):
         """Shutdown all active asynchronous generators."""
         raise NotImplementedError

+    async def shutdown_default_executor(self):
+        """Schedule the shutdown of the default executor."""
+        raise NotImplementedError
+
     # Methods scheduling callbacks.  All these return Handles.
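The new AbstractServer surface (start_serving(), serve_forever(), and the __aenter__/__aexit__ pair) lets a server be driven as an async context manager. A minimal sketch using the high-level asyncio.start_server() helper; the echo callback and port are illustrative:

```python
import asyncio

async def echo(reader, writer):
    # Echo one chunk back to the client, then close the connection.
    data = await reader.read(1024)
    writer.write(data)
    await writer.drain()
    writer.close()

async def main():
    # start_serving=False defers accepting until serve_forever() below.
    server = await asyncio.start_server(echo, '127.0.0.1', 8888,
                                        start_serving=False)
    async with server:              # __aexit__ -> close() + wait_closed()
        await server.serve_forever()

asyncio.run(main())   # runs until cancelled, e.g. by KeyboardInterrupt
```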
def _timer_handle_cancelled(self, handle): """Notification that a TimerHandle has been cancelled.""" raise NotImplementedError - def call_soon(self, callback, *args): - return self.call_later(0, callback, *args) + def call_soon(self, callback, *args, context=None): + return self.call_later(0, callback, *args, context=context) - def call_later(self, delay, callback, *args): + def call_later(self, delay, callback, *args, context=None): raise NotImplementedError - def call_at(self, when, callback, *args): + def call_at(self, when, callback, *args, context=None): raise NotImplementedError def time(self): @@ -277,12 +282,12 @@ def create_future(self): # Method scheduling a coroutine object: create a task. - def create_task(self, coro): + def create_task(self, coro, *, name=None, context=None): raise NotImplementedError # Methods for interacting with threads. - def call_soon_threadsafe(self, callback, *args): + def call_soon_threadsafe(self, callback, *args, context=None): raise NotImplementedError def run_in_executor(self, executor, func, *args): @@ -293,21 +298,31 @@ def set_default_executor(self, executor): # Network I/O methods returning Futures. - def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + async def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): raise NotImplementedError - def getnameinfo(self, sockaddr, flags=0): + async def getnameinfo(self, sockaddr, flags=0): raise NotImplementedError - def create_connection(self, protocol_factory, host=None, port=None, *, - ssl=None, family=0, proto=0, flags=0, sock=None, - local_addr=None, server_hostname=None): + async def create_connection( + self, protocol_factory, host=None, port=None, + *, ssl=None, family=0, proto=0, + flags=0, sock=None, local_addr=None, + server_hostname=None, + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, + happy_eyeballs_delay=None, interleave=None): raise NotImplementedError - def create_server(self, protocol_factory, host=None, port=None, *, - family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, - sock=None, backlog=100, ssl=None, reuse_address=None, - reuse_port=None): + async def create_server( + self, protocol_factory, host=None, port=None, + *, family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, sock=None, backlog=100, + ssl=None, reuse_address=None, reuse_port=None, + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, + start_serving=True): """A coroutine which creates a TCP server bound to host and port. The return value is a Server object which can be used to stop @@ -315,8 +330,8 @@ def create_server(self, protocol_factory, host=None, port=None, *, If host is an empty string or None all interfaces are assumed and a list of multiple sockets will be returned (most likely - one for IPv4 and another one for IPv6). The host parameter can also be a - sequence (e.g. list) of hosts to bind to. + one for IPv4 and another one for IPv6). The host parameter can also be + a sequence (e.g. list) of hosts to bind to. family can be set to either AF_INET or AF_INET6 to force the socket to use IPv4 or IPv6. If not set it will be determined @@ -342,22 +357,62 @@ def create_server(self, protocol_factory, host=None, port=None, *, the same port as other existing endpoints are bound to, so long as they all set this flag when being created. This option is not supported on Windows. + + ssl_handshake_timeout is the time in seconds that an SSL server + will wait for completion of the SSL handshake before aborting the + connection. Default is 60s. 
+
+        ssl_shutdown_timeout is the time in seconds that an SSL server
+        will wait for completion of the SSL shutdown procedure
+        before aborting the connection. Default is 30s.
+
+        start_serving set to True (default) causes the created server
+        to start accepting connections immediately.  When set to False,
+        the user should await Server.start_serving() or Server.serve_forever()
+        to make the server start accepting connections.
+        """
+        raise NotImplementedError
+
+    async def sendfile(self, transport, file, offset=0, count=None,
+                       *, fallback=True):
+        """Send a file through a transport.
+
+        Return the total number of bytes sent.
+        """
+        raise NotImplementedError
+
+    async def start_tls(self, transport, protocol, sslcontext, *,
+                        server_side=False,
+                        server_hostname=None,
+                        ssl_handshake_timeout=None,
+                        ssl_shutdown_timeout=None):
+        """Upgrade a transport to TLS.
+
+        Return a new transport that *protocol* should start using
+        immediately.
         """
         raise NotImplementedError

-    def create_unix_connection(self, protocol_factory, path, *,
-                               ssl=None, sock=None,
-                               server_hostname=None):
+    async def create_unix_connection(
+            self, protocol_factory, path=None, *,
+            ssl=None, sock=None,
+            server_hostname=None,
+            ssl_handshake_timeout=None,
+            ssl_shutdown_timeout=None):
         raise NotImplementedError

-    def create_unix_server(self, protocol_factory, path, *,
-                           sock=None, backlog=100, ssl=None):
+    async def create_unix_server(
+            self, protocol_factory, path=None, *,
+            sock=None, backlog=100, ssl=None,
+            ssl_handshake_timeout=None,
+            ssl_shutdown_timeout=None,
+            start_serving=True):
         """A coroutine which creates a UNIX Domain Socket server.

         The return value is a Server object, which can be used to stop
         the service.

-        path is a str, representing a file systsem path to bind the
+        path is a str, representing a file system path to bind the
         server socket to.

         sock can optionally be specified in order to use a preexisting
@@ -368,14 +423,40 @@ def create_unix_server(self, protocol_factory, path, *,

         ssl can be set to an SSLContext to enable SSL over the
         accepted connections.
+
+        ssl_handshake_timeout is the time in seconds that an SSL server
+        will wait for the SSL handshake to complete (defaults to 60s).
+
+        ssl_shutdown_timeout is the time in seconds that an SSL server
+        will wait for the SSL shutdown to finish (defaults to 30s).
+
+        start_serving set to True (default) causes the created server
+        to start accepting connections immediately.  When set to False,
+        the user should await Server.start_serving() or Server.serve_forever()
+        to make the server start accepting connections.
+        """
+        raise NotImplementedError
+
+    async def connect_accepted_socket(
+            self, protocol_factory, sock,
+            *, ssl=None,
+            ssl_handshake_timeout=None,
+            ssl_shutdown_timeout=None):
+        """Handle an accepted connection.
+
+        This is used by servers that accept connections outside of
+        asyncio, but use asyncio to handle connections.
+
+        This method is a coroutine.  When completed, the coroutine
+        returns a (transport, protocol) pair.
         """
         raise NotImplementedError

-    def create_datagram_endpoint(self, protocol_factory,
-                                 local_addr=None, remote_addr=None, *,
-                                 family=0, proto=0, flags=0,
-                                 reuse_address=None, reuse_port=None,
-                                 allow_broadcast=None, sock=None):
+    async def create_datagram_endpoint(self, protocol_factory,
+                                       local_addr=None, remote_addr=None, *,
+                                       family=0, proto=0, flags=0,
+                                       reuse_address=None, reuse_port=None,
+                                       allow_broadcast=None, sock=None):
         """A coroutine which creates a datagram endpoint.
 This method will try to establish the endpoint in the background.
@@ -383,8 +464,8 @@ def create_datagram_endpoint(self, protocol_factory,

         protocol_factory must be a callable returning a protocol instance.

-        socket family AF_INET or socket.AF_INET6 depending on host (or
-        family if specified), socket type SOCK_DGRAM.
+        socket family AF_INET, socket.AF_INET6 or socket.AF_UNIX depending on
+        host (or family if specified), socket type SOCK_DGRAM.

         reuse_address tells the kernel to reuse a local socket in
         TIME_WAIT state, without waiting for its natural timeout to
@@ -408,7 +489,7 @@ def create_datagram_endpoint(self, protocol_factory,

     # Pipes and subprocesses.

-    def connect_read_pipe(self, protocol_factory, pipe):
+    async def connect_read_pipe(self, protocol_factory, pipe):
         """Register read pipe in event loop. Set the pipe to non-blocking mode.

         protocol_factory should instantiate object with Protocol interface.
@@ -418,10 +499,10 @@ def connect_read_pipe(self, protocol_factory, pipe):
         # The reason to accept file-like object instead of just file descriptor
         # is: we need to own pipe and close it at transport finishing
-        # Can got complicated errors if pass f.fileno(),
-        # close fd in pipe transport then close f and vise versa.
+        # We can get complicated errors if we pass f.fileno(),
+        # close the fd in the pipe transport, then close f, or vice versa.
         raise NotImplementedError

-    def connect_write_pipe(self, protocol_factory, pipe):
+    async def connect_write_pipe(self, protocol_factory, pipe):
         """Register write pipe in event loop.

         protocol_factory should instantiate object with BaseProtocol interface.
@@ -431,17 +512,21 @@ def connect_write_pipe(self, protocol_factory, pipe):
         # The reason to accept file-like object instead of just file descriptor
         # is: we need to own pipe and close it at transport finishing
-        # Can got complicated errors if pass f.fileno(),
-        # close fd in pipe transport then close f and vise versa.
+        # We can get complicated errors if we pass f.fileno(),
+        # close the fd in the pipe transport, then close f, or vice versa.
         raise NotImplementedError

-    def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE,
-                         stdout=subprocess.PIPE, stderr=subprocess.PIPE,
-                         **kwargs):
+    async def subprocess_shell(self, protocol_factory, cmd, *,
+                               stdin=subprocess.PIPE,
+                               stdout=subprocess.PIPE,
+                               stderr=subprocess.PIPE,
+                               **kwargs):
         raise NotImplementedError

-    def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE,
-                        stdout=subprocess.PIPE, stderr=subprocess.PIPE,
-                        **kwargs):
+    async def subprocess_exec(self, protocol_factory, *args,
+                              stdin=subprocess.PIPE,
+                              stdout=subprocess.PIPE,
+                              stderr=subprocess.PIPE,
+                              **kwargs):
         raise NotImplementedError

     # Ready-based callback registration methods.
@@ -463,16 +548,32 @@ def remove_writer(self, fd):

     # Completion based I/O methods returning Futures.
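The subprocess_shell()/subprocess_exec() coroutines declared above are normally reached through the high-level wrappers rather than called directly. A minimal sketch, assuming a POSIX-style echo executable is on PATH:

```python
import asyncio

async def main():
    # asyncio.create_subprocess_exec() ultimately calls the loop's
    # subprocess_exec(); the program and its arguments are passed
    # separately and never go through a shell.
    proc = await asyncio.create_subprocess_exec(
        'echo', 'hello',
        stdout=asyncio.subprocess.PIPE)
    out, _ = await proc.communicate()
    print(out.decode().rstrip(), 'exit code:', proc.returncode)

asyncio.run(main())
```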
- def sock_recv(self, sock, nbytes): + async def sock_recv(self, sock, nbytes): raise NotImplementedError - def sock_sendall(self, sock, data): + async def sock_recv_into(self, sock, buf): raise NotImplementedError - def sock_connect(self, sock, address): + async def sock_recvfrom(self, sock, bufsize): raise NotImplementedError - def sock_accept(self, sock): + async def sock_recvfrom_into(self, sock, buf, nbytes=0): + raise NotImplementedError + + async def sock_sendall(self, sock, data): + raise NotImplementedError + + async def sock_sendto(self, sock, data, address): + raise NotImplementedError + + async def sock_connect(self, sock, address): + raise NotImplementedError + + async def sock_accept(self, sock): + raise NotImplementedError + + async def sock_sendfile(self, sock, file, offset=0, count=None, + *, fallback=None): raise NotImplementedError # Signal handling. @@ -520,7 +621,7 @@ class AbstractEventLoopPolicy: def get_event_loop(self): """Get the event loop for the current context. - Returns an event loop object implementing the BaseEventLoop interface, + Returns an event loop object implementing the AbstractEventLoop interface, or raises an exception in case no event loop has been set for the current context and the current policy does not specify to create one. @@ -571,23 +672,43 @@ def __init__(self): self._local = self._Local() def get_event_loop(self): - """Get the event loop. + """Get the event loop for the current context. - This may be None or an instance of EventLoop. + Returns an instance of EventLoop or raises an exception. """ if (self._local._loop is None and - not self._local._set_called and - isinstance(threading.current_thread(), threading._MainThread)): + not self._local._set_called and + threading.current_thread() is threading.main_thread()): + stacklevel = 2 + try: + f = sys._getframe(1) + except AttributeError: + pass + else: + # Move up the call stack so that the warning is attached + # to the line outside asyncio itself. + while f: + module = f.f_globals.get('__name__') + if not (module == 'asyncio' or module.startswith('asyncio.')): + break + f = f.f_back + stacklevel += 1 + import warnings + warnings.warn('There is no current event loop', + DeprecationWarning, stacklevel=stacklevel) self.set_event_loop(self.new_event_loop()) + if self._local._loop is None: raise RuntimeError('There is no current event loop in thread %r.' % threading.current_thread().name) + return self._local._loop def set_event_loop(self, loop): """Set the event loop.""" self._local._set_called = True - assert loop is None or isinstance(loop, AbstractEventLoop) + if loop is not None and not isinstance(loop, AbstractEventLoop): + raise TypeError(f"loop must be an instance of AbstractEventLoop or None, not '{type(loop).__name__}'") self._local._loop = loop def new_event_loop(self): @@ -611,7 +732,9 @@ def new_event_loop(self): # A TLS for the running event loop, used by _get_running_loop. class _RunningLoop(threading.local): - _loop = None + loop_pid = (None, None) + + _running_loop = _RunningLoop() @@ -633,7 +756,10 @@ def _get_running_loop(): This is a low-level function intended to be used by event loops. This function is thread-specific. 
""" - return _running_loop._loop + # NOTE: this function is implemented in C (see _asynciomodule.c) + running_loop, pid = _running_loop.loop_pid + if running_loop is not None and pid == os.getpid(): + return running_loop def _set_running_loop(loop): @@ -642,7 +768,8 @@ def _set_running_loop(loop): This is a low-level function intended to be used by event loops. This function is thread-specific. """ - _running_loop._loop = loop + # NOTE: this function is implemented in C (see _asynciomodule.c) + _running_loop.loop_pid = (loop, os.getpid()) def _init_event_loop_policy(): @@ -665,7 +792,8 @@ def set_event_loop_policy(policy): If policy is None, the default policy is restored.""" global _event_loop_policy - assert policy is None or isinstance(policy, AbstractEventLoopPolicy) + if policy is not None and not isinstance(policy, AbstractEventLoopPolicy): + raise TypeError(f"policy must be an instance of AbstractEventLoopPolicy or None, not '{type(policy).__name__}'") _event_loop_policy = policy @@ -678,6 +806,7 @@ def get_event_loop(): If there is no running event loop set, the function will return the result of `get_event_loop_policy().get_event_loop()` call. """ + # NOTE: this function is implemented in C (see _asynciomodule.c) current_loop = _get_running_loop() if current_loop is not None: return current_loop @@ -703,3 +832,37 @@ def set_child_watcher(watcher): """Equivalent to calling get_event_loop_policy().set_child_watcher(watcher).""" return get_event_loop_policy().set_child_watcher(watcher) + + +# Alias pure-Python implementations for testing purposes. +_py__get_running_loop = _get_running_loop +_py__set_running_loop = _set_running_loop +_py_get_running_loop = get_running_loop +_py_get_event_loop = get_event_loop + + +try: + # get_event_loop() is one of the most frequently called + # functions in asyncio. Pure Python implementation is + # about 4 times slower than C-accelerated. + from _asyncio import (_get_running_loop, _set_running_loop, + get_running_loop, get_event_loop) +except ImportError: + pass +else: + # Alias C implementations for testing purposes. + _c__get_running_loop = _get_running_loop + _c__set_running_loop = _set_running_loop + _c_get_running_loop = get_running_loop + _c_get_event_loop = get_event_loop + + +if hasattr(os, 'fork'): + def on_fork(): + # Reset the loop and wakeupfd in the forked child process. + if _event_loop_policy is not None: + _event_loop_policy._local = BaseDefaultEventLoopPolicy._Local() + _set_running_loop(None) + signal.set_wakeup_fd(-1) + + os.register_at_fork(after_in_child=on_fork) diff --git a/Lib/asyncio/exceptions.py b/Lib/asyncio/exceptions.py new file mode 100644 index 0000000000..5ece595aad --- /dev/null +++ b/Lib/asyncio/exceptions.py @@ -0,0 +1,62 @@ +"""asyncio exceptions.""" + + +__all__ = ('BrokenBarrierError', + 'CancelledError', 'InvalidStateError', 'TimeoutError', + 'IncompleteReadError', 'LimitOverrunError', + 'SendfileNotAvailableError') + + +class CancelledError(BaseException): + """The Future or Task was cancelled.""" + + +TimeoutError = TimeoutError # make local alias for the standard exception + + +class InvalidStateError(Exception): + """The operation is not allowed in this state.""" + + +class SendfileNotAvailableError(RuntimeError): + """Sendfile syscall is not available. + + Raised if OS does not support sendfile syscall for given socket or + file type. + """ + + +class IncompleteReadError(EOFError): + """ + Incomplete read error. 
+    Attributes:
+
+    - partial: read bytes string before the end of stream was reached
+    - expected: total number of expected bytes (or None if unknown)
+    """
+    def __init__(self, partial, expected):
+        r_expected = 'undefined' if expected is None else repr(expected)
+        super().__init__(f'{len(partial)} bytes read on a total of '
+                         f'{r_expected} expected bytes')
+        self.partial = partial
+        self.expected = expected
+
+    def __reduce__(self):
+        return type(self), (self.partial, self.expected)
+
+
+class LimitOverrunError(Exception):
+    """Reached the buffer limit while looking for a separator.
+
+    Attributes:
+    - consumed: total number of bytes to be consumed.
+    """
+    def __init__(self, message, consumed):
+        super().__init__(message)
+        self.consumed = consumed
+
+    def __reduce__(self):
+        return type(self), (self.args[0], self.consumed)
+
+
+class BrokenBarrierError(RuntimeError):
+    """Barrier is broken by barrier.abort() call."""
diff --git a/Lib/asyncio/format_helpers.py b/Lib/asyncio/format_helpers.py
new file mode 100644
index 0000000000..27d11fd4fa
--- /dev/null
+++ b/Lib/asyncio/format_helpers.py
@@ -0,0 +1,76 @@
+import functools
+import inspect
+import reprlib
+import sys
+import traceback
+
+from . import constants
+
+
+def _get_function_source(func):
+    func = inspect.unwrap(func)
+    if inspect.isfunction(func):
+        code = func.__code__
+        return (code.co_filename, code.co_firstlineno)
+    if isinstance(func, functools.partial):
+        return _get_function_source(func.func)
+    if isinstance(func, functools.partialmethod):
+        return _get_function_source(func.func)
+    return None
+
+
+def _format_callback_source(func, args):
+    func_repr = _format_callback(func, args, None)
+    source = _get_function_source(func)
+    if source:
+        func_repr += f' at {source[0]}:{source[1]}'
+    return func_repr
+
+
+def _format_args_and_kwargs(args, kwargs):
+    """Format function arguments and keyword arguments.
+
+    Special case for a single parameter: ('hello',) is formatted as ('hello').
+    """
+    # use reprlib to limit the length of the output
+    items = []
+    if args:
+        items.extend(reprlib.repr(arg) for arg in args)
+    if kwargs:
+        items.extend(f'{k}={reprlib.repr(v)}' for k, v in kwargs.items())
+    return '({})'.format(', '.join(items))
+
+
+def _format_callback(func, args, kwargs, suffix=''):
+    if isinstance(func, functools.partial):
+        suffix = _format_args_and_kwargs(args, kwargs) + suffix
+        return _format_callback(func.func, func.args, func.keywords, suffix)
+
+    if hasattr(func, '__qualname__') and func.__qualname__:
+        func_repr = func.__qualname__
+    elif hasattr(func, '__name__') and func.__name__:
+        func_repr = func.__name__
+    else:
+        func_repr = repr(func)
+
+    func_repr += _format_args_and_kwargs(args, kwargs)
+    if suffix:
+        func_repr += suffix
+    return func_repr
+
+
+def extract_stack(f=None, limit=None):
+    """Replacement for traceback.extract_stack() that only does the
+    necessary work for asyncio debug mode.
+    """
+    if f is None:
+        f = sys._getframe().f_back
+    if limit is None:
+        # Limit the amount of work to a reasonable amount, as extract_stack()
+        # can be called for each coroutine and future in debug mode.
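+        # (Editorial sketch: a typical debug-mode call site, as used by
+        # Future.__init__() later in this patch, is
+        #
+        #     self._source_traceback = format_helpers.extract_stack(
+        #         sys._getframe(1))
+        #
+        # which records at most DEBUG_STACK_DEPTH frames.)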
+ limit = constants.DEBUG_STACK_DEPTH + stack = traceback.StackSummary.extract(traceback.walk_stack(f), + limit=limit, + lookup_lines=False) + stack.reverse() + return stack diff --git a/Lib/asyncio/futures.py b/Lib/asyncio/futures.py index 82c03330ad..97fc4e3fcb 100644 --- a/Lib/asyncio/futures.py +++ b/Lib/asyncio/futures.py @@ -1,21 +1,21 @@ """A Future class similar to the one in PEP 3148.""" -__all__ = ['CancelledError', 'TimeoutError', 'InvalidStateError', - 'Future', 'wrap_future', 'isfuture'] +__all__ = ( + 'Future', 'wrap_future', 'isfuture', +) import concurrent.futures +import contextvars import logging import sys -import traceback +from types import GenericAlias from . import base_futures -from . import compat from . import events +from . import exceptions +from . import format_helpers -CancelledError = base_futures.CancelledError -InvalidStateError = base_futures.InvalidStateError -TimeoutError = base_futures.TimeoutError isfuture = base_futures.isfuture @@ -27,96 +27,18 @@ STACK_DEBUG = logging.DEBUG - 1 # heavy-duty debugging -class _TracebackLogger: - """Helper to log a traceback upon destruction if not cleared. - - This solves a nasty problem with Futures and Tasks that have an - exception set: if nobody asks for the exception, the exception is - never logged. This violates the Zen of Python: 'Errors should - never pass silently. Unless explicitly silenced.' - - However, we don't want to log the exception as soon as - set_exception() is called: if the calling code is written - properly, it will get the exception and handle it properly. But - we *do* want to log it if result() or exception() was never called - -- otherwise developers waste a lot of time wondering why their - buggy code fails silently. - - An earlier attempt added a __del__() method to the Future class - itself, but this backfired because the presence of __del__() - prevents garbage collection from breaking cycles. A way out of - this catch-22 is to avoid having a __del__() method on the Future - class itself, but instead to have a reference to a helper object - with a __del__() method that logs the traceback, where we ensure - that the helper object doesn't participate in cycles, and only the - Future has a reference to it. - - The helper object is added when set_exception() is called. When - the Future is collected, and the helper is present, the helper - object is also collected, and its __del__() method will log the - traceback. When the Future's result() or exception() method is - called (and a helper object is present), it removes the helper - object, after calling its clear() method to prevent it from - logging. - - One downside is that we do a fair amount of work to extract the - traceback from the exception, even when it is never logged. It - would seem cheaper to just store the exception object, but that - references the traceback, which references stack frames, which may - reference the Future, which references the _TracebackLogger, and - then the _TracebackLogger would be included in a cycle, which is - what we're trying to avoid! As an optimization, we don't - immediately format the exception; we only do the work when - activate() is called, which call is delayed until after all the - Future's callbacks have run. Since usually a Future has at least - one callback (typically set by 'yield from') and usually that - callback extracts the callback, thereby removing the need to - format the exception. - - PS. I don't claim credit for this solution. 
I first heard of it - in a discussion about closing files when they are collected. - """ - - __slots__ = ('loop', 'source_traceback', 'exc', 'tb') - - def __init__(self, future, exc): - self.loop = future._loop - self.source_traceback = future._source_traceback - self.exc = exc - self.tb = None - - def activate(self): - exc = self.exc - if exc is not None: - self.exc = None - self.tb = traceback.format_exception(exc.__class__, exc, - exc.__traceback__) - - def clear(self): - self.exc = None - self.tb = None - - def __del__(self): - if self.tb: - msg = 'Future/Task exception was never retrieved\n' - if self.source_traceback: - src = ''.join(traceback.format_list(self.source_traceback)) - msg += 'Future/Task created at (most recent call last):\n' - msg += '%s\n' % src.rstrip() - msg += ''.join(self.tb).rstrip() - self.loop.call_exception_handler({'message': msg}) - - class Future: """This class is *almost* compatible with concurrent.futures.Future. Differences: + - This class is not thread-safe. + - result() and exception() do not take a timeout argument and raise an exception when the future isn't done yet. - Callbacks registered with add_done_callback() are always called - via the event loop's call_soon_threadsafe(). + via the event loop's call_soon(). - This class is not compatible with the wait() and as_completed() methods in the concurrent.futures package. @@ -130,6 +52,9 @@ class Future: _exception = None _loop = None _source_traceback = None + _cancel_message = None + # A saved CancelledError for later chaining as an exception context. + _cancelled_exc = None # This field is used for a dual purpose: # - Its presence is a marker to declare that a class implements @@ -137,12 +62,12 @@ class Future: # The value must also be not-None, to enable a subclass to declare # that it is not compatible by setting this to None. # - It is set by __iter__() below so that Task._step() can tell - # the difference between `yield from Future()` (correct) vs. + # the difference between + # `await Future()` or`yield from Future()` (correct) vs. # `yield Future()` (incorrect). _asyncio_future_blocking = False - _log_traceback = False # Used for Python 3.4 and later - _tb_logger = None # Used for Python 3.3 only + __log_traceback = False def __init__(self, *, loop=None): """Initialize the future. @@ -157,50 +82,83 @@ def __init__(self, *, loop=None): self._loop = loop self._callbacks = [] if self._loop.get_debug(): - self._source_traceback = traceback.extract_stack(sys._getframe(1)) - - _repr_info = base_futures._future_repr_info + self._source_traceback = format_helpers.extract_stack( + sys._getframe(1)) def __repr__(self): - return '<%s %s>' % (self.__class__.__name__, ' '.join(self._repr_info())) - - # On Python 3.3 and older, objects with a destructor part of a reference - # cycle are never destroyed. It's not more the case on Python 3.4 thanks - # to the PEP 442. 
- if compat.PY34: - def __del__(self): - if not self._log_traceback: - # set_exception() was not called, or result() or exception() - # has consumed the exception - return - exc = self._exception - context = { - 'message': ('%s exception was never retrieved' - % self.__class__.__name__), - 'exception': exc, - 'future': self, - } - if self._source_traceback: - context['source_traceback'] = self._source_traceback - self._loop.call_exception_handler(context) - - def __class_getitem__(cls, type): - return cls - - def cancel(self): + return base_futures._future_repr(self) + + def __del__(self): + if not self.__log_traceback: + # set_exception() was not called, or result() or exception() + # has consumed the exception + return + exc = self._exception + context = { + 'message': + f'{self.__class__.__name__} exception was never retrieved', + 'exception': exc, + 'future': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + + __class_getitem__ = classmethod(GenericAlias) + + @property + def _log_traceback(self): + return self.__log_traceback + + @_log_traceback.setter + def _log_traceback(self, val): + if val: + raise ValueError('_log_traceback can only be set to False') + self.__log_traceback = False + + def get_loop(self): + """Return the event loop the Future is bound to.""" + loop = self._loop + if loop is None: + raise RuntimeError("Future object is not initialized.") + return loop + + def _make_cancelled_error(self): + """Create the CancelledError to raise if the Future is cancelled. + + This should only be called once when handling a cancellation since + it erases the saved context exception value. + """ + if self._cancelled_exc is not None: + exc = self._cancelled_exc + self._cancelled_exc = None + return exc + + if self._cancel_message is None: + exc = exceptions.CancelledError() + else: + exc = exceptions.CancelledError(self._cancel_message) + exc.__context__ = self._cancelled_exc + # Remove the reference since we don't need this anymore. + self._cancelled_exc = None + return exc + + def cancel(self, msg=None): """Cancel the future and schedule callbacks. If the future is already done or cancelled, return False. Otherwise, change the future's state to cancelled, schedule the callbacks and return True. """ + self.__log_traceback = False if self._state != _PENDING: return False self._state = _CANCELLED - self._schedule_callbacks() + self._cancel_message = msg + self.__schedule_callbacks() return True - def _schedule_callbacks(self): + def __schedule_callbacks(self): """Internal: Ask the event loop to call all callbacks. The callbacks are scheduled to be called as soon as possible. Also @@ -211,8 +169,8 @@ def _schedule_callbacks(self): return self._callbacks[:] = [] - for callback in callbacks: - self._loop.call_soon(callback, self) + for callback, ctx in callbacks: + self._loop.call_soon(callback, self, context=ctx) def cancelled(self): """Return True if the future was cancelled.""" @@ -236,15 +194,13 @@ def result(self): the future is done and has an exception set, this exception is raised. 
""" if self._state == _CANCELLED: - raise CancelledError + exc = self._make_cancelled_error() + raise exc if self._state != _FINISHED: - raise InvalidStateError('Result is not ready.') - self._log_traceback = False - if self._tb_logger is not None: - self._tb_logger.clear() - self._tb_logger = None + raise exceptions.InvalidStateError('Result is not ready.') + self.__log_traceback = False if self._exception is not None: - raise self._exception + raise self._exception.with_traceback(self._exception_tb) return self._result def exception(self): @@ -256,16 +212,14 @@ def exception(self): InvalidStateError. """ if self._state == _CANCELLED: - raise CancelledError + exc = self._make_cancelled_error() + raise exc if self._state != _FINISHED: - raise InvalidStateError('Exception is not set.') - self._log_traceback = False - if self._tb_logger is not None: - self._tb_logger.clear() - self._tb_logger = None + raise exceptions.InvalidStateError('Exception is not set.') + self.__log_traceback = False return self._exception - def add_done_callback(self, fn): + def add_done_callback(self, fn, *, context=None): """Add a callback to be run when the future becomes done. The callback is called with a single argument - the future object. If @@ -273,9 +227,11 @@ def add_done_callback(self, fn): scheduled with call_soon. """ if self._state != _PENDING: - self._loop.call_soon(fn, self) + self._loop.call_soon(fn, self, context=context) else: - self._callbacks.append(fn) + if context is None: + context = contextvars.copy_context() + self._callbacks.append((fn, context)) # New method not in PEP 3148. @@ -284,7 +240,9 @@ def remove_done_callback(self, fn): Returns the number of callbacks removed. """ - filtered_callbacks = [f for f in self._callbacks if f != fn] + filtered_callbacks = [(f, ctx) + for (f, ctx) in self._callbacks + if f != fn] removed_count = len(self._callbacks) - len(filtered_callbacks) if removed_count: self._callbacks[:] = filtered_callbacks @@ -299,10 +257,10 @@ def set_result(self, result): InvalidStateError. """ if self._state != _PENDING: - raise InvalidStateError('{}: {!r}'.format(self._state, self)) + raise exceptions.InvalidStateError(f'{self._state}: {self!r}') self._result = result self._state = _FINISHED - self._schedule_callbacks() + self.__schedule_callbacks() def set_exception(self, exception): """Mark the future done and set an exception. @@ -311,38 +269,45 @@ def set_exception(self, exception): InvalidStateError. """ if self._state != _PENDING: - raise InvalidStateError('{}: {!r}'.format(self._state, self)) + raise exceptions.InvalidStateError(f'{self._state}: {self!r}') if isinstance(exception, type): exception = exception() if type(exception) is StopIteration: raise TypeError("StopIteration interacts badly with generators " "and cannot be raised into a Future") self._exception = exception + self._exception_tb = exception.__traceback__ self._state = _FINISHED - self._schedule_callbacks() - if compat.PY34: - self._log_traceback = True - else: - self._tb_logger = _TracebackLogger(self, exception) - # Arrange for the logger to be activated after all callbacks - # have had a chance to call result() or exception(). - self._loop.call_soon(self._tb_logger.activate) + self.__schedule_callbacks() + self.__log_traceback = True - def __iter__(self): + def __await__(self): if not self.done(): self._asyncio_future_blocking = True yield self # This tells Task to wait for completion. 
-        assert self.done(), "yield from wasn't used with future"
+        if not self.done():
+            raise RuntimeError("await wasn't used with future")
         return self.result()  # May raise too.

-    if compat.PY35:
-        __await__ = __iter__  # make compatible with 'await' expression
+    __iter__ = __await__  # make compatible with 'yield from'.


 # Needed for testing purposes.
 _PyFuture = Future


+def _get_loop(fut):
+    # Tries to call Future.get_loop() if it's available.
+    # Otherwise falls back to using the old '_loop' property.
+    try:
+        get_loop = fut.get_loop
+    except AttributeError:
+        pass
+    else:
+        return get_loop()
+    return fut._loop
+
+
 def _set_result_unless_cancelled(fut, result):
     """Helper setting the result only if the future was not cancelled."""
     if fut.cancelled():
@@ -350,6 +315,18 @@ def _set_result_unless_cancelled(fut, result):
     fut.set_result(result)


+def _convert_future_exc(exc):
+    exc_class = type(exc)
+    if exc_class is concurrent.futures.CancelledError:
+        return exceptions.CancelledError(*exc.args)
+    elif exc_class is concurrent.futures.TimeoutError:
+        return exceptions.TimeoutError(*exc.args)
+    elif exc_class is concurrent.futures.InvalidStateError:
+        return exceptions.InvalidStateError(*exc.args)
+    else:
+        return exc
+
+
 def _set_concurrent_future_state(concurrent, source):
     """Copy state from a future to a concurrent.futures.Future."""
     assert source.done()
@@ -359,7 +336,7 @@ def _set_concurrent_future_state(concurrent, source):
         return
     exception = source.exception()
     if exception is not None:
-        concurrent.set_exception(exception)
+        concurrent.set_exception(_convert_future_exc(exception))
     else:
         result = source.result()
         concurrent.set_result(result)
@@ -379,7 +356,7 @@ def _copy_future_state(source, dest):
     else:
         exception = source.exception()
         if exception is not None:
-            dest.set_exception(exception)
+            dest.set_exception(_convert_future_exc(exception))
         else:
             result = source.result()
             dest.set_result(result)
@@ -398,8 +375,8 @@ def _chain_future(source, destination):
     if not isfuture(destination) and not isinstance(destination,
                                                     concurrent.futures.Future):
         raise TypeError('A future is required for destination argument')
-    source_loop = source._loop if isfuture(source) else None
-    dest_loop = destination._loop if isfuture(destination) else None
+    source_loop = _get_loop(source) if isfuture(source) else None
+    dest_loop = _get_loop(destination) if isfuture(destination) else None

     def _set_state(future, other):
         if isfuture(future):
@@ -415,9 +392,14 @@ def _call_check_cancel(destination):
             source_loop.call_soon_threadsafe(source.cancel)

     def _call_set_state(source):
+        if (destination.cancelled() and
+                dest_loop is not None and dest_loop.is_closed()):
+            return
         if dest_loop is None or dest_loop is source_loop:
             _set_state(destination, source)
         else:
+            if dest_loop.is_closed():
+                return
             dest_loop.call_soon_threadsafe(_set_state, destination, source)

     destination.add_done_callback(_call_check_cancel)
@@ -429,7 +411,7 @@ def wrap_future(future, *, loop=None):
     if isfuture(future):
         return future
     assert isinstance(future, concurrent.futures.Future), \
-        'concurrent.futures.Future is expected, got {!r}'.format(future)
+        f'concurrent.futures.Future is expected, got {future!r}'
     if loop is None:
         loop = events.get_event_loop()
     new_future = loop.create_future()
diff --git a/Lib/asyncio/locks.py b/Lib/asyncio/locks.py
index deefc938ec..ce5d8d5bfb 100644
--- a/Lib/asyncio/locks.py
+++ b/Lib/asyncio/locks.py
@@ -1,92 +1,26 @@
 """Synchronization primitives."""

-__all__ = ['Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore']
+__all__ = ('Lock', 'Event', 'Condition', 'Semaphore', + 'BoundedSemaphore', 'Barrier') import collections +import enum -from . import compat -from . import events -from . import futures -from .coroutines import coroutine +from . import exceptions +from . import mixins - -class _ContextManager: - """Context manager. - - This enables the following idiom for acquiring and releasing a - lock around a block: - - with (yield from lock): - - - while failing loudly when accidentally using: - - with lock: - - """ - - def __init__(self, lock): - self._lock = lock - - def __enter__(self): +class _ContextManagerMixin: + async def __aenter__(self): + await self.acquire() # We have no use for the "as ..." clause in the with # statement for locks. return None - def __exit__(self, *args): - try: - self._lock.release() - finally: - self._lock = None # Crudely prevent reuse. - - -class _ContextManagerMixin: - def __enter__(self): - raise RuntimeError( - '"yield from" should be used as context manager expression') + async def __aexit__(self, exc_type, exc, tb): + self.release() - def __exit__(self, *args): - # This must exist because __enter__ exists, even though that - # always raises; that's how the with-statement works. - pass - @coroutine - def __iter__(self): - # This is not a coroutine. It is meant to enable the idiom: - # - # with (yield from lock): - # - # - # as an alternative to: - # - # yield from lock.acquire() - # try: - # - # finally: - # lock.release() - yield from self.acquire() - return _ContextManager(self) - - if compat.PY35: - - def __await__(self): - # To make "with await lock" work. - yield from self.acquire() - return _ContextManager(self) - - @coroutine - def __aenter__(self): - yield from self.acquire() - # We have no use for the "as ..." clause in the with - # statement for locks. - return None - - @coroutine - def __aexit__(self, exc_type, exc, tb): - self.release() - - -class Lock(_ContextManagerMixin): +class Lock(_ContextManagerMixin, mixins._LoopBoundMixin): """Primitive lock objects. A primitive lock is a synchronization primitive that is not owned @@ -108,16 +42,16 @@ class Lock(_ContextManagerMixin): release() call resets the state to unlocked; first coroutine which is blocked in acquire() is being processed. - acquire() is a coroutine and should be called with 'yield from'. + acquire() is a coroutine and should be called with 'await'. - Locks also support the context management protocol. '(yield from lock)' - should be used as the context manager expression. + Locks also support the asynchronous context management protocol. + 'async with lock' statement should be used. Usage: lock = Lock() ... - yield from lock + await lock.acquire() try: ... finally: @@ -127,57 +61,65 @@ class Lock(_ContextManagerMixin): lock = Lock() ... - with (yield from lock): + async with lock: ... Lock objects can be tested for locking state: if not lock.locked(): - yield from lock + await lock.acquire() else: # lock is acquired ... 
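
    A bounded wait is also possible (an editorial sketch, not upstream text,
    assuming asyncio.wait_for() is available):

        try:
            await asyncio.wait_for(lock.acquire(), timeout=1.0)
        except TimeoutError:
            ...  # the lock could not be taken in time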
""" - def __init__(self, *, loop=None): - self._waiters = collections.deque() + def __init__(self): + self._waiters = None self._locked = False - if loop is not None: - self._loop = loop - else: - self._loop = events.get_event_loop() def __repr__(self): res = super().__repr__() extra = 'locked' if self._locked else 'unlocked' if self._waiters: - extra = '{},waiters:{}'.format(extra, len(self._waiters)) - return '<{} [{}]>'.format(res[1:-1], extra) + extra = f'{extra}, waiters:{len(self._waiters)}' + return f'<{res[1:-1]} [{extra}]>' def locked(self): """Return True if lock is acquired.""" return self._locked - @coroutine - def acquire(self): + async def acquire(self): """Acquire a lock. This method blocks until the lock is unlocked, then sets it to locked and returns True. """ - if not self._locked and all(w.cancelled() for w in self._waiters): + if (not self._locked and (self._waiters is None or + all(w.cancelled() for w in self._waiters))): self._locked = True return True - fut = self._loop.create_future() + if self._waiters is None: + self._waiters = collections.deque() + fut = self._get_loop().create_future() self._waiters.append(fut) + + # Finally block should be called before the CancelledError + # handling as we don't want CancelledError to call + # _wake_up_first() and attempt to wake up itself. try: - yield from fut - self._locked = True - return True - finally: - self._waiters.remove(fut) + try: + await fut + finally: + self._waiters.remove(fut) + except exceptions.CancelledError: + if not self._locked: + self._wake_up_first() + raise + + self._locked = True + return True def release(self): """Release a lock. @@ -192,16 +134,27 @@ def release(self): """ if self._locked: self._locked = False - # Wake up the first waiter who isn't cancelled. - for fut in self._waiters: - if not fut.done(): - fut.set_result(True) - break + self._wake_up_first() else: raise RuntimeError('Lock is not acquired.') + def _wake_up_first(self): + """Wake up the first waiter if it isn't done.""" + if not self._waiters: + return + try: + fut = next(iter(self._waiters)) + except StopIteration: + return + + # .done() necessarily means that a waiter will wake up later on and + # either take the lock, or, if it was cancelled and lock wasn't + # taken already, will hit this again and wake up a new waiter. + if not fut.done(): + fut.set_result(True) + -class Event: +class Event(mixins._LoopBoundMixin): """Asynchronous equivalent to threading.Event. Class implementing event objects. An event manages a flag that can be set @@ -210,20 +163,16 @@ class Event: false. """ - def __init__(self, *, loop=None): + def __init__(self): self._waiters = collections.deque() self._value = False - if loop is not None: - self._loop = loop - else: - self._loop = events.get_event_loop() def __repr__(self): res = super().__repr__() extra = 'set' if self._value else 'unset' if self._waiters: - extra = '{},waiters:{}'.format(extra, len(self._waiters)) - return '<{} [{}]>'.format(res[1:-1], extra) + extra = f'{extra}, waiters:{len(self._waiters)}' + return f'<{res[1:-1]} [{extra}]>' def is_set(self): """Return True if and only if the internal flag is true.""" @@ -247,8 +196,7 @@ def clear(self): to true again.""" self._value = False - @coroutine - def wait(self): + async def wait(self): """Block until the internal flag is true. 
If the internal flag is true on entry, return True @@ -258,16 +206,16 @@ def wait(self): if self._value: return True - fut = self._loop.create_future() + fut = self._get_loop().create_future() self._waiters.append(fut) try: - yield from fut + await fut return True finally: self._waiters.remove(fut) -class Condition(_ContextManagerMixin): +class Condition(_ContextManagerMixin, mixins._LoopBoundMixin): """Asynchronous equivalent to threading.Condition. This class implements condition variable objects. A condition variable @@ -277,16 +225,9 @@ class Condition(_ContextManagerMixin): A new Lock object is created and used as the underlying lock. """ - def __init__(self, lock=None, *, loop=None): - if loop is not None: - self._loop = loop - else: - self._loop = events.get_event_loop() - + def __init__(self, lock=None): if lock is None: - lock = Lock(loop=self._loop) - elif lock._loop is not self._loop: - raise ValueError("loop argument must agree with lock") + lock = Lock() self._lock = lock # Export the lock's locked(), acquire() and release() methods. @@ -300,11 +241,10 @@ def __repr__(self): res = super().__repr__() extra = 'locked' if self.locked() else 'unlocked' if self._waiters: - extra = '{},waiters:{}'.format(extra, len(self._waiters)) - return '<{} [{}]>'.format(res[1:-1], extra) + extra = f'{extra}, waiters:{len(self._waiters)}' + return f'<{res[1:-1]} [{extra}]>' - @coroutine - def wait(self): + async def wait(self): """Wait until notified. If the calling coroutine has not acquired the lock when this @@ -320,25 +260,28 @@ def wait(self): self.release() try: - fut = self._loop.create_future() + fut = self._get_loop().create_future() self._waiters.append(fut) try: - yield from fut + await fut return True finally: self._waiters.remove(fut) finally: # Must reacquire lock even if wait is cancelled + cancelled = False while True: try: - yield from self.acquire() + await self.acquire() break - except futures.CancelledError: - pass + except exceptions.CancelledError: + cancelled = True + + if cancelled: + raise exceptions.CancelledError - @coroutine - def wait_for(self, predicate): + async def wait_for(self, predicate): """Wait until a predicate becomes true. The predicate should be a callable which result will be @@ -347,7 +290,7 @@ def wait_for(self, predicate): """ result = predicate() while not result: - yield from self.wait() + await self.wait() result = predicate() return result @@ -384,7 +327,7 @@ def notify_all(self): self.notify(len(self._waiters)) -class Semaphore(_ContextManagerMixin): +class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin): """A Semaphore implementation. A semaphore manages an internal counter which is decremented by each @@ -399,37 +342,25 @@ class Semaphore(_ContextManagerMixin): ValueError is raised. 
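
    A usage sketch (an editorial addition, not upstream text): bound
    concurrency to ten tasks at a time.

        sem = Semaphore(10)

        async def worker():
            async with sem:
                ...  # at most ten workers run this block concurrently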
""" - def __init__(self, value=1, *, loop=None): + def __init__(self, value=1): if value < 0: raise ValueError("Semaphore initial value must be >= 0") + self._waiters = None self._value = value - self._waiters = collections.deque() - if loop is not None: - self._loop = loop - else: - self._loop = events.get_event_loop() def __repr__(self): res = super().__repr__() - extra = 'locked' if self.locked() else 'unlocked,value:{}'.format( - self._value) + extra = 'locked' if self.locked() else f'unlocked, value:{self._value}' if self._waiters: - extra = '{},waiters:{}'.format(extra, len(self._waiters)) - return '<{} [{}]>'.format(res[1:-1], extra) - - def _wake_up_next(self): - while self._waiters: - waiter = self._waiters.popleft() - if not waiter.done(): - waiter.set_result(None) - return + extra = f'{extra}, waiters:{len(self._waiters)}' + return f'<{res[1:-1]} [{extra}]>' def locked(self): - """Returns True if semaphore can not be acquired immediately.""" - return self._value == 0 + """Returns True if semaphore cannot be acquired immediately.""" + return self._value == 0 or ( + any(not w.cancelled() for w in (self._waiters or ()))) - @coroutine - def acquire(self): + async def acquire(self): """Acquire a semaphore. If the internal counter is larger than zero on entry, @@ -438,28 +369,53 @@ def acquire(self): called release() to make it larger than 0, and then return True. """ - while self._value <= 0: - fut = self._loop.create_future() - self._waiters.append(fut) + if not self.locked(): + self._value -= 1 + return True + + if self._waiters is None: + self._waiters = collections.deque() + fut = self._get_loop().create_future() + self._waiters.append(fut) + + # Finally block should be called before the CancelledError + # handling as we don't want CancelledError to call + # _wake_up_first() and attempt to wake up itself. + try: try: - yield from fut - except: - # See the similar code in Queue.get. - fut.cancel() - if self._value > 0 and not fut.cancelled(): - self._wake_up_next() - raise - self._value -= 1 + await fut + finally: + self._waiters.remove(fut) + except exceptions.CancelledError: + if not fut.cancelled(): + self._value += 1 + self._wake_up_next() + raise + + if self._value > 0: + self._wake_up_next() return True def release(self): """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to become larger than zero again, wake up that coroutine. """ self._value += 1 self._wake_up_next() + def _wake_up_next(self): + """Wake up the first waiter that isn't done.""" + if not self._waiters: + return + + for fut in self._waiters: + if not fut.done(): + self._value -= 1 + fut.set_result(True) + return + class BoundedSemaphore(Semaphore): """A bounded semaphore implementation. @@ -468,11 +424,163 @@ class BoundedSemaphore(Semaphore): above the initial value. """ - def __init__(self, value=1, *, loop=None): + def __init__(self, value=1): self._bound_value = value - super().__init__(value, loop=loop) + super().__init__(value) def release(self): if self._value >= self._bound_value: raise ValueError('BoundedSemaphore released too many times') super().release() + + + +class _BarrierState(enum.Enum): + FILLING = 'filling' + DRAINING = 'draining' + RESETTING = 'resetting' + BROKEN = 'broken' + + +class Barrier(mixins._LoopBoundMixin): + """Asyncio equivalent to threading.Barrier + + Implements a Barrier primitive. + Useful for synchronizing a fixed number of tasks at known synchronization + points. 
+    Tasks block on 'wait()' and are simultaneously awoken once they
+    have all made their call.
+    """
+
+    def __init__(self, parties):
+        """Create a barrier, initialised to 'parties' tasks."""
+        if parties < 1:
+            raise ValueError('parties must be > 0')
+
+        self._cond = Condition()  # notify all tasks when state changes
+
+        self._parties = parties
+        self._state = _BarrierState.FILLING
+        self._count = 0  # count tasks in Barrier
+
+    def __repr__(self):
+        res = super().__repr__()
+        extra = f'{self._state.value}'
+        if not self.broken:
+            extra += f', waiters:{self.n_waiting}/{self.parties}'
+        return f'<{res[1:-1]} [{extra}]>'
+
+    async def __aenter__(self):
+        # Wait until the barrier has been reached by 'parties' tasks;
+        # when draining starts, release and return the index of this task.
+        return await self.wait()
+
+    async def __aexit__(self, *args):
+        pass
+
+    async def wait(self):
+        """Wait for the barrier.
+
+        When the specified number of tasks have started waiting, they are all
+        simultaneously awoken.
+        Returns a unique, individual index number from 0 to 'parties-1'.
+        """
+        async with self._cond:
+            await self._block()  # Block while the barrier drains or resets.
+            try:
+                index = self._count
+                self._count += 1
+                if index + 1 == self._parties:
+                    # We release the barrier
+                    await self._release()
+                else:
+                    await self._wait()
+                return index
+            finally:
+                self._count -= 1
+                # Wake up any tasks waiting for barrier to drain.
+                self._exit()

+    async def _block(self):
+        # Block until the barrier is ready for us,
+        # or raise an exception if it is broken.
+        #
+        # If it is draining or resetting, wait until that is done,
+        # unless a CancelledError occurs.
+        await self._cond.wait_for(
+            lambda: self._state not in (
+                _BarrierState.DRAINING, _BarrierState.RESETTING
+            )
+        )
+
+        # see if the barrier is in a broken state
+        if self._state is _BarrierState.BROKEN:
+            raise exceptions.BrokenBarrierError("Barrier aborted")
+
+    async def _release(self):
+        # Release the tasks waiting in the barrier.
+
+        # Enter draining state.
+        # The next waiting tasks will be blocked until the end of draining.
+        self._state = _BarrierState.DRAINING
+        self._cond.notify_all()
+
+    async def _wait(self):
+        # Wait in the barrier until we are released.  Raise an exception
+        # if the barrier is reset or broken.
+
+        # Wait for the end of filling,
+        # unless a CancelledError occurs.
+        await self._cond.wait_for(
+            lambda: self._state is not _BarrierState.FILLING)
+
+        if self._state in (_BarrierState.BROKEN, _BarrierState.RESETTING):
+            raise exceptions.BrokenBarrierError("Abort or reset of barrier")
+
+    def _exit(self):
+        # If we are the last task to exit the barrier, signal any tasks
+        # waiting for the barrier to drain.
+        if self._count == 0:
+            if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING):
+                self._state = _BarrierState.FILLING
+            self._cond.notify_all()
+
+    async def reset(self):
+        """Reset the barrier to the initial state.
+
+        Any tasks currently waiting will get the BrokenBarrierError exception
+        raised.
+        """
+        async with self._cond:
+            if self._count > 0:
+                if self._state is not _BarrierState.RESETTING:
+                    # reset the barrier, waking up tasks
+                    self._state = _BarrierState.RESETTING
+            else:
+                self._state = _BarrierState.FILLING
+            self._cond.notify_all()
+
+    async def abort(self):
+        """Place the barrier into a 'broken' state.
+
+        Useful in case of error.  Any currently waiting tasks and tasks
+        attempting to 'wait()' will have BrokenBarrierError raised.
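+
+        A sketch of typical use (editorial, not upstream text):
+
+            try:
+                await barrier.wait()
+            except asyncio.BrokenBarrierError:
+                ...  # some other task called abort() or reset()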
+ """ + async with self._cond: + self._state = _BarrierState.BROKEN + self._cond.notify_all() + + @property + def parties(self): + """Return the number of tasks required to trip the barrier.""" + return self._parties + + @property + def n_waiting(self): + """Return the number of tasks currently waiting at the barrier.""" + if self._state is _BarrierState.FILLING: + return self._count + return 0 + + @property + def broken(self): + """Return True if the barrier is in a broken state.""" + return self._state is _BarrierState.BROKEN diff --git a/Lib/asyncio/mixins.py b/Lib/asyncio/mixins.py new file mode 100644 index 0000000000..c6bf97329e --- /dev/null +++ b/Lib/asyncio/mixins.py @@ -0,0 +1,21 @@ +"""Event loop mixins.""" + +import threading +from . import events + +_global_lock = threading.Lock() + + +class _LoopBoundMixin: + _loop = None + + def _get_loop(self): + loop = events._get_running_loop() + + if self._loop is None: + with _global_lock: + if self._loop is None: + self._loop = loop + if loop is not self._loop: + raise RuntimeError(f'{self!r} is bound to a different event loop') + return loop diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index ff12877fae..1e2a730cf3 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -4,20 +4,45 @@ proactor is only implemented on Windows with IOCP. """ -__all__ = ['BaseProactorEventLoop'] +__all__ = 'BaseProactorEventLoop', +import io +import os import socket import warnings +import signal +import threading +import collections from . import base_events -from . import compat from . import constants from . import futures +from . import exceptions +from . import protocols from . import sslproto from . import transports +from . import trsock from .log import logger +def _set_socket_extra(transport, sock): + transport._extra['socket'] = trsock.TransportSocket(sock) + + try: + transport._extra['sockname'] = sock.getsockname() + except socket.error: + if transport._loop.get_debug(): + logger.warning( + "getsockname() failed on %r", sock, exc_info=True) + + if 'peername' not in transport._extra: + try: + transport._extra['peername'] = sock.getpeername() + except socket.error: + # UDP sockets may not have a peer name + transport._extra['peername'] = None + + class _ProactorBasePipeTransport(transports._FlowControlMixin, transports.BaseTransport): """Base class for pipe and socket transports.""" @@ -27,7 +52,7 @@ def __init__(self, loop, sock, protocol, waiter=None, super().__init__(extra, loop) self._set_extra(sock) self._sock = sock - self._protocol = protocol + self.set_protocol(protocol) self._server = server self._buffer = None # None or bytearray. self._read_fut = None @@ -35,6 +60,7 @@ def __init__(self, loop, sock, protocol, waiter=None, self._pending_write = 0 self._conn_lost = 0 self._closing = False # Set when close() called. 
+ self._called_connection_lost = False self._eof_written = False if self._server is not None: self._server._attach() @@ -51,17 +77,16 @@ def __repr__(self): elif self._closing: info.append('closing') if self._sock is not None: - info.append('fd=%s' % self._sock.fileno()) + info.append(f'fd={self._sock.fileno()}') if self._read_fut is not None: - info.append('read=%s' % self._read_fut) + info.append(f'read={self._read_fut!r}') if self._write_fut is not None: - info.append("write=%r" % self._write_fut) + info.append(f'write={self._write_fut!r}') if self._buffer: - bufsize = len(self._buffer) - info.append('write_bufsize=%s' % bufsize) + info.append(f'write_bufsize={len(self._buffer)}') if self._eof_written: info.append('EOF written') - return '<%s>' % ' '.join(info) + return '<{}>'.format(' '.join(info)) def _set_extra(self, sock): self._extra['pipe'] = sock @@ -86,31 +111,33 @@ def close(self): self._read_fut.cancel() self._read_fut = None - # On Python 3.3 and older, objects with a destructor part of a reference - # cycle are never destroyed. It's not more the case on Python 3.4 thanks - # to the PEP 442. - if compat.PY34: - def __del__(self): - if self._sock is not None: - warnings.warn("unclosed transport %r" % self, ResourceWarning, - source=self) - self.close() + def __del__(self, _warn=warnings.warn): + if self._sock is not None: + _warn(f"unclosed transport {self!r}", ResourceWarning, source=self) + self._sock.close() def _fatal_error(self, exc, message='Fatal error on pipe transport'): - if isinstance(exc, base_events._FATAL_ERROR_IGNORE): - if self._loop.get_debug(): - logger.debug("%r: %s", self, message, exc_info=True) - else: - self._loop.call_exception_handler({ - 'message': message, - 'exception': exc, - 'transport': self, - 'protocol': self._protocol, - }) - self._force_close(exc) + try: + if isinstance(exc, OSError): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + finally: + self._force_close(exc) def _force_close(self, exc): - if self._closing: + if self._empty_waiter is not None and not self._empty_waiter.done(): + if exc is None: + self._empty_waiter.set_result(None) + else: + self._empty_waiter.set_exception(exc) + if self._closing and self._called_connection_lost: return self._closing = True self._conn_lost += 1 @@ -125,6 +152,8 @@ def _force_close(self, exc): self._loop.call_soon(self._call_connection_lost, exc) def _call_connection_lost(self, exc): + if self._called_connection_lost: + return try: self._protocol.connection_lost(exc) finally: @@ -132,7 +161,7 @@ def _call_connection_lost(self, exc): # end then it may fail with ERROR_NETNAME_DELETED if we # just close our end. First calling shutdown() seems to # cure it, but maybe using DisconnectEx() would be better. 
-            if hasattr(self._sock, 'shutdown'):
+            if hasattr(self._sock, 'shutdown') and self._sock.fileno() != -1:
                 self._sock.shutdown(socket.SHUT_RDWR)
             self._sock.close()
             self._sock = None
@@ -140,6 +169,7 @@ def _call_connection_lost(self, exc):
             if server is not None:
                 server._detach()
                 self._server = None
+            self._called_connection_lost = True

     def get_write_buffer_size(self):
         size = self._pending_write
@@ -153,53 +183,127 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
     """Transport for read pipes."""

     def __init__(self, loop, sock, protocol, waiter=None,
-                 extra=None, server=None):
+                 extra=None, server=None, buffer_size=65536):
+        self._pending_data_length = -1
+        self._paused = True
         super().__init__(loop, sock, protocol, waiter, extra, server)
-        self._paused = False
+
+        self._data = bytearray(buffer_size)
         self._loop.call_soon(self._loop_reading)
+        self._paused = False
+
+    def is_reading(self):
+        return not self._paused and not self._closing

     def pause_reading(self):
-        if self._closing:
-            raise RuntimeError('Cannot pause_reading() when closing')
-        if self._paused:
-            raise RuntimeError('Already paused')
+        if self._closing or self._paused:
+            return
         self._paused = True
+
+        # bpo-33694: Don't cancel self._read_fut because cancelling an
+        # overlapped WSASend() silently loses data with the current proactor
+        # implementation.
+        #
+        # If CancelIoEx() fails with ERROR_NOT_FOUND, it means that WSASend()
+        # completed (even if HasOverlappedIoCompleted() returns 0), but
+        # Overlapped.cancel() currently silently ignores the ERROR_NOT_FOUND
+        # error.  Once the overlapped is ignored, the IOCP loop will ignore
+        # the completion I/O event and so never read the result of the
+        # overlapped WSARecv().
+
         if self._loop.get_debug():
             logger.debug("%r pauses reading", self)

     def resume_reading(self):
-        if not self._paused:
-            raise RuntimeError('Not paused')
-        self._paused = False
-        if self._closing:
+        if self._closing or not self._paused:
             return
-        self._loop.call_soon(self._loop_reading, self._read_fut)
+
+        self._paused = False
+        if self._read_fut is None:
+            self._loop.call_soon(self._loop_reading, None)
+
+        length = self._pending_data_length
+        self._pending_data_length = -1
+        if length > -1:
+            # Call the protocol method after calling _loop_reading(),
+            # since the protocol can decide to pause reading again.
+            self._loop.call_soon(self._data_received,
+                                 self._data[:length], length)
+
         if self._loop.get_debug():
             logger.debug("%r resumes reading", self)

-    def _loop_reading(self, fut=None):
+    def _eof_received(self):
+        if self._loop.get_debug():
+            logger.debug("%r received EOF", self)
+
+        try:
+            keep_open = self._protocol.eof_received()
+        except (SystemExit, KeyboardInterrupt):
+            raise
+        except BaseException as exc:
+            self._fatal_error(
+                exc, 'Fatal error: protocol.eof_received() call failed.')
+            return
+
+        if not keep_open:
+            self.close()
+
+    def _data_received(self, data, length):
         if self._paused:
+            # Don't call any protocol method while reading is paused.
+            # The protocol will be called on resume_reading().
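+            # (Editorial note: resume_reading() re-delivers the saved chunk
+            # via _pending_data_length; at most one chunk can be pending
+            # because _loop_reading() is not rescheduled while paused.)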
+ assert self._pending_data_length == -1 + self._pending_data_length = length return - data = None + if length == 0: + self._eof_received() + return + + if isinstance(self._protocol, protocols.BufferedProtocol): + try: + protocols._feed_data_to_buffered_proto(self._protocol, data) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._fatal_error(exc, + 'Fatal error: protocol.buffer_updated() ' + 'call failed.') + return + else: + self._protocol.data_received(data) + + def _loop_reading(self, fut=None): + length = -1 + data = None try: if fut is not None: assert self._read_fut is fut or (self._read_fut is None and self._closing) self._read_fut = None - data = fut.result() # deliver data later in "finally" clause + if fut.done(): + # deliver data later in "finally" clause + length = fut.result() + if length == 0: + # we got end-of-file so no need to reschedule a new read + return + + # It's a new slice so make it immutable so protocols upstream don't have problems + data = bytes(memoryview(self._data)[:length]) + else: + # the future will be replaced by next proactor.recv call + fut.cancel() if self._closing: # since close() has been called we ignore any read data - data = None return - if data == b'': - # we got end-of-file so no need to reschedule a new read - return + # bpo-33694: buffer_updated() has currently no fast path because of + # a data loss issue caused by overlapped WSASend() cancellation. - # reschedule a new read - self._read_fut = self._loop._proactor.recv(self._sock, 4096) + if not self._paused: + # reschedule a new read + self._read_fut = self._loop._proactor.recv_into(self._sock, self._data) except ConnectionAbortedError as exc: if not self._closing: self._fatal_error(exc, 'Fatal read error on pipe transport') @@ -210,32 +314,36 @@ def _loop_reading(self, fut=None): self._force_close(exc) except OSError as exc: self._fatal_error(exc, 'Fatal read error on pipe transport') - except futures.CancelledError: + except exceptions.CancelledError: if not self._closing: raise else: - self._read_fut.add_done_callback(self._loop_reading) + if not self._paused: + self._read_fut.add_done_callback(self._loop_reading) finally: - if data: - self._protocol.data_received(data) - elif data is not None: - if self._loop.get_debug(): - logger.debug("%r received EOF", self) - keep_open = self._protocol.eof_received() - if not keep_open: - self.close() + if length > -1: + self._data_received(data, length) class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, transports.WriteTransport): """Transport for write pipes.""" + _start_tls_compatible = True + + def __init__(self, *args, **kw): + super().__init__(*args, **kw) + self._empty_waiter = None + def write(self, data): if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError('data argument must be byte-ish (%r)', - type(data)) + raise TypeError( + f"data argument must be a bytes-like object, " + f"not {type(data).__name__}") if self._eof_written: raise RuntimeError('write_eof() already called') + if self._empty_waiter is not None: + raise RuntimeError('unable to write; sendfile is in progress') if not data: return @@ -267,6 +375,10 @@ def write(self, data): def _loop_writing(self, f=None, data=None): try: + if f is not None and self._write_fut is None and self._closing: + # XXX most likely self._force_close() has been called, and + # it has set self._write_fut to None. 
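+                # Returning early here avoids failing the
+                # "assert f is self._write_fut" check below on a stale
+                # future (editorial clarification).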
+                return
             assert f is self._write_fut
             self._write_fut = None
             self._pending_write = 0
@@ -295,6 +407,8 @@ def _loop_writing(self, f=None, data=None):
                     self._maybe_pause_protocol()
                 else:
                     self._write_fut.add_done_callback(self._loop_writing)
+                if self._empty_waiter is not None and self._write_fut is None:
+                    self._empty_waiter.set_result(None)
         except ConnectionResetError as exc:
             self._force_close(exc)
         except OSError as exc:
@@ -309,6 +423,17 @@ def write_eof(self):
     def abort(self):
         self._force_close(None)

+    def _make_empty_waiter(self):
+        if self._empty_waiter is not None:
+            raise RuntimeError("Empty waiter is already set")
+        self._empty_waiter = self._loop.create_future()
+        if self._write_fut is None:
+            self._empty_waiter.set_result(None)
+        return self._empty_waiter
+
+    def _reset_empty_waiter(self):
+        self._empty_waiter = None
+

 class _ProactorWritePipeTransport(_ProactorBaseWritePipeTransport):
     def __init__(self, *args, **kw):
@@ -332,6 +457,138 @@ def _pipe_closed(self, fut):
         self.close()


+class _ProactorDatagramTransport(_ProactorBasePipeTransport,
+                                 transports.DatagramTransport):
+    max_size = 256 * 1024
+
+    def __init__(self, loop, sock, protocol, address=None,
+                 waiter=None, extra=None):
+        self._address = address
+        self._empty_waiter = None
+        self._buffer_size = 0
+        # We don't need to call _protocol.connection_made() since our base
+        # constructor does it for us.
+        super().__init__(loop, sock, protocol, waiter=waiter, extra=extra)
+
+        # The base constructor sets _buffer = None, so we set it here
+        self._buffer = collections.deque()
+        self._loop.call_soon(self._loop_reading)
+
+    def _set_extra(self, sock):
+        _set_socket_extra(self, sock)
+
+    def get_write_buffer_size(self):
+        return self._buffer_size
+
+    def abort(self):
+        self._force_close(None)
+
+    def sendto(self, data, addr=None):
+        if not isinstance(data, (bytes, bytearray, memoryview)):
+            raise TypeError(
+                f'data argument must be a bytes-like object, '
+                f'not {type(data).__name__}')
+
+        if not data:
+            return
+
+        if self._address is not None and addr not in (None, self._address):
+            raise ValueError(
+                f'Invalid address: must be None or {self._address}')
+
+        if self._conn_lost and self._address:
+            if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
+                logger.warning('socket.sendto() raised exception.')
+            self._conn_lost += 1
+            return
+
+        # Ensure that what we buffer is immutable.
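+        # (Editorial note: the caller may mutate a bytearray or memoryview
+        # after sendto() returns, so bytes() takes a private copy before the
+        # datagram is queued.)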
+ self._buffer.append((bytes(data), addr)) + self._buffer_size += len(data) + + if self._write_fut is None: + # No current write operations are active, kick one off + self._loop_writing() + # else: A write operation is already kicked off + + self._maybe_pause_protocol() + + def _loop_writing(self, fut=None): + try: + if self._conn_lost: + return + + assert fut is self._write_fut + self._write_fut = None + if fut: + # We are in a _loop_writing() done callback, get the result + fut.result() + + if not self._buffer or (self._conn_lost and self._address): + # The connection has been closed + if self._closing: + self._loop.call_soon(self._call_connection_lost, None) + return + + data, addr = self._buffer.popleft() + self._buffer_size -= len(data) + if self._address is not None: + self._write_fut = self._loop._proactor.send(self._sock, + data) + else: + self._write_fut = self._loop._proactor.sendto(self._sock, + data, + addr=addr) + except OSError as exc: + self._protocol.error_received(exc) + except Exception as exc: + self._fatal_error(exc, 'Fatal write error on datagram transport') + else: + self._write_fut.add_done_callback(self._loop_writing) + self._maybe_resume_protocol() + + def _loop_reading(self, fut=None): + data = None + try: + if self._conn_lost: + return + + assert self._read_fut is fut or (self._read_fut is None and + self._closing) + + self._read_fut = None + if fut is not None: + res = fut.result() + + if self._closing: + # since close() has been called we ignore any read data + data = None + return + + if self._address is not None: + data, addr = res, self._address + else: + data, addr = res + + if self._conn_lost: + return + if self._address is not None: + self._read_fut = self._loop._proactor.recv(self._sock, + self.max_size) + else: + self._read_fut = self._loop._proactor.recvfrom(self._sock, + self.max_size) + except OSError as exc: + self._protocol.error_received(exc) + except exceptions.CancelledError: + if not self._closing: + raise + else: + if self._read_fut is not None: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.datagram_received(data, addr) + + class _ProactorDuplexPipeTransport(_ProactorReadPipeTransport, _ProactorBaseWritePipeTransport, transports.Transport): @@ -349,21 +606,15 @@ class _ProactorSocketTransport(_ProactorReadPipeTransport, transports.Transport): """Transport for connected sockets.""" + _sendfile_compatible = constants._SendfileMode.TRY_NATIVE + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(loop, sock, protocol, waiter, extra, server) + base_events._set_nodelay(sock) + def _set_extra(self, sock): - self._extra['socket'] = sock - try: - self._extra['sockname'] = sock.getsockname() - except (socket.error, AttributeError): - if self._loop.get_debug(): - logger.warning("getsockname() failed on %r", - sock, exc_info=True) - if 'peername' not in self._extra: - try: - self._extra['peername'] = sock.getpeername() - except (socket.error, AttributeError): - if self._loop.get_debug(): - logger.warning("getpeername() failed on %r", - sock, exc_info=True) + _set_socket_extra(self, sock) def can_write_eof(self): return True @@ -387,26 +638,35 @@ def __init__(self, proactor): self._accept_futures = {} # socket file descriptor => Future proactor.set_loop(self) self._make_self_pipe() + if threading.current_thread() is threading.main_thread(): + # wakeup fd can only be installed to a file descriptor from the main thread + 
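+            # (Editorial note: with the wakeup fd pointing at the self-pipe,
+            # a signal such as SIGINT wakes the blocked proactor promptly
+            # instead of waiting for the next I/O completion.)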
signal.set_wakeup_fd(self._csock.fileno()) def _make_socket_transport(self, sock, protocol, waiter=None, extra=None, server=None): return _ProactorSocketTransport(self, sock, protocol, waiter, extra, server) - def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, - *, server_side=False, server_hostname=None, - extra=None, server=None): - if not sslproto._is_sslproto_available(): - raise NotImplementedError("Proactor event loop requires Python 3.5" - " or newer (ssl.MemoryBIO) to support " - "SSL") - - ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter, - server_side, server_hostname) + def _make_ssl_transport( + self, rawsock, protocol, sslcontext, waiter=None, + *, server_side=False, server_hostname=None, + extra=None, server=None, + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): + ssl_protocol = sslproto.SSLProtocol( + self, protocol, sslcontext, waiter, + server_side, server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) _ProactorSocketTransport(self, rawsock, ssl_protocol, extra=extra, server=server) return ssl_protocol._app_transport + def _make_datagram_transport(self, sock, protocol, + address=None, waiter=None, extra=None): + return _ProactorDatagramTransport(self, sock, protocol, address, + waiter, extra) + def _make_duplex_pipe_transport(self, sock, protocol, waiter=None, extra=None): return _ProactorDuplexPipeTransport(self, @@ -428,6 +688,8 @@ def close(self): if self.is_closed(): return + if threading.current_thread() is threading.main_thread(): + signal.set_wakeup_fd(-1) # Call these methods before closing the event loop (before calling # BaseEventLoop.close), because they can schedule callbacks with # call_soon(), which is forbidden when the event loop is closed. 
@@ -440,20 +702,73 @@ def close(self): # Close the event loop super().close() - def sock_recv(self, sock, n): - return self._proactor.recv(sock, n) + async def sock_recv(self, sock, n): + return await self._proactor.recv(sock, n) - def sock_sendall(self, sock, data): - return self._proactor.send(sock, data) + async def sock_recv_into(self, sock, buf): + return await self._proactor.recv_into(sock, buf) - def sock_connect(self, sock, address): - return self._proactor.connect(sock, address) + async def sock_recvfrom(self, sock, bufsize): + return await self._proactor.recvfrom(sock, bufsize) - def sock_accept(self, sock): - return self._proactor.accept(sock) + async def sock_recvfrom_into(self, sock, buf, nbytes=0): + if not nbytes: + nbytes = len(buf) - def _socketpair(self): - raise NotImplementedError + return await self._proactor.recvfrom_into(sock, buf, nbytes) + + async def sock_sendall(self, sock, data): + return await self._proactor.send(sock, data) + + async def sock_sendto(self, sock, data, address): + return await self._proactor.sendto(sock, data, 0, address) + + async def sock_connect(self, sock, address): + return await self._proactor.connect(sock, address) + + async def sock_accept(self, sock): + return await self._proactor.accept(sock) + + async def _sock_sendfile_native(self, sock, file, offset, count): + try: + fileno = file.fileno() + except (AttributeError, io.UnsupportedOperation) as err: + raise exceptions.SendfileNotAvailableError("not a regular file") + try: + fsize = os.fstat(fileno).st_size + except OSError: + raise exceptions.SendfileNotAvailableError("not a regular file") + blocksize = count if count else fsize + if not blocksize: + return 0 # empty file + + blocksize = min(blocksize, 0xffff_ffff) + end_pos = min(offset + count, fsize) if count else fsize + offset = min(offset, fsize) + total_sent = 0 + try: + while True: + blocksize = min(end_pos - offset, blocksize) + if blocksize <= 0: + return total_sent + await self._proactor.sendfile(sock, file, offset, blocksize) + offset += blocksize + total_sent += blocksize + finally: + if total_sent > 0: + file.seek(offset) + + async def _sendfile_native(self, transp, file, offset, count): + resume_reading = transp.is_reading() + transp.pause_reading() + await transp._make_empty_waiter() + try: + return await self.sock_sendfile(transp._sock, file, offset, count, + fallback=False) + finally: + transp._reset_empty_waiter() + if resume_reading: + transp.resume_reading() def _close_self_pipe(self): if self._self_reading_future is not None: @@ -467,21 +782,30 @@ def _close_self_pipe(self): def _make_self_pipe(self): # A self-socket, really. :-) - self._ssock, self._csock = self._socketpair() + self._ssock, self._csock = socket.socketpair() self._ssock.setblocking(False) self._csock.setblocking(False) self._internal_fds += 1 - self.call_soon(self._loop_self_reading) def _loop_self_reading(self, f=None): try: if f is not None: f.result() # may raise + if self._self_reading_future is not f: + # When we scheduled this Future, we assigned it to + # _self_reading_future. If it's not there now, something has + # tried to cancel the loop while this callback was still in the + # queue (see windows_events.ProactorEventLoop.run_forever). In + # that case stop here instead of continuing to schedule a new + # iteration. 
+ return f = self._proactor.recv(self._ssock, 4096) - except futures.CancelledError: + except exceptions.CancelledError: # _close_self_pipe() has been called, stop waiting for data return - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: self.call_exception_handler({ 'message': 'Error on reading from the event loop self pipe', 'exception': exc, @@ -492,10 +816,27 @@ def _loop_self_reading(self, f=None): f.add_done_callback(self._loop_self_reading) def _write_to_self(self): - self._csock.send(b'\0') + # This may be called from a different thread, possibly after + # _close_self_pipe() has been called or even while it is + # running. Guard for self._csock being None or closed. When + # a socket is closed, send() raises OSError (with errno set to + # EBADF, but let's not rely on the exact error code). + csock = self._csock + if csock is None: + return + + try: + csock.send(b'\0') + except OSError: + if self._debug: + logger.debug("Fail to write a null byte into the " + "self-pipe socket", + exc_info=True) def _start_serving(self, protocol_factory, sock, - sslcontext=None, server=None, backlog=100): + sslcontext=None, server=None, backlog=100, + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): def loop(f=None): try: @@ -508,7 +849,9 @@ def loop(f=None): if sslcontext is not None: self._make_ssl_transport( conn, protocol, sslcontext, server_side=True, - extra={'peername': addr}, server=server) + extra={'peername': addr}, server=server, + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) else: self._make_socket_transport( conn, protocol, @@ -521,13 +864,13 @@ def loop(f=None): self.call_exception_handler({ 'message': 'Accept failed on a socket', 'exception': exc, - 'socket': sock, + 'socket': trsock.TransportSocket(sock), }) sock.close() elif self._debug: logger.debug("Accept failed on socket %r", sock, exc_info=True) - except futures.CancelledError: + except exceptions.CancelledError: sock.close() else: self._accept_futures[sock.fileno()] = f @@ -545,6 +888,8 @@ def _stop_accept_futures(self): self._accept_futures.clear() def _stop_serving(self, sock): - self._stop_accept_futures() + future = self._accept_futures.pop(sock.fileno(), None) + if future: + future.cancel() self._proactor._stop_serving(sock) sock.close() diff --git a/Lib/asyncio/protocols.py b/Lib/asyncio/protocols.py index 80fcac9a82..09987b164c 100644 --- a/Lib/asyncio/protocols.py +++ b/Lib/asyncio/protocols.py @@ -1,7 +1,9 @@ -"""Abstract Protocol class.""" +"""Abstract Protocol base classes.""" -__all__ = ['BaseProtocol', 'Protocol', 'DatagramProtocol', - 'SubprocessProtocol'] +__all__ = ( + 'BaseProtocol', 'Protocol', 'DatagramProtocol', + 'SubprocessProtocol', 'BufferedProtocol', +) class BaseProtocol: @@ -14,6 +16,8 @@ class BaseProtocol: write-only transport like write pipe """ + __slots__ = () + def connection_made(self, transport): """Called when a connection is made. @@ -85,6 +89,8 @@ class Protocol(BaseProtocol): * CL: connection_lost() """ + __slots__ = () + def data_received(self, data): """Called when some data is received. @@ -100,9 +106,64 @@ def eof_received(self): """ +class BufferedProtocol(BaseProtocol): + """Interface for stream protocol with manual buffer control. + + Event methods, such as `create_server` and `create_connection`, + accept factories that return protocols that implement this interface. + + The idea of BufferedProtocol is that it allows to manually allocate + and control the receive buffer. 
Event loops can then use the buffer
+    provided by the protocol to avoid unnecessary data copies. This
+    can result in noticeable performance improvement for protocols that
+    receive big amounts of data. Sophisticated protocols can allocate
+    the buffer only once at creation time.
+
+    State machine of calls:
+
+      start -> CM [-> GB [-> BU?]]* [-> ER?] -> CL -> end
+
+    * CM: connection_made()
+    * GB: get_buffer()
+    * BU: buffer_updated()
+    * ER: eof_received()
+    * CL: connection_lost()
+    """
+
+    __slots__ = ()
+
+    def get_buffer(self, sizehint):
+        """Called to allocate a new receive buffer.
+
+        *sizehint* is a recommended minimal size for the returned
+        buffer. When set to -1, the buffer size can be arbitrary.
+
+        Must return an object that implements the
+        :ref:`buffer protocol <bufferobjects>`.
+        It is an error to return a zero-sized buffer.
+        """
+
+    def buffer_updated(self, nbytes):
+        """Called when the buffer was updated with the received data.
+
+        *nbytes* is the total number of bytes that were written to
+        the buffer.
+        """
+
+    def eof_received(self):
+        """Called when the other end calls write_eof() or equivalent.
+
+        If this returns a false value (including None), the transport
+        will close itself. If it returns a true value, closing the
+        transport is up to the protocol.
+        """
+
+
 class DatagramProtocol(BaseProtocol):
     """Interface for datagram protocol."""
 
+    __slots__ = ()
+
     def datagram_received(self, data, addr):
         """Called when some datagram is received."""
 
@@ -116,6 +177,8 @@ def error_received(self, exc):
 class SubprocessProtocol(BaseProtocol):
     """Interface for protocol for subprocess calls."""
 
+    __slots__ = ()
+
     def pipe_data_received(self, fd, data):
         """Called when the subprocess writes data into stdout/stderr pipe.
 
@@ -132,3 +195,22 @@ def pipe_connection_lost(self, fd, exc):
 
     def process_exited(self):
         """Called when subprocess has exited."""
+
+
+def _feed_data_to_buffered_proto(proto, data):
+    data_len = len(data)
+    while data_len:
+        buf = proto.get_buffer(data_len)
+        buf_len = len(buf)
+        if not buf_len:
+            raise RuntimeError('get_buffer() returned an empty buffer')
+
+        if buf_len >= data_len:
+            buf[:data_len] = data
+            proto.buffer_updated(data_len)
+            return
+        else:
+            buf[:buf_len] = data[:buf_len]
+            proto.buffer_updated(buf_len)
+            data = data[buf_len:]
+            data_len = len(data)
diff --git a/Lib/asyncio/queues.py b/Lib/asyncio/queues.py
index e16c46ae73..a9656a6df5 100644
--- a/Lib/asyncio/queues.py
+++ b/Lib/asyncio/queues.py
@@ -1,35 +1,28 @@
-"""Queues"""
-
-__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty']
+__all__ = ('Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty')
 
 import collections
 import heapq
+from types import GenericAlias
 
-from . import compat
-from . import events
 from . import locks
-from .coroutines import coroutine
+from . import mixins
 
 
 class QueueEmpty(Exception):
-    """Exception raised when Queue.get_nowait() is called on a Queue object
-    which is empty.
-    """
+    """Raised when Queue.get_nowait() is called on an empty Queue."""
     pass
 
 
 class QueueFull(Exception):
-    """Exception raised when the Queue.put_nowait() method is called on a Queue
-    object which is full.
-    """
+    """Raised when the Queue.put_nowait() method is called on a full Queue."""
     pass
 
 
-class Queue:
+class Queue(mixins._LoopBoundMixin):
     """A queue, useful for coordinating producer and consumer coroutines.
 
     If maxsize is less than or equal to zero, the queue size is infinite.
If it - is an integer greater than 0, then "yield from put()" will block when the + is an integer greater than 0, then "await put()" will block when the queue reaches maxsize, until an item is removed by get(). Unlike the standard library Queue, you can reliably know this Queue's size @@ -37,11 +30,7 @@ class Queue: interrupted between calling qsize() and doing an operation on the Queue. """ - def __init__(self, maxsize=0, *, loop=None): - if loop is None: - self._loop = events.get_event_loop() - else: - self._loop = loop + def __init__(self, maxsize=0): self._maxsize = maxsize # Futures. @@ -49,7 +38,7 @@ def __init__(self, maxsize=0, *, loop=None): # Futures. self._putters = collections.deque() self._unfinished_tasks = 0 - self._finished = locks.Event(loop=self._loop) + self._finished = locks.Event() self._finished.set() self._init(maxsize) @@ -75,25 +64,23 @@ def _wakeup_next(self, waiters): break def __repr__(self): - return '<{} at {:#x} {}>'.format( - type(self).__name__, id(self), self._format()) + return f'<{type(self).__name__} at {id(self):#x} {self._format()}>' def __str__(self): - return '<{} {}>'.format(type(self).__name__, self._format()) + return f'<{type(self).__name__} {self._format()}>' - def __class_getitem__(cls, type): - return cls + __class_getitem__ = classmethod(GenericAlias) def _format(self): - result = 'maxsize={!r}'.format(self._maxsize) + result = f'maxsize={self._maxsize!r}' if getattr(self, '_queue', None): - result += ' _queue={!r}'.format(list(self._queue)) + result += f' _queue={list(self._queue)!r}' if self._getters: - result += ' _getters[{}]'.format(len(self._getters)) + result += f' _getters[{len(self._getters)}]' if self._putters: - result += ' _putters[{}]'.format(len(self._putters)) + result += f' _putters[{len(self._putters)}]' if self._unfinished_tasks: - result += ' tasks={}'.format(self._unfinished_tasks) + result += f' tasks={self._unfinished_tasks}' return result def qsize(self): @@ -120,22 +107,26 @@ def full(self): else: return self.qsize() >= self._maxsize - @coroutine - def put(self, item): + async def put(self, item): """Put an item into the queue. Put an item into the queue. If the queue is full, wait until a free slot is available before adding item. - - This method is a coroutine. """ while self.full(): - putter = self._loop.create_future() + putter = self._get_loop().create_future() self._putters.append(putter) try: - yield from putter + await putter except: putter.cancel() # Just in case putter is not done yet. + try: + # Clean self._putters from canceled putters. + self._putters.remove(putter) + except ValueError: + # The putter could be removed from self._putters by a + # previous get_nowait call. + pass if not self.full() and not putter.cancelled(): # We were woken up by get_nowait(), but can't take # the call. Wake up the next in line. @@ -155,21 +146,25 @@ def put_nowait(self, item): self._finished.clear() self._wakeup_next(self._getters) - @coroutine - def get(self): + async def get(self): """Remove and return an item from the queue. If queue is empty, wait until an item is available. - - This method is a coroutine. """ while self.empty(): - getter = self._loop.create_future() + getter = self._get_loop().create_future() self._getters.append(getter) try: - yield from getter + await getter except: getter.cancel() # Just in case getter is not done yet. + try: + # Clean self._getters from canceled getters. 
+ self._getters.remove(getter) + except ValueError: + # The getter could be removed from self._getters by a + # previous put_nowait call. + pass if not self.empty() and not getter.cancelled(): # We were woken up by put_nowait(), but can't take # the call. Wake up the next in line. @@ -208,8 +203,7 @@ def task_done(self): if self._unfinished_tasks == 0: self._finished.set() - @coroutine - def join(self): + async def join(self): """Block until all items in the queue have been gotten and processed. The count of unfinished tasks goes up whenever an item is added to the @@ -218,7 +212,7 @@ def join(self): When the count of unfinished tasks drops to zero, join() unblocks. """ if self._unfinished_tasks > 0: - yield from self._finished.wait() + await self._finished.wait() class PriorityQueue(Queue): @@ -248,9 +242,3 @@ def _put(self, item): def _get(self): return self._queue.pop() - - -if not compat.PY35: - JoinableQueue = Queue - """Deprecated alias for Queue.""" - __all__.append('JoinableQueue') diff --git a/Lib/asyncio/runners.py b/Lib/asyncio/runners.py index c3a696ef57..1b89236599 100644 --- a/Lib/asyncio/runners.py +++ b/Lib/asyncio/runners.py @@ -1,16 +1,168 @@ -__all__ = ['run'] +__all__ = ('Runner', 'run') +import contextvars +import enum +import functools +import threading +import signal from . import coroutines from . import events +from . import exceptions from . import tasks +from . import constants +class _State(enum.Enum): + CREATED = "created" + INITIALIZED = "initialized" + CLOSED = "closed" -def run(main, *, debug=False): - """Run a coroutine. + +class Runner: + """A context manager that controls event loop life cycle. + + The context manager always creates a new event loop, + allows to run async functions inside it, + and properly finalizes the loop at the context manager exit. + + If debug is True, the event loop will be run in debug mode. + If loop_factory is passed, it is used for new event loop creation. + + asyncio.run(main(), debug=True) + + is a shortcut for + + with asyncio.Runner(debug=True) as runner: + runner.run(main()) + + The run() method can be called multiple times within the runner's context. + + This can be useful for interactive console (e.g. IPython), + unittest runners, console tools, -- everywhere when async code + is called from existing sync framework and where the preferred single + asyncio.run() call doesn't work. + + """ + + # Note: the class is final, it is not intended for inheritance. 
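Reviewer sketch (illustrative, not part of the patch), looping back to the queues.py hunk above: the user-visible change there is that Queue no longer takes a loop argument and instead binds to the running loop lazily via mixins._LoopBoundMixin. A minimal producer/consumer under that assumption:

    import asyncio

    async def worker(q):
        while (item := await q.get()) is not None:
            print('consumed', item)
            q.task_done()

    async def main():
        q = asyncio.Queue(maxsize=2)      # note: no loop= argument anymore
        t = asyncio.create_task(worker(q))
        for i in range(5):
            await q.put(i)                # suspends while the queue is full
        await q.join()                    # waits for the matching task_done()s
        await q.put(None)                 # tell the worker to exit
        await t

    asyncio.run(main())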
+ + def __init__(self, *, debug=None, loop_factory=None): + self._state = _State.CREATED + self._debug = debug + self._loop_factory = loop_factory + self._loop = None + self._context = None + self._interrupt_count = 0 + self._set_event_loop = False + + def __enter__(self): + self._lazy_init() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def close(self): + """Shutdown and close event loop.""" + if self._state is not _State.INITIALIZED: + return + try: + loop = self._loop + _cancel_all_tasks(loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.run_until_complete( + loop.shutdown_default_executor(constants.THREAD_JOIN_TIMEOUT)) + finally: + if self._set_event_loop: + events.set_event_loop(None) + loop.close() + self._loop = None + self._state = _State.CLOSED + + def get_loop(self): + """Return embedded event loop.""" + self._lazy_init() + return self._loop + + def run(self, coro, *, context=None): + """Run a coroutine inside the embedded event loop.""" + if not coroutines.iscoroutine(coro): + raise ValueError("a coroutine was expected, got {!r}".format(coro)) + + if events._get_running_loop() is not None: + # fail fast with short traceback + raise RuntimeError( + "Runner.run() cannot be called from a running event loop") + + self._lazy_init() + + if context is None: + context = self._context + task = self._loop.create_task(coro, context=context) + + if (threading.current_thread() is threading.main_thread() + and signal.getsignal(signal.SIGINT) is signal.default_int_handler + ): + sigint_handler = functools.partial(self._on_sigint, main_task=task) + try: + signal.signal(signal.SIGINT, sigint_handler) + except ValueError: + # `signal.signal` may throw if `threading.main_thread` does + # not support signals (e.g. embedded interpreter with signals + # not registered - see gh-91880) + sigint_handler = None + else: + sigint_handler = None + + self._interrupt_count = 0 + try: + return self._loop.run_until_complete(task) + except exceptions.CancelledError: + if self._interrupt_count > 0: + uncancel = getattr(task, "uncancel", None) + if uncancel is not None and uncancel() == 0: + raise KeyboardInterrupt() + raise # CancelledError + finally: + if (sigint_handler is not None + and signal.getsignal(signal.SIGINT) is sigint_handler + ): + signal.signal(signal.SIGINT, signal.default_int_handler) + + def _lazy_init(self): + if self._state is _State.CLOSED: + raise RuntimeError("Runner is closed") + if self._state is _State.INITIALIZED: + return + if self._loop_factory is None: + self._loop = events.new_event_loop() + if not self._set_event_loop: + # Call set_event_loop only once to avoid calling + # attach_loop multiple times on child watchers + events.set_event_loop(self._loop) + self._set_event_loop = True + else: + self._loop = self._loop_factory() + if self._debug is not None: + self._loop.set_debug(self._debug) + self._context = contextvars.copy_context() + self._state = _State.INITIALIZED + + def _on_sigint(self, signum, frame, main_task): + self._interrupt_count += 1 + if self._interrupt_count == 1 and not main_task.done(): + main_task.cancel() + # wakeup loop if it is blocked by select() with long timeout + self._loop.call_soon_threadsafe(lambda: None) + return + raise KeyboardInterrupt() + + +def run(main, *, debug=None, loop_factory=None): + """Execute the coroutine and return the result. This function runs the passed coroutine, taking care of - managing the asyncio event loop and finalizing asynchronous - generators. 
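Reviewer sketch (illustrative, not part of the patch) of the new Runner class just added above, showing the property its docstring advertises: one event loop reused across several top-level calls, which separate asyncio.run() invocations would not give you.

    import asyncio

    async def job(n):
        await asyncio.sleep(0)
        return n * 2

    with asyncio.Runner(debug=True) as runner:
        print(runner.run(job(1)))   # 2
        print(runner.run(job(2)))   # 4, same loop as the first call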
+ managing the asyncio event loop, finalizing asynchronous + generators and closing the default executor. This function cannot be called when another asyncio event loop is running in the same thread. @@ -21,6 +173,10 @@ def run(main, *, debug=False): It should be used as a main entry point for asyncio programs, and should ideally only be called once. + The executor is given a timeout duration of 5 minutes to shutdown. + If the executor hasn't finished within that duration, a warning is + emitted and the executor is closed. + Example: async def main(): @@ -30,24 +186,12 @@ async def main(): asyncio.run(main()) """ if events._get_running_loop() is not None: + # fail fast with short traceback raise RuntimeError( "asyncio.run() cannot be called from a running event loop") - if not coroutines.iscoroutine(main): - raise ValueError("a coroutine was expected, got {!r}".format(main)) - - loop = events.new_event_loop() - try: - events.set_event_loop(loop) - loop.set_debug(debug) - return loop.run_until_complete(main) - finally: - try: - _cancel_all_tasks(loop) - loop.run_until_complete(loop.shutdown_asyncgens()) - finally: - events.set_event_loop(None) - loop.close() + with Runner(debug=debug, loop_factory=loop_factory) as runner: + return runner.run(main) def _cancel_all_tasks(loop): @@ -58,8 +202,7 @@ def _cancel_all_tasks(loop): for task in to_cancel: task.cancel() - loop.run_until_complete( - tasks.gather(*to_cancel, loop=loop, return_exceptions=True)) + loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True)) for task in to_cancel: if task.cancelled(): diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 9dbe550b01..790711f834 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -4,11 +4,14 @@ also includes support for signal handling, see the unix_events sub-module. """ -__all__ = ['BaseSelectorEventLoop'] +__all__ = 'BaseSelectorEventLoop', import collections import errno import functools +import itertools +import os +import selectors import socket import warnings import weakref @@ -18,16 +21,23 @@ ssl = None from . import base_events -from . import compat from . import constants from . import events from . import futures -from . import selectors -from . import transports +from . import protocols from . import sslproto -from .coroutines import coroutine +from . import transports +from . import trsock from .log import logger +_HAS_SENDMSG = hasattr(socket.socket, 'sendmsg') + +if _HAS_SENDMSG: + try: + SC_IOV_MAX = os.sysconf('SC_IOV_MAX') + except OSError: + # Fallback to send + _HAS_SENDMSG = False def _test_selector_event(selector, fd, event): # Test if the selector is monitoring 'event' events @@ -40,17 +50,6 @@ def _test_selector_event(selector, fd, event): return bool(key.events & event) -if hasattr(socket, 'TCP_NODELAY'): - def _set_nodelay(sock): - if (sock.family in {socket.AF_INET, socket.AF_INET6} and - sock.type == socket.SOCK_STREAM and - sock.proto == socket.IPPROTO_TCP): - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) -else: - def _set_nodelay(sock): - pass - - class BaseSelectorEventLoop(base_events.BaseEventLoop): """Selector event loop. 
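Reviewer sketch (illustrative, POSIX-only, not part of the patch): the _HAS_SENDMSG / SC_IOV_MAX probe above gates the new scatter-gather write path (_write_sendmsg, further down in this file). The underlying primitive it relies on:

    import socket

    s1, s2 = socket.socketpair()
    try:
        # sendmsg() takes an iterable of buffers and writes them in one
        # syscall; the kernel caps the buffer count at SC_IOV_MAX, which
        # is why the module probes os.sysconf('SC_IOV_MAX') above.
        nbytes = s1.sendmsg([b'hello ', b'world', b'\n'])
        print(nbytes, s2.recv(64))
    finally:
        s1.close()
        s2.close()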
@@ -69,36 +68,31 @@ def __init__(self, selector=None): def _make_socket_transport(self, sock, protocol, waiter=None, *, extra=None, server=None): + self._ensure_fd_no_transport(sock) return _SelectorSocketTransport(self, sock, protocol, waiter, extra, server) - def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, - *, server_side=False, server_hostname=None, - extra=None, server=None): - if not sslproto._is_sslproto_available(): - return self._make_legacy_ssl_transport( - rawsock, protocol, sslcontext, waiter, - server_side=server_side, server_hostname=server_hostname, - extra=extra, server=server) - - ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter, - server_side, server_hostname) + def _make_ssl_transport( + self, rawsock, protocol, sslcontext, waiter=None, + *, server_side=False, server_hostname=None, + extra=None, server=None, + ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, + ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, + ): + self._ensure_fd_no_transport(rawsock) + ssl_protocol = sslproto.SSLProtocol( + self, protocol, sslcontext, waiter, + server_side, server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout + ) _SelectorSocketTransport(self, rawsock, ssl_protocol, extra=extra, server=server) return ssl_protocol._app_transport - def _make_legacy_ssl_transport(self, rawsock, protocol, sslcontext, - waiter, *, - server_side=False, server_hostname=None, - extra=None, server=None): - # Use the legacy API: SSL_write, SSL_read, etc. The legacy API is used - # on Python 3.4 and older, when ssl.MemoryBIO is not available. - return _SelectorSslTransport( - self, rawsock, protocol, sslcontext, waiter, - server_side, server_hostname, extra, server) - def _make_datagram_transport(self, sock, protocol, address=None, waiter=None, extra=None): + self._ensure_fd_no_transport(sock) return _SelectorDatagramTransport(self, sock, protocol, address, waiter, extra) @@ -113,9 +107,6 @@ def close(self): self._selector.close() self._selector = None - def _socketpair(self): - raise NotImplementedError - def _close_self_pipe(self): self._remove_reader(self._ssock.fileno()) self._ssock.close() @@ -126,7 +117,7 @@ def _close_self_pipe(self): def _make_self_pipe(self): # A self-socket, really. :-) - self._ssock, self._csock = self._socketpair() + self._ssock, self._csock = socket.socketpair() self._ssock.setblocking(False) self._csock.setblocking(False) self._internal_fds += 1 @@ -154,22 +145,30 @@ def _write_to_self(self): # a socket is closed, send() raises OSError (with errno set to # EBADF, but let's not rely on the exact error code). 
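Reviewer sketch (illustrative, not part of the patch) of the self-pipe trick discussed in the comment above; the patched _write_to_self() code continues below. One end of the socketpair is watched by the selector, the other is written from any thread to wake a blocked select():

    import selectors
    import socket

    ssock, csock = socket.socketpair()
    ssock.setblocking(False)
    csock.setblocking(False)

    sel = selectors.DefaultSelector()
    sel.register(ssock, selectors.EVENT_READ)

    csock.send(b'\0')                      # the _write_to_self() side
    for key, _ in sel.select(timeout=1):
        print('woken up:', key.fileobj.recv(4096))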
csock = self._csock - if csock is not None: - try: - csock.send(b'\0') - except OSError: - if self._debug: - logger.debug("Fail to write a null byte into the " - "self-pipe socket", - exc_info=True) + if csock is None: + return + + try: + csock.send(b'\0') + except OSError: + if self._debug: + logger.debug("Fail to write a null byte into the " + "self-pipe socket", + exc_info=True) def _start_serving(self, protocol_factory, sock, - sslcontext=None, server=None, backlog=100): + sslcontext=None, server=None, backlog=100, + ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, + ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT): self._add_reader(sock.fileno(), self._accept_connection, - protocol_factory, sock, sslcontext, server, backlog) - - def _accept_connection(self, protocol_factory, sock, - sslcontext=None, server=None, backlog=100): + protocol_factory, sock, sslcontext, server, backlog, + ssl_handshake_timeout, ssl_shutdown_timeout) + + def _accept_connection( + self, protocol_factory, sock, + sslcontext=None, server=None, backlog=100, + ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, + ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT): # This method is only called once for each event loop tick where the # listening socket has triggered an EVENT_READ. There may be multiple # connections waiting for an .accept() so it is called in a loop. @@ -194,24 +193,28 @@ def _accept_connection(self, protocol_factory, sock, self.call_exception_handler({ 'message': 'socket.accept() out of system resource', 'exception': exc, - 'socket': sock, + 'socket': trsock.TransportSocket(sock), }) self._remove_reader(sock.fileno()) self.call_later(constants.ACCEPT_RETRY_DELAY, self._start_serving, protocol_factory, sock, sslcontext, server, - backlog) + backlog, ssl_handshake_timeout, + ssl_shutdown_timeout) else: raise # The event loop will catch, log and ignore it. else: extra = {'peername': addr} - accept = self._accept_connection2(protocol_factory, conn, extra, - sslcontext, server) + accept = self._accept_connection2( + protocol_factory, conn, extra, sslcontext, server, + ssl_handshake_timeout, ssl_shutdown_timeout) self.create_task(accept) - @coroutine - def _accept_connection2(self, protocol_factory, conn, extra, - sslcontext=None, server=None): + async def _accept_connection2( + self, protocol_factory, conn, extra, + sslcontext=None, server=None, + ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, + ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT): protocol = None transport = None try: @@ -220,24 +223,32 @@ def _accept_connection2(self, protocol_factory, conn, extra, if sslcontext: transport = self._make_ssl_transport( conn, protocol, sslcontext, waiter=waiter, - server_side=True, extra=extra, server=server) + server_side=True, extra=extra, server=server, + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) else: transport = self._make_socket_transport( conn, protocol, waiter=waiter, extra=extra, server=server) try: - yield from waiter - except: + await waiter + except BaseException: transport.close() + # gh-109534: When an exception is raised by the SSLProtocol object the + # exception set in this future can keep the protocol object alive and + # cause a reference cycle. + waiter = None raise + # It's now up to the protocol to handle the connection. - # It's now up to the protocol to handle the connection. 
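Reviewer sketch (illustrative, not part of the patch): every connection accepted by the loop above runs through protocol_factory, one of the _make_*_transport methods, and the waiter dance in _accept_connection2. End to end, with a placeholder port:

    import asyncio

    class Greeter(asyncio.Protocol):
        def connection_made(self, transport):
            # Reached once the waiter in _accept_connection2 completes.
            transport.write(b'hello\n')
            transport.close()

    async def main():
        loop = asyncio.get_running_loop()
        server = await loop.create_server(Greeter, '127.0.0.1', 8888)
        async with server:
            await server.serve_forever()

    asyncio.run(main())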
- except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: if self._debug: context = { - 'message': ('Error on transport creation ' - 'for incoming connection'), + 'message': + 'Error on transport creation for incoming connection', 'exception': exc, } if protocol is not None: @@ -247,19 +258,26 @@ def _accept_connection2(self, protocol_factory, conn, extra, self.call_exception_handler(context) def _ensure_fd_no_transport(self, fd): + fileno = fd + if not isinstance(fileno, int): + try: + fileno = int(fileno.fileno()) + except (AttributeError, TypeError, ValueError): + # This code matches selectors._fileobj_to_fd function. + raise ValueError(f"Invalid file object: {fd!r}") from None try: - transport = self._transports[fd] + transport = self._transports[fileno] except KeyError: pass else: if not transport.is_closing(): raise RuntimeError( - 'File descriptor {!r} is used by transport {!r}'.format( - fd, transport)) + f'File descriptor {fd!r} is used by transport ' + f'{transport!r}') def _add_reader(self, fd, callback, *args): self._check_closed() - handle = events.Handle(callback, args, self) + handle = events.Handle(callback, args, self, None) try: key = self._selector.get_key(fd) except KeyError: @@ -271,6 +289,7 @@ def _add_reader(self, fd, callback, *args): (handle, writer)) if reader is not None: reader.cancel() + return handle def _remove_reader(self, fd): if self.is_closed(): @@ -295,7 +314,7 @@ def _remove_reader(self, fd): def _add_writer(self, fd, callback, *args): self._check_closed() - handle = events.Handle(callback, args, self) + handle = events.Handle(callback, args, self, None) try: key = self._selector.get_key(fd) except KeyError: @@ -307,6 +326,7 @@ def _add_writer(self, fd, callback, *args): (reader, handle)) if writer is not None: writer.cancel() + return handle def _remove_writer(self, fd): """Remove a writer callback.""" @@ -334,7 +354,7 @@ def _remove_writer(self, fd): def add_reader(self, fd, callback, *args): """Add a reader callback.""" self._ensure_fd_no_transport(fd) - return self._add_reader(fd, callback, *args) + self._add_reader(fd, callback, *args) def remove_reader(self, fd): """Remove a reader callback.""" @@ -344,111 +364,294 @@ def remove_reader(self, fd): def add_writer(self, fd, callback, *args): """Add a writer callback..""" self._ensure_fd_no_transport(fd) - return self._add_writer(fd, callback, *args) + self._add_writer(fd, callback, *args) def remove_writer(self, fd): """Remove a writer callback.""" self._ensure_fd_no_transport(fd) return self._remove_writer(fd) - def sock_recv(self, sock, n): + async def sock_recv(self, sock, n): """Receive data from the socket. The return value is a bytes object representing the data received. The maximum amount of data to be received at once is specified by nbytes. - - This method is a coroutine. 
""" + base_events._check_ssl_socket(sock) if self._debug and sock.gettimeout() != 0: raise ValueError("the socket must be non-blocking") + try: + return sock.recv(n) + except (BlockingIOError, InterruptedError): + pass fut = self.create_future() - self._sock_recv(fut, False, sock, n) - return fut + fd = sock.fileno() + self._ensure_fd_no_transport(fd) + handle = self._add_reader(fd, self._sock_recv, fut, sock, n) + fut.add_done_callback( + functools.partial(self._sock_read_done, fd, handle=handle)) + return await fut + + def _sock_read_done(self, fd, fut, handle=None): + if handle is None or not handle.cancelled(): + self.remove_reader(fd) - def _sock_recv(self, fut, registered, sock, n): + def _sock_recv(self, fut, sock, n): # _sock_recv() can add itself as an I/O callback if the operation can't # be done immediately. Don't use it directly, call sock_recv(). - fd = sock.fileno() - if registered: - # Remove the callback early. It should be rare that the - # selector says the fd is ready but the call still returns - # EAGAIN, and I am willing to take a hit in that case in - # order to simplify the common case. - self.remove_reader(fd) - if fut.cancelled(): + if fut.done(): return try: data = sock.recv(n) except (BlockingIOError, InterruptedError): - self.add_reader(fd, self._sock_recv, fut, True, sock, n) - except Exception as exc: + return # try again next time + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: fut.set_exception(exc) else: fut.set_result(data) - def sock_sendall(self, sock, data): - """Send data to the socket. - - The socket must be connected to a remote socket. This method continues - to send data from data until either all data has been sent or an - error occurs. None is returned on success. On error, an exception is - raised, and there is no way to determine how much data, if any, was - successfully processed by the receiving end of the connection. + async def sock_recv_into(self, sock, buf): + """Receive data from the socket. - This method is a coroutine. + The received data is written into *buf* (a writable buffer). + The return value is the number of bytes written. """ + base_events._check_ssl_socket(sock) if self._debug and sock.gettimeout() != 0: raise ValueError("the socket must be non-blocking") + try: + return sock.recv_into(buf) + except (BlockingIOError, InterruptedError): + pass fut = self.create_future() - if data: - self._sock_sendall(fut, False, sock, data) + fd = sock.fileno() + self._ensure_fd_no_transport(fd) + handle = self._add_reader(fd, self._sock_recv_into, fut, sock, buf) + fut.add_done_callback( + functools.partial(self._sock_read_done, fd, handle=handle)) + return await fut + + def _sock_recv_into(self, fut, sock, buf): + # _sock_recv_into() can add itself as an I/O callback if the operation + # can't be done immediately. Don't use it directly, call + # sock_recv_into(). + if fut.done(): + return + try: + nbytes = sock.recv_into(buf) + except (BlockingIOError, InterruptedError): + return # try again next time + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + fut.set_exception(exc) else: - fut.set_result(None) - return fut + fut.set_result(nbytes) + + async def sock_recvfrom(self, sock, bufsize): + """Receive a datagram from a datagram socket. - def _sock_sendall(self, fut, registered, sock, data): + The return value is a tuple of (bytes, address) representing the + datagram received and the address it came from. 
+ The maximum amount of data to be received at once is specified by + nbytes. + """ + base_events._check_ssl_socket(sock) + if self._debug and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + try: + return sock.recvfrom(bufsize) + except (BlockingIOError, InterruptedError): + pass + fut = self.create_future() fd = sock.fileno() + self._ensure_fd_no_transport(fd) + handle = self._add_reader(fd, self._sock_recvfrom, fut, sock, bufsize) + fut.add_done_callback( + functools.partial(self._sock_read_done, fd, handle=handle)) + return await fut + + def _sock_recvfrom(self, fut, sock, bufsize): + # _sock_recvfrom() can add itself as an I/O callback if the operation + # can't be done immediately. Don't use it directly, call + # sock_recvfrom(). + if fut.done(): + return + try: + result = sock.recvfrom(bufsize) + except (BlockingIOError, InterruptedError): + return # try again next time + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + fut.set_exception(exc) + else: + fut.set_result(result) - if registered: - self.remove_writer(fd) - if fut.cancelled(): + async def sock_recvfrom_into(self, sock, buf, nbytes=0): + """Receive data from the socket. + + The received data is written into *buf* (a writable buffer). + The return value is a tuple of (number of bytes written, address). + """ + base_events._check_ssl_socket(sock) + if self._debug and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + if not nbytes: + nbytes = len(buf) + + try: + return sock.recvfrom_into(buf, nbytes) + except (BlockingIOError, InterruptedError): + pass + fut = self.create_future() + fd = sock.fileno() + self._ensure_fd_no_transport(fd) + handle = self._add_reader(fd, self._sock_recvfrom_into, fut, sock, buf, + nbytes) + fut.add_done_callback( + functools.partial(self._sock_read_done, fd, handle=handle)) + return await fut + + def _sock_recvfrom_into(self, fut, sock, buf, bufsize): + # _sock_recv_into() can add itself as an I/O callback if the operation + # can't be done immediately. Don't use it directly, call + # sock_recv_into(). + if fut.done(): return + try: + result = sock.recvfrom_into(buf, bufsize) + except (BlockingIOError, InterruptedError): + return # try again next time + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + fut.set_exception(exc) + else: + fut.set_result(result) + async def sock_sendall(self, sock, data): + """Send data to the socket. + + The socket must be connected to a remote socket. This method continues + to send data from data until either all data has been sent or an + error occurs. None is returned on success. On error, an exception is + raised, and there is no way to determine how much data, if any, was + successfully processed by the receiving end of the connection. 
+ """ + base_events._check_ssl_socket(sock) + if self._debug and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") try: n = sock.send(data) except (BlockingIOError, InterruptedError): n = 0 - except Exception as exc: + + if n == len(data): + # all data sent + return + + fut = self.create_future() + fd = sock.fileno() + self._ensure_fd_no_transport(fd) + # use a trick with a list in closure to store a mutable state + handle = self._add_writer(fd, self._sock_sendall, fut, sock, + memoryview(data), [n]) + fut.add_done_callback( + functools.partial(self._sock_write_done, fd, handle=handle)) + return await fut + + def _sock_sendall(self, fut, sock, view, pos): + if fut.done(): + # Future cancellation can be scheduled on previous loop iteration + return + start = pos[0] + try: + n = sock.send(view[start:]) + except (BlockingIOError, InterruptedError): + return + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: fut.set_exception(exc) return - if n == len(data): + start += n + + if start == len(view): fut.set_result(None) else: - if n: - data = data[n:] - self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + pos[0] = start - @coroutine - def sock_connect(self, sock, address): + async def sock_sendto(self, sock, data, address): + """Send data to the socket. + + The socket must be connected to a remote socket. This method continues + to send data from data until either all data has been sent or an + error occurs. None is returned on success. On error, an exception is + raised, and there is no way to determine how much data, if any, was + successfully processed by the receiving end of the connection. + """ + base_events._check_ssl_socket(sock) + if self._debug and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + try: + return sock.sendto(data, address) + except (BlockingIOError, InterruptedError): + pass + + fut = self.create_future() + fd = sock.fileno() + self._ensure_fd_no_transport(fd) + # use a trick with a list in closure to store a mutable state + handle = self._add_writer(fd, self._sock_sendto, fut, sock, data, + address) + fut.add_done_callback( + functools.partial(self._sock_write_done, fd, handle=handle)) + return await fut + + def _sock_sendto(self, fut, sock, data, address): + if fut.done(): + # Future cancellation can be scheduled on previous loop iteration + return + try: + n = sock.sendto(data, 0, address) + except (BlockingIOError, InterruptedError): + return + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + fut.set_exception(exc) + else: + fut.set_result(n) + + async def sock_connect(self, sock, address): """Connect to a remote socket at address. This method is a coroutine. 
""" + base_events._check_ssl_socket(sock) if self._debug and sock.gettimeout() != 0: raise ValueError("the socket must be non-blocking") - if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX: - resolved = base_events._ensure_resolved( - address, family=sock.family, proto=sock.proto, loop=self) - if not resolved.done(): - yield from resolved - _, _, _, _, address = resolved.result()[0] + if sock.family == socket.AF_INET or ( + base_events._HAS_IPv6 and sock.family == socket.AF_INET6): + resolved = await self._ensure_resolved( + address, family=sock.family, type=sock.type, proto=sock.proto, + loop=self, + ) + _, _, _, _, address = resolved[0] fut = self.create_future() self._sock_connect(fut, sock, address) - return (yield from fut) + try: + return await fut + finally: + # Needed to break cycles when an exception occurs. + fut = None def _sock_connect(self, fut, sock, address): fd = sock.fileno() @@ -459,66 +662,91 @@ def _sock_connect(self, fut, sock, address): # connection runs in background. We have to wait until the socket # becomes writable to be notified when the connection succeed or # fails. + self._ensure_fd_no_transport(fd) + handle = self._add_writer( + fd, self._sock_connect_cb, fut, sock, address) fut.add_done_callback( - functools.partial(self._sock_connect_done, fd)) - self.add_writer(fd, self._sock_connect_cb, fut, sock, address) - except Exception as exc: + functools.partial(self._sock_write_done, fd, handle=handle)) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: fut.set_exception(exc) else: fut.set_result(None) + finally: + fut = None - def _sock_connect_done(self, fd, fut): - self.remove_writer(fd) + def _sock_write_done(self, fd, fut, handle=None): + if handle is None or not handle.cancelled(): + self.remove_writer(fd) def _sock_connect_cb(self, fut, sock, address): - if fut.cancelled(): + if fut.done(): return try: err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) if err != 0: # Jump to any except clause below. - raise OSError(err, 'Connect call failed %s' % (address,)) + raise OSError(err, f'Connect call failed {address}') except (BlockingIOError, InterruptedError): # socket is still registered, the callback will be retried later pass - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: fut.set_exception(exc) else: fut.set_result(None) + finally: + fut = None - def sock_accept(self, sock): + async def sock_accept(self, sock): """Accept a connection. The socket must be bound to an address and listening for connections. The return value is a pair (conn, address) where conn is a new socket object usable to send and receive data on the connection, and address is the address bound to the socket on the other end of the connection. - - This method is a coroutine. 
""" + base_events._check_ssl_socket(sock) if self._debug and sock.gettimeout() != 0: raise ValueError("the socket must be non-blocking") fut = self.create_future() - self._sock_accept(fut, False, sock) - return fut + self._sock_accept(fut, sock) + return await fut - def _sock_accept(self, fut, registered, sock): + def _sock_accept(self, fut, sock): fd = sock.fileno() - if registered: - self.remove_reader(fd) - if fut.cancelled(): - return try: conn, address = sock.accept() conn.setblocking(False) except (BlockingIOError, InterruptedError): - self.add_reader(fd, self._sock_accept, fut, True, sock) - except Exception as exc: + self._ensure_fd_no_transport(fd) + handle = self._add_reader(fd, self._sock_accept, fut, sock) + fut.add_done_callback( + functools.partial(self._sock_read_done, fd, handle=handle)) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: fut.set_exception(exc) else: fut.set_result((conn, address)) + async def _sendfile_native(self, transp, file, offset, count): + del self._transports[transp._sock_fd] + resume_reading = transp.is_reading() + transp.pause_reading() + await transp._make_empty_waiter() + try: + return await self.sock_sendfile(transp._sock, file, offset, count, + fallback=False) + finally: + transp._reset_empty_waiter() + if resume_reading: + transp.resume_reading() + self._transports[transp._sock_fd] = transp + def _process_events(self, event_list): for key, mask in event_list: fileobj, (reader, writer) = key.fileobj, key.data @@ -543,8 +771,6 @@ class _SelectorTransport(transports._FlowControlMixin, max_size = 256 * 1024 # Buffer size passed to recv(). - _buffer_factory = bytearray # Constructs initial value for self._buffer. - # Attribute used in the destructor: it must be set even if the constructor # is not called (see _SelectorSslTransport which may start by raising an # exception) @@ -552,8 +778,11 @@ class _SelectorTransport(transports._FlowControlMixin, def __init__(self, loop, sock, protocol, extra=None, server=None): super().__init__(extra, loop) - self._extra['socket'] = sock - self._extra['sockname'] = sock.getsockname() + self._extra['socket'] = trsock.TransportSocket(sock) + try: + self._extra['sockname'] = sock.getsockname() + except OSError: + self._extra['sockname'] = None if 'peername' not in self._extra: try: self._extra['peername'] = sock.getpeername() @@ -561,12 +790,16 @@ def __init__(self, loop, sock, protocol, extra=None, server=None): self._extra['peername'] = None self._sock = sock self._sock_fd = sock.fileno() - self._protocol = protocol - self._protocol_connected = True + + self._protocol_connected = False + self.set_protocol(protocol) + self._server = server - self._buffer = self._buffer_factory() + self._buffer = collections.deque() self._conn_lost = 0 # Set when call to connection_lost scheduled. self._closing = False # Set when close() called. 
+ self._paused = False # Set when pause_reading() called + if self._server is not None: self._server._attach() loop._transports[self._sock_fd] = self @@ -577,7 +810,7 @@ def __repr__(self): info.append('closed') elif self._closing: info.append('closing') - info.append('fd=%s' % self._sock_fd) + info.append(f'fd={self._sock_fd}') # test if the transport was closed if self._loop is not None and not self._loop.is_closed(): polling = _test_selector_event(self._loop._selector, @@ -596,14 +829,15 @@ def __repr__(self): state = 'idle' bufsize = self.get_write_buffer_size() - info.append('write=<%s, bufsize=%s>' % (state, bufsize)) - return '<%s>' % ' '.join(info) + info.append(f'write=<{state}, bufsize={bufsize}>') + return '<{}>'.format(' '.join(info)) def abort(self): self._force_close(None) def set_protocol(self, protocol): self._protocol = protocol + self._protocol_connected = True def get_protocol(self): return self._protocol @@ -611,6 +845,25 @@ def get_protocol(self): def is_closing(self): return self._closing + def is_reading(self): + return not self.is_closing() and not self._paused + + def pause_reading(self): + if not self.is_reading(): + return + self._paused = True + self._loop._remove_reader(self._sock_fd) + if self._loop.get_debug(): + logger.debug("%r pauses reading", self) + + def resume_reading(self): + if self._closing or not self._paused: + return + self._paused = False + self._add_reader(self._sock_fd, self._read_ready) + if self._loop.get_debug(): + logger.debug("%r resumes reading", self) + def close(self): if self._closing: return @@ -621,19 +874,14 @@ def close(self): self._loop._remove_writer(self._sock_fd) self._loop.call_soon(self._call_connection_lost, None) - # On Python 3.3 and older, objects with a destructor part of a reference - # cycle are never destroyed. It's not more the case on Python 3.4 thanks - # to the PEP 442. - if compat.PY34: - def __del__(self): - if self._sock is not None: - warnings.warn("unclosed transport %r" % self, ResourceWarning, - source=self) - self._sock.close() + def __del__(self, _warn=warnings.warn): + if self._sock is not None: + _warn(f"unclosed transport {self!r}", ResourceWarning, source=self) + self._sock.close() def _fatal_error(self, exc, message='Fatal error on transport'): # Should be called from exception handler only. - if isinstance(exc, base_events._FATAL_ERROR_IGNORE): + if isinstance(exc, OSError): if self._loop.get_debug(): logger.debug("%r: %s", self, message, exc_info=True) else: @@ -672,81 +920,146 @@ def _call_connection_lost(self, exc): self._server = None def get_write_buffer_size(self): - return len(self._buffer) + return sum(map(len, self._buffer)) + + def _add_reader(self, fd, callback, *args): + if not self.is_reading(): + return + self._loop._add_reader(fd, callback, *args) class _SelectorSocketTransport(_SelectorTransport): + _start_tls_compatible = True + _sendfile_compatible = constants._SendfileMode.TRY_NATIVE + def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): + + self._read_ready_cb = None super().__init__(loop, sock, protocol, extra, server) self._eof = False - self._paused = False - + self._empty_waiter = None + if _HAS_SENDMSG: + self._write_ready = self._write_sendmsg + else: + self._write_ready = self._write_send # Disable the Nagle algorithm -- small writes will be # sent without waiting for the TCP ACK. This generally # decreases the latency (in some cases significantly.) 
- _set_nodelay(self._sock) + base_events._set_nodelay(self._sock) self._loop.call_soon(self._protocol.connection_made, self) # only start reading when connection_made() has been called - self._loop.call_soon(self._loop._add_reader, + self._loop.call_soon(self._add_reader, self._sock_fd, self._read_ready) if waiter is not None: # only wake up the waiter when connection_made() has been called self._loop.call_soon(futures._set_result_unless_cancelled, waiter, None) - def pause_reading(self): - if self._closing: - raise RuntimeError('Cannot pause_reading() when closing') - if self._paused: - raise RuntimeError('Already paused') - self._paused = True - self._loop._remove_reader(self._sock_fd) - if self._loop.get_debug(): - logger.debug("%r pauses reading", self) + def set_protocol(self, protocol): + if isinstance(protocol, protocols.BufferedProtocol): + self._read_ready_cb = self._read_ready__get_buffer + else: + self._read_ready_cb = self._read_ready__data_received - def resume_reading(self): - if not self._paused: - raise RuntimeError('Not paused') - self._paused = False - if self._closing: - return - self._loop._add_reader(self._sock_fd, self._read_ready) - if self._loop.get_debug(): - logger.debug("%r resumes reading", self) + super().set_protocol(protocol) def _read_ready(self): + self._read_ready_cb() + + def _read_ready__get_buffer(self): + if self._conn_lost: + return + + try: + buf = self._protocol.get_buffer(-1) + if not len(buf): + raise RuntimeError('get_buffer() returned an empty buffer') + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._fatal_error( + exc, 'Fatal error: protocol.get_buffer() call failed.') + return + + try: + nbytes = self._sock.recv_into(buf) + except (BlockingIOError, InterruptedError): + return + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._fatal_error(exc, 'Fatal read error on socket transport') + return + + if not nbytes: + self._read_ready__on_eof() + return + + try: + self._protocol.buffer_updated(nbytes) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._fatal_error( + exc, 'Fatal error: protocol.buffer_updated() call failed.') + + def _read_ready__data_received(self): if self._conn_lost: return try: data = self._sock.recv(self.max_size) except (BlockingIOError, InterruptedError): - pass - except Exception as exc: + return + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: self._fatal_error(exc, 'Fatal read error on socket transport') + return + + if not data: + self._read_ready__on_eof() + return + + try: + self._protocol.data_received(data) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._fatal_error( + exc, 'Fatal error: protocol.data_received() call failed.') + + def _read_ready__on_eof(self): + if self._loop.get_debug(): + logger.debug("%r received EOF", self) + + try: + keep_open = self._protocol.eof_received() + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._fatal_error( + exc, 'Fatal error: protocol.eof_received() call failed.') + return + + if keep_open: + # We're keeping the connection open so the + # protocol can write more, but we still can't + # receive more, so remove the reader callback. 
+ self._loop._remove_reader(self._sock_fd) else: - if data: - self._protocol.data_received(data) - else: - if self._loop.get_debug(): - logger.debug("%r received EOF", self) - keep_open = self._protocol.eof_received() - if keep_open: - # We're keeping the connection open so the - # protocol can write more, but we still can't - # receive more, so remove the reader callback. - self._loop._remove_reader(self._sock_fd) - else: - self.close() + self.close() def write(self, data): if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError('data argument must be a bytes-like object, ' - 'not %r' % type(data).__name__) + raise TypeError(f'data argument must be a bytes-like object, ' + f'not {type(data).__name__!r}') if self._eof: raise RuntimeError('Cannot call write() after write_eof()') + if self._empty_waiter is not None: + raise RuntimeError('unable to write; sendfile is in progress') if not data: return @@ -762,288 +1075,142 @@ def write(self, data): n = self._sock.send(data) except (BlockingIOError, InterruptedError): pass - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: self._fatal_error(exc, 'Fatal write error on socket transport') return else: - data = data[n:] + data = memoryview(data)[n:] if not data: return # Not all was written; register write handler. self._loop._add_writer(self._sock_fd, self._write_ready) # Add it to the buffer. - self._buffer.extend(data) + self._buffer.append(data) self._maybe_pause_protocol() - def _write_ready(self): - assert self._buffer, 'Data should not be empty' + def _get_sendmsg_buffer(self): + return itertools.islice(self._buffer, SC_IOV_MAX) + def _write_sendmsg(self): + assert self._buffer, 'Data should not be empty' if self._conn_lost: return try: - n = self._sock.send(self._buffer) + nbytes = self._sock.sendmsg(self._get_sendmsg_buffer()) + self._adjust_leftover_buffer(nbytes) except (BlockingIOError, InterruptedError): pass - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: self._loop._remove_writer(self._sock_fd) self._buffer.clear() self._fatal_error(exc, 'Fatal write error on socket transport') + if self._empty_waiter is not None: + self._empty_waiter.set_exception(exc) else: - if n: - del self._buffer[:n] self._maybe_resume_protocol() # May append to buffer. 
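Reviewer sketch (illustrative, not part of the patch): the _read_ready__get_buffer branch above is the consumer side of the BufferedProtocol interface added in protocols.py earlier in this diff. A minimal protocol exercising it, with the receive buffer allocated once up front:

    import asyncio

    class BufferedEcho(asyncio.BufferedProtocol):
        def __init__(self):
            self._buf = bytearray(64 * 1024)   # allocated once, reused
            self._transport = None

        def connection_made(self, transport):
            self._transport = transport

        def get_buffer(self, sizehint):
            # sizehint may be -1 ("any size is fine"); returning a
            # zero-sized buffer would be an error, per the docstring.
            return self._buf

        def buffer_updated(self, nbytes):
            # The loop wrote nbytes into self._buf via recv_into().
            self._transport.write(bytes(self._buf[:nbytes]))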
if not self._buffer: self._loop._remove_writer(self._sock_fd) + if self._empty_waiter is not None: + self._empty_waiter.set_result(None) if self._closing: self._call_connection_lost(None) elif self._eof: self._sock.shutdown(socket.SHUT_WR) - def write_eof(self): - if self._eof: - return - self._eof = True - if not self._buffer: - self._sock.shutdown(socket.SHUT_WR) - - def can_write_eof(self): - return True - - -class _SelectorSslTransport(_SelectorTransport): - - _buffer_factory = bytearray - - def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, - server_side=False, server_hostname=None, - extra=None, server=None): - if ssl is None: - raise RuntimeError('stdlib ssl module not available') - - if not sslcontext: - sslcontext = sslproto._create_transport_context(server_side, server_hostname) - - wrap_kwargs = { - 'server_side': server_side, - 'do_handshake_on_connect': False, - } - if server_hostname and not server_side: - wrap_kwargs['server_hostname'] = server_hostname - sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs) - - super().__init__(loop, sslsock, protocol, extra, server) - # the protocol connection is only made after the SSL handshake - self._protocol_connected = False - - self._server_hostname = server_hostname - self._waiter = waiter - self._sslcontext = sslcontext - self._paused = False - - # SSL-specific extra info. (peercert is set later) - self._extra.update(sslcontext=sslcontext) - - if self._loop.get_debug(): - logger.debug("%r starts SSL handshake", self) - start_time = self._loop.time() - else: - start_time = None - self._on_handshake(start_time) - - def _wakeup_waiter(self, exc=None): - if self._waiter is None: - return - if not self._waiter.cancelled(): - if exc is not None: - self._waiter.set_exception(exc) - else: - self._waiter.set_result(None) - self._waiter = None - - def _on_handshake(self, start_time): - try: - self._sock.do_handshake() - except ssl.SSLWantReadError: - self._loop._add_reader(self._sock_fd, - self._on_handshake, start_time) - return - except ssl.SSLWantWriteError: - self._loop._add_writer(self._sock_fd, - self._on_handshake, start_time) - return - except BaseException as exc: - if self._loop.get_debug(): - logger.warning("%r: SSL handshake failed", - self, exc_info=True) - self._loop._remove_reader(self._sock_fd) - self._loop._remove_writer(self._sock_fd) - self._sock.close() - self._wakeup_waiter(exc) - if isinstance(exc, Exception): - return + def _adjust_leftover_buffer(self, nbytes: int) -> None: + buffer = self._buffer + while nbytes: + b = buffer.popleft() + b_len = len(b) + if b_len <= nbytes: + nbytes -= b_len else: - raise - - self._loop._remove_reader(self._sock_fd) - self._loop._remove_writer(self._sock_fd) - - peercert = self._sock.getpeercert() - if not hasattr(self._sslcontext, 'check_hostname'): - # Verify hostname if requested, Python 3.4+ uses check_hostname - # and checks the hostname in do_handshake() - if (self._server_hostname and - self._sslcontext.verify_mode != ssl.CERT_NONE): - try: - ssl.match_hostname(peercert, self._server_hostname) - except Exception as exc: - if self._loop.get_debug(): - logger.warning("%r: SSL handshake failed " - "on matching the hostname", - self, exc_info=True) - self._sock.close() - self._wakeup_waiter(exc) - return - - # Add extra info that becomes available after handshake. 
- self._extra.update(peercert=peercert, - cipher=self._sock.cipher(), - compression=self._sock.compression(), - ssl_object=self._sock, - ) - - self._read_wants_write = False - self._write_wants_read = False - self._loop._add_reader(self._sock_fd, self._read_ready) - self._protocol_connected = True - self._loop.call_soon(self._protocol.connection_made, self) - # only wake up the waiter when connection_made() has been called - self._loop.call_soon(self._wakeup_waiter) - - if self._loop.get_debug(): - dt = self._loop.time() - start_time - logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3) - - def pause_reading(self): - # XXX This is a bit icky, given the comment at the top of - # _read_ready(). Is it possible to evoke a deadlock? I don't - # know, although it doesn't look like it; write() will still - # accept more data for the buffer and eventually the app will - # call resume_reading() again, and things will flow again. - - if self._closing: - raise RuntimeError('Cannot pause_reading() when closing') - if self._paused: - raise RuntimeError('Already paused') - self._paused = True - self._loop._remove_reader(self._sock_fd) - if self._loop.get_debug(): - logger.debug("%r pauses reading", self) - - def resume_reading(self): - if not self._paused: - raise RuntimeError('Not paused') - self._paused = False - if self._closing: - return - self._loop._add_reader(self._sock_fd, self._read_ready) - if self._loop.get_debug(): - logger.debug("%r resumes reading", self) + buffer.appendleft(b[nbytes:]) + break - def _read_ready(self): + def _write_send(self): + assert self._buffer, 'Data should not be empty' if self._conn_lost: return - if self._write_wants_read: - self._write_wants_read = False - self._write_ready() - - if self._buffer: - self._loop._add_writer(self._sock_fd, self._write_ready) - try: - data = self._sock.recv(self.max_size) - except (BlockingIOError, InterruptedError, ssl.SSLWantReadError): + buffer = self._buffer.popleft() + n = self._sock.send(buffer) + if n != len(buffer): + # Not all data was written + self._buffer.appendleft(buffer[n:]) + except (BlockingIOError, InterruptedError): pass - except ssl.SSLWantWriteError: - self._read_wants_write = True - self._loop._remove_reader(self._sock_fd) - self._loop._add_writer(self._sock_fd, self._write_ready) - except Exception as exc: - self._fatal_error(exc, 'Fatal read error on SSL transport') + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._loop._remove_writer(self._sock_fd) + self._buffer.clear() + self._fatal_error(exc, 'Fatal write error on socket transport') + if self._empty_waiter is not None: + self._empty_waiter.set_exception(exc) else: - if data: - self._protocol.data_received(data) - else: - try: - if self._loop.get_debug(): - logger.debug("%r received EOF", self) - keep_open = self._protocol.eof_received() - if keep_open: - logger.warning('returning true from eof_received() ' - 'has no effect when using ssl') - finally: - self.close() - - def _write_ready(self): - if self._conn_lost: - return - if self._read_wants_write: - self._read_wants_write = False - self._read_ready() - - if not (self._paused or self._closing): - self._loop._add_reader(self._sock_fd, self._read_ready) - - if self._buffer: - try: - n = self._sock.send(self._buffer) - except (BlockingIOError, InterruptedError, ssl.SSLWantWriteError): - n = 0 - except ssl.SSLWantReadError: - n = 0 - self._loop._remove_writer(self._sock_fd) - self._write_wants_read = True - except Exception as exc: + 
self._maybe_resume_protocol() # May append to buffer. + if not self._buffer: self._loop._remove_writer(self._sock_fd) - self._buffer.clear() - self._fatal_error(exc, 'Fatal write error on SSL transport') - return - - if n: - del self._buffer[:n] - - self._maybe_resume_protocol() # May append to buffer. + if self._empty_waiter is not None: + self._empty_waiter.set_result(None) + if self._closing: + self._call_connection_lost(None) + elif self._eof: + self._sock.shutdown(socket.SHUT_WR) + def write_eof(self): + if self._closing or self._eof: + return + self._eof = True if not self._buffer: - self._loop._remove_writer(self._sock_fd) - if self._closing: - self._call_connection_lost(None) + self._sock.shutdown(socket.SHUT_WR) - def write(self, data): - if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError('data argument must be a bytes-like object, ' - 'not %r' % type(data).__name__) - if not data: + def writelines(self, list_of_data): + if self._eof: + raise RuntimeError('Cannot call writelines() after write_eof()') + if self._empty_waiter is not None: + raise RuntimeError('unable to writelines; sendfile is in progress') + if not list_of_data: return + self._buffer.extend([memoryview(data) for data in list_of_data]) + self._write_ready() + # If the entire buffer couldn't be written, register a write handler + if self._buffer: + self._loop._add_writer(self._sock_fd, self._write_ready) - if self._conn_lost: - if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: - logger.warning('socket.send() raised exception.') - self._conn_lost += 1 - return + def can_write_eof(self): + return True + def _call_connection_lost(self, exc): + super()._call_connection_lost(exc) + if self._empty_waiter is not None: + self._empty_waiter.set_exception( + ConnectionError("Connection is closed by peer")) + + def _make_empty_waiter(self): + if self._empty_waiter is not None: + raise RuntimeError("Empty waiter is already set") + self._empty_waiter = self._loop.create_future() if not self._buffer: - self._loop._add_writer(self._sock_fd, self._write_ready) + self._empty_waiter.set_result(None) + return self._empty_waiter - # Add it to the buffer. 
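The _make_empty_waiter()/_reset_empty_waiter() pair added here gives callers (the sendfile machinery is the assumed consumer; it is not shown in this hunk) a future that resolves once the transport's buffer has fully drained. A toy sketch of the pattern, with a stand-in class rather than the real transport:

import asyncio

class BufferedWriter:
    # Toy stand-in for the transport: a buffer plus an optional
    # future that is resolved once the buffer drains.
    def __init__(self):
        self.buffer = bytearray()
        self._empty_waiter = None

    def make_empty_waiter(self):
        if self._empty_waiter is not None:
            raise RuntimeError("Empty waiter is already set")
        self._empty_waiter = asyncio.get_running_loop().create_future()
        if not self.buffer:
            self._empty_waiter.set_result(None)
        return self._empty_waiter

    def flush_some(self, n):
        del self.buffer[:n]
        if not self.buffer and self._empty_waiter is not None:
            self._empty_waiter.set_result(None)

async def main():
    w = BufferedWriter()
    w.buffer += b"pending data"
    waiter = w.make_empty_waiter()
    asyncio.get_running_loop().call_soon(w.flush_some, len(w.buffer))
    await waiter  # returns only once the buffer is empty

asyncio.run(main())
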
- self._buffer.extend(data) - self._maybe_pause_protocol() + def _reset_empty_waiter(self): + self._empty_waiter = None - def can_write_eof(self): - return False + def close(self): + self._read_ready_cb = None + self._write_ready = None + super().close() -class _SelectorDatagramTransport(_SelectorTransport): +class _SelectorDatagramTransport(_SelectorTransport, transports.DatagramTransport): _buffer_factory = collections.deque @@ -1051,9 +1218,10 @@ def __init__(self, loop, sock, protocol, address=None, waiter=None, extra=None): super().__init__(loop, sock, protocol, extra) self._address = address + self._buffer_size = 0 self._loop.call_soon(self._protocol.connection_made, self) # only start reading when connection_made() has been called - self._loop.call_soon(self._loop._add_reader, + self._loop.call_soon(self._add_reader, self._sock_fd, self._read_ready) if waiter is not None: # only wake up the waiter when connection_made() has been called @@ -1061,7 +1229,7 @@ def __init__(self, loop, sock, protocol, address=None, waiter, None) def get_write_buffer_size(self): - return sum(len(data) for data, _ in self._buffer) + return self._buffer_size def _read_ready(self): if self._conn_lost: @@ -1072,21 +1240,25 @@ def _read_ready(self): pass except OSError as exc: self._protocol.error_received(exc) - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: self._fatal_error(exc, 'Fatal read error on datagram transport') else: self._protocol.datagram_received(data, addr) def sendto(self, data, addr=None): if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError('data argument must be a bytes-like object, ' - 'not %r' % type(data).__name__) + raise TypeError(f'data argument must be a bytes-like object, ' + f'not {type(data).__name__!r}') if not data: return - if self._address and addr not in (None, self._address): - raise ValueError('Invalid address: must be None or %s' % - (self._address,)) + if self._address: + if addr not in (None, self._address): + raise ValueError( + f'Invalid address: must be None or {self._address}') + addr = self._address if self._conn_lost and self._address: if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: @@ -1097,7 +1269,7 @@ def sendto(self, data, addr=None): if not self._buffer: # Attempt to send it right away first. try: - if self._address: + if self._extra['peername']: self._sock.send(data) else: self._sock.sendto(data, addr) @@ -1107,32 +1279,39 @@ def sendto(self, data, addr=None): except OSError as exc: self._protocol.error_received(exc) return - except Exception as exc: - self._fatal_error(exc, - 'Fatal write error on datagram transport') + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._fatal_error( + exc, 'Fatal write error on datagram transport') return # Ensure that what we buffer is immutable. self._buffer.append((bytes(data), addr)) + self._buffer_size += len(data) self._maybe_pause_protocol() def _sendto_ready(self): while self._buffer: data, addr = self._buffer.popleft() + self._buffer_size -= len(data) try: - if self._address: + if self._extra['peername']: self._sock.send(data) else: self._sock.sendto(data, addr) except (BlockingIOError, InterruptedError): self._buffer.appendleft((data, addr)) # Try again later. 
+ self._buffer_size += len(data) break except OSError as exc: self._protocol.error_received(exc) return - except Exception as exc: - self._fatal_error(exc, - 'Fatal write error on datagram transport') + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._fatal_error( + exc, 'Fatal write error on datagram transport') return self._maybe_resume_protocol() # May append to buffer. diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py index 7ad28d6aa0..e51669a2ab 100644 --- a/Lib/asyncio/sslproto.py +++ b/Lib/asyncio/sslproto.py @@ -1,16 +1,48 @@ +# Contains code from https://github.com/MagicStack/uvloop/tree/v0.16.0 +# SPDX-License-Identifier: PSF-2.0 AND (MIT OR Apache-2.0) +# SPDX-FileCopyrightText: Copyright (c) 2015-2021 MagicStack Inc. http://magic.io + import collections +import enum import warnings try: import ssl except ImportError: # pragma: no cover ssl = None -from . import base_events -from . import compat +from . import constants +from . import exceptions from . import protocols from . import transports from .log import logger +if ssl is not None: + SSLAgainErrors = (ssl.SSLWantReadError, ssl.SSLSyscallError) + + +class SSLProtocolState(enum.Enum): + UNWRAPPED = "UNWRAPPED" + DO_HANDSHAKE = "DO_HANDSHAKE" + WRAPPED = "WRAPPED" + FLUSHING = "FLUSHING" + SHUTDOWN = "SHUTDOWN" + + +class AppProtocolState(enum.Enum): + # This tracks the state of app protocol (https://git.io/fj59P): + # + # INIT -cm-> CON_MADE [-dr*->] [-er-> EOF?] -cl-> CON_LOST + # + # * cm: connection_made() + # * dr: data_received() + # * er: eof_received() + # * cl: connection_lost() + + STATE_INIT = "STATE_INIT" + STATE_CON_MADE = "STATE_CON_MADE" + STATE_EOF = "STATE_EOF" + STATE_CON_LOST = "STATE_CON_LOST" + def _create_transport_context(server_side, server_hostname): if server_side: @@ -19,286 +51,43 @@ def _create_transport_context(server_side, server_hostname): # Client side may pass ssl=True to use a default # context; in that case the sslcontext passed is None. # The default is secure for client connections. - if hasattr(ssl, 'create_default_context'): - # Python 3.4+: use up-to-date strong settings. - sslcontext = ssl.create_default_context() - if not server_hostname: - sslcontext.check_hostname = False - else: - # Fallback for Python 3.3. - sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslcontext.options |= ssl.OP_NO_SSLv2 - sslcontext.options |= ssl.OP_NO_SSLv3 - sslcontext.set_default_verify_paths() - sslcontext.verify_mode = ssl.CERT_REQUIRED + # Python 3.4+: use up-to-date strong settings. + sslcontext = ssl.create_default_context() + if not server_hostname: + sslcontext.check_hostname = False return sslcontext -def _is_sslproto_available(): - return hasattr(ssl, "MemoryBIO") - - -# States of an _SSLPipe. -_UNWRAPPED = "UNWRAPPED" -_DO_HANDSHAKE = "DO_HANDSHAKE" -_WRAPPED = "WRAPPED" -_SHUTDOWN = "SHUTDOWN" - - -class _SSLPipe(object): - """An SSL "Pipe". - - An SSL pipe allows you to communicate with an SSL/TLS protocol instance - through memory buffers. It can be used to implement a security layer for an - existing connection where you don't have access to the connection's file - descriptor, or for some reason you don't want to use it. - - An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode, - data is passed through untransformed. In wrapped mode, application level - data is encrypted to SSL record level data and vice versa. 
The SSL record - level is the lowest level in the SSL protocol suite and is what travels - as-is over the wire. - - An SslPipe initially is in "unwrapped" mode. To start SSL, call - do_handshake(). To shutdown SSL again, call unwrap(). - """ - - max_size = 256 * 1024 # Buffer size passed to read() - - def __init__(self, context, server_side, server_hostname=None): - """ - The *context* argument specifies the ssl.SSLContext to use. - - The *server_side* argument indicates whether this is a server side or - client side transport. - - The optional *server_hostname* argument can be used to specify the - hostname you are connecting to. You may only specify this parameter if - the _ssl module supports Server Name Indication (SNI). - """ - self._context = context - self._server_side = server_side - self._server_hostname = server_hostname - self._state = _UNWRAPPED - self._incoming = ssl.MemoryBIO() - self._outgoing = ssl.MemoryBIO() - self._sslobj = None - self._need_ssldata = False - self._handshake_cb = None - self._shutdown_cb = None - - @property - def context(self): - """The SSL context passed to the constructor.""" - return self._context - - @property - def ssl_object(self): - """The internal ssl.SSLObject instance. - - Return None if the pipe is not wrapped. - """ - return self._sslobj - - @property - def need_ssldata(self): - """Whether more record level data is needed to complete a handshake - that is currently in progress.""" - return self._need_ssldata - - @property - def wrapped(self): - """ - Whether a security layer is currently in effect. - - Return False during handshake. - """ - return self._state == _WRAPPED - - def do_handshake(self, callback=None): - """Start the SSL handshake. - - Return a list of ssldata. A ssldata element is a list of buffers - - The optional *callback* argument can be used to install a callback that - will be called when the handshake is complete. The callback will be - called with None if successful, else an exception instance. - """ - if self._state != _UNWRAPPED: - raise RuntimeError('handshake in progress or completed') - self._sslobj = self._context.wrap_bio( - self._incoming, self._outgoing, - server_side=self._server_side, - server_hostname=self._server_hostname) - self._state = _DO_HANDSHAKE - self._handshake_cb = callback - ssldata, appdata = self.feed_ssldata(b'', only_handshake=True) - assert len(appdata) == 0 - return ssldata - - def shutdown(self, callback=None): - """Start the SSL shutdown sequence. - - Return a list of ssldata. A ssldata element is a list of buffers - - The optional *callback* argument can be used to install a callback that - will be called when the shutdown is complete. The callback will be - called without arguments. - """ - if self._state == _UNWRAPPED: - raise RuntimeError('no security layer present') - if self._state == _SHUTDOWN: - raise RuntimeError('shutdown in progress') - assert self._state in (_WRAPPED, _DO_HANDSHAKE) - self._state = _SHUTDOWN - self._shutdown_cb = callback - ssldata, appdata = self.feed_ssldata(b'') - assert appdata == [] or appdata == [b''] - return ssldata - - def feed_eof(self): - """Send a potentially "ragged" EOF. - - This method will raise an SSL_ERROR_EOF exception if the EOF is - unexpected. - """ - self._incoming.write_eof() - ssldata, appdata = self.feed_ssldata(b'') - assert appdata == [] or appdata == [b''] - - def feed_ssldata(self, data, only_handshake=False): - """Feed SSL record level data into the pipe. - - The data must be a bytes instance. 
It is OK to send an empty bytes - instance. This can be used to get ssldata for a handshake initiated by - this endpoint. - - Return a (ssldata, appdata) tuple. The ssldata element is a list of - buffers containing SSL data that needs to be sent to the remote SSL. - - The appdata element is a list of buffers containing plaintext data that - needs to be forwarded to the application. The appdata list may contain - an empty buffer indicating an SSL "close_notify" alert. This alert must - be acknowledged by calling shutdown(). - """ - if self._state == _UNWRAPPED: - # If unwrapped, pass plaintext data straight through. - if data: - appdata = [data] - else: - appdata = [] - return ([], appdata) - - self._need_ssldata = False - if data: - self._incoming.write(data) - - ssldata = [] - appdata = [] - try: - if self._state == _DO_HANDSHAKE: - # Call do_handshake() until it doesn't raise anymore. - self._sslobj.do_handshake() - self._state = _WRAPPED - if self._handshake_cb: - self._handshake_cb(None) - if only_handshake: - return (ssldata, appdata) - # Handshake done: execute the wrapped block - - if self._state == _WRAPPED: - # Main state: read data from SSL until close_notify - while True: - chunk = self._sslobj.read(self.max_size) - appdata.append(chunk) - if not chunk: # close_notify - break +def add_flowcontrol_defaults(high, low, kb): + if high is None: + if low is None: + hi = kb * 1024 + else: + lo = low + hi = 4 * lo + else: + hi = high + if low is None: + lo = hi // 4 + else: + lo = low - elif self._state == _SHUTDOWN: - # Call shutdown() until it doesn't raise anymore. - self._sslobj.unwrap() - self._sslobj = None - self._state = _UNWRAPPED - if self._shutdown_cb: - self._shutdown_cb() - - elif self._state == _UNWRAPPED: - # Drain possible plaintext data after close_notify. - appdata.append(self._incoming.read()) - except (ssl.SSLError, ssl.CertificateError) as exc: - if getattr(exc, 'errno', None) not in ( - ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, - ssl.SSL_ERROR_SYSCALL): - if self._state == _DO_HANDSHAKE and self._handshake_cb: - self._handshake_cb(exc) - raise - self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ) - - # Check for record level data that needs to be sent back. - # Happens for the initial handshake and renegotiations. - if self._outgoing.pending: - ssldata.append(self._outgoing.read()) - return (ssldata, appdata) - - def feed_appdata(self, data, offset=0): - """Feed plaintext data into the pipe. - - Return an (ssldata, offset) tuple. The ssldata element is a list of - buffers containing record level data that needs to be sent to the - remote SSL instance. The offset is the number of plaintext bytes that - were processed, which may be less than the length of data. - - NOTE: In case of short writes, this call MUST be retried with the SAME - buffer passed into the *data* argument (i.e. the id() must be the - same). This is an OpenSSL requirement. A further particularity is that - a short write will always have offset == 0, because the _ssl module - does not enable partial writes. And even though the offset is zero, - there will still be encrypted data in ssldata. 
- """ - assert 0 <= offset <= len(data) - if self._state == _UNWRAPPED: - # pass through data in unwrapped mode - if offset < len(data): - ssldata = [data[offset:]] - else: - ssldata = [] - return (ssldata, len(data)) + if not hi >= lo >= 0: + raise ValueError('high (%r) must be >= low (%r) must be >= 0' % + (hi, lo)) - ssldata = [] - view = memoryview(data) - while True: - self._need_ssldata = False - try: - if offset < len(view): - offset += self._sslobj.write(view[offset:]) - except ssl.SSLError as exc: - # It is not allowed to call write() after unwrap() until the - # close_notify is acknowledged. We return the condition to the - # caller as a short write. - if exc.reason == 'PROTOCOL_IS_SHUTDOWN': - exc.errno = ssl.SSL_ERROR_WANT_READ - if exc.errno not in (ssl.SSL_ERROR_WANT_READ, - ssl.SSL_ERROR_WANT_WRITE, - ssl.SSL_ERROR_SYSCALL): - raise - self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ) - - # See if there's any record level data back for us. - if self._outgoing.pending: - ssldata.append(self._outgoing.read()) - if offset == len(view) or self._need_ssldata: - break - return (ssldata, offset) + return hi, lo class _SSLProtocolTransport(transports._FlowControlMixin, transports.Transport): - def __init__(self, loop, ssl_protocol, app_protocol): + _start_tls_compatible = True + _sendfile_compatible = constants._SendfileMode.FALLBACK + + def __init__(self, loop, ssl_protocol): self._loop = loop - # SSLProtocol instance self._ssl_protocol = ssl_protocol - self._app_protocol = app_protocol self._closed = False def get_extra_info(self, name, default=None): @@ -306,10 +95,10 @@ def get_extra_info(self, name, default=None): return self._ssl_protocol._get_extra_info(name, default) def set_protocol(self, protocol): - self._app_protocol = protocol + self._ssl_protocol._set_app_protocol(protocol) def get_protocol(self): - return self._app_protocol + return self._ssl_protocol._app_protocol def is_closing(self): return self._closed @@ -322,18 +111,21 @@ def close(self): protocol's connection_lost() method will (eventually) called with None as its argument. """ - self._closed = True - self._ssl_protocol._start_shutdown() - - # On Python 3.3 and older, objects with a destructor part of a reference - # cycle are never destroyed. It's not more the case on Python 3.4 thanks - # to the PEP 442. - if compat.PY34: - def __del__(self): - if not self._closed: - warnings.warn("unclosed transport %r" % self, ResourceWarning, - source=self) - self.close() + if not self._closed: + self._closed = True + self._ssl_protocol._start_shutdown() + else: + self._ssl_protocol = None + + def __del__(self, _warnings=warnings): + if not self._closed: + self._closed = True + _warnings.warn( + "unclosed transport ", ResourceWarning) + + def is_reading(self): + return not self._ssl_protocol._app_reading_paused def pause_reading(self): """Pause the receiving end. @@ -341,7 +133,7 @@ def pause_reading(self): No data will be passed to the protocol's data_received() method until resume_reading() is called. """ - self._ssl_protocol._transport.pause_reading() + self._ssl_protocol._pause_reading() def resume_reading(self): """Resume the receiving end. @@ -349,7 +141,7 @@ def resume_reading(self): Data received will once again be passed to the protocol's data_received() method. """ - self._ssl_protocol._transport.resume_reading() + self._ssl_protocol._resume_reading() def set_write_buffer_limits(self, high=None, low=None): """Set the high- and low-water limits for write flow control. 
@@ -370,11 +162,51 @@ def set_write_buffer_limits(self, high=None, low=None): reduces opportunities for doing I/O and computation concurrently. """ - self._ssl_protocol._transport.set_write_buffer_limits(high, low) + self._ssl_protocol._set_write_buffer_limits(high, low) + self._ssl_protocol._control_app_writing() + + def get_write_buffer_limits(self): + return (self._ssl_protocol._outgoing_low_water, + self._ssl_protocol._outgoing_high_water) def get_write_buffer_size(self): - """Return the current size of the write buffer.""" - return self._ssl_protocol._transport.get_write_buffer_size() + """Return the current size of the write buffers.""" + return self._ssl_protocol._get_write_buffer_size() + + def set_read_buffer_limits(self, high=None, low=None): + """Set the high- and low-water limits for read flow control. + + These two values control when to call the upstream transport's + pause_reading() and resume_reading() methods. If specified, + the low-water limit must be less than or equal to the + high-water limit. Neither value can be negative. + + The defaults are implementation-specific. If only the + high-water limit is given, the low-water limit defaults to an + implementation-specific value less than or equal to the + high-water limit. Setting high to zero forces low to zero as + well, and causes pause_reading() to be called whenever the + buffer becomes non-empty. Setting low to zero causes + resume_reading() to be called only once the buffer is empty. + Use of zero for either limit is generally sub-optimal as it + reduces opportunities for doing I/O and computation + concurrently. + """ + self._ssl_protocol._set_read_buffer_limits(high, low) + self._ssl_protocol._control_ssl_reading() + + def get_read_buffer_limits(self): + return (self._ssl_protocol._incoming_low_water, + self._ssl_protocol._incoming_high_water) + + def get_read_buffer_size(self): + """Return the current size of the read buffer.""" + return self._ssl_protocol._get_read_buffer_size() + + @property + def _protocol_paused(self): + # Required for sendfile fallback pause_writing/resume_writing logic + return self._ssl_protocol._app_writing_paused def write(self, data): """Write some data bytes to the transport. @@ -383,11 +215,26 @@ def write(self, data): to be sent out asynchronously. """ if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError("data: expecting a bytes-like instance, got {!r}" - .format(type(data).__name__)) + raise TypeError(f"data: expecting a bytes-like instance, " + f"got {type(data).__name__}") if not data: return - self._ssl_protocol._write_appdata(data) + self._ssl_protocol._write_appdata((data,)) + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation concatenates the arguments and + calls write() on the result. + """ + self._ssl_protocol._write_appdata(list_of_data) + + def write_eof(self): + """Close the write end after flushing buffered data. + + This raises :exc:`NotImplementedError` right now. + """ + raise NotImplementedError def can_write_eof(self): """Return True if this transport supports write_eof(), False if not.""" @@ -400,24 +247,53 @@ def abort(self): The protocol's connection_lost() method will (eventually) be called with None as its argument. 
""" - self._ssl_protocol._abort() + self._force_close(None) + def _force_close(self, exc): + self._closed = True + if self._ssl_protocol is not None: + self._ssl_protocol._abort(exc) + + def _test__append_write_backlog(self, data): + # for test only + self._ssl_protocol._write_backlog.append(data) + self._ssl_protocol._write_buffer_size += len(data) -class SSLProtocol(protocols.Protocol): - """SSL protocol. - Implementation of SSL on top of a socket using incoming and outgoing - buffers which are ssl.MemoryBIO objects. - """ +class SSLProtocol(protocols.BufferedProtocol): + max_size = 256 * 1024 # Buffer size passed to read() + + _handshake_start_time = None + _handshake_timeout_handle = None + _shutdown_timeout_handle = None def __init__(self, loop, app_protocol, sslcontext, waiter, server_side=False, server_hostname=None, - call_connection_made=True): + call_connection_made=True, + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): if ssl is None: - raise RuntimeError('stdlib ssl module not available') + raise RuntimeError("stdlib ssl module not available") + + self._ssl_buffer = bytearray(self.max_size) + self._ssl_buffer_view = memoryview(self._ssl_buffer) + + if ssl_handshake_timeout is None: + ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT + elif ssl_handshake_timeout <= 0: + raise ValueError( + f"ssl_handshake_timeout should be a positive number, " + f"got {ssl_handshake_timeout}") + if ssl_shutdown_timeout is None: + ssl_shutdown_timeout = constants.SSL_SHUTDOWN_TIMEOUT + elif ssl_shutdown_timeout <= 0: + raise ValueError( + f"ssl_shutdown_timeout should be a positive number, " + f"got {ssl_shutdown_timeout}") if not sslcontext: - sslcontext = _create_transport_context(server_side, server_hostname) + sslcontext = _create_transport_context( + server_side, server_hostname) self._server_side = server_side if server_hostname and not server_side: @@ -435,17 +311,55 @@ def __init__(self, loop, app_protocol, sslcontext, waiter, self._waiter = waiter self._loop = loop - self._app_protocol = app_protocol - self._app_transport = _SSLProtocolTransport(self._loop, - self, self._app_protocol) - # _SSLPipe instance (None until the connection is made) - self._sslpipe = None - self._session_established = False - self._in_handshake = False - self._in_shutdown = False + self._set_app_protocol(app_protocol) + self._app_transport = None + self._app_transport_created = False # transport, ex: SelectorSocketTransport self._transport = None - self._call_connection_made = call_connection_made + self._ssl_handshake_timeout = ssl_handshake_timeout + self._ssl_shutdown_timeout = ssl_shutdown_timeout + # SSL and state machine + self._incoming = ssl.MemoryBIO() + self._outgoing = ssl.MemoryBIO() + self._state = SSLProtocolState.UNWRAPPED + self._conn_lost = 0 # Set when connection_lost called + if call_connection_made: + self._app_state = AppProtocolState.STATE_INIT + else: + self._app_state = AppProtocolState.STATE_CON_MADE + self._sslobj = self._sslcontext.wrap_bio( + self._incoming, self._outgoing, + server_side=self._server_side, + server_hostname=self._server_hostname) + + # Flow Control + + self._ssl_writing_paused = False + + self._app_reading_paused = False + + self._ssl_reading_paused = False + self._incoming_high_water = 0 + self._incoming_low_water = 0 + self._set_read_buffer_limits() + self._eof_received = False + + self._app_writing_paused = False + self._outgoing_high_water = 0 + self._outgoing_low_water = 0 + self._set_write_buffer_limits() + self._get_app_transport() + + def 
_set_app_protocol(self, app_protocol): + self._app_protocol = app_protocol + # Make fast hasattr check first + if (hasattr(app_protocol, 'get_buffer') and + isinstance(app_protocol, protocols.BufferedProtocol)): + self._app_protocol_get_buffer = app_protocol.get_buffer + self._app_protocol_buffer_updated = app_protocol.buffer_updated + self._app_protocol_is_buffer = True + else: + self._app_protocol_is_buffer = False def _wakeup_waiter(self, exc=None): if self._waiter is None: @@ -457,15 +371,20 @@ def _wakeup_waiter(self, exc=None): self._waiter.set_result(None) self._waiter = None + def _get_app_transport(self): + if self._app_transport is None: + if self._app_transport_created: + raise RuntimeError('Creating _SSLProtocolTransport twice') + self._app_transport = _SSLProtocolTransport(self._loop, self) + self._app_transport_created = True + return self._app_transport + def connection_made(self, transport): """Called when the low-level connection is made. Start the SSL handshake. """ self._transport = transport - self._sslpipe = _SSLPipe(self._sslcontext, - self._server_side, - self._server_hostname) self._start_handshake() def connection_lost(self, exc): @@ -475,48 +394,58 @@ def connection_lost(self, exc): meaning a regular EOF is received or the connection was aborted or closed). """ - if self._session_established: - self._session_established = False - self._loop.call_soon(self._app_protocol.connection_lost, exc) + self._write_backlog.clear() + self._outgoing.read() + self._conn_lost += 1 + + # Just mark the app transport as closed so that its __dealloc__ + # doesn't complain. + if self._app_transport is not None: + self._app_transport._closed = True + + if self._state != SSLProtocolState.DO_HANDSHAKE: + if ( + self._app_state == AppProtocolState.STATE_CON_MADE or + self._app_state == AppProtocolState.STATE_EOF + ): + self._app_state = AppProtocolState.STATE_CON_LOST + self._loop.call_soon(self._app_protocol.connection_lost, exc) + self._set_state(SSLProtocolState.UNWRAPPED) self._transport = None self._app_transport = None + self._app_protocol = None self._wakeup_waiter(exc) - def pause_writing(self): - """Called when the low-level transport's buffer goes over - the high-water mark. - """ - self._app_protocol.pause_writing() + if self._shutdown_timeout_handle: + self._shutdown_timeout_handle.cancel() + self._shutdown_timeout_handle = None + if self._handshake_timeout_handle: + self._handshake_timeout_handle.cancel() + self._handshake_timeout_handle = None - def resume_writing(self): - """Called when the low-level transport's buffer drains below - the low-water mark. - """ - self._app_protocol.resume_writing() + def get_buffer(self, n): + want = n + if want <= 0 or want > self.max_size: + want = self.max_size + if len(self._ssl_buffer) < want: + self._ssl_buffer = bytearray(want) + self._ssl_buffer_view = memoryview(self._ssl_buffer) + return self._ssl_buffer_view - def data_received(self, data): - """Called when some SSL data is received. + def buffer_updated(self, nbytes): + self._incoming.write(self._ssl_buffer_view[:nbytes]) - The argument is a bytes object. 
- """ - try: - ssldata, appdata = self._sslpipe.feed_ssldata(data) - except ssl.SSLError as e: - if self._loop.get_debug(): - logger.warning('%r: SSL error %s (reason %s)', - self, e.errno, e.reason) - self._abort() - return + if self._state == SSLProtocolState.DO_HANDSHAKE: + self._do_handshake() - for chunk in ssldata: - self._transport.write(chunk) + elif self._state == SSLProtocolState.WRAPPED: + self._do_read() - for chunk in appdata: - if chunk: - self._app_protocol.data_received(chunk) - else: - self._start_shutdown() - break + elif self._state == SSLProtocolState.FLUSHING: + self._do_flush() + + elif self._state == SSLProtocolState.SHUTDOWN: + self._do_shutdown() def eof_received(self): """Called when the other end of the low-level stream @@ -526,36 +455,80 @@ def eof_received(self): will close itself. If it returns a true value, closing the transport is up to the protocol. """ + self._eof_received = True try: if self._loop.get_debug(): logger.debug("%r received EOF", self) - self._wakeup_waiter(ConnectionResetError) + if self._state == SSLProtocolState.DO_HANDSHAKE: + self._on_handshake_complete(ConnectionResetError) - if not self._in_handshake: - keep_open = self._app_protocol.eof_received() - if keep_open: - logger.warning('returning true from eof_received() ' - 'has no effect when using ssl') - finally: + elif self._state == SSLProtocolState.WRAPPED: + self._set_state(SSLProtocolState.FLUSHING) + if self._app_reading_paused: + return True + else: + self._do_flush() + + elif self._state == SSLProtocolState.FLUSHING: + self._do_write() + self._set_state(SSLProtocolState.SHUTDOWN) + self._do_shutdown() + + elif self._state == SSLProtocolState.SHUTDOWN: + self._do_shutdown() + + except Exception: self._transport.close() + raise def _get_extra_info(self, name, default=None): if name in self._extra: return self._extra[name] - else: + elif self._transport is not None: return self._transport.get_extra_info(name, default) + else: + return default - def _start_shutdown(self): - if self._in_shutdown: - return - self._in_shutdown = True - self._write_appdata(b'') + def _set_state(self, new_state): + allowed = False + + if new_state == SSLProtocolState.UNWRAPPED: + allowed = True + + elif ( + self._state == SSLProtocolState.UNWRAPPED and + new_state == SSLProtocolState.DO_HANDSHAKE + ): + allowed = True + + elif ( + self._state == SSLProtocolState.DO_HANDSHAKE and + new_state == SSLProtocolState.WRAPPED + ): + allowed = True - def _write_appdata(self, data): - self._write_backlog.append((data, 0)) - self._write_buffer_size += len(data) - self._process_write_backlog() + elif ( + self._state == SSLProtocolState.WRAPPED and + new_state == SSLProtocolState.FLUSHING + ): + allowed = True + + elif ( + self._state == SSLProtocolState.FLUSHING and + new_state == SSLProtocolState.SHUTDOWN + ): + allowed = True + + if allowed: + self._state = new_state + + else: + raise RuntimeError( + 'cannot switch state from {} to {}'.format( + self._state, new_state)) + + # Handshake flow def _start_handshake(self): if self._loop.get_debug(): @@ -563,42 +536,58 @@ def _start_handshake(self): self._handshake_start_time = self._loop.time() else: self._handshake_start_time = None - self._in_handshake = True - # (b'', 1) is a special value in _process_write_backlog() to do - # the SSL handshake - self._write_backlog.append((b'', 1)) - self._loop.call_soon(self._process_write_backlog) + + self._set_state(SSLProtocolState.DO_HANDSHAKE) + + # start handshake timeout count down + self._handshake_timeout_handle = \ + 
self._loop.call_later(self._ssl_handshake_timeout, + lambda: self._check_handshake_timeout()) + + self._do_handshake() + + def _check_handshake_timeout(self): + if self._state == SSLProtocolState.DO_HANDSHAKE: + msg = ( + f"SSL handshake is taking longer than " + f"{self._ssl_handshake_timeout} seconds: " + f"aborting the connection" + ) + self._fatal_error(ConnectionAbortedError(msg)) + + def _do_handshake(self): + try: + self._sslobj.do_handshake() + except SSLAgainErrors: + self._process_outgoing() + except ssl.SSLError as exc: + self._on_handshake_complete(exc) + else: + self._on_handshake_complete(None) def _on_handshake_complete(self, handshake_exc): - self._in_handshake = False + if self._handshake_timeout_handle is not None: + self._handshake_timeout_handle.cancel() + self._handshake_timeout_handle = None - sslobj = self._sslpipe.ssl_object + sslobj = self._sslobj try: - if handshake_exc is not None: + if handshake_exc is None: + self._set_state(SSLProtocolState.WRAPPED) + else: raise handshake_exc peercert = sslobj.getpeercert() - if not hasattr(self._sslcontext, 'check_hostname'): - # Verify hostname if requested, Python 3.4+ uses check_hostname - # and checks the hostname in do_handshake() - if (self._server_hostname - and self._sslcontext.verify_mode != ssl.CERT_NONE): - ssl.match_hostname(peercert, self._server_hostname) - except BaseException as exc: - if self._loop.get_debug(): - if isinstance(exc, ssl.CertificateError): - logger.warning("%r: SSL handshake failed " - "on verifying the certificate", - self, exc_info=True) - else: - logger.warning("%r: SSL handshake failed", - self, exc_info=True) - self._transport.close() - if isinstance(exc, Exception): - self._wakeup_waiter(exc) - return + except Exception as exc: + handshake_exc = None + self._set_state(SSLProtocolState.UNWRAPPED) + if isinstance(exc, ssl.CertificateError): + msg = 'SSL handshake failed on verifying the certificate' else: - raise + msg = 'SSL handshake failed' + self._fatal_error(exc, msg) + self._wakeup_waiter(exc) + return if self._loop.get_debug(): dt = self._loop.time() - self._handshake_start_time @@ -608,85 +597,330 @@ def _on_handshake_complete(self, handshake_exc): self._extra.update(peercert=peercert, cipher=sslobj.cipher(), compression=sslobj.compression(), - ssl_object=sslobj, - ) - if self._call_connection_made: - self._app_protocol.connection_made(self._app_transport) + ssl_object=sslobj) + if self._app_state == AppProtocolState.STATE_INIT: + self._app_state = AppProtocolState.STATE_CON_MADE + self._app_protocol.connection_made(self._get_app_transport()) self._wakeup_waiter() - self._session_established = True - # In case transport.write() was already called. Don't call - # immediately _process_write_backlog(), but schedule it: - # _on_handshake_complete() can be called indirectly from - # _process_write_backlog(), and _process_write_backlog() is not - # reentrant. - self._loop.call_soon(self._process_write_backlog) - - def _process_write_backlog(self): - # Try to make progress on the write backlog. 
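The handshake timeout wired up here surfaces as the ssl_handshake_timeout keyword on the event-loop connection APIs. A hedged usage sketch (it assumes a reachable TLS server at example.org:443; ssl_shutdown_timeout is accepted the same way where this patch's API is present):

import asyncio
import ssl

async def main():
    loop = asyncio.get_running_loop()
    # If the handshake stalls past 10s, _check_handshake_timeout()
    # aborts the connection with ConnectionAbortedError.
    transport, protocol = await loop.create_connection(
        asyncio.Protocol, 'example.org', 443,
        ssl=ssl.create_default_context(), ssl_handshake_timeout=10.0)
    transport.close()

asyncio.run(main())
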
- if self._transport is None: + self._do_read() + + # Shutdown flow + + def _start_shutdown(self): + if ( + self._state in ( + SSLProtocolState.FLUSHING, + SSLProtocolState.SHUTDOWN, + SSLProtocolState.UNWRAPPED + ) + ): return + if self._app_transport is not None: + self._app_transport._closed = True + if self._state == SSLProtocolState.DO_HANDSHAKE: + self._abort(None) + else: + self._set_state(SSLProtocolState.FLUSHING) + self._shutdown_timeout_handle = self._loop.call_later( + self._ssl_shutdown_timeout, + lambda: self._check_shutdown_timeout() + ) + self._do_flush() + + def _check_shutdown_timeout(self): + if ( + self._state in ( + SSLProtocolState.FLUSHING, + SSLProtocolState.SHUTDOWN + ) + ): + self._transport._force_close( + exceptions.TimeoutError('SSL shutdown timed out')) + + def _do_flush(self): + self._do_read() + self._set_state(SSLProtocolState.SHUTDOWN) + self._do_shutdown() + + def _do_shutdown(self): + try: + if not self._eof_received: + self._sslobj.unwrap() + except SSLAgainErrors: + self._process_outgoing() + except ssl.SSLError as exc: + self._on_shutdown_complete(exc) + else: + self._process_outgoing() + self._call_eof_received() + self._on_shutdown_complete(None) + def _on_shutdown_complete(self, shutdown_exc): + if self._shutdown_timeout_handle is not None: + self._shutdown_timeout_handle.cancel() + self._shutdown_timeout_handle = None + + if shutdown_exc: + self._fatal_error(shutdown_exc) + else: + self._loop.call_soon(self._transport.close) + + def _abort(self, exc): + self._set_state(SSLProtocolState.UNWRAPPED) + if self._transport is not None: + self._transport._force_close(exc) + + # Outgoing flow + + def _write_appdata(self, list_of_data): + if ( + self._state in ( + SSLProtocolState.FLUSHING, + SSLProtocolState.SHUTDOWN, + SSLProtocolState.UNWRAPPED + ) + ): + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('SSL connection is closed') + self._conn_lost += 1 + return + + for data in list_of_data: + self._write_backlog.append(data) + self._write_buffer_size += len(data) + + try: + if self._state == SSLProtocolState.WRAPPED: + self._do_write() + + except Exception as ex: + self._fatal_error(ex, 'Fatal error on SSL protocol') + + def _do_write(self): + try: + while self._write_backlog: + data = self._write_backlog[0] + count = self._sslobj.write(data) + data_len = len(data) + if count < data_len: + self._write_backlog[0] = data[count:] + self._write_buffer_size -= count + else: + del self._write_backlog[0] + self._write_buffer_size -= data_len + except SSLAgainErrors: + pass + self._process_outgoing() + + def _process_outgoing(self): + if not self._ssl_writing_paused: + data = self._outgoing.read() + if len(data): + self._transport.write(data) + self._control_app_writing() + + # Incoming flow + + def _do_read(self): + if ( + self._state not in ( + SSLProtocolState.WRAPPED, + SSLProtocolState.FLUSHING, + ) + ): + return try: - for i in range(len(self._write_backlog)): - data, offset = self._write_backlog[0] - if data: - ssldata, offset = self._sslpipe.feed_appdata(data, offset) - elif offset: - ssldata = self._sslpipe.do_handshake( - self._on_handshake_complete) - offset = 1 + if not self._app_reading_paused: + if self._app_protocol_is_buffer: + self._do_read__buffered() else: - ssldata = self._sslpipe.shutdown(self._finalize) - offset = 1 - - for chunk in ssldata: - self._transport.write(chunk) - - if offset < len(data): - self._write_backlog[0] = (data, offset) - # A short write means that a write is blocked on a read - # We 
need to enable reading if it is paused! - assert self._sslpipe.need_ssldata - if self._transport._paused: - self._transport.resume_reading() + self._do_read__copied() + if self._write_backlog: + self._do_write() + else: + self._process_outgoing() + self._control_ssl_reading() + except Exception as ex: + self._fatal_error(ex, 'Fatal error on SSL protocol') + + def _do_read__buffered(self): + offset = 0 + count = 1 + + buf = self._app_protocol_get_buffer(self._get_read_buffer_size()) + wants = len(buf) + + try: + count = self._sslobj.read(wants, buf) + + if count > 0: + offset = count + while offset < wants: + count = self._sslobj.read(wants - offset, buf[offset:]) + if count > 0: + offset += count + else: + break + else: + self._loop.call_soon(lambda: self._do_read()) + except SSLAgainErrors: + pass + if offset > 0: + self._app_protocol_buffer_updated(offset) + if not count: + # close_notify + self._call_eof_received() + self._start_shutdown() + + def _do_read__copied(self): + chunk = b'1' + zero = True + one = False + + try: + while True: + chunk = self._sslobj.read(self.max_size) + if not chunk: break + if zero: + zero = False + one = True + first = chunk + elif one: + one = False + data = [first, chunk] + else: + data.append(chunk) + except SSLAgainErrors: + pass + if one: + self._app_protocol.data_received(first) + elif not zero: + self._app_protocol.data_received(b''.join(data)) + if not chunk: + # close_notify + self._call_eof_received() + self._start_shutdown() + + def _call_eof_received(self): + try: + if self._app_state == AppProtocolState.STATE_CON_MADE: + self._app_state = AppProtocolState.STATE_EOF + keep_open = self._app_protocol.eof_received() + if keep_open: + logger.warning('returning true from eof_received() ' + 'has no effect when using ssl') + except (KeyboardInterrupt, SystemExit): + raise + except BaseException as ex: + self._fatal_error(ex, 'Error calling eof_received()') - # An entire chunk from the backlog was processed. We can - # delete it and reduce the outstanding buffer size. - del self._write_backlog[0] - self._write_buffer_size -= len(data) - except BaseException as exc: - if self._in_handshake: - # BaseExceptions will be re-raised in _on_handshake_complete. 
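The _do_read__buffered() path is only taken when the application protocol implements the buffered-protocol API detected by _set_app_protocol(); otherwise data is copied via _do_read__copied() and data_received(). A minimal, illustrative protocol that qualifies for the zero-copy path:

import asyncio

class Echo(asyncio.BufferedProtocol):
    def __init__(self):
        self._buf = bytearray(65536)

    def connection_made(self, transport):
        self.transport = transport

    def get_buffer(self, sizehint):
        # The SSL layer asks for at least `sizehint` bytes; handing back
        # a larger reusable buffer is fine.
        return memoryview(self._buf)

    def buffer_updated(self, nbytes):
        # The first `nbytes` of the buffer now hold decrypted data.
        self.transport.write(bytes(self._buf[:nbytes]))
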
- self._on_handshake_complete(exc) - else: - self._fatal_error(exc, 'Fatal error on SSL transport') - if not isinstance(exc, Exception): - # BaseException + # Flow control for writes from APP socket + + def _control_app_writing(self): + size = self._get_write_buffer_size() + if size >= self._outgoing_high_water and not self._app_writing_paused: + self._app_writing_paused = True + try: + self._app_protocol.pause_writing() + except (KeyboardInterrupt, SystemExit): raise + except BaseException as exc: + self._loop.call_exception_handler({ + 'message': 'protocol.pause_writing() failed', + 'exception': exc, + 'transport': self._app_transport, + 'protocol': self, + }) + elif size <= self._outgoing_low_water and self._app_writing_paused: + self._app_writing_paused = False + try: + self._app_protocol.resume_writing() + except (KeyboardInterrupt, SystemExit): + raise + except BaseException as exc: + self._loop.call_exception_handler({ + 'message': 'protocol.resume_writing() failed', + 'exception': exc, + 'transport': self._app_transport, + 'protocol': self, + }) + + def _get_write_buffer_size(self): + return self._outgoing.pending + self._write_buffer_size + + def _set_write_buffer_limits(self, high=None, low=None): + high, low = add_flowcontrol_defaults( + high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_WRITE) + self._outgoing_high_water = high + self._outgoing_low_water = low + + # Flow control for reads to APP socket + + def _pause_reading(self): + self._app_reading_paused = True + + def _resume_reading(self): + if self._app_reading_paused: + self._app_reading_paused = False + + def resume(): + if self._state == SSLProtocolState.WRAPPED: + self._do_read() + elif self._state == SSLProtocolState.FLUSHING: + self._do_flush() + elif self._state == SSLProtocolState.SHUTDOWN: + self._do_shutdown() + self._loop.call_soon(resume) + + # Flow control for reads from SSL socket + + def _control_ssl_reading(self): + size = self._get_read_buffer_size() + if size >= self._incoming_high_water and not self._ssl_reading_paused: + self._ssl_reading_paused = True + self._transport.pause_reading() + elif size <= self._incoming_low_water and self._ssl_reading_paused: + self._ssl_reading_paused = False + self._transport.resume_reading() + + def _set_read_buffer_limits(self, high=None, low=None): + high, low = add_flowcontrol_defaults( + high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_READ) + self._incoming_high_water = high + self._incoming_low_water = low + + def _get_read_buffer_size(self): + return self._incoming.pending + + # Flow control for writes to SSL socket + + def pause_writing(self): + """Called when the low-level transport's buffer goes over + the high-water mark. + """ + assert not self._ssl_writing_paused + self._ssl_writing_paused = True + + def resume_writing(self): + """Called when the low-level transport's buffer drains below + the low-water mark. + """ + assert self._ssl_writing_paused + self._ssl_writing_paused = False + self._process_outgoing() def _fatal_error(self, exc, message='Fatal error on transport'): - # Should be called from exception handler only. 
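A quick worked example of the water marks these _set_*_buffer_limits() methods fall back to, via the add_flowcontrol_defaults() helper introduced earlier in this file (assumes the patched asyncio is importable; the 512 here mirrors the write-side KiB constant):

from asyncio.sslproto import add_flowcontrol_defaults

# Defaults: high = kb KiB, low = high // 4.
assert add_flowcontrol_defaults(None, None, 512) == (512 * 1024, 128 * 1024)
# An explicit high water mark pulls the default low water mark with it.
assert add_flowcontrol_defaults(16 * 1024, None, 512) == (16 * 1024, 4 * 1024)
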
- if isinstance(exc, base_events._FATAL_ERROR_IGNORE): + if self._transport: + self._transport._force_close(exc) + + if isinstance(exc, OSError): if self._loop.get_debug(): logger.debug("%r: %s", self, message, exc_info=True) - else: + elif not isinstance(exc, exceptions.CancelledError): self._loop.call_exception_handler({ 'message': message, 'exception': exc, 'transport': self._transport, 'protocol': self, }) - if self._transport: - self._transport._force_close(exc) - - def _finalize(self): - if self._transport is not None: - self._transport.close() - - def _abort(self): - if self._transport is not None: - try: - self._transport.abort() - finally: - self._finalize() diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py new file mode 100644 index 0000000000..451a53a16f --- /dev/null +++ b/Lib/asyncio/staggered.py @@ -0,0 +1,149 @@ +"""Support for running coroutines in parallel with staggered start times.""" + +__all__ = 'staggered_race', + +import contextlib +import typing + +from . import events +from . import exceptions as exceptions_mod +from . import locks +from . import tasks + + +async def staggered_race( + coro_fns: typing.Iterable[typing.Callable[[], typing.Awaitable]], + delay: typing.Optional[float], + *, + loop: events.AbstractEventLoop = None, +) -> typing.Tuple[ + typing.Any, + typing.Optional[int], + typing.List[typing.Optional[Exception]] +]: + """Run coroutines with staggered start times and take the first to finish. + + This method takes an iterable of coroutine functions. The first one is + started immediately. From then on, whenever the immediately preceding one + fails (raises an exception), or when *delay* seconds has passed, the next + coroutine is started. This continues until one of the coroutines complete + successfully, in which case all others are cancelled, or until all + coroutines fail. + + The coroutines provided should be well-behaved in the following way: + + * They should only ``return`` if completed successfully. + + * They should always raise an exception if they did not complete + successfully. In particular, if they handle cancellation, they should + probably reraise, like this:: + + try: + # do work + except asyncio.CancelledError: + # undo partially completed work + raise + + Args: + coro_fns: an iterable of coroutine functions, i.e. callables that + return a coroutine object when called. Use ``functools.partial`` or + lambdas to pass arguments. + + delay: amount of time, in seconds, between starting coroutines. If + ``None``, the coroutines will run sequentially. + + loop: the event loop to use. + + Returns: + tuple *(winner_result, winner_index, exceptions)* where + + - *winner_result*: the result of the winning coroutine, or ``None`` + if no coroutines won. + + - *winner_index*: the index of the winning coroutine in + ``coro_fns``, or ``None`` if no coroutines won. If the winning + coroutine may return None on success, *winner_index* can be used + to definitively determine whether any coroutine won. + + - *exceptions*: list of exceptions returned by the coroutines. + ``len(exceptions)`` is equal to the number of coroutines actually + started, and the order is the same as in ``coro_fns``. The winning + coroutine's entry is ``None``. + + """ + # TODO: when we have aiter() and anext(), allow async iterables in coro_fns. 
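A short usage sketch of staggered_race() as documented above, happy-eyeballs style: the second attempt only starts if the first fails or 0.3s elapses, and the first success cancels the rest (the addresses are illustrative):

import asyncio
from asyncio.staggered import staggered_race

async def attempt(addr, latency):
    await asyncio.sleep(latency)
    return addr

async def main():
    winner, index, errors = await staggered_race(
        [lambda: attempt('192.0.2.1', 0.1),
         lambda: attempt('2001:db8::1', 0.1)],
        delay=0.3)
    print(winner, index, errors)   # 192.0.2.1 0 [None]

asyncio.run(main())
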
+ loop = loop or events.get_running_loop() + enum_coro_fns = enumerate(coro_fns) + winner_result = None + winner_index = None + exceptions = [] + running_tasks = [] + + async def run_one_coro( + previous_failed: typing.Optional[locks.Event]) -> None: + # Wait for the previous task to finish, or for delay seconds + if previous_failed is not None: + with contextlib.suppress(exceptions_mod.TimeoutError): + # Use asyncio.wait_for() instead of asyncio.wait() here, so + # that if we get cancelled at this point, Event.wait() is also + # cancelled, otherwise there will be a "Task destroyed but it is + # pending" later. + await tasks.wait_for(previous_failed.wait(), delay) + # Get the next coroutine to run + try: + this_index, coro_fn = next(enum_coro_fns) + except StopIteration: + return + # Start task that will run the next coroutine + this_failed = locks.Event() + next_task = loop.create_task(run_one_coro(this_failed)) + running_tasks.append(next_task) + assert len(running_tasks) == this_index + 2 + # Prepare place to put this coroutine's exceptions if not won + exceptions.append(None) + assert len(exceptions) == this_index + 1 + + try: + result = await coro_fn() + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as e: + exceptions[this_index] = e + this_failed.set() # Kickstart the next coroutine + else: + # Store winner's results + nonlocal winner_index, winner_result + assert winner_index is None + winner_index = this_index + winner_result = result + # Cancel all other tasks. We take care to not cancel the current + # task as well. If we do so, then since there is no `await` after + # here and CancelledError are usually thrown at one, we will + # encounter a curious corner case where the current task will end + # up as done() == True, cancelled() == False, exception() == + # asyncio.CancelledError. This behavior is specified in + # https://bugs.python.org/issue30048 + for i, t in enumerate(running_tasks): + if i != this_index: + t.cancel() + + first_task = loop.create_task(run_one_coro(None)) + running_tasks.append(first_task) + try: + # Wait for a growing list of tasks to all finish: poor man's version of + # curio's TaskGroup or trio's nursery + done_count = 0 + while done_count != len(running_tasks): + done, _ = await tasks.wait(running_tasks) + done_count = len(done) + # If run_one_coro raises an unhandled exception, it's probably a + # programming error, and I want to see it. + if __debug__: + for d in done: + if d.done() and not d.cancelled() and d.exception(): + raise d.exception() + return winner_result, winner_index, exceptions + finally: + # Make sure no tasks are left running if we leave this function + for t in running_tasks: + t.cancel() diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index a82cc79aca..f310aa2f36 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -1,55 +1,30 @@ -"""Stream-related things.""" - -__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol', - 'open_connection', 'start_server', - 'IncompleteReadError', - 'LimitOverrunError', - ] +__all__ = ( + 'StreamReader', 'StreamWriter', 'StreamReaderProtocol', + 'open_connection', 'start_server') +import collections import socket +import sys +import warnings +import weakref if hasattr(socket, 'AF_UNIX'): - __all__.extend(['open_unix_connection', 'start_unix_server']) + __all__ += ('open_unix_connection', 'start_unix_server') from . import coroutines -from . import compat from . import events +from . import exceptions +from . import format_helpers from . 
import protocols -from .coroutines import coroutine from .log import logger +from .tasks import sleep -_DEFAULT_LIMIT = 2 ** 16 - - -class IncompleteReadError(EOFError): - """ - Incomplete read error. Attributes: - - - partial: read bytes string before the end of stream was reached - - expected: total number of expected bytes (or None if unknown) - """ - def __init__(self, partial, expected): - super().__init__("%d bytes read on a total of %r expected bytes" - % (len(partial), expected)) - self.partial = partial - self.expected = expected - - -class LimitOverrunError(Exception): - """Reached the buffer limit while looking for a separator. - - Attributes: - - consumed: total number of to be consumed bytes. - """ - def __init__(self, message, consumed): - super().__init__(message) - self.consumed = consumed +_DEFAULT_LIMIT = 2 ** 16 # 64 KiB -@coroutine -def open_connection(host=None, port=None, *, - loop=None, limit=_DEFAULT_LIMIT, **kwds): +async def open_connection(host=None, port=None, *, + limit=_DEFAULT_LIMIT, **kwds): """A wrapper for create_connection() returning a (reader, writer) pair. The reader returned is a StreamReader instance; the writer is a @@ -67,19 +42,17 @@ def open_connection(host=None, port=None, *, StreamReaderProtocol classes, just copy the code -- there's really nothing special here except some convenience.) """ - if loop is None: - loop = events.get_event_loop() + loop = events.get_running_loop() reader = StreamReader(limit=limit, loop=loop) protocol = StreamReaderProtocol(reader, loop=loop) - transport, _ = yield from loop.create_connection( + transport, _ = await loop.create_connection( lambda: protocol, host, port, **kwds) writer = StreamWriter(transport, protocol, reader, loop) return reader, writer -@coroutine -def start_server(client_connected_cb, host=None, port=None, *, - loop=None, limit=_DEFAULT_LIMIT, **kwds): +async def start_server(client_connected_cb, host=None, port=None, *, + limit=_DEFAULT_LIMIT, **kwds): """Start a socket server, call back for each client connected. The first parameter, `client_connected_cb`, takes two parameters: @@ -94,15 +67,13 @@ def start_server(client_connected_cb, host=None, port=None, *, positional host and port, with various optional keyword arguments following. The return value is the same as loop.create_server(). - Additional optional keyword arguments are loop (to set the event loop - instance to use) and limit (to set the buffer limit passed to the - StreamReader). + Additional optional keyword argument is limit (to set the buffer + limit passed to the StreamReader). The return value is the same as loop.create_server(), i.e. a Server object which can be used to stop the service. 
""" - if loop is None: - loop = events.get_event_loop() + loop = events.get_running_loop() def factory(): reader = StreamReader(limit=limit, loop=loop) @@ -110,31 +81,28 @@ def factory(): loop=loop) return protocol - return (yield from loop.create_server(factory, host, port, **kwds)) + return await loop.create_server(factory, host, port, **kwds) if hasattr(socket, 'AF_UNIX'): # UNIX Domain Sockets are supported on this platform - @coroutine - def open_unix_connection(path=None, *, - loop=None, limit=_DEFAULT_LIMIT, **kwds): + async def open_unix_connection(path=None, *, + limit=_DEFAULT_LIMIT, **kwds): """Similar to `open_connection` but works with UNIX Domain Sockets.""" - if loop is None: - loop = events.get_event_loop() + loop = events.get_running_loop() + reader = StreamReader(limit=limit, loop=loop) protocol = StreamReaderProtocol(reader, loop=loop) - transport, _ = yield from loop.create_unix_connection( + transport, _ = await loop.create_unix_connection( lambda: protocol, path, **kwds) writer = StreamWriter(transport, protocol, reader, loop) return reader, writer - @coroutine - def start_unix_server(client_connected_cb, path=None, *, - loop=None, limit=_DEFAULT_LIMIT, **kwds): + async def start_unix_server(client_connected_cb, path=None, *, + limit=_DEFAULT_LIMIT, **kwds): """Similar to `start_server` but works with UNIX Domain Sockets.""" - if loop is None: - loop = events.get_event_loop() + loop = events.get_running_loop() def factory(): reader = StreamReader(limit=limit, loop=loop) @@ -142,14 +110,14 @@ def factory(): loop=loop) return protocol - return (yield from loop.create_unix_server(factory, path, **kwds)) + return await loop.create_unix_server(factory, path, **kwds) class FlowControlMixin(protocols.Protocol): """Reusable flow control logic for StreamWriter.drain(). This implements the protocol methods pause_writing(), - resume_reading() and connection_lost(). If the subclass overrides + resume_writing() and connection_lost(). If the subclass overrides these it must call the super methods. StreamWriter.drain() must wait for _drain_helper() coroutine. @@ -161,7 +129,7 @@ def __init__(self, loop=None): else: self._loop = loop self._paused = False - self._drain_waiter = None + self._drain_waiters = collections.deque() self._connection_lost = False def pause_writing(self): @@ -176,39 +144,37 @@ def resume_writing(self): if self._loop.get_debug(): logger.debug("%r resumes writing", self) - waiter = self._drain_waiter - if waiter is not None: - self._drain_waiter = None + for waiter in self._drain_waiters: if not waiter.done(): waiter.set_result(None) def connection_lost(self, exc): self._connection_lost = True - # Wake up the writer if currently paused. + # Wake up the writer(s) if currently paused. 
if not self._paused: return - waiter = self._drain_waiter - if waiter is None: - return - self._drain_waiter = None - if waiter.done(): - return - if exc is None: - waiter.set_result(None) - else: - waiter.set_exception(exc) - @coroutine - def _drain_helper(self): + for waiter in self._drain_waiters: + if not waiter.done(): + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + async def _drain_helper(self): if self._connection_lost: raise ConnectionResetError('Connection lost') if not self._paused: return - waiter = self._drain_waiter - assert waiter is None or waiter.cancelled() waiter = self._loop.create_future() - self._drain_waiter = waiter - yield from waiter + self._drain_waiters.append(waiter) + try: + await waiter + finally: + self._drain_waiters.remove(waiter) + + def _get_close_waiter(self, stream): + raise NotImplementedError class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): @@ -220,40 +186,110 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): call inappropriate methods of the protocol.) """ + _source_traceback = None + def __init__(self, stream_reader, client_connected_cb=None, loop=None): super().__init__(loop=loop) - self._stream_reader = stream_reader + if stream_reader is not None: + self._stream_reader_wr = weakref.ref(stream_reader) + self._source_traceback = stream_reader._source_traceback + else: + self._stream_reader_wr = None + if client_connected_cb is not None: + # This is a stream created by the `create_server()` function. + # Keep a strong reference to the reader until a connection + # is established. + self._strong_reader = stream_reader + self._reject_connection = False self._stream_writer = None + self._task = None + self._transport = None self._client_connected_cb = client_connected_cb self._over_ssl = False + self._closed = self._loop.create_future() + + @property + def _stream_reader(self): + if self._stream_reader_wr is None: + return None + return self._stream_reader_wr() + + def _replace_writer(self, writer): + loop = self._loop + transport = writer.transport + self._stream_writer = writer + self._transport = transport + self._over_ssl = transport.get_extra_info('sslcontext') is not None def connection_made(self, transport): - self._stream_reader.set_transport(transport) + if self._reject_connection: + context = { + 'message': ('An open stream was garbage collected prior to ' + 'establishing network connection; ' + 'call "stream.close()" explicitly.') + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + transport.abort() + return + self._transport = transport + reader = self._stream_reader + if reader is not None: + reader.set_transport(transport) self._over_ssl = transport.get_extra_info('sslcontext') is not None if self._client_connected_cb is not None: self._stream_writer = StreamWriter(transport, self, - self._stream_reader, + reader, self._loop) - res = self._client_connected_cb(self._stream_reader, + res = self._client_connected_cb(reader, self._stream_writer) if coroutines.iscoroutine(res): - self._loop.create_task(res) + def callback(task): + if task.cancelled(): + transport.close() + return + exc = task.exception() + if exc is not None: + self._loop.call_exception_handler({ + 'message': 'Unhandled exception in client_connected_cb', + 'exception': exc, + 'transport': transport, + }) + transport.close() + + self._task = self._loop.create_task(res) + self._task.add_done_callback(callback) + + 
self._strong_reader = None def connection_lost(self, exc): - if self._stream_reader is not None: + reader = self._stream_reader + if reader is not None: + if exc is None: + reader.feed_eof() + else: + reader.set_exception(exc) + if not self._closed.done(): if exc is None: - self._stream_reader.feed_eof() + self._closed.set_result(None) else: - self._stream_reader.set_exception(exc) + self._closed.set_exception(exc) super().connection_lost(exc) - self._stream_reader = None + self._stream_reader_wr = None self._stream_writer = None + self._task = None + self._transport = None def data_received(self, data): - self._stream_reader.feed_data(data) + reader = self._stream_reader + if reader is not None: + reader.feed_data(data) def eof_received(self): - self._stream_reader.feed_eof() + reader = self._stream_reader + if reader is not None: + reader.feed_eof() if self._over_ssl: # Prevent a warning in SSLProtocol.eof_received: # "returning true from eof_received() @@ -261,6 +297,20 @@ def eof_received(self): return False return True + def _get_close_waiter(self, stream): + return self._closed + + def __del__(self): + # Prevent reports about unhandled exceptions. + # Better than self._closed._log_traceback = False hack + try: + closed = self._closed + except AttributeError: + pass # failed constructor + else: + if closed.done() and not closed.cancelled(): + closed.exception() + class StreamWriter: """Wraps a Transport. @@ -279,12 +329,14 @@ def __init__(self, transport, protocol, reader, loop): assert reader is None or isinstance(reader, StreamReader) self._reader = reader self._loop = loop + self._complete_fut = self._loop.create_future() + self._complete_fut.set_result(None) def __repr__(self): - info = [self.__class__.__name__, 'transport=%r' % self._transport] + info = [self.__class__.__name__, f'transport={self._transport!r}'] if self._reader is not None: - info.append('reader=%r' % self._reader) - return '<%s>' % ' '.join(info) + info.append(f'reader={self._reader!r}') + return '<{}>'.format(' '.join(info)) @property def transport(self): @@ -305,36 +357,68 @@ def can_write_eof(self): def close(self): return self._transport.close() + def is_closing(self): + return self._transport.is_closing() + + async def wait_closed(self): + await self._protocol._get_close_waiter(self) + def get_extra_info(self, name, default=None): return self._transport.get_extra_info(name, default) - @coroutine - def drain(self): + async def drain(self): """Flush the write buffer. The intended use is to write w.write(data) - yield from w.drain() + await w.drain() """ if self._reader is not None: exc = self._reader.exception() if exc is not None: raise exc - if self._transport is not None: - if self._transport.is_closing(): - # Yield to the event loop so connection_lost() may be - # called. Without this, _drain_helper() would return - # immediately, and code that calls - # write(...); yield from drain() - # in a loop would never call connection_lost(), so it - # would not see an error when the socket is closed. - yield - yield from self._protocol._drain_helper() - + if self._transport.is_closing(): + # Wait for protocol.connection_lost() call + # Raise connection closing error if any, + # ConnectionResetError otherwise + # Yield to the event loop so connection_lost() may be + # called. Without this, _drain_helper() would return + # immediately, and code that calls + # write(...); await drain() + # in a loop would never call connection_lost(), so it + # would not see an error when the socket is closed. 
+ await sleep(0) + await self._protocol._drain_helper() + + async def start_tls(self, sslcontext, *, + server_hostname=None, + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): + """Upgrade an existing stream-based connection to TLS.""" + server_side = self._protocol._client_connected_cb is not None + protocol = self._protocol + await self.drain() + new_transport = await self._loop.start_tls( # type: ignore + self._transport, protocol, sslcontext, + server_side=server_side, server_hostname=server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) + self._transport = new_transport + protocol._replace_writer(self) + + def __del__(self): + if not self._transport.is_closing(): + if self._loop.is_closed(): + warnings.warn("loop is closed", ResourceWarning) + else: + self.close() + warnings.warn(f"unclosed {self!r}", ResourceWarning) class StreamReader: + _source_traceback = None + def __init__(self, limit=_DEFAULT_LIMIT, loop=None): # The line length limit is a security feature; # it also doubles as half the buffer limit. @@ -353,24 +437,27 @@ def __init__(self, limit=_DEFAULT_LIMIT, loop=None): self._exception = None self._transport = None self._paused = False + if self._loop.get_debug(): + self._source_traceback = format_helpers.extract_stack( + sys._getframe(1)) def __repr__(self): info = ['StreamReader'] if self._buffer: - info.append('%d bytes' % len(self._buffer)) + info.append(f'{len(self._buffer)} bytes') if self._eof: info.append('eof') if self._limit != _DEFAULT_LIMIT: - info.append('l=%d' % self._limit) + info.append(f'limit={self._limit}') if self._waiter: - info.append('w=%r' % self._waiter) + info.append(f'waiter={self._waiter!r}') if self._exception: - info.append('e=%r' % self._exception) + info.append(f'exception={self._exception!r}') if self._transport: - info.append('t=%r' % self._transport) + info.append(f'transport={self._transport!r}') if self._paused: info.append('paused') - return '<%s>' % ' '.join(info) + return '<{}>'.format(' '.join(info)) def exception(self): return self._exception @@ -431,8 +518,7 @@ def feed_data(self, data): else: self._paused = True - @coroutine - def _wait_for_data(self, func_name): + async def _wait_for_data(self, func_name): """Wait until feed_data() or feed_eof() is called. If stream was paused, automatically resume it. @@ -442,8 +528,9 @@ def _wait_for_data(self, func_name): # would have an unexpected behaviour. It would not possible to know # which coroutine would get the next data. if self._waiter is not None: - raise RuntimeError('%s() called while another coroutine is ' - 'already waiting for incoming data' % func_name) + raise RuntimeError( + f'{func_name}() called while another coroutine is ' + f'already waiting for incoming data') assert not self._eof, '_wait_for_data after EOF' @@ -455,12 +542,11 @@ def _wait_for_data(self, func_name): self._waiter = self._loop.create_future() try: - yield from self._waiter + await self._waiter finally: self._waiter = None - @coroutine - def readline(self): + async def readline(self): """Read chunk of data from the stream until newline (b'\n') is found. On success, return chunk that ends with newline. 
If only partial @@ -479,10 +565,10 @@ def readline(self): sep = b'\n' seplen = len(sep) try: - line = yield from self.readuntil(sep) - except IncompleteReadError as e: + line = await self.readuntil(sep) + except exceptions.IncompleteReadError as e: return e.partial - except LimitOverrunError as e: + except exceptions.LimitOverrunError as e: if self._buffer.startswith(sep, e.consumed): del self._buffer[:e.consumed + seplen] else: @@ -491,8 +577,7 @@ def readline(self): raise ValueError(e.args[0]) return line - @coroutine - def readuntil(self, separator=b'\n'): + async def readuntil(self, separator=b'\n'): """Read data from the stream until ``separator`` is found. On success, the data and separator will be removed from the @@ -558,7 +643,7 @@ def readuntil(self, separator=b'\n'): # see upper comment for explanation. offset = buflen + 1 - seplen if offset > self._limit: - raise LimitOverrunError( + raise exceptions.LimitOverrunError( 'Separator is not found, and chunk exceed the limit', offset) @@ -569,13 +654,13 @@ def readuntil(self, separator=b'\n'): if self._eof: chunk = bytes(self._buffer) self._buffer.clear() - raise IncompleteReadError(chunk, None) + raise exceptions.IncompleteReadError(chunk, None) # _wait_for_data() will resume reading if stream was paused. - yield from self._wait_for_data('readuntil') + await self._wait_for_data('readuntil') if isep > self._limit: - raise LimitOverrunError( + raise exceptions.LimitOverrunError( 'Separator is found, but chunk is longer than limit', isep) chunk = self._buffer[:isep + seplen] @@ -583,20 +668,20 @@ def readuntil(self, separator=b'\n'): self._maybe_resume_transport() return bytes(chunk) - @coroutine - def read(self, n=-1): + async def read(self, n=-1): """Read up to `n` bytes from the stream. - If n is not provided, or set to -1, read until EOF and return all read - bytes. If the EOF was received and the internal buffer is empty, return - an empty bytes object. + If `n` is not provided or set to -1, + read until EOF, then return all read bytes. + If EOF was received and the internal buffer is empty, + return an empty bytes object. - If n is zero, return empty bytes object immediately. + If `n` is 0, return an empty bytes object immediately. - If n is positive, this function try to read `n` bytes, and may return - less or equal bytes than requested, but at least one byte. If EOF was - received before any byte is read, this function returns empty byte - object. + If `n` is positive, return at most `n` available bytes + as soon as at least 1 byte is available in the internal buffer. + If EOF is received before any byte is read, return an empty + bytes object. Returned value is not limited with limit, configured at stream creation. @@ -618,24 +703,23 @@ def read(self, n=-1): # bytes. So just call self.read(self._limit) until EOF. blocks = [] while True: - block = yield from self.read(self._limit) + block = await self.read(self._limit) if not block: break blocks.append(block) return b''.join(blocks) if not self._buffer and not self._eof: - yield from self._wait_for_data('read') + await self._wait_for_data('read') # This will work right even if buffer is less than n bytes - data = bytes(self._buffer[:n]) + data = bytes(memoryview(self._buffer)[:n]) del self._buffer[:n] self._maybe_resume_transport() return data - @coroutine - def readexactly(self, n): + async def readexactly(self, n): """Read exactly `n` bytes. 
Raise an IncompleteReadError if EOF is reached before `n` bytes can be @@ -663,33 +747,24 @@ def readexactly(self, n): if self._eof: incomplete = bytes(self._buffer) self._buffer.clear() - raise IncompleteReadError(incomplete, n) + raise exceptions.IncompleteReadError(incomplete, n) - yield from self._wait_for_data('readexactly') + await self._wait_for_data('readexactly') if len(self._buffer) == n: data = bytes(self._buffer) self._buffer.clear() else: - data = bytes(self._buffer[:n]) + data = bytes(memoryview(self._buffer)[:n]) del self._buffer[:n] self._maybe_resume_transport() return data - if compat.PY35: - @coroutine - def __aiter__(self): - return self - - @coroutine - def __anext__(self): - val = yield from self.readline() - if val == b'': - raise StopAsyncIteration - return val - - if compat.PY352: - # In Python 3.5.2 and greater, __aiter__ should return - # the asynchronous iterator directly. - def __aiter__(self): - return self + def __aiter__(self): + return self + + async def __anext__(self): + val = await self.readline() + if val == b'': + raise StopAsyncIteration + return val diff --git a/Lib/asyncio/subprocess.py b/Lib/asyncio/subprocess.py index b2f5304f77..043359bbd0 100644 --- a/Lib/asyncio/subprocess.py +++ b/Lib/asyncio/subprocess.py @@ -1,4 +1,4 @@ -__all__ = ['create_subprocess_exec', 'create_subprocess_shell'] +__all__ = 'create_subprocess_exec', 'create_subprocess_shell' import subprocess @@ -6,7 +6,6 @@ from . import protocols from . import streams from . import tasks -from .coroutines import coroutine from .log import logger @@ -24,16 +23,19 @@ def __init__(self, limit, loop): self._limit = limit self.stdin = self.stdout = self.stderr = None self._transport = None + self._process_exited = False + self._pipe_fds = [] + self._stdin_closed = self._loop.create_future() def __repr__(self): info = [self.__class__.__name__] if self.stdin is not None: - info.append('stdin=%r' % self.stdin) + info.append(f'stdin={self.stdin!r}') if self.stdout is not None: - info.append('stdout=%r' % self.stdout) + info.append(f'stdout={self.stdout!r}') if self.stderr is not None: - info.append('stderr=%r' % self.stderr) - return '<%s>' % ' '.join(info) + info.append(f'stderr={self.stderr!r}') + return '<{}>'.format(' '.join(info)) def connection_made(self, transport): self._transport = transport @@ -43,12 +45,14 @@ def connection_made(self, transport): self.stdout = streams.StreamReader(limit=self._limit, loop=self._loop) self.stdout.set_transport(stdout_transport) + self._pipe_fds.append(1) stderr_transport = transport.get_pipe_transport(2) if stderr_transport is not None: self.stderr = streams.StreamReader(limit=self._limit, loop=self._loop) self.stderr.set_transport(stderr_transport) + self._pipe_fds.append(2) stdin_transport = transport.get_pipe_transport(0) if stdin_transport is not None: @@ -73,6 +77,13 @@ def pipe_connection_lost(self, fd, exc): if pipe is not None: pipe.close() self.connection_lost(exc) + if exc is None: + self._stdin_closed.set_result(None) + else: + self._stdin_closed.set_exception(exc) + # Since calling `wait_closed()` is not mandatory, + # we shouldn't log the traceback if this is not awaited. 
+ self._stdin_closed._log_traceback = False return if fd == 1: reader = self.stdout @@ -80,15 +91,28 @@ def pipe_connection_lost(self, fd, exc): reader = self.stderr else: reader = None - if reader != None: + if reader is not None: if exc is None: reader.feed_eof() else: reader.set_exception(exc) + if fd in self._pipe_fds: + self._pipe_fds.remove(fd) + self._maybe_close_transport() + def process_exited(self): - self._transport.close() - self._transport = None + self._process_exited = True + self._maybe_close_transport() + + def _maybe_close_transport(self): + if len(self._pipe_fds) == 0 and self._process_exited: + self._transport.close() + self._transport = None + + def _get_close_waiter(self, stream): + if stream is self.stdin: + return self._stdin_closed class Process: @@ -102,18 +126,15 @@ def __init__(self, transport, protocol, loop): self.pid = transport.get_pid() def __repr__(self): - return '<%s %s>' % (self.__class__.__name__, self.pid) + return f'<{self.__class__.__name__} {self.pid}>' @property def returncode(self): return self._transport.get_returncode() - @coroutine - def wait(self): - """Wait until the process exit and return the process return code. - - This method is a coroutine.""" - return (yield from self._transport._wait()) + async def wait(self): + """Wait until the process exit and return the process return code.""" + return await self._transport._wait() def send_signal(self, signal): self._transport.send_signal(signal) @@ -124,17 +145,19 @@ def terminate(self): def kill(self): self._transport.kill() - @coroutine - def _feed_stdin(self, input): + async def _feed_stdin(self, input): debug = self._loop.get_debug() - self.stdin.write(input) - if debug: - logger.debug('%r communicate: feed stdin (%s bytes)', - self, len(input)) try: - yield from self.stdin.drain() + if input is not None: + self.stdin.write(input) + if debug: + logger.debug( + '%r communicate: feed stdin (%s bytes)', self, len(input)) + + await self.stdin.drain() except (BrokenPipeError, ConnectionResetError) as exc: - # communicate() ignores BrokenPipeError and ConnectionResetError + # communicate() ignores BrokenPipeError and ConnectionResetError. + # write() and drain() can raise these exceptions. 
if debug: logger.debug('%r communicate: stdin got %r', self, exc) @@ -142,12 +165,10 @@ def _feed_stdin(self, input): logger.debug('%r communicate: close stdin', self) self.stdin.close() - @coroutine - def _noop(self): + async def _noop(self): return None - @coroutine - def _read_stream(self, fd): + async def _read_stream(self, fd): transport = self._transport.get_pipe_transport(fd) if fd == 2: stream = self.stderr @@ -157,16 +178,15 @@ def _read_stream(self, fd): if self._loop.get_debug(): name = 'stdout' if fd == 1 else 'stderr' logger.debug('%r communicate: read %s', self, name) - output = yield from stream.read() + output = await stream.read() if self._loop.get_debug(): name = 'stdout' if fd == 1 else 'stderr' logger.debug('%r communicate: close %s', self, name) transport.close() return output - @coroutine - def communicate(self, input=None): - if input is not None: + async def communicate(self, input=None): + if self.stdin is not None: stdin = self._feed_stdin(input) else: stdin = self._noop() @@ -178,36 +198,32 @@ def communicate(self, input=None): stderr = self._read_stream(2) else: stderr = self._noop() - stdin, stdout, stderr = yield from tasks.gather(stdin, stdout, stderr, - loop=self._loop) - yield from self.wait() + stdin, stdout, stderr = await tasks.gather(stdin, stdout, stderr) + await self.wait() return (stdout, stderr) -@coroutine -def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None, - loop=None, limit=streams._DEFAULT_LIMIT, **kwds): - if loop is None: - loop = events.get_event_loop() +async def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None, + limit=streams._DEFAULT_LIMIT, **kwds): + loop = events.get_running_loop() protocol_factory = lambda: SubprocessStreamProtocol(limit=limit, loop=loop) - transport, protocol = yield from loop.subprocess_shell( - protocol_factory, - cmd, stdin=stdin, stdout=stdout, - stderr=stderr, **kwds) + transport, protocol = await loop.subprocess_shell( + protocol_factory, + cmd, stdin=stdin, stdout=stdout, + stderr=stderr, **kwds) return Process(transport, protocol, loop) -@coroutine -def create_subprocess_exec(program, *args, stdin=None, stdout=None, - stderr=None, loop=None, - limit=streams._DEFAULT_LIMIT, **kwds): - if loop is None: - loop = events.get_event_loop() + +async def create_subprocess_exec(program, *args, stdin=None, stdout=None, + stderr=None, limit=streams._DEFAULT_LIMIT, + **kwds): + loop = events.get_running_loop() protocol_factory = lambda: SubprocessStreamProtocol(limit=limit, loop=loop) - transport, protocol = yield from loop.subprocess_exec( - protocol_factory, - program, *args, - stdin=stdin, stdout=stdout, - stderr=stderr, **kwds) + transport, protocol = await loop.subprocess_exec( + protocol_factory, + program, *args, + stdin=stdin, stdout=stdout, + stderr=stderr, **kwds) return Process(transport, protocol, loop) diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py new file mode 100644 index 0000000000..d264e51f1f --- /dev/null +++ b/Lib/asyncio/taskgroups.py @@ -0,0 +1,240 @@ +# Adapted with permission from the EdgeDB project; +# license: PSFL. + + +__all__ = ("TaskGroup",) + +from . import events +from . import exceptions +from . import tasks + + +class TaskGroup: + """Asynchronous context manager for managing groups of tasks. 
+
+    Example use:
+
+        async with asyncio.TaskGroup() as group:
+            task1 = group.create_task(some_coroutine(...))
+            task2 = group.create_task(other_coroutine(...))
+        print("Both tasks have completed now.")
+
+    All tasks are awaited when the context manager exits.
+
+    Any exceptions other than `asyncio.CancelledError` raised within
+    a task will cancel all remaining tasks and wait for them to exit.
+    The exceptions are then combined and raised as an `ExceptionGroup`.
+    """
+    def __init__(self):
+        self._entered = False
+        self._exiting = False
+        self._aborting = False
+        self._loop = None
+        self._parent_task = None
+        self._parent_cancel_requested = False
+        self._tasks = set()
+        self._errors = []
+        self._base_error = None
+        self._on_completed_fut = None
+
+    def __repr__(self):
+        info = ['']
+        if self._tasks:
+            info.append(f'tasks={len(self._tasks)}')
+        if self._errors:
+            info.append(f'errors={len(self._errors)}')
+        if self._aborting:
+            info.append('cancelling')
+        elif self._entered:
+            info.append('entered')
+
+        info_str = ' '.join(info)
+        return f'<TaskGroup{info_str}>'
+
+    async def __aenter__(self):
+        if self._entered:
+            raise RuntimeError(
+                f"TaskGroup {self!r} has already been entered")
+        if self._loop is None:
+            self._loop = events.get_running_loop()
+        self._parent_task = tasks.current_task(self._loop)
+        if self._parent_task is None:
+            raise RuntimeError(
+                f'TaskGroup {self!r} cannot determine the parent task')
+        self._entered = True
+
+        return self
+
+    async def __aexit__(self, et, exc, tb):
+        self._exiting = True
+
+        if (exc is not None and
+                self._is_base_error(exc) and
+                self._base_error is None):
+            self._base_error = exc
+
+        propagate_cancellation_error = \
+            exc if et is exceptions.CancelledError else None
+        if self._parent_cancel_requested:
+            # If this flag is set we *must* call uncancel().
+            if self._parent_task.uncancel() == 0:
+                # If there are no pending cancellations left,
+                # don't propagate CancelledError.
+                propagate_cancellation_error = None
+
+        if et is not None:
+            if not self._aborting:
+                # Our parent task is being cancelled:
+                #
+                #    async with TaskGroup() as g:
+                #        g.create_task(...)
+                #        await ...  # <- CancelledError
+                #
+                # or there's an exception in "async with":
+                #
+                #    async with TaskGroup() as g:
+                #        g.create_task(...)
+                #        1 / 0
+                #
+                self._abort()
+
+        # We use while-loop here because "self._on_completed_fut"
+        # can be cancelled multiple times if our parent task
+        # is being cancelled repeatedly (or even once, when
+        # our own cancellation is already in progress)
+        while self._tasks:
+            if self._on_completed_fut is None:
+                self._on_completed_fut = self._loop.create_future()
+
+            try:
+                await self._on_completed_fut
+            except exceptions.CancelledError as ex:
+                if not self._aborting:
+                    # Our parent task is being cancelled:
+                    #
+                    #    async def wrapper():
+                    #        async with TaskGroup() as g:
+                    #            g.create_task(foo)
+                    #
+                    # "wrapper" is being cancelled while "foo" is
+                    # still running.
+                    propagate_cancellation_error = ex
+                    self._abort()
+
+            self._on_completed_fut = None
+
+        assert not self._tasks
+
+        if self._base_error is not None:
+            raise self._base_error
+
+        # Propagate CancelledError if there is one, except if there
+        # are other errors -- those have priority.
+        if propagate_cancellation_error and not self._errors:
+            raise propagate_cancellation_error
+
+        if et is not None and et is not exceptions.CancelledError:
+            self._errors.append(exc)
+
+        if self._errors:
+            # Exceptions are heavy objects that can have object
+            # cycles (bad for GC); let's not keep a reference to
+            # a bunch of them.
+ try: + me = BaseExceptionGroup('unhandled errors in a TaskGroup', self._errors) + raise me from None + finally: + self._errors = None + + def create_task(self, coro, *, name=None, context=None): + """Create a new task in this group and return it. + + Similar to `asyncio.create_task`. + """ + if not self._entered: + raise RuntimeError(f"TaskGroup {self!r} has not been entered") + if self._exiting and not self._tasks: + raise RuntimeError(f"TaskGroup {self!r} is finished") + if self._aborting: + raise RuntimeError(f"TaskGroup {self!r} is shutting down") + if context is None: + task = self._loop.create_task(coro) + else: + task = self._loop.create_task(coro, context=context) + tasks._set_task_name(task, name) + # optimization: Immediately call the done callback if the task is + # already done (e.g. if the coro was able to complete eagerly), + # and skip scheduling a done callback + if task.done(): + self._on_task_done(task) + else: + self._tasks.add(task) + task.add_done_callback(self._on_task_done) + return task + + # Since Python 3.8 Tasks propagate all exceptions correctly, + # except for KeyboardInterrupt and SystemExit which are + # still considered special. + + def _is_base_error(self, exc: BaseException) -> bool: + assert isinstance(exc, BaseException) + return isinstance(exc, (SystemExit, KeyboardInterrupt)) + + def _abort(self): + self._aborting = True + + for t in self._tasks: + if not t.done(): + t.cancel() + + def _on_task_done(self, task): + self._tasks.discard(task) + + if self._on_completed_fut is not None and not self._tasks: + if not self._on_completed_fut.done(): + self._on_completed_fut.set_result(True) + + if task.cancelled(): + return + + exc = task.exception() + if exc is None: + return + + self._errors.append(exc) + if self._is_base_error(exc) and self._base_error is None: + self._base_error = exc + + if self._parent_task.done(): + # Not sure if this case is possible, but we want to handle + # it anyways. + self._loop.call_exception_handler({ + 'message': f'Task {task!r} has errored out but its parent ' + f'task {self._parent_task} is already completed', + 'exception': exc, + 'task': task, + }) + return + + if not self._aborting and not self._parent_cancel_requested: + # If parent task *is not* being cancelled, it means that we want + # to manually cancel it to abort whatever is being run right now + # in the TaskGroup. But we want to mark parent task as + # "not cancelled" later in __aexit__. Example situation that + # we need to handle: + # + # async def foo(): + # try: + # async with TaskGroup() as g: + # g.create_task(crash_soon()) + # await something # <- this needs to be canceled + # # by the TaskGroup, e.g. + # # foo() needs to be cancelled + # except Exception: + # # Ignore any exceptions raised in the TaskGroup + # pass + # await something_else # this line has to be called + # # after TaskGroup is finished. 
+ self._abort() + self._parent_cancel_requested = True + self._parent_task.cancel() diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py index 8a8427fe68..0b22e28d8e 100644 --- a/Lib/asyncio/tasks.py +++ b/Lib/asyncio/tasks.py @@ -1,24 +1,36 @@ """Support for tasks, coroutines and the scheduler.""" -__all__ = ['Task', 'create_task', - 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', - 'wait', 'wait_for', 'as_completed', 'sleep', 'async', - 'gather', 'shield', 'ensure_future', 'run_coroutine_threadsafe', - 'all_tasks' - ] +__all__ = ( + 'Task', 'create_task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'wait_for', 'as_completed', 'sleep', + 'gather', 'shield', 'ensure_future', 'run_coroutine_threadsafe', + 'current_task', 'all_tasks', + 'create_eager_task_factory', 'eager_task_factory', + '_register_task', '_unregister_task', '_enter_task', '_leave_task', +) import concurrent.futures +import contextvars import functools import inspect +import itertools +import types import warnings import weakref +from types import GenericAlias from . import base_tasks -from . import compat from . import coroutines from . import events +from . import exceptions from . import futures -from .coroutines import coroutine +from . import timeouts + +# Helper to generate new task names +# This uses itertools.count() instead of a "+= 1" operation because the latter +# is not thread safe. See bpo-11866 for a longer explanation. +_task_name_counter = itertools.count(1).__next__ def current_task(loop=None): @@ -32,102 +44,134 @@ def all_tasks(loop=None): """Return a set of all tasks for the loop.""" if loop is None: loop = events.get_running_loop() - return {t for t in _all_tasks + # capturing the set of eager tasks first, so if an eager task "graduates" + # to a regular task in another thread, we don't risk missing it. + eager_tasks = list(_eager_tasks) + # Looping over the WeakSet isn't safe as it can be updated from another + # thread, therefore we cast it to list prior to filtering. The list cast + # itself requires iteration, so we repeat it several times ignoring + # RuntimeErrors (which are not very likely to occur). + # See issues 34970 and 36607 for details. + scheduled_tasks = None + i = 0 + while True: + try: + scheduled_tasks = list(_scheduled_tasks) + except RuntimeError: + i += 1 + if i >= 1000: + raise + else: + break + return {t for t in itertools.chain(scheduled_tasks, eager_tasks) if futures._get_loop(t) is loop and not t.done()} -def _all_tasks_compat(loop=None): - # Different from "all_task()" by returning *all* Tasks, including - # the completed ones. Used to implement deprecated "Tasks.all_task()" - # method. - if loop is None: - loop = events.get_event_loop() - return {t for t in _all_tasks if futures._get_loop(t) is loop} - - def _set_task_name(task, name): if name is not None: try: set_name = task.set_name except AttributeError: - pass + warnings.warn("Task.set_name() was added in Python 3.8, " + "the method support will be mandatory for third-party " + "task implementations since 3.13.", + DeprecationWarning, stacklevel=3) else: set_name(name) -class Task(futures.Future): +class Task(futures._PyFuture): # Inherit Python Task implementation + # from a Python Future implementation. + """A coroutine wrapped in a Future.""" # An important invariant maintained while a Task not done: + # _fut_waiter is either None or a Future. The Future + # can be either done() or not done(). 
+    # The task can be in any of 3 states:
+    #
+    # - 1: _fut_waiter is not None and not _fut_waiter.done():
+    #      __step() is *not* scheduled and the Task is waiting for _fut_waiter.
+    # - 2: (_fut_waiter is None or _fut_waiter.done()) and __step() is scheduled:
+    #      the Task is waiting for __step() to be executed.
+    # - 3: _fut_waiter is None and __step() is *not* scheduled:
+    #      the Task is currently executing (in __step()).
     #
-    # - Either _fut_waiter is None, and _step() is scheduled;
-    # - or _fut_waiter is some Future, and _step() is *not* scheduled.
+    # * In state 1, one of the callbacks of __fut_waiter must be __wakeup().
+    # * The transition from 1 to 2 happens when _fut_waiter becomes done(),
+    #   as it schedules __wakeup() to be called (which calls __step() so
+    #   we say that __step() is scheduled).
+    # * It transitions from 2 to 3 when __step() is executed, and it clears
+    #   _fut_waiter to None.
     #
-    # The only transition from the latter to the former is through
-    # _wakeup(). When _fut_waiter is not None, one of its callbacks
-    # must be _wakeup().
+    # If False, don't log a message if the task is destroyed while its
+    # status is still pending
+    _log_destroy_pending = True

-    # Weak set containing all tasks alive.
-    _all_tasks = weakref.WeakSet()
+    def __init__(self, coro, *, loop=None, name=None, context=None,
+                 eager_start=False):
+        super().__init__(loop=loop)
+        if self._source_traceback:
+            del self._source_traceback[-1]
+        if not coroutines.iscoroutine(coro):
+            # raise after Future.__init__(), attrs are required for __del__
+            # prevent logging for pending task in __del__
+            self._log_destroy_pending = False
+            raise TypeError(f"a coroutine was expected, got {coro!r}")
+
+        if name is None:
+            self._name = f'Task-{_task_name_counter()}'
+        else:
+            self._name = str(name)

-    # Dictionary containing tasks that are currently active in
-    # all running event loops. {EventLoop: Task}
-    _current_tasks = {}
+        self._num_cancels_requested = 0
+        self._must_cancel = False
+        self._fut_waiter = None
+        self._coro = coro
+        if context is None:
+            self._context = contextvars.copy_context()
+        else:
+            self._context = context

-    # If False, don't log a message if the task is destroyed whereas its
-    # status is still pending
-    _log_destroy_pending = True
+        if eager_start and self._loop.is_running():
+            self.__eager_start()
+        else:
+            self._loop.call_soon(self.__step, context=self._context)
+            _register_task(self)

-    @classmethod
-    def current_task(cls, loop=None):
-        """Return the currently running task in an event loop or None.
+    def __del__(self):
+        if self._state == futures._PENDING and self._log_destroy_pending:
+            context = {
+                'task': self,
+                'message': 'Task was destroyed but it is pending!',
+            }
+            if self._source_traceback:
+                context['source_traceback'] = self._source_traceback
+            self._loop.call_exception_handler(context)
+        super().__del__()

-        By default the current task for the current event loop is returned.
+    __class_getitem__ = classmethod(GenericAlias)

-        None is returned when called not in the context of a Task.
-        """
-        if loop is None:
-            loop = events.get_event_loop()
-        return cls._current_tasks.get(loop)
+    def __repr__(self):
+        return base_tasks._task_repr(self)

-    @classmethod
-    def all_tasks(cls, loop=None):
-        """Return a set of all tasks for an event loop.
+    def get_coro(self):
+        return self._coro

-        By default all tasks for the current event loop are returned.
- """ - if loop is None: - loop = events.get_event_loop() - return {t for t in cls._all_tasks if t._loop is loop} + def get_context(self): + return self._context - def __init__(self, coro, *, loop=None): - assert coroutines.iscoroutine(coro), repr(coro) - super().__init__(loop=loop) - if self._source_traceback: - del self._source_traceback[-1] - self._coro = coro - self._fut_waiter = None - self._must_cancel = False - self._loop.call_soon(self._step) - self.__class__._all_tasks.add(self) - - # On Python 3.3 or older, objects with a destructor that are part of a - # reference cycle are never destroyed. That's not the case any more on - # Python 3.4 thanks to the PEP 442. - if compat.PY34: - def __del__(self): - if self._state == futures._PENDING and self._log_destroy_pending: - context = { - 'task': self, - 'message': 'Task was destroyed but it is pending!', - } - if self._source_traceback: - context['source_traceback'] = self._source_traceback - self._loop.call_exception_handler(context) - futures.Future.__del__(self) - - def _repr_info(self): - return base_tasks._task_repr_info(self) + def get_name(self): + return self._name + + def set_name(self, value): + self._name = str(value) + + def set_result(self, result): + raise RuntimeError('Task does not support set_result operation') + + def set_exception(self, exception): + raise RuntimeError('Task does not support set_exception operation') def get_stack(self, *, limit=None): """Return the list of stack frames for this task's coroutine. @@ -163,7 +207,7 @@ def print_stack(self, *, limit=None, file=None): """ return base_tasks._task_print_stack(self, limit, file) - def cancel(self): + def cancel(self, msg=None): """Request that this task cancel itself. This arranges for a CancelledError to be thrown into the @@ -182,31 +226,87 @@ def cancel(self): task will be marked as cancelled when the wrapped coroutine terminates with a CancelledError exception (even if cancel() was not called). + + This also increases the task's count of cancellation requests. """ + self._log_traceback = False if self.done(): return False + self._num_cancels_requested += 1 + # These two lines are controversial. See discussion starting at + # https://github.com/python/cpython/pull/31394#issuecomment-1053545331 + # Also remember that this is duplicated in _asynciomodule.c. + # if self._num_cancels_requested > 1: + # return False if self._fut_waiter is not None: - if self._fut_waiter.cancel(): + if self._fut_waiter.cancel(msg=msg): # Leave self._fut_waiter; it may be a Task that # catches and ignores the cancellation so we may have # to cancel it again later. return True - # It must be the case that self._step is already scheduled. + # It must be the case that self.__step is already scheduled. self._must_cancel = True + self._cancel_message = msg return True - def _step(self, exc=None): - assert not self.done(), \ - '_step(): already done: {!r}, {!r}'.format(self, exc) + def cancelling(self): + """Return the count of the task's cancellation requests. + + This count is incremented when .cancel() is called + and may be decremented using .uncancel(). + """ + return self._num_cancels_requested + + def uncancel(self): + """Decrement the task's count of cancellation requests. + + This should be called by the party that called `cancel()` on the task + beforehand. + + Returns the remaining number of cancellation requests. 
+ """ + if self._num_cancels_requested > 0: + self._num_cancels_requested -= 1 + return self._num_cancels_requested + + def __eager_start(self): + prev_task = _swap_current_task(self._loop, self) + try: + _register_eager_task(self) + try: + self._context.run(self.__step_run_and_handle_result, None) + finally: + _unregister_eager_task(self) + finally: + try: + curtask = _swap_current_task(self._loop, prev_task) + assert curtask is self + finally: + if self.done(): + self._coro = None + self = None # Needed to break cycles when an exception occurs. + else: + _register_task(self) + + def __step(self, exc=None): + if self.done(): + raise exceptions.InvalidStateError( + f'_step(): already done: {self!r}, {exc!r}') if self._must_cancel: - if not isinstance(exc, futures.CancelledError): - exc = futures.CancelledError() + if not isinstance(exc, exceptions.CancelledError): + exc = self._make_cancelled_error() self._must_cancel = False - coro = self._coro self._fut_waiter = None - self.__class__._current_tasks[self._loop] = self - # Call either coro.throw(exc) or coro.send(None). + _enter_task(self._loop, self) + try: + self.__step_run_and_handle_result(exc) + finally: + _leave_task(self._loop, self) + self = None # Needed to break cycles when an exception occurs. + + def __step_run_and_handle_result(self, exc): + coro = self._coro try: if exc is None: # We use the `send` method directly, because coroutines @@ -215,71 +315,77 @@ def _step(self, exc=None): else: result = coro.throw(exc) except StopIteration as exc: - self.set_result(exc.value) - except futures.CancelledError: + if self._must_cancel: + # Task is cancelled right before coro stops. + self._must_cancel = False + super().cancel(msg=self._cancel_message) + else: + super().set_result(exc.value) + except exceptions.CancelledError as exc: + # Save the original exception so we can chain it later. + self._cancelled_exc = exc super().cancel() # I.e., Future.cancel(self). - except Exception as exc: - self.set_exception(exc) - except BaseException as exc: - self.set_exception(exc) + except (KeyboardInterrupt, SystemExit) as exc: + super().set_exception(exc) raise + except BaseException as exc: + super().set_exception(exc) else: blocking = getattr(result, '_asyncio_future_blocking', None) if blocking is not None: # Yielded Future must come from Future.__iter__(). 
- if result._loop is not self._loop: + if futures._get_loop(result) is not self._loop: + new_exc = RuntimeError( + f'Task {self!r} got Future ' + f'{result!r} attached to a different loop') self._loop.call_soon( - self._step, - RuntimeError( - 'Task {!r} got Future {!r} attached to a ' - 'different loop'.format(self, result))) + self.__step, new_exc, context=self._context) elif blocking: if result is self: + new_exc = RuntimeError( + f'Task cannot await on itself: {self!r}') self._loop.call_soon( - self._step, - RuntimeError( - 'Task cannot await on itself: {!r}'.format( - self))) + self.__step, new_exc, context=self._context) else: result._asyncio_future_blocking = False - result.add_done_callback(self._wakeup) + result.add_done_callback( + self.__wakeup, context=self._context) self._fut_waiter = result if self._must_cancel: - if self._fut_waiter.cancel(): + if self._fut_waiter.cancel( + msg=self._cancel_message): self._must_cancel = False else: + new_exc = RuntimeError( + f'yield was used instead of yield from ' + f'in task {self!r} with {result!r}') self._loop.call_soon( - self._step, - RuntimeError( - 'yield was used instead of yield from ' - 'in task {!r} with {!r}'.format(self, result))) + self.__step, new_exc, context=self._context) + elif result is None: # Bare yield relinquishes control for one event loop iteration. - self._loop.call_soon(self._step) + self._loop.call_soon(self.__step, context=self._context) elif inspect.isgenerator(result): # Yielding a generator is just wrong. + new_exc = RuntimeError( + f'yield was used instead of yield from for ' + f'generator in task {self!r} with {result!r}') self._loop.call_soon( - self._step, - RuntimeError( - 'yield was used instead of yield from for ' - 'generator in task {!r} with {}'.format( - self, result))) + self.__step, new_exc, context=self._context) else: # Yielding something else is an error. + new_exc = RuntimeError(f'Task got bad yield: {result!r}') self._loop.call_soon( - self._step, - RuntimeError( - 'Task got bad yield: {!r}'.format(result))) + self.__step, new_exc, context=self._context) finally: - self.__class__._current_tasks.pop(self._loop) self = None # Needed to break cycles when an exception occurs. - def _wakeup(self, future): + def __wakeup(self, future): try: future.result() - except Exception as exc: + except BaseException as exc: # This may also be a cancellation. - self._step(exc) + self.__step(exc) else: # Don't pass the value of `future.result()` explicitly, # as `Future.__iter__` and `Future.__await__` don't need it. @@ -287,7 +393,7 @@ def _wakeup(self, future): # Python eval loop would use `.send(value)` method call, # instead of `__next__()`, which is slower for futures # that return non-generator iterators from their `__iter__`. - self._step() + self.__step() self = None # Needed to break cycles when an exception occurs. @@ -303,13 +409,18 @@ def _wakeup(self, future): Task = _CTask = _asyncio.Task -def create_task(coro, *, name=None): +def create_task(coro, *, name=None, context=None): """Schedule the execution of a coroutine object in a spawn task. Return a Task object. 
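A short usage sketch for the new signature; holding a strong reference is the recommended idiom because the loop keeps only weak references to tasks (the task name and the set below are illustrative):

    import asyncio

    background_tasks = set()

    async def tick():
        await asyncio.sleep(0.1)

    async def main():
        task = asyncio.create_task(tick(), name='ticker')
        background_tasks.add(task)                        # strong reference
        task.add_done_callback(background_tasks.discard)  # drop when done
        print(task.get_name())                            # 'ticker'
        await task

    asyncio.run(main())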
""" loop = events.get_running_loop() - task = loop.create_task(coro) + if context is None: + # Use legacy API if context is not needed + task = loop.create_task(coro) + else: + task = loop.create_task(coro, context=context) + _set_task_name(task, name) return task @@ -321,11 +432,10 @@ def create_task(coro, *, name=None): ALL_COMPLETED = concurrent.futures.ALL_COMPLETED -@coroutine -def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): - """Wait for the Futures and coroutines given by fs to complete. +async def wait(fs, *, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures or Tasks given by fs to complete. - The sequence futures must not be empty. + The fs iterable must not be empty. Coroutines will be wrapped in Tasks. @@ -333,24 +443,25 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): Usage: - done, pending = yield from asyncio.wait(fs) + done, pending = await asyncio.wait(fs) Note: This does not raise TimeoutError! Futures that aren't done when the timeout occurs are returned in the second set. """ if futures.isfuture(fs) or coroutines.iscoroutine(fs): - raise TypeError("expect a list of futures, not %s" % type(fs).__name__) + raise TypeError(f"expect a list of futures, not {type(fs).__name__}") if not fs: - raise ValueError('Set of coroutines/Futures is empty.') + raise ValueError('Set of Tasks/Futures is empty.') if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): - raise ValueError('Invalid return_when value: {}'.format(return_when)) + raise ValueError(f'Invalid return_when value: {return_when}') - if loop is None: - loop = events.get_event_loop() + fs = set(fs) - fs = {ensure_future(f, loop=loop) for f in set(fs)} + if any(coroutines.iscoroutine(f) for f in fs): + raise TypeError("Passing coroutines is forbidden, use tasks explicitly.") - return (yield from _wait(fs, timeout, return_when, loop)) + loop = events.get_running_loop() + return await _wait(fs, timeout, return_when, loop) def _release_waiter(waiter, *args): @@ -358,8 +469,7 @@ def _release_waiter(waiter, *args): waiter.set_result(None) -@coroutine -def wait_for(fut, timeout, *, loop=None): +async def wait_for(fut, timeout): """Wait for the single Future or coroutine to complete, with timeout. Coroutine will be wrapped in Task. @@ -370,43 +480,47 @@ def wait_for(fut, timeout, *, loop=None): If the wait is cancelled, the task is also cancelled. + If the task suppresses the cancellation and returns a value instead, + that value is returned. + This function is a coroutine. 
""" - if loop is None: - loop = events.get_event_loop() + # The special case for timeout <= 0 is for the following case: + # + # async def test_waitfor(): + # func_started = False + # + # async def func(): + # nonlocal func_started + # func_started = True + # + # try: + # await asyncio.wait_for(func(), 0) + # except asyncio.TimeoutError: + # assert not func_started + # else: + # assert False + # + # asyncio.run(test_waitfor()) - if timeout is None: - return (yield from fut) - waiter = loop.create_future() - timeout_handle = loop.call_later(timeout, _release_waiter, waiter) - cb = functools.partial(_release_waiter, waiter) + if timeout is not None and timeout <= 0: + fut = ensure_future(fut) - fut = ensure_future(fut, loop=loop) - fut.add_done_callback(cb) + if fut.done(): + return fut.result() - try: - # wait until the future completes or the timeout + await _cancel_and_wait(fut) try: - yield from waiter - except futures.CancelledError: - fut.remove_done_callback(cb) - fut.cancel() - raise - - if fut.done(): return fut.result() - else: - fut.remove_done_callback(cb) - fut.cancel() - raise futures.TimeoutError() - finally: - timeout_handle.cancel() + except exceptions.CancelledError as exc: + raise TimeoutError from exc + async with timeouts.timeout(timeout): + return await fut -@coroutine -def _wait(fs, timeout, return_when, loop): - """Internal helper for wait() and wait_for(). +async def _wait(fs, timeout, return_when, loop): + """Internal helper for wait(). The fs argument must be a collection of Futures. """ @@ -433,14 +547,15 @@ def _on_completion(f): f.add_done_callback(_on_completion) try: - yield from waiter + await waiter finally: if timeout_handle is not None: timeout_handle.cancel() + for f in fs: + f.remove_done_callback(_on_completion) done, pending = set(), set() for f in fs: - f.remove_done_callback(_on_completion) if f.done(): done.add(f) else: @@ -448,8 +563,25 @@ def _on_completion(f): return done, pending +async def _cancel_and_wait(fut): + """Cancel the *fut* future or task and wait until it completes.""" + + loop = events.get_running_loop() + waiter = loop.create_future() + cb = functools.partial(_release_waiter, waiter) + fut.add_done_callback(cb) + + try: + fut.cancel() + # We cannot wait on *fut* directly to make + # sure _cancel_and_wait itself is reliably cancellable. + await waiter + finally: + fut.remove_done_callback(cb) + + # This is *not* a @coroutine! It is just an iterator (yielding Futures). -def as_completed(fs, *, loop=None, timeout=None): +def as_completed(fs, *, timeout=None): """Return an iterator whose values are coroutines. When waiting for the yielded coroutines you'll get the results (or @@ -459,20 +591,22 @@ def as_completed(fs, *, loop=None, timeout=None): This differs from PEP 3148; the proper way to use this is: for f in as_completed(fs): - result = yield from f # The 'yield from' may raise. + result = await f # The 'await' may raise. # Use result. - If a timeout is specified, the 'yield from' will raise + If a timeout is specified, the 'await' will raise TimeoutError when the timeout occurs before all Futures are done. Note: The futures 'f' are not necessarily members of fs. 
""" if futures.isfuture(fs) or coroutines.iscoroutine(fs): - raise TypeError("expect a list of futures, not %s" % type(fs).__name__) - loop = loop if loop is not None else events.get_event_loop() - todo = {ensure_future(f, loop=loop) for f in set(fs)} + raise TypeError(f"expect an iterable of futures, not {type(fs).__name__}") + from .queues import Queue # Import here to avoid circular import problem. - done = Queue(loop=loop) + done = Queue() + + loop = events.get_event_loop() + todo = {ensure_future(f, loop=loop) for f in set(fs)} timeout_handle = None def _on_timeout(): @@ -489,12 +623,11 @@ def _on_completion(f): if not todo and timeout_handle is not None: timeout_handle.cancel() - @coroutine - def _wait_for_one(): - f = yield from done.get() + async def _wait_for_one(): + f = await done.get() if f is None: # Dummy value from _on_timeout(). - raise futures.TimeoutError + raise exceptions.TimeoutError return f.result() # May raise f.exception(). for f in todo: @@ -505,74 +638,65 @@ def _wait_for_one(): yield _wait_for_one() -@coroutine -def sleep(delay, result=None, *, loop=None): +@types.coroutine +def __sleep0(): + """Skip one event loop run cycle. + + This is a private helper for 'asyncio.sleep()', used + when the 'delay' is set to 0. It uses a bare 'yield' + expression (which Task.__step knows how to handle) + instead of creating a Future object. + """ + yield + + +async def sleep(delay, result=None): """Coroutine that completes after a given time (in seconds).""" - if delay == 0: - yield + if delay <= 0: + await __sleep0() return result - if loop is None: - loop = events.get_event_loop() + loop = events.get_running_loop() future = loop.create_future() - h = future._loop.call_later(delay, - futures._set_result_unless_cancelled, - future, result) + h = loop.call_later(delay, + futures._set_result_unless_cancelled, + future, result) try: - return (yield from future) + return await future finally: h.cancel() -def async_(coro_or_future, *, loop=None): - """Wrap a coroutine in a future. - - If the argument is a Future, it is returned directly. - - This function is deprecated in 3.5. Use asyncio.ensure_future() instead. - """ - - warnings.warn("asyncio.async() function is deprecated, use ensure_future()", - DeprecationWarning) - - return ensure_future(coro_or_future, loop=loop) - -# Silence DeprecationWarning: -globals()['async'] = async_ -async_.__name__ = 'async' -del async_ - - def ensure_future(coro_or_future, *, loop=None): """Wrap a coroutine or an awaitable in a future. If the argument is a Future, it is returned directly. """ if futures.isfuture(coro_or_future): - if loop is not None and loop is not coro_or_future._loop: - raise ValueError('loop argument must agree with Future') + if loop is not None and loop is not futures._get_loop(coro_or_future): + raise ValueError('The future belongs to a different loop than ' + 'the one specified as the loop argument') return coro_or_future - elif coroutines.iscoroutine(coro_or_future): - if loop is None: - loop = events.get_event_loop() - task = loop.create_task(coro_or_future) - if task._source_traceback: - del task._source_traceback[-1] - return task - elif compat.PY35 and inspect.isawaitable(coro_or_future): - return ensure_future(_wrap_awaitable(coro_or_future), loop=loop) - else: - raise TypeError('A Future, a coroutine or an awaitable is required') - - -@coroutine -def _wrap_awaitable(awaitable): - """Helper for asyncio.ensure_future(). 
+ should_close = True + if not coroutines.iscoroutine(coro_or_future): + if inspect.isawaitable(coro_or_future): + async def _wrap_awaitable(awaitable): + return await awaitable + + coro_or_future = _wrap_awaitable(coro_or_future) + should_close = False + else: + raise TypeError('An asyncio.Future, a coroutine or an awaitable ' + 'is required') - Wraps awaitable (an object with __await__) into a coroutine - that will later be wrapped in a Task by ensure_future(). - """ - return (yield from awaitable.__await__()) + if loop is None: + loop = events.get_event_loop() + try: + return loop.create_task(coro_or_future) + except RuntimeError: + if should_close: + coro_or_future.close() + raise class _GatheringFuture(futures.Future): @@ -583,23 +707,29 @@ class _GatheringFuture(futures.Future): cancelled. """ - def __init__(self, children, *, loop=None): + def __init__(self, children, *, loop): + assert loop is not None super().__init__(loop=loop) self._children = children + self._cancel_requested = False - def cancel(self): + def cancel(self, msg=None): if self.done(): return False ret = False for child in self._children: - if child.cancel(): + if child.cancel(msg=msg): ret = True + if ret: + # If any child tasks were actually cancelled, we should + # propagate the cancellation request regardless of + # *return_exceptions* argument. See issue 32684. + self._cancel_requested = True return ret -def gather(*coros_or_futures, loop=None, return_exceptions=False): - """Return a future aggregating results from the given coroutines - or futures. +def gather(*coros_or_futures, return_exceptions=False): + """Return a future aggregating results from the given coroutines/futures. Coroutines will be wrapped in a future and scheduled in the event loop. They will not necessarily be scheduled in the same order as @@ -620,77 +750,129 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): the outer Future is *not* cancelled in this case. (This is to prevent the cancellation of one child to cause other children to be cancelled.) + + If *return_exceptions* is False, cancelling gather() after it + has been marked done won't cancel any submitted awaitables. + For instance, gather can be marked done after propagating an + exception to the caller, therefore, calling ``gather.cancel()`` + after catching an exception (raised by one of the awaitables) from + gather won't cancel any other awaitables. """ if not coros_or_futures: - if loop is None: - loop = events.get_event_loop() + loop = events.get_event_loop() outer = loop.create_future() outer.set_result([]) return outer - arg_to_fut = {} - for arg in set(coros_or_futures): - if not futures.isfuture(arg): - fut = ensure_future(arg, loop=loop) - if loop is None: - loop = fut._loop - # The caller cannot control this future, the "destroy pending task" - # warning should not be emitted. - fut._log_destroy_pending = False - else: - fut = arg - if loop is None: - loop = fut._loop - elif fut._loop is not loop: - raise ValueError("futures are tied to different event loops") - arg_to_fut[arg] = fut - - children = [arg_to_fut[arg] for arg in coros_or_futures] - nchildren = len(children) - outer = _GatheringFuture(children, loop=loop) - nfinished = 0 - results = [None] * nchildren - - def _done_callback(i, fut): + def _done_callback(fut): nonlocal nfinished - if outer.done(): + nfinished += 1 + + if outer is None or outer.done(): if not fut.cancelled(): # Mark exception retrieved. 
fut.exception() return - if fut.cancelled(): - res = futures.CancelledError() - if not return_exceptions: - outer.set_exception(res) - return - elif fut._exception is not None: - res = fut.exception() # Mark exception retrieved. - if not return_exceptions: - outer.set_exception(res) + if not return_exceptions: + if fut.cancelled(): + # Check if 'fut' is cancelled first, as + # 'fut.exception()' will *raise* a CancelledError + # instead of returning it. + exc = fut._make_cancelled_error() + outer.set_exception(exc) return + else: + exc = fut.exception() + if exc is not None: + outer.set_exception(exc) + return + + if nfinished == nfuts: + # All futures are done; create a list of results + # and set it to the 'outer' future. + results = [] + + for fut in children: + if fut.cancelled(): + # Check if 'fut' is cancelled first, as 'fut.exception()' + # will *raise* a CancelledError instead of returning it. + # Also, since we're adding the exception return value + # to 'results' instead of raising it, don't bother + # setting __context__. This also lets us preserve + # calling '_make_cancelled_error()' at most once. + res = exceptions.CancelledError( + '' if fut._cancel_message is None else + fut._cancel_message) + else: + res = fut.exception() + if res is None: + res = fut.result() + results.append(res) + + if outer._cancel_requested: + # If gather is being cancelled we must propagate the + # cancellation regardless of *return_exceptions* argument. + # See issue 32684. + exc = fut._make_cancelled_error() + outer.set_exception(exc) + else: + outer.set_result(results) + + arg_to_fut = {} + children = [] + nfuts = 0 + nfinished = 0 + done_futs = [] + loop = None + outer = None # bpo-46672 + for arg in coros_or_futures: + if arg not in arg_to_fut: + fut = ensure_future(arg, loop=loop) + if loop is None: + loop = futures._get_loop(fut) + if fut is not arg: + # 'arg' was not a Future, therefore, 'fut' is a new + # Future created specifically for 'arg'. Since the caller + # can't control it, disable the "destroy pending task" + # warning. + fut._log_destroy_pending = False + + nfuts += 1 + arg_to_fut[arg] = fut + if fut.done(): + done_futs.append(fut) + else: + fut.add_done_callback(_done_callback) + else: - res = fut._result - results[i] = res - nfinished += 1 - if nfinished == nchildren: - outer.set_result(results) + # There's a duplicate Future object in coros_or_futures. + fut = arg_to_fut[arg] + + children.append(fut) - for i, fut in enumerate(children): - fut.add_done_callback(functools.partial(_done_callback, i)) + outer = _GatheringFuture(children, loop=loop) + # Run done callbacks after GatheringFuture created so any post-processing + # can be performed at this point + # optimization: in the special case that *all* futures finished eagerly, + # this will effectively complete the gather eagerly, with the last + # callback setting the result (or exception) on outer before returning it + for fut in done_futs: + _done_callback(fut) return outer -def shield(arg, *, loop=None): +def shield(arg): """Wait for a future, shielding it from cancellation. The statement - res = yield from shield(something()) + task = asyncio.create_task(something()) + res = await shield(task) is exactly equivalent to the statement - res = yield from something() + res = await something() *except* that if the coroutine containing it is cancelled, the task running in something() is not cancelled. 
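    For illustration, a minimal sketch of that guarantee (something()
    is a stand-in coroutine, not an asyncio API):

        async def caller():
            try:
                return await shield(something())
            except CancelledError:
                # caller() was cancelled, but the task wrapping
                # something() keeps running.
                raise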
From the POV of @@ -702,19 +884,25 @@ def shield(arg, *, loop=None): If you want to completely ignore cancellation (not recommended) you can combine shield() with a try/except clause, as follows: + task = asyncio.create_task(something()) try: - res = yield from shield(something()) + res = await shield(task) except CancelledError: res = None + + Save a reference to tasks passed to this function, to avoid + a task disappearing mid-execution. The event loop only keeps + weak references to tasks. A task that isn't referenced elsewhere + may get garbage collected at any time, even before it's done. """ - inner = ensure_future(arg, loop=loop) + inner = ensure_future(arg) if inner.done(): # Shortcut. return inner - loop = inner._loop + loop = futures._get_loop(inner) outer = loop.create_future() - def _done_callback(inner): + def _inner_done_callback(inner): if outer.cancelled(): if not inner.cancelled(): # Mark inner's result as retrieved. @@ -730,7 +918,13 @@ def _done_callback(inner): else: outer.set_result(inner.result()) - inner.add_done_callback(_done_callback) + + def _outer_done_callback(outer): + if not inner.done(): + inner.remove_done_callback(_inner_done_callback) + + inner.add_done_callback(_inner_done_callback) + outer.add_done_callback(_outer_done_callback) return outer @@ -746,7 +940,9 @@ def run_coroutine_threadsafe(coro, loop): def callback(): try: futures._chain_future(ensure_future(coro, loop=loop), future) - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: if future.set_running_or_notify_cancel(): future.set_exception(exc) raise @@ -755,8 +951,40 @@ def callback(): return future -# WeakSet containing all alive tasks. -_all_tasks = weakref.WeakSet() +def create_eager_task_factory(custom_task_constructor): + """Create a function suitable for use as a task factory on an event-loop. + + Example usage: + + loop.set_task_factory( + asyncio.create_eager_task_factory(my_task_constructor)) + + Now, tasks created will be started immediately (rather than being first + scheduled to an event loop). The constructor argument can be any callable + that returns a Task-compatible object and has a signature compatible + with `Task.__init__`; it must have the `eager_start` keyword argument. + + Most applications will use `Task` for `custom_task_constructor` and in + this case there's no need to call `create_eager_task_factory()` + directly. Instead the global `eager_task_factory` instance can be + used. E.g. `loop.set_task_factory(asyncio.eager_task_factory)`. + """ + + def factory(loop, coro, *, name=None, context=None): + return custom_task_constructor( + coro, loop=loop, name=name, context=context, eager_start=True) + + return factory + + +eager_task_factory = create_eager_task_factory(Task) + + +# Collectively these two sets hold references to the complete set of active +# tasks. Eagerly executed tasks use a faster regular set as an optimization +# but may graduate to a WeakSet if the task blocks on IO. +_scheduled_tasks = weakref.WeakSet() +_eager_tasks = set() # Dictionary containing tasks that are currently active in # all running event loops. 
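# For illustration, a sketch of opting a loop in to eager task execution
# with the factory defined above (the loop setup here is illustrative):
#
#     loop = asyncio.new_event_loop()
#     loop.set_task_factory(asyncio.eager_task_factory)
#     # Each new task now runs its first step synchronously inside
#     # create_task(); tasks that finish during that step are never
#     # scheduled onto the loop at all.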
{EventLoop: Task} @@ -764,8 +992,13 @@ def callback(): def _register_task(task): - """Register a new task in asyncio as executed by loop.""" - _all_tasks.add(task) + """Register an asyncio Task scheduled to run on an event loop.""" + _scheduled_tasks.add(task) + + +def _register_eager_task(task): + """Register an asyncio Task about to be eagerly executed.""" + _eager_tasks.add(task) def _enter_task(loop, task): @@ -784,25 +1017,49 @@ def _leave_task(loop, task): del _current_tasks[loop] +def _swap_current_task(loop, task): + prev_task = _current_tasks.get(loop) + if task is None: + del _current_tasks[loop] + else: + _current_tasks[loop] = task + return prev_task + + def _unregister_task(task): - """Unregister a task.""" - _all_tasks.discard(task) + """Unregister a completed, scheduled Task.""" + _scheduled_tasks.discard(task) + + +def _unregister_eager_task(task): + """Unregister a task which finished its first eager step.""" + _eager_tasks.discard(task) +_py_current_task = current_task _py_register_task = _register_task +_py_register_eager_task = _register_eager_task _py_unregister_task = _unregister_task +_py_unregister_eager_task = _unregister_eager_task _py_enter_task = _enter_task _py_leave_task = _leave_task +_py_swap_current_task = _swap_current_task try: - from _asyncio import (_register_task, _unregister_task, - _enter_task, _leave_task, - _all_tasks, _current_tasks) + from _asyncio import (_register_task, _register_eager_task, + _unregister_task, _unregister_eager_task, + _enter_task, _leave_task, _swap_current_task, + _scheduled_tasks, _eager_tasks, _current_tasks, + current_task) except ImportError: pass else: + _c_current_task = current_task _c_register_task = _register_task + _c_register_eager_task = _register_eager_task _c_unregister_task = _unregister_task + _c_unregister_eager_task = _unregister_eager_task _c_enter_task = _enter_task _c_leave_task = _leave_task + _c_swap_current_task = _swap_current_task diff --git a/Lib/asyncio/test_utils.py b/Lib/asyncio/test_utils.py deleted file mode 100644 index 99e3839f45..0000000000 --- a/Lib/asyncio/test_utils.py +++ /dev/null @@ -1,503 +0,0 @@ -"""Utilities shared by tests.""" - -import collections -import contextlib -import io -import logging -import os -import re -import socket -import socketserver -import sys -import tempfile -import threading -import time -import unittest -import weakref - -from unittest import mock - -from http.server import HTTPServer -from wsgiref.simple_server import WSGIRequestHandler, WSGIServer - -try: - import ssl -except ImportError: # pragma: no cover - ssl = None - -from . import base_events -from . import compat -from . import events -from . import futures -from . import selectors -from . import tasks -from .coroutines import coroutine -from .log import logger - - -if sys.platform == 'win32': # pragma: no cover - from .windows_utils import socketpair -else: - from socket import socketpair # pragma: no cover - - -def dummy_ssl_context(): - if ssl is None: - return None - else: - return ssl.SSLContext(ssl.PROTOCOL_SSLv23) - - -def run_briefly(loop): - @coroutine - def once(): - pass - gen = once() - t = loop.create_task(gen) - # Don't log a warning if the task is not done after run_until_complete(). - # It occurs if the loop is stopped or if a task raises a BaseException. 
- t._log_destroy_pending = False - try: - loop.run_until_complete(t) - finally: - gen.close() - - -def run_until(loop, pred, timeout=30): - deadline = time.time() + timeout - while not pred(): - if timeout is not None: - timeout = deadline - time.time() - if timeout <= 0: - raise futures.TimeoutError() - loop.run_until_complete(tasks.sleep(0.001, loop=loop)) - - -def run_once(loop): - """Legacy API to run once through the event loop. - - This is the recommended pattern for test code. It will poll the - selector once and run all callbacks scheduled in response to I/O - events. - """ - loop.call_soon(loop.stop) - loop.run_forever() - - -class SilentWSGIRequestHandler(WSGIRequestHandler): - - def get_stderr(self): - return io.StringIO() - - def log_message(self, format, *args): - pass - - -class SilentWSGIServer(WSGIServer): - - request_timeout = 2 - - def get_request(self): - request, client_addr = super().get_request() - request.settimeout(self.request_timeout) - return request, client_addr - - def handle_error(self, request, client_address): - pass - - -class SSLWSGIServerMixin: - - def finish_request(self, request, client_address): - # The relative location of our test directory (which - # contains the ssl key and certificate files) differs - # between the stdlib and stand-alone asyncio. - # Prefer our own if we can find it. - here = os.path.join(os.path.dirname(__file__), '..', 'tests') - if not os.path.isdir(here): - here = os.path.join(os.path.dirname(os.__file__), - 'test', 'test_asyncio') - keyfile = os.path.join(here, 'ssl_key.pem') - certfile = os.path.join(here, 'ssl_cert.pem') - context = ssl.SSLContext() - context.load_cert_chain(certfile, keyfile) - - ssock = context.wrap_socket(request, server_side=True) - try: - self.RequestHandlerClass(ssock, client_address, self) - ssock.close() - except OSError: - # maybe socket has been closed by peer - pass - - -class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer): - pass - - -def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls): - - def app(environ, start_response): - status = '200 OK' - headers = [('Content-type', 'text/plain')] - start_response(status, headers) - return [b'Test message'] - - # Run the test WSGI server in a separate thread in order not to - # interfere with event handling in the main thread - server_class = server_ssl_cls if use_ssl else server_cls - httpd = server_class(address, SilentWSGIRequestHandler) - httpd.set_app(app) - httpd.address = httpd.server_address - server_thread = threading.Thread( - target=lambda: httpd.serve_forever(poll_interval=0.05)) - server_thread.start() - try: - yield httpd - finally: - httpd.shutdown() - httpd.server_close() - server_thread.join() - - -if hasattr(socket, 'AF_UNIX'): - - class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer): - - def server_bind(self): - socketserver.UnixStreamServer.server_bind(self) - self.server_name = '127.0.0.1' - self.server_port = 80 - - - class UnixWSGIServer(UnixHTTPServer, WSGIServer): - - request_timeout = 2 - - def server_bind(self): - UnixHTTPServer.server_bind(self) - self.setup_environ() - - def get_request(self): - request, client_addr = super().get_request() - request.settimeout(self.request_timeout) - # Code in the stdlib expects that get_request - # will return a socket and a tuple (host, port). 
- # However, this isn't true for UNIX sockets, - # as the second return value will be a path; - # hence we return some fake data sufficient - # to get the tests going - return request, ('127.0.0.1', '') - - - class SilentUnixWSGIServer(UnixWSGIServer): - - def handle_error(self, request, client_address): - pass - - - class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer): - pass - - - def gen_unix_socket_path(): - with tempfile.NamedTemporaryFile() as file: - return file.name - - - @contextlib.contextmanager - def unix_socket_path(): - path = gen_unix_socket_path() - try: - yield path - finally: - try: - os.unlink(path) - except OSError: - pass - - - @contextlib.contextmanager - def run_test_unix_server(*, use_ssl=False): - with unix_socket_path() as path: - yield from _run_test_server(address=path, use_ssl=use_ssl, - server_cls=SilentUnixWSGIServer, - server_ssl_cls=UnixSSLWSGIServer) - - -@contextlib.contextmanager -def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): - yield from _run_test_server(address=(host, port), use_ssl=use_ssl, - server_cls=SilentWSGIServer, - server_ssl_cls=SSLWSGIServer) - - -def make_test_protocol(base): - dct = {} - for name in dir(base): - if name.startswith('__') and name.endswith('__'): - # skip magic names - continue - dct[name] = MockCallback(return_value=None) - return type('TestProtocol', (base,) + base.__bases__, dct)() - - -class TestSelector(selectors.BaseSelector): - - def __init__(self): - self.keys = {} - - def register(self, fileobj, events, data=None): - key = selectors.SelectorKey(fileobj, 0, events, data) - self.keys[fileobj] = key - return key - - def unregister(self, fileobj): - return self.keys.pop(fileobj) - - def select(self, timeout): - return [] - - def get_map(self): - return self.keys - - -class TestLoop(base_events.BaseEventLoop): - """Loop for unittests. - - It manages self time directly. - If something scheduled to be executed later then - on next loop iteration after all ready handlers done - generator passed to __init__ is calling. - - Generator should be like this: - - def gen(): - ... - when = yield ... - ... = yield time_advance - - Value returned by yield is absolute time of next scheduled handler. - Value passed to yield is time advance to move loop's time forward. 
- """ - - def __init__(self, gen=None): - super().__init__() - - if gen is None: - def gen(): - yield - self._check_on_close = False - else: - self._check_on_close = True - - self._gen = gen() - next(self._gen) - self._time = 0 - self._clock_resolution = 1e-9 - self._timers = [] - self._selector = TestSelector() - - self.readers = {} - self.writers = {} - self.reset_counters() - - self._transports = weakref.WeakValueDictionary() - - def time(self): - return self._time - - def advance_time(self, advance): - """Move test time forward.""" - if advance: - self._time += advance - - def close(self): - super().close() - if self._check_on_close: - try: - self._gen.send(0) - except StopIteration: - pass - else: # pragma: no cover - raise AssertionError("Time generator is not finished") - - def _add_reader(self, fd, callback, *args): - self.readers[fd] = events.Handle(callback, args, self) - - def _remove_reader(self, fd): - self.remove_reader_count[fd] += 1 - if fd in self.readers: - del self.readers[fd] - return True - else: - return False - - def assert_reader(self, fd, callback, *args): - assert fd in self.readers, 'fd {} is not registered'.format(fd) - handle = self.readers[fd] - assert handle._callback == callback, '{!r} != {!r}'.format( - handle._callback, callback) - assert handle._args == args, '{!r} != {!r}'.format( - handle._args, args) - - def _add_writer(self, fd, callback, *args): - self.writers[fd] = events.Handle(callback, args, self) - - def _remove_writer(self, fd): - self.remove_writer_count[fd] += 1 - if fd in self.writers: - del self.writers[fd] - return True - else: - return False - - def assert_writer(self, fd, callback, *args): - assert fd in self.writers, 'fd {} is not registered'.format(fd) - handle = self.writers[fd] - assert handle._callback == callback, '{!r} != {!r}'.format( - handle._callback, callback) - assert handle._args == args, '{!r} != {!r}'.format( - handle._args, args) - - def _ensure_fd_no_transport(self, fd): - try: - transport = self._transports[fd] - except KeyError: - pass - else: - raise RuntimeError( - 'File descriptor {!r} is used by transport {!r}'.format( - fd, transport)) - - def add_reader(self, fd, callback, *args): - """Add a reader callback.""" - self._ensure_fd_no_transport(fd) - return self._add_reader(fd, callback, *args) - - def remove_reader(self, fd): - """Remove a reader callback.""" - self._ensure_fd_no_transport(fd) - return self._remove_reader(fd) - - def add_writer(self, fd, callback, *args): - """Add a writer callback..""" - self._ensure_fd_no_transport(fd) - return self._add_writer(fd, callback, *args) - - def remove_writer(self, fd): - """Remove a writer callback.""" - self._ensure_fd_no_transport(fd) - return self._remove_writer(fd) - - def reset_counters(self): - self.remove_reader_count = collections.defaultdict(int) - self.remove_writer_count = collections.defaultdict(int) - - def _run_once(self): - super()._run_once() - for when in self._timers: - advance = self._gen.send(when) - self.advance_time(advance) - self._timers = [] - - def call_at(self, when, callback, *args): - self._timers.append(when) - return super().call_at(when, callback, *args) - - def _process_events(self, event_list): - return - - def _write_to_self(self): - pass - - -def MockCallback(**kwargs): - return mock.Mock(spec=['__call__'], **kwargs) - - -class MockPattern(str): - """A regex based str with a fuzzy __eq__. - - Use this helper with 'mock.assert_called_with', or anywhere - where a regex comparison between strings is needed. 
- - For instance: - mock_call.assert_called_with(MockPattern('spam.*ham')) - """ - def __eq__(self, other): - return bool(re.search(str(self), other, re.S)) - - -def get_function_source(func): - source = events._get_function_source(func) - if source is None: - raise ValueError("unable to get the source of %r" % (func,)) - return source - - -class TestCase(unittest.TestCase): - def set_event_loop(self, loop, *, cleanup=True): - assert loop is not None - # ensure that the event loop is passed explicitly in asyncio - events.set_event_loop(None) - if cleanup: - self.addCleanup(loop.close) - - def new_test_loop(self, gen=None): - loop = TestLoop(gen) - self.set_event_loop(loop) - return loop - - def setUp(self): - self._get_running_loop = events._get_running_loop - events._get_running_loop = lambda: None - - def tearDown(self): - events._get_running_loop = self._get_running_loop - - events.set_event_loop(None) - - # Detect CPython bug #23353: ensure that yield/yield-from is not used - # in an except block of a generator - self.assertEqual(sys.exc_info(), (None, None, None)) - - if not compat.PY34: - # Python 3.3 compatibility - def subTest(self, *args, **kwargs): - class EmptyCM: - def __enter__(self): - pass - def __exit__(self, *exc): - pass - return EmptyCM() - - -@contextlib.contextmanager -def disable_logger(): - """Context manager to disable asyncio logger. - - For example, it can be used to ignore warnings in debug mode. - """ - old_level = logger.level - try: - logger.setLevel(logging.CRITICAL+1) - yield - finally: - logger.setLevel(old_level) - - -def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM, - family=socket.AF_INET): - """Create a mock of a non-blocking socket.""" - sock = mock.MagicMock(socket.socket) - sock.proto = proto - sock.type = type - sock.family = family - sock.gettimeout.return_value = 0.0 - return sock - - -def force_legacy_ssl_support(): - return mock.patch('asyncio.sslproto._is_sslproto_available', - return_value=False) diff --git a/Lib/asyncio/threads.py b/Lib/asyncio/threads.py new file mode 100644 index 0000000000..db048a8231 --- /dev/null +++ b/Lib/asyncio/threads.py @@ -0,0 +1,25 @@ +"""High-level support for working with threads in asyncio""" + +import functools +import contextvars + +from . import events + + +__all__ = "to_thread", + + +async def to_thread(func, /, *args, **kwargs): + """Asynchronously run function *func* in a separate thread. + + Any *args and **kwargs supplied for this function are directly passed + to *func*. Also, the current :class:`contextvars.Context` is propagated, + allowing context variables from the main thread to be accessed in the + separate thread. + + Return a coroutine that can be awaited to get the eventual result of *func*. + """ + loop = events.get_running_loop() + ctx = contextvars.copy_context() + func_call = functools.partial(ctx.run, func, *args, **kwargs) + return await loop.run_in_executor(None, func_call) diff --git a/Lib/asyncio/timeouts.py b/Lib/asyncio/timeouts.py new file mode 100644 index 0000000000..30042abb3a --- /dev/null +++ b/Lib/asyncio/timeouts.py @@ -0,0 +1,168 @@ +import enum + +from types import TracebackType +from typing import final, Optional, Type + +from . import events +from . import exceptions +from . 
import tasks + + __all__ = ( + "Timeout", + "timeout", + "timeout_at", + ) + + class _State(enum.Enum): + CREATED = "created" + ENTERED = "active" + EXPIRING = "expiring" + EXPIRED = "expired" + EXITED = "finished" + + @final + class Timeout: + """Asynchronous context manager for cancelling overdue coroutines. + + Use `timeout()` or `timeout_at()` rather than instantiating this class directly. + """ + + def __init__(self, when: Optional[float]) -> None: + """Schedule a timeout that will trigger at a given loop time. + + - If `when` is `None`, the timeout will never trigger. + - If `when < loop.time()`, the timeout will trigger on the next + iteration of the event loop. + """ + self._state = _State.CREATED + + self._timeout_handler: Optional[events.TimerHandle] = None + self._task: Optional[tasks.Task] = None + self._when = when + + def when(self) -> Optional[float]: + """Return the current deadline.""" + return self._when + + def reschedule(self, when: Optional[float]) -> None: + """Reschedule the timeout.""" + if self._state is not _State.ENTERED: + if self._state is _State.CREATED: + raise RuntimeError("Timeout has not been entered") + raise RuntimeError( + f"Cannot change state of {self._state.value} Timeout", + ) + + self._when = when + + if self._timeout_handler is not None: + self._timeout_handler.cancel() + + if when is None: + self._timeout_handler = None + else: + loop = events.get_running_loop() + if when <= loop.time(): + self._timeout_handler = loop.call_soon(self._on_timeout) + else: + self._timeout_handler = loop.call_at(when, self._on_timeout) + + def expired(self) -> bool: + """Is timeout expired during execution?""" + return self._state in (_State.EXPIRING, _State.EXPIRED) + + def __repr__(self) -> str: + info = [''] + if self._state is _State.ENTERED: + when = round(self._when, 3) if self._when is not None else None + info.append(f"when={when}") + info_str = ' '.join(info) + return f"<Timeout [{self._state.value}]{info_str}>" + + async def __aenter__(self) -> "Timeout": + if self._state is not _State.CREATED: + raise RuntimeError("Timeout has already been entered") + task = tasks.current_task() + if task is None: + raise RuntimeError("Timeout should be used inside a task") + self._state = _State.ENTERED + self._task = task + self._cancelling = self._task.cancelling() + self.reschedule(self._when) + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + assert self._state in (_State.ENTERED, _State.EXPIRING) + + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._timeout_handler = None + + if self._state is _State.EXPIRING: + self._state = _State.EXPIRED + + if self._task.uncancel() <= self._cancelling and exc_type is exceptions.CancelledError: + # Since there are no new cancel requests, we're + # handling this. + raise TimeoutError from exc_val + elif self._state is _State.ENTERED: + self._state = _State.EXITED + + return None + + def _on_timeout(self) -> None: + assert self._state is _State.ENTERED + self._task.cancel() + self._state = _State.EXPIRING + # drop the reference early + self._timeout_handler = None + + def timeout(delay: Optional[float]) -> Timeout: + """Timeout async context manager. + + Useful in cases when you want to apply timeout logic around block + of code or in cases when asyncio.wait_for is not suitable. For example: + + >>> async with asyncio.timeout(10): # 10 seconds timeout + ... 
await long_running_task() + + + delay - value in seconds or None to disable timeout logic + + long_running_task() is interrupted by raising asyncio.CancelledError, + the top-most affected timeout() context manager converts CancelledError + into TimeoutError. + """ + loop = events.get_running_loop() + return Timeout(loop.time() + delay if delay is not None else None) + + +def timeout_at(when: Optional[float]) -> Timeout: + """Schedule the timeout at absolute time. + + Like timeout() but argument gives absolute time in the same clock system + as loop.time(). + + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. + + >>> async with asyncio.timeout_at(loop.time() + 10): + ... await long_running_task() + + + when - a deadline when timeout occurs or None to disable timeout logic + + long_running_task() is interrupted by raising asyncio.CancelledError, + the top-most affected timeout() context manager converts CancelledError + into TimeoutError. + """ + return Timeout(when) diff --git a/Lib/asyncio/transports.py b/Lib/asyncio/transports.py index 0db0875715..30fd41d49a 100644 --- a/Lib/asyncio/transports.py +++ b/Lib/asyncio/transports.py @@ -1,15 +1,16 @@ """Abstract Transport class.""" -from asyncio import compat - -__all__ = ['BaseTransport', 'ReadTransport', 'WriteTransport', - 'Transport', 'DatagramTransport', 'SubprocessTransport', - ] +__all__ = ( + 'BaseTransport', 'ReadTransport', 'WriteTransport', + 'Transport', 'DatagramTransport', 'SubprocessTransport', +) class BaseTransport: """Base class for transports.""" + __slots__ = ('_extra',) + def __init__(self, extra=None): if extra is None: extra = {} @@ -28,8 +29,8 @@ def close(self): Buffered data will be flushed asynchronously. No more data will be received. After all buffered data is flushed, the - protocol's connection_lost() method will (eventually) called - with None as its argument. + protocol's connection_lost() method will (eventually) be + called with None as its argument. """ raise NotImplementedError @@ -45,6 +46,12 @@ def get_protocol(self): class ReadTransport(BaseTransport): """Interface for read-only transports.""" + __slots__ = () + + def is_reading(self): + """Return True if the transport is receiving.""" + raise NotImplementedError + def pause_reading(self): """Pause the receiving end. @@ -65,6 +72,8 @@ def resume_reading(self): class WriteTransport(BaseTransport): """Interface for write-only transports.""" + __slots__ = () + def set_write_buffer_limits(self, high=None, low=None): """Set the high- and low-water limits for write flow control. @@ -90,6 +99,12 @@ def get_write_buffer_size(self): """Return the current size of the write buffer.""" raise NotImplementedError + def get_write_buffer_limits(self): + """Get the high and low watermarks for write flow control. + Return a tuple (low, high) where low and high are + positive number of bytes.""" + raise NotImplementedError + def write(self, data): """Write some data bytes to the transport. @@ -104,7 +119,7 @@ def writelines(self, list_of_data): The default implementation concatenates the arguments and calls write() on the result. """ - data = compat.flatten_list_bytes(list_of_data) + data = b''.join(list_of_data) self.write(data) def write_eof(self): @@ -151,10 +166,14 @@ class Transport(ReadTransport, WriteTransport): except writelines(), which calls write() in a loop. 
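    For illustration, write flow control can be tuned per transport
    (a sketch; the byte counts are arbitrary):

        transport.set_write_buffer_limits(high=64 * 1024, low=16 * 1024)
        low, high = transport.get_write_buffer_limits()  # (16384, 65536)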
""" + __slots__ = () + class DatagramTransport(BaseTransport): """Interface for datagram (UDP) transports.""" + __slots__ = () + def sendto(self, data, addr=None): """Send data to the transport. @@ -177,6 +196,8 @@ def abort(self): class SubprocessTransport(BaseTransport): + __slots__ = () + def get_pid(self): """Get subprocess id.""" raise NotImplementedError @@ -244,6 +265,8 @@ class _FlowControlMixin(Transport): resume_writing() may be called. """ + __slots__ = ('_loop', '_protocol_paused', '_high_water', '_low_water') + def __init__(self, extra=None, loop=None): super().__init__(extra) assert loop is not None @@ -259,7 +282,9 @@ def _maybe_pause_protocol(self): self._protocol_paused = True try: self._protocol.pause_writing() - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: self._loop.call_exception_handler({ 'message': 'protocol.pause_writing() failed', 'exception': exc, @@ -269,11 +294,13 @@ def _maybe_pause_protocol(self): def _maybe_resume_protocol(self): if (self._protocol_paused and - self.get_write_buffer_size() <= self._low_water): + self.get_write_buffer_size() <= self._low_water): self._protocol_paused = False try: self._protocol.resume_writing() - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: self._loop.call_exception_handler({ 'message': 'protocol.resume_writing() failed', 'exception': exc, @@ -287,14 +314,16 @@ def get_write_buffer_limits(self): def _set_write_buffer_limits(self, high=None, low=None): if high is None: if low is None: - high = 64*1024 + high = 64 * 1024 else: - high = 4*low + high = 4 * low if low is None: low = high // 4 + if not high >= low >= 0: - raise ValueError('high (%r) must be >= low (%r) must be >= 0' % - (high, low)) + raise ValueError( + f'high ({high!r}) must be >= low ({low!r}) must be >= 0') + self._high_water = high self._low_water = low diff --git a/Lib/asyncio/trsock.py b/Lib/asyncio/trsock.py new file mode 100644 index 0000000000..c1f20473b3 --- /dev/null +++ b/Lib/asyncio/trsock.py @@ -0,0 +1,98 @@ +import socket + + +class TransportSocket: + + """A socket-like wrapper for exposing real transport sockets. + + These objects can be safely returned by APIs like + `transport.get_extra_info('socket')`. All potentially disruptive + operations (like "socket.close()") are banned. + """ + + __slots__ = ('_sock',) + + def __init__(self, sock: socket.socket): + self._sock = sock + + @property + def family(self): + return self._sock.family + + @property + def type(self): + return self._sock.type + + @property + def proto(self): + return self._sock.proto + + def __repr__(self): + s = ( + f"" + + def __getstate__(self): + raise TypeError("Cannot serialize asyncio.TransportSocket object") + + def fileno(self): + return self._sock.fileno() + + def dup(self): + return self._sock.dup() + + def get_inheritable(self): + return self._sock.get_inheritable() + + def shutdown(self, how): + # asyncio doesn't currently provide a high-level transport API + # to shutdown the connection. 
+ self._sock.shutdown(how) + + def getsockopt(self, *args, **kwargs): + return self._sock.getsockopt(*args, **kwargs) + + def setsockopt(self, *args, **kwargs): + self._sock.setsockopt(*args, **kwargs) + + def getpeername(self): + return self._sock.getpeername() + + def getsockname(self): + return self._sock.getsockname() + + def getsockbyname(self): + return self._sock.getsockbyname() + + def settimeout(self, value): + if value == 0: + return + raise ValueError( + 'settimeout(): only 0 timeout is allowed on transport sockets') + + def gettimeout(self): + return 0 + + def setblocking(self, flag): + if not flag: + return + raise ValueError( + 'setblocking(): transport sockets cannot be blocking') diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py index 9db09b9d9b..f2e920ada4 100644 --- a/Lib/asyncio/unix_events.py +++ b/Lib/asyncio/unix_events.py @@ -1,7 +1,10 @@ """Selector event loop for Unix with signal handling.""" import errno +import io +import itertools import os +import selectors import signal import socket import stat @@ -10,25 +13,27 @@ import threading import warnings - from . import base_events from . import base_subprocess -from . import compat from . import constants from . import coroutines from . import events +from . import exceptions from . import futures from . import selector_events -from . import selectors +from . import tasks from . import transports -from .coroutines import coroutine from .log import logger -__all__ = ['SelectorEventLoop', - 'AbstractChildWatcher', 'SafeChildWatcher', - 'FastChildWatcher', 'DefaultEventLoopPolicy', - ] +__all__ = ( + 'SelectorEventLoop', + 'AbstractChildWatcher', 'SafeChildWatcher', + 'FastChildWatcher', 'PidfdChildWatcher', + 'MultiLoopChildWatcher', 'ThreadedChildWatcher', + 'DefaultEventLoopPolicy', +) + if sys.platform == 'win32': # pragma: no cover raise ImportError('Signals are not really supported on Windows') @@ -39,11 +44,14 @@ def _sighandler_noop(signum, frame): pass -try: - _fspath = os.fspath -except AttributeError: - # Python 3.5 or earlier - _fspath = lambda path: path +def waitstatus_to_exitcode(status): + try: + return os.waitstatus_to_exitcode(status) + except ValueError: + # The child exited, but we don't understand its status. + # This shouldn't happen, but if it does, let's just + # return that status; perhaps that helps debug it. + return status class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): @@ -56,13 +64,19 @@ def __init__(self, selector=None): super().__init__(selector) self._signal_handlers = {} - def _socketpair(self): - return socket.socketpair() - def close(self): super().close() - for sig in list(self._signal_handlers): - self.remove_signal_handler(sig) + if not sys.is_finalizing(): + for sig in list(self._signal_handlers): + self.remove_signal_handler(sig) + else: + if self._signal_handlers: + warnings.warn(f"Closing the loop {self!r} " + f"on interpreter shutdown " + f"stage, skipping signal handlers removal", + ResourceWarning, + source=self) + self._signal_handlers.clear() def _process_self_data(self, data): for signum in data: @@ -77,8 +91,8 @@ def add_signal_handler(self, sig, callback, *args): Raise ValueError if the signal number is invalid or uncatchable. Raise RuntimeError if there is a problem setting up the handler. 
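        For example (a sketch; the signal choice and callback are
        illustrative):

            def on_sigterm():
                print('got SIGTERM, stopping the loop')
                loop.stop()

            loop.add_signal_handler(signal.SIGTERM, on_sigterm)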
""" - if (coroutines.iscoroutine(callback) - or coroutines.iscoroutinefunction(callback)): + if (coroutines.iscoroutine(callback) or + coroutines.iscoroutinefunction(callback)): raise TypeError("coroutines cannot be used " "with add_signal_handler()") self._check_signal(sig) @@ -92,12 +106,12 @@ def add_signal_handler(self, sig, callback, *args): except (ValueError, OSError) as exc: raise RuntimeError(str(exc)) - handle = events.Handle(callback, args, self) + handle = events.Handle(callback, args, self, None) self._signal_handlers[sig] = handle try: # Register a dummy signal handler to ask Python to write the signal - # number in the wakup file descriptor. _process_self_data() will + # number in the wakeup file descriptor. _process_self_data() will # read signal numbers from this file descriptor to handle signals. signal.signal(sig, _sighandler_noop) @@ -112,7 +126,7 @@ def add_signal_handler(self, sig, callback, *args): logger.info('set_wakeup_fd(-1) failed: %s', nexc) if exc.errno == errno.EINVAL: - raise RuntimeError('sig {} cannot be caught'.format(sig)) + raise RuntimeError(f'sig {sig} cannot be caught') else: raise @@ -146,7 +160,7 @@ def remove_signal_handler(self, sig): signal.signal(sig, handler) except OSError as exc: if exc.errno == errno.EINVAL: - raise RuntimeError('sig {} cannot be caught'.format(sig)) + raise RuntimeError(f'sig {sig} cannot be caught') else: raise @@ -165,11 +179,10 @@ def _check_signal(self, sig): Raise RuntimeError if there is a problem setting up the handler. """ if not isinstance(sig, int): - raise TypeError('sig must be an int, not {!r}'.format(sig)) + raise TypeError(f'sig must be an int, not {sig!r}') - if not (1 <= sig < signal.NSIG): - raise ValueError( - 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + if sig not in signal.valid_signals(): + raise ValueError(f'invalid signal number {sig}') def _make_read_pipe_transport(self, pipe, protocol, waiter=None, extra=None): @@ -179,43 +192,48 @@ def _make_write_pipe_transport(self, pipe, protocol, waiter=None, extra=None): return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) - @coroutine - def _make_subprocess_transport(self, protocol, args, shell, - stdin, stdout, stderr, bufsize, - extra=None, **kwargs): - with events.get_child_watcher() as watcher: + async def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + watcher = events.get_child_watcher() + + with watcher: + if not watcher.is_active(): + # Check early. + # Raising exception before process creation + # prevents subprocess execution if the watcher + # is not ready to handle it. 
+ raise RuntimeError("asyncio.get_child_watcher() is not activated, " + "subprocess support is not installed.") waiter = self.create_future() transp = _UnixSubprocessTransport(self, protocol, args, shell, - stdin, stdout, stderr, bufsize, - waiter=waiter, extra=extra, - **kwargs) - + stdin, stdout, stderr, bufsize, + waiter=waiter, extra=extra, + **kwargs) watcher.add_child_handler(transp.get_pid(), - self._child_watcher_callback, transp) + self._child_watcher_callback, transp) try: - yield from waiter - except Exception as exc: - # Workaround CPython bug #23353: using yield/yield-from in an - # except block of a generator doesn't clear properly - # sys.exc_info() - err = exc - else: - err = None - - if err is not None: + await waiter + except (SystemExit, KeyboardInterrupt): + raise + except BaseException: transp.close() - yield from transp._wait() - raise err + await transp._wait() + raise return transp def _child_watcher_callback(self, pid, returncode, transp): self.call_soon_threadsafe(transp._process_exited, returncode) - @coroutine - def create_unix_connection(self, protocol_factory, path, *, - ssl=None, sock=None, - server_hostname=None): + async def create_unix_connection( + self, protocol_factory, path=None, *, + ssl=None, sock=None, + server_hostname=None, + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): assert server_hostname is None or isinstance(server_hostname, str) if ssl: if server_hostname is None: @@ -224,16 +242,23 @@ def create_unix_connection(self, protocol_factory, path, *, else: if server_hostname is not None: raise ValueError('server_hostname is only meaningful with ssl') + if ssl_handshake_timeout is not None: + raise ValueError( + 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') if path is not None: if sock is not None: raise ValueError( 'path and sock can not be specified at the same time') + path = os.fspath(path) sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0) try: sock.setblocking(False) - yield from self.sock_connect(sock, path) + await self.sock_connect(sock, path) except: sock.close() raise @@ -242,28 +267,40 @@ def create_unix_connection(self, protocol_factory, path, *, if sock is None: raise ValueError('no path and sock were specified') if (sock.family != socket.AF_UNIX or - not base_events._is_stream_socket(sock)): + sock.type != socket.SOCK_STREAM): raise ValueError( - 'A UNIX Domain Stream Socket was expected, got {!r}' - .format(sock)) + f'A UNIX Domain Stream Socket was expected, got {sock!r}') sock.setblocking(False) - transport, protocol = yield from self._create_connection_transport( - sock, protocol_factory, ssl, server_hostname) + transport, protocol = await self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) return transport, protocol - @coroutine - def create_unix_server(self, protocol_factory, path=None, *, - sock=None, backlog=100, ssl=None): + async def create_unix_server( + self, protocol_factory, path=None, *, + sock=None, backlog=100, ssl=None, + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, + start_serving=True): if isinstance(ssl, bool): raise TypeError('ssl argument must be an SSLContext or None') + if ssl_handshake_timeout is not None and not ssl: + raise ValueError( + 'ssl_handshake_timeout is only meaningful with ssl') + + if ssl_shutdown_timeout is not None and 
not ssl: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') + if path is not None: if sock is not None: raise ValueError( 'path and sock can not be specified at the same time') - path = _fspath(path) + path = os.fspath(path) sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) # Check for abstract socket. `str` and `bytes` paths are supported. @@ -275,7 +312,8 @@ def create_unix_server(self, protocol_factory, path=None, *, pass except OSError as err: # Directory may have permissions only to create socket. - logger.error('Unable to check or remove stale UNIX socket %r: %r', path, err) + logger.error('Unable to check or remove stale UNIX socket ' + '%r: %r', path, err) try: sock.bind(path) @@ -284,7 +322,7 @@ def create_unix_server(self, protocol_factory, path=None, *, if exc.errno == errno.EADDRINUSE: # Let's improve the error message by adding # with what exact address it occurs. - msg = 'Address {!r} is already in use'.format(path) + msg = f'Address {path!r} is already in use' raise OSError(errno.EADDRINUSE, msg) from None else: raise @@ -297,28 +335,126 @@ def create_unix_server(self, protocol_factory, path=None, *, 'path was not specified, and no sock specified') if (sock.family != socket.AF_UNIX or - not base_events._is_stream_socket(sock)): + sock.type != socket.SOCK_STREAM): raise ValueError( - 'A UNIX Domain Stream Socket was expected, got {!r}' - .format(sock)) + f'A UNIX Domain Stream Socket was expected, got {sock!r}') - server = base_events.Server(self, [sock]) - sock.listen(backlog) sock.setblocking(False) - self._start_serving(protocol_factory, sock, ssl, server) - return server + server = base_events.Server(self, [sock], protocol_factory, + ssl, backlog, ssl_handshake_timeout, + ssl_shutdown_timeout) + if start_serving: + server._start_serving() + # Skip one loop iteration so that all 'loop.add_reader' + # go through. + await tasks.sleep(0) + return server -#if hasattr(os, 'set_blocking'): -# def _set_nonblocking(fd): -# os.set_blocking(fd, False) -#else: -# import fcntl + async def _sock_sendfile_native(self, sock, file, offset, count): + try: + os.sendfile + except AttributeError: + raise exceptions.SendfileNotAvailableError( + "os.sendfile() is not available") + try: + fileno = file.fileno() + except (AttributeError, io.UnsupportedOperation) as err: + raise exceptions.SendfileNotAvailableError("not a regular file") + try: + fsize = os.fstat(fileno).st_size + except OSError: + raise exceptions.SendfileNotAvailableError("not a regular file") + blocksize = count if count else fsize + if not blocksize: + return 0 # empty file + + fut = self.create_future() + self._sock_sendfile_native_impl(fut, None, sock, fileno, + offset, count, blocksize, 0) + return await fut + + def _sock_sendfile_native_impl(self, fut, registered_fd, sock, fileno, + offset, count, blocksize, total_sent): + fd = sock.fileno() + if registered_fd is not None: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. 
+ self.remove_writer(registered_fd) + if fut.cancelled(): + self._sock_sendfile_update_filepos(fileno, offset, total_sent) + return + if count: + blocksize = count - total_sent + if blocksize <= 0: + self._sock_sendfile_update_filepos(fileno, offset, total_sent) + fut.set_result(total_sent) + return -# def _set_nonblocking(fd): -# flags = fcntl.fcntl(fd, fcntl.F_GETFL) -# flags = flags | os.O_NONBLOCK -# fcntl.fcntl(fd, fcntl.F_SETFL, flags) + try: + sent = os.sendfile(fd, fileno, offset, blocksize) + except (BlockingIOError, InterruptedError): + if registered_fd is None: + self._sock_add_cancellation_callback(fut, sock) + self.add_writer(fd, self._sock_sendfile_native_impl, fut, + fd, sock, fileno, + offset, count, blocksize, total_sent) + except OSError as exc: + if (registered_fd is not None and + exc.errno == errno.ENOTCONN and + type(exc) is not ConnectionError): + # If we have an ENOTCONN and this isn't a first call to + # sendfile(), i.e. the connection was closed in the middle + # of the operation, normalize the error to ConnectionError + # to make it consistent across all Posix systems. + new_exc = ConnectionError( + "socket is not connected", errno.ENOTCONN) + new_exc.__cause__ = exc + exc = new_exc + if total_sent == 0: + # We can get here for different reasons, the main + # one being 'file' is not a regular mmap(2)-like + # file, in which case we'll fall back on using + # plain send(). + err = exceptions.SendfileNotAvailableError( + "os.sendfile call failed") + self._sock_sendfile_update_filepos(fileno, offset, total_sent) + fut.set_exception(err) + else: + self._sock_sendfile_update_filepos(fileno, offset, total_sent) + fut.set_exception(exc) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._sock_sendfile_update_filepos(fileno, offset, total_sent) + fut.set_exception(exc) + else: + if sent == 0: + # EOF + self._sock_sendfile_update_filepos(fileno, offset, total_sent) + fut.set_result(total_sent) + else: + offset += sent + total_sent += sent + if registered_fd is None: + self._sock_add_cancellation_callback(fut, sock) + self.add_writer(fd, self._sock_sendfile_native_impl, fut, + fd, sock, fileno, + offset, count, blocksize, total_sent) + + def _sock_sendfile_update_filepos(self, fileno, offset, total_sent): + if total_sent > 0: + os.lseek(fileno, offset, os.SEEK_SET) + + def _sock_add_cancellation_callback(self, fut, sock): + def cb(fut): + if fut.cancelled(): + fd = sock.fileno() + if fd != -1: + self.remove_writer(fd) + fut.add_done_callback(cb) class _UnixReadPipeTransport(transports.ReadTransport): @@ -333,6 +469,7 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._fileno = pipe.fileno() self._protocol = protocol self._closing = False + self._paused = False mode = os.fstat(self._fileno).st_mode if not (stat.S_ISFIFO(mode) or @@ -343,29 +480,36 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._protocol = None raise ValueError("Pipe transport is for pipes/sockets only.") - _set_nonblocking(self._fileno) + os.set_blocking(self._fileno, False) self._loop.call_soon(self._protocol.connection_made, self) # only start reading when connection_made() has been called - self._loop.call_soon(self._loop._add_reader, + self._loop.call_soon(self._add_reader, self._fileno, self._read_ready) if waiter is not None: # only wake up the waiter when connection_made() has been called self._loop.call_soon(futures._set_result_unless_cancelled, waiter, None) + def _add_reader(self, fd, callback): + if not 
self.is_reading(): + return + self._loop._add_reader(fd, callback) + + def is_reading(self): + return not self._paused and not self._closing + def __repr__(self): info = [self.__class__.__name__] if self._pipe is None: info.append('closed') elif self._closing: info.append('closing') - info.append('fd=%s' % self._fileno) + info.append(f'fd={self._fileno}') selector = getattr(self._loop, '_selector', None) if self._pipe is not None and selector is not None: polling = selector_events._test_selector_event( - selector, - self._fileno, selectors.EVENT_READ) + selector, self._fileno, selectors.EVENT_READ) if polling: info.append('polling') else: @@ -374,7 +518,7 @@ def __repr__(self): info.append('open') else: info.append('closed') - return '<%s>' % ' '.join(info) + return '<{}>'.format(' '.join(info)) def _read_ready(self): try: @@ -395,10 +539,20 @@ def _read_ready(self): self._loop.call_soon(self._call_connection_lost, None) def pause_reading(self): + if not self.is_reading(): + return + self._paused = True self._loop._remove_reader(self._fileno) + if self._loop.get_debug(): + logger.debug("%r pauses reading", self) def resume_reading(self): + if self._closing or not self._paused: + return + self._paused = False self._loop._add_reader(self._fileno, self._read_ready) + if self._loop.get_debug(): + logger.debug("%r resumes reading", self) def set_protocol(self, protocol): self._protocol = protocol @@ -413,15 +567,10 @@ def close(self): if not self._closing: self._close(None) - # On Python 3.3 and older, objects with a destructor part of a reference - # cycle are never destroyed. It's not more the case on Python 3.4 thanks - # to the PEP 442. - if compat.PY34: - def __del__(self): - if self._pipe is not None: - warnings.warn("unclosed transport %r" % self, ResourceWarning, - source=self) - self._pipe.close() + def __del__(self, _warn=warnings.warn): + if self._pipe is not None: + _warn(f"unclosed transport {self!r}", ResourceWarning, source=self) + self._pipe.close() def _fatal_error(self, exc, message='Fatal error on pipe transport'): # should be called by exception handler only @@ -476,7 +625,7 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): raise ValueError("Pipe transport is only for " "pipes, sockets and character devices") - _set_nonblocking(self._fileno) + os.set_blocking(self._fileno, False) self._loop.call_soon(self._protocol.connection_made, self) # On AIX, the reader trick (to be notified when the read end of the @@ -498,24 +647,23 @@ def __repr__(self): info.append('closed') elif self._closing: info.append('closing') - info.append('fd=%s' % self._fileno) + info.append(f'fd={self._fileno}') selector = getattr(self._loop, '_selector', None) if self._pipe is not None and selector is not None: polling = selector_events._test_selector_event( - selector, - self._fileno, selectors.EVENT_WRITE) + selector, self._fileno, selectors.EVENT_WRITE) if polling: info.append('polling') else: info.append('idle') bufsize = self.get_write_buffer_size() - info.append('bufsize=%s' % bufsize) + info.append(f'bufsize={bufsize}') elif self._pipe is not None: info.append('open') else: info.append('closed') - return '<%s>' % ' '.join(info) + return '<{}>'.format(' '.join(info)) def get_write_buffer_size(self): return len(self._buffer) @@ -549,7 +697,9 @@ def write(self, data): n = os.write(self._fileno, data) except (BlockingIOError, InterruptedError): n = 0 - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: self._conn_lost += 1 
self._fatal_error(exc, 'Fatal write error on pipe transport') return @@ -569,7 +719,9 @@ def _write_ready(self): n = os.write(self._fileno, self._buffer) except (BlockingIOError, InterruptedError): pass - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: self._buffer.clear() self._conn_lost += 1 # Remove writer here, _fatal_error() doesn't it @@ -614,22 +766,17 @@ def close(self): # write_eof is all what we needed to close the write pipe self.write_eof() - # On Python 3.3 and older, objects with a destructor part of a reference - # cycle are never destroyed. It's not more the case on Python 3.4 thanks - # to the PEP 442. - if compat.PY34: - def __del__(self): - if self._pipe is not None: - warnings.warn("unclosed transport %r" % self, ResourceWarning, - source=self) - self._pipe.close() + def __del__(self, _warn=warnings.warn): + if self._pipe is not None: + _warn(f"unclosed transport {self!r}", ResourceWarning, source=self) + self._pipe.close() def abort(self): self._close(None) def _fatal_error(self, exc, message='Fatal error on pipe transport'): # should be called by exception handler only - if isinstance(exc, base_events._FATAL_ERROR_IGNORE): + if isinstance(exc, OSError): if self._loop.get_debug(): logger.debug("%r: %s", self, message, exc_info=True) else: @@ -659,45 +806,28 @@ def _call_connection_lost(self, exc): self._loop = None -#if hasattr(os, 'set_inheritable'): -# # Python 3.4 and newer -# _set_inheritable = os.set_inheritable -#else: -# import fcntl -# -# def _set_inheritable(fd, inheritable): -# cloexec_flag = getattr(fcntl, 'FD_CLOEXEC', 1) -# -# old = fcntl.fcntl(fd, fcntl.F_GETFD) -# if not inheritable: -# fcntl.fcntl(fd, fcntl.F_SETFD, old | cloexec_flag) -# else: -# fcntl.fcntl(fd, fcntl.F_SETFD, old & ~cloexec_flag) - - class _UnixSubprocessTransport(base_subprocess.BaseSubprocessTransport): def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): stdin_w = None - if stdin == subprocess.PIPE: - # Use a socket pair for stdin, since not all platforms + if stdin == subprocess.PIPE and sys.platform.startswith('aix'): + # Use a socket pair for stdin on AIX, since it does not # support selecting read events on the write end of a # socket (which we use in order to detect closing of the - # other end). Notably this is needed on AIX, and works - # just fine on other platforms. - stdin, stdin_w = self._loop._socketpair() - - # Mark the write end of the stdin pipe as non-inheritable, - # needed by close_fds=False on Python 3.3 and older - # (Python 3.4 implements the PEP 446, socketpair returns - # non-inheritable sockets) - _set_inheritable(stdin_w.fileno(), False) - self._proc = subprocess.Popen( - args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, - universal_newlines=False, bufsize=bufsize, **kwargs) - if stdin_w is not None: - stdin.close() - self._proc.stdin = open(stdin_w.detach(), 'wb', buffering=bufsize) + # other end). + stdin, stdin_w = socket.socketpair() + try: + self._proc = subprocess.Popen( + args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, + universal_newlines=False, bufsize=bufsize, **kwargs) + if stdin_w is not None: + stdin.close() + self._proc.stdin = open(stdin_w.detach(), 'wb', buffering=bufsize) + stdin_w = None + finally: + if stdin_w is not None: + stdin.close() + stdin_w.close() class AbstractChildWatcher: @@ -723,6 +853,13 @@ class AbstractChildWatcher: waitpid(-1), there should be only one active object per process. 
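    For illustration, installing a concrete watcher (a sketch; must run
    in the main thread, before subprocesses are spawned):

        watcher = asyncio.SafeChildWatcher()
        watcher.attach_loop(loop)
        asyncio.set_child_watcher(watcher)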
""" + def __init_subclass__(cls) -> None: + if cls.__module__ != __name__: + warnings._deprecated("AbstractChildWatcher", + "{name!r} is deprecated as of Python 3.12 and will be " + "removed in Python {remove}.", + remove=(3, 14)) + def add_child_handler(self, pid, callback, *args): """Register a new child handler. @@ -759,6 +896,15 @@ def close(self): """ raise NotImplementedError() + def is_active(self): + """Return ``True`` if the watcher is active and is used by the event loop. + + Return True if the watcher is installed and ready to handle process exit + notifications. + + """ + raise NotImplementedError() + def __enter__(self): """Enter the watcher's context and allow starting new processes @@ -770,6 +916,64 @@ def __exit__(self, a, b, c): raise NotImplementedError() +class PidfdChildWatcher(AbstractChildWatcher): + """Child watcher implementation using Linux's pid file descriptors. + + This child watcher polls process file descriptors (pidfds) to await child + process termination. In some respects, PidfdChildWatcher is a "Goldilocks" + child watcher implementation. It doesn't require signals or threads, doesn't + interfere with any processes launched outside the event loop, and scales + linearly with the number of subprocesses launched by the event loop. The + main disadvantage is that pidfds are specific to Linux, and only work on + recent (5.3+) kernels. + """ + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + pass + + def is_active(self): + return True + + def close(self): + pass + + def attach_loop(self, loop): + pass + + def add_child_handler(self, pid, callback, *args): + loop = events.get_running_loop() + pidfd = os.pidfd_open(pid) + loop._add_reader(pidfd, self._do_wait, pid, pidfd, callback, args) + + def _do_wait(self, pid, pidfd, callback, args): + loop = events.get_running_loop() + loop._remove_reader(pidfd) + try: + _, status = os.waitpid(pid, 0) + except ChildProcessError: + # The child process is already reaped + # (may happen if waitpid() is called elsewhere). + returncode = 255 + logger.warning( + "child process pid %d exit status already read: " + " will report returncode 255", + pid) + else: + returncode = waitstatus_to_exitcode(status) + + os.close(pidfd) + callback(pid, returncode, *args) + + def remove_child_handler(self, pid): + # asyncio never calls remove_child_handler() !!! + # The method is no-op but is implemented because + # abstract base classes require it. + return True + + class BaseChildWatcher(AbstractChildWatcher): def __init__(self): @@ -779,6 +983,9 @@ def __init__(self): def close(self): self.attach_loop(None) + def is_active(self): + return self._loop is not None and self._loop.is_running() + def _do_waitpid(self, expected_pid): raise NotImplementedError() @@ -808,7 +1015,9 @@ def attach_loop(self, loop): def _sig_chld(self): try: self._do_waitpid_all() - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: # self._loop should always be available here # as '_sig_chld' is added as a signal handler # in 'attach_loop' @@ -817,19 +1026,6 @@ def _sig_chld(self): 'exception': exc, }) - def _compute_returncode(self, status): - if os.WIFSIGNALED(status): - # The child process died because of a signal. - return -os.WTERMSIG(status) - elif os.WIFEXITED(status): - # The child process exited (e.g sys.exit()). - return os.WEXITSTATUS(status) - else: - # The child exited, but we don't understand its status. 
- # This shouldn't happen, but if it does, let's just - # return that status; perhaps that helps debug it. - return status - class SafeChildWatcher(BaseChildWatcher): """'Safe' child watcher implementation. @@ -842,6 +1038,13 @@ class SafeChildWatcher(BaseChildWatcher): big number of children (O(n) each time SIGCHLD is raised) """ + def __init__(self): + super().__init__() + warnings._deprecated("SafeChildWatcher", + "{name!r} is deprecated as of Python 3.12 and will be " + "removed in Python {remove}.", + remove=(3, 14)) + def close(self): self._callbacks.clear() super().close() @@ -853,11 +1056,6 @@ def __exit__(self, a, b, c): pass def add_child_handler(self, pid, callback, *args): - if self._loop is None: - raise RuntimeError( - "Cannot add child handler, " - "the child watcher does not have a loop attached") - self._callbacks[pid] = (callback, args) # Prevent a race condition in case the child is already terminated. @@ -893,7 +1091,7 @@ def _do_waitpid(self, expected_pid): # The child process is still alive. return - returncode = self._compute_returncode(status) + returncode = waitstatus_to_exitcode(status) if self._loop.get_debug(): logger.debug('process %s exited with returncode %s', expected_pid, returncode) @@ -925,6 +1123,10 @@ def __init__(self): self._lock = threading.Lock() self._zombies = {} self._forks = 0 + warnings._deprecated("FastChildWatcher", + "{name!r} is deprecated as of Python 3.12 and will be " + "removed in Python {remove}.", + remove=(3, 14)) def close(self): self._callbacks.clear() @@ -954,11 +1156,6 @@ def __exit__(self, a, b, c): def add_child_handler(self, pid, callback, *args): assert self._forks, "Must use the context manager" - if self._loop is None: - raise RuntimeError( - "Cannot add child handler, " - "the child watcher does not have a loop attached") - with self._lock: try: returncode = self._zombies.pop(pid) @@ -991,7 +1188,7 @@ def _do_waitpid_all(self): # A child process is still alive. return - returncode = self._compute_returncode(status) + returncode = waitstatus_to_exitcode(status) with self._lock: try: @@ -1020,6 +1217,228 @@ def _do_waitpid_all(self): callback(pid, returncode, *args) +class MultiLoopChildWatcher(AbstractChildWatcher): + """A watcher that doesn't require running loop in the main thread. + + This implementation registers a SIGCHLD signal handler on + instantiation (which may conflict with other code that + install own handler for this signal). + + The solution is safe but it has a significant overhead when + handling a big number of processes (*O(n)* each time a + SIGCHLD is received). 
+ """ + + # Implementation note: + # The class keeps compatibility with AbstractChildWatcher ABC + # To achieve this it has empty attach_loop() method + # and doesn't accept explicit loop argument + # for add_child_handler()/remove_child_handler() + # but retrieves the current loop by get_running_loop() + + def __init__(self): + self._callbacks = {} + self._saved_sighandler = None + warnings._deprecated("MultiLoopChildWatcher", + "{name!r} is deprecated as of Python 3.12 and will be " + "removed in Python {remove}.", + remove=(3, 14)) + + def is_active(self): + return self._saved_sighandler is not None + + def close(self): + self._callbacks.clear() + if self._saved_sighandler is None: + return + + handler = signal.getsignal(signal.SIGCHLD) + if handler != self._sig_chld: + logger.warning("SIGCHLD handler was changed by outside code") + else: + signal.signal(signal.SIGCHLD, self._saved_sighandler) + self._saved_sighandler = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def add_child_handler(self, pid, callback, *args): + loop = events.get_running_loop() + self._callbacks[pid] = (loop, callback, args) + + # Prevent a race condition in case the child is already terminated. + self._do_waitpid(pid) + + def remove_child_handler(self, pid): + try: + del self._callbacks[pid] + return True + except KeyError: + return False + + def attach_loop(self, loop): + # Don't save the loop but initialize itself if called first time + # The reason to do it here is that attach_loop() is called from + # unix policy only for the main thread. + # Main thread is required for subscription on SIGCHLD signal + if self._saved_sighandler is not None: + return + + self._saved_sighandler = signal.signal(signal.SIGCHLD, self._sig_chld) + if self._saved_sighandler is None: + logger.warning("Previous SIGCHLD handler was set by non-Python code, " + "restore to default handler on watcher close.") + self._saved_sighandler = signal.SIG_DFL + + # Set SA_RESTART to limit EINTR occurrences. + signal.siginterrupt(signal.SIGCHLD, False) + + def _do_waitpid_all(self): + for pid in list(self._callbacks): + self._do_waitpid(pid) + + def _do_waitpid(self, expected_pid): + assert expected_pid > 0 + + try: + pid, status = os.waitpid(expected_pid, os.WNOHANG) + except ChildProcessError: + # The child process is already reaped + # (may happen if waitpid() is called elsewhere). + pid = expected_pid + returncode = 255 + logger.warning( + "Unknown child process pid %d, will report returncode 255", + pid) + debug_log = False + else: + if pid == 0: + # The child process is still alive. + return + + returncode = waitstatus_to_exitcode(status) + debug_log = True + try: + loop, callback, args = self._callbacks.pop(pid) + except KeyError: # pragma: no cover + # May happen if .remove_child_handler() is called + # after os.waitpid() returns. 
+ logger.warning("Child watcher got an unexpected pid: %r", + pid, exc_info=True) + else: + if loop.is_closed(): + logger.warning("Loop %r that handles pid %r is closed", loop, pid) + else: + if debug_log and loop.get_debug(): + logger.debug('process %s exited with returncode %s', + expected_pid, returncode) + loop.call_soon_threadsafe(callback, pid, returncode, *args) + + def _sig_chld(self, signum, frame): + try: + self._do_waitpid_all() + except (SystemExit, KeyboardInterrupt): + raise + except BaseException: + logger.warning('Unknown exception in SIGCHLD handler', exc_info=True) + + +class ThreadedChildWatcher(AbstractChildWatcher): + """Threaded child watcher implementation. + + The watcher uses a thread per process + for waiting for the process finish. + + It doesn't require subscription on POSIX signal + but a thread creation is not free. + + The watcher has O(1) complexity, its performance doesn't depend + on amount of spawn processes. + """ + + def __init__(self): + self._pid_counter = itertools.count(0) + self._threads = {} + + def is_active(self): + return True + + def close(self): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def __del__(self, _warn=warnings.warn): + threads = [thread for thread in list(self._threads.values()) + if thread.is_alive()] + if threads: + _warn(f"{self.__class__} has registered but not finished child processes", + ResourceWarning, + source=self) + + def add_child_handler(self, pid, callback, *args): + loop = events.get_running_loop() + thread = threading.Thread(target=self._do_waitpid, + name=f"asyncio-waitpid-{next(self._pid_counter)}", + args=(loop, pid, callback, args), + daemon=True) + self._threads[pid] = thread + thread.start() + + def remove_child_handler(self, pid): + # asyncio never calls remove_child_handler() !!! + # The method is no-op but is implemented because + # abstract base classes require it. + return True + + def attach_loop(self, loop): + pass + + def _do_waitpid(self, loop, expected_pid, callback, args): + assert expected_pid > 0 + + try: + pid, status = os.waitpid(expected_pid, 0) + except ChildProcessError: + # The child process is already reaped + # (may happen if waitpid() is called elsewhere). 
+ pid = expected_pid + returncode = 255 + logger.warning( + "Unknown child process pid %d, will report returncode 255", + pid) + else: + returncode = waitstatus_to_exitcode(status) + if loop.get_debug(): + logger.debug('process %s exited with returncode %s', + expected_pid, returncode) + + if loop.is_closed(): + logger.warning("Loop %r that handles pid %r is closed", loop, pid) + else: + loop.call_soon_threadsafe(callback, pid, returncode, *args) + + self._threads.pop(expected_pid) + +def can_use_pidfd(): + if not hasattr(os, 'pidfd_open'): + return False + try: + pid = os.getpid() + os.close(os.pidfd_open(pid, 0)) + except OSError: + # blocked by security policy like SECCOMP + return False + return True + + class _UnixDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy): """UNIX event loop policy with a watcher for child processes.""" _loop_factory = _UnixSelectorEventLoop @@ -1031,10 +1450,10 @@ def __init__(self): def _init_watcher(self): with events._lock: if self._watcher is None: # pragma: no branch - self._watcher = SafeChildWatcher() - if isinstance(threading.current_thread(), - threading._MainThread): - self._watcher.attach_loop(self._local._loop) + if can_use_pidfd(): + self._watcher = PidfdChildWatcher() + else: + self._watcher = ThreadedChildWatcher() def set_event_loop(self, loop): """Set the event loop. @@ -1046,18 +1465,21 @@ def set_event_loop(self, loop): super().set_event_loop(loop) - if self._watcher is not None and \ - isinstance(threading.current_thread(), threading._MainThread): + if (self._watcher is not None and + threading.current_thread() is threading.main_thread()): self._watcher.attach_loop(loop) def get_child_watcher(self): """Get the watcher for child processes. - If not yet set, a SafeChildWatcher object is automatically created. + If not yet set, a ThreadedChildWatcher object is automatically created. """ if self._watcher is None: self._init_watcher() + warnings._deprecated("get_child_watcher", + "{name!r} is deprecated as of Python 3.12 and will be " + "removed in Python {remove}.", remove=(3, 14)) return self._watcher def set_child_watcher(self, watcher): @@ -1069,6 +1491,10 @@ def set_child_watcher(self, watcher): self._watcher.close() self._watcher = watcher + warnings._deprecated("set_child_watcher", + "{name!r} is deprecated as of Python 3.12 and will be " + "removed in Python {remove}.", remove=(3, 14)) + SelectorEventLoop = _UnixSelectorEventLoop DefaultEventLoopPolicy = _UnixDefaultEventLoopPolicy diff --git a/Lib/asyncio/windows_events.py b/Lib/asyncio/windows_events.py index 2c68bc526a..cb613451a5 100644 --- a/Lib/asyncio/windows_events.py +++ b/Lib/asyncio/windows_events.py @@ -1,32 +1,41 @@ """Selector and proactor event loops for Windows.""" +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('win32 only') + +import _overlapped import _winapi import errno +from functools import partial import math +import msvcrt import socket import struct +import time import weakref from . import events from . import base_subprocess from . import futures +from . import exceptions from . import proactor_events from . import selector_events from . import tasks from . import windows_utils -# XXX RustPython TODO: _overlapped -# from . 
import _overlapped -from .coroutines import coroutine from .log import logger -__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor', - 'DefaultEventLoopPolicy', - ] +__all__ = ( + 'SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor', + 'DefaultEventLoopPolicy', 'WindowsSelectorEventLoopPolicy', + 'WindowsProactorEventLoopPolicy', +) -NULL = 0 -INFINITE = 0xffffffff +NULL = _winapi.NULL +INFINITE = _winapi.INFINITE ERROR_CONNECTION_REFUSED = 1225 ERROR_CONNECTION_ABORTED = 1236 @@ -53,7 +62,7 @@ def _repr_info(self): info = super()._repr_info() if self._ov is not None: state = 'pending' if self._ov.pending else 'completed' - info.insert(1, 'overlapped=<%s, %#x>' % (state, self._ov.address)) + info.insert(1, f'overlapped=<{state}, {self._ov.address:#x}>') return info def _cancel_overlapped(self): @@ -72,9 +81,9 @@ def _cancel_overlapped(self): self._loop.call_exception_handler(context) self._ov = None - def cancel(self): + def cancel(self, msg=None): self._cancel_overlapped() - return super().cancel() + return super().cancel(msg=msg) def set_exception(self, exception): super().set_exception(exception) @@ -109,12 +118,12 @@ def _poll(self): def _repr_info(self): info = super()._repr_info() - info.append('handle=%#x' % self._handle) + info.append(f'handle={self._handle:#x}') if self._handle is not None: state = 'signaled' if self._poll() else 'waiting' info.append(state) if self._wait_handle is not None: - info.append('wait_handle=%#x' % self._wait_handle) + info.append(f'wait_handle={self._wait_handle:#x}') return info def _unregister_wait_cb(self, fut): @@ -146,9 +155,9 @@ def _unregister_wait(self): self._unregister_wait_cb(None) - def cancel(self): + def cancel(self, msg=None): self._unregister_wait() - return super().cancel() + return super().cancel(msg=msg) def set_exception(self, exception): self._unregister_wait() @@ -297,9 +306,6 @@ def close(self): class _WindowsSelectorEventLoop(selector_events.BaseSelectorEventLoop): """Windows version of selector event loop.""" - def _socketpair(self): - return windows_utils.socketpair() - class ProactorEventLoop(proactor_events.BaseProactorEventLoop): """Windows version of proactor event loop using IOCP.""" @@ -309,20 +315,34 @@ def __init__(self, proactor=None): proactor = IocpProactor() super().__init__(proactor) - def _socketpair(self): - return windows_utils.socketpair() - - @coroutine - def create_pipe_connection(self, protocol_factory, address): + def run_forever(self): + try: + assert self._self_reading_future is None + self.call_soon(self._loop_self_reading) + super().run_forever() + finally: + if self._self_reading_future is not None: + ov = self._self_reading_future._ov + self._self_reading_future.cancel() + # self_reading_future always uses IOCP, so even though it's + # been cancelled, we need to make sure that the IOCP message + # is received so that the kernel is not holding on to the + # memory, possibly causing memory corruption later. Only + # unregister it if IO is complete in all respects. Otherwise + # we need another _poll() later to complete the IO. 
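_OverlappedFuture.cancel() and _WaitHandleFuture.cancel() now accept and forward the msg argument that asyncio.Future.cancel() grew in Python 3.9. The shape of that change, reduced to a sketch with hypothetical names (CleanupFuture, _do_cleanup):

import asyncio

class CleanupFuture(asyncio.Future):
    def _do_cleanup(self):
        # subclass-specific teardown, e.g. cancelling an overlapped I/O op
        pass

    def cancel(self, msg=None):
        self._do_cleanup()
        # Forward the cancellation message so it reaches the CancelledError.
        return super().cancel(msg=msg)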
+ if ov is not None and not ov.pending: + self._proactor._unregister(ov) + self._self_reading_future = None + + async def create_pipe_connection(self, protocol_factory, address): f = self._proactor.connect_pipe(address) - pipe = yield from f + pipe = await f protocol = protocol_factory() trans = self._make_duplex_pipe_transport(pipe, protocol, extra={'addr': address}) return trans, protocol - @coroutine - def start_serving_pipe(self, protocol_factory, address): + async def start_serving_pipe(self, protocol_factory, address): server = PipeServer(address) def loop_accept_pipe(f=None): @@ -347,6 +367,10 @@ def loop_accept_pipe(f=None): return f = self._proactor.accept_pipe(pipe) + except BrokenPipeError: + if pipe and pipe.fileno() != -1: + pipe.close() + self.call_soon(loop_accept_pipe) except OSError as exc: if pipe and pipe.fileno() != -1: self.call_exception_handler({ @@ -358,7 +382,8 @@ def loop_accept_pipe(f=None): elif self._debug: logger.warning("Accept pipe failed on pipe %r", pipe, exc_info=True) - except futures.CancelledError: + self.call_soon(loop_accept_pipe) + except exceptions.CancelledError: if pipe: pipe.close() else: @@ -368,28 +393,22 @@ def loop_accept_pipe(f=None): self.call_soon(loop_accept_pipe) return [server] - @coroutine - def _make_subprocess_transport(self, protocol, args, shell, - stdin, stdout, stderr, bufsize, - extra=None, **kwargs): + async def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): waiter = self.create_future() transp = _WindowsSubprocessTransport(self, protocol, args, shell, stdin, stdout, stderr, bufsize, waiter=waiter, extra=extra, **kwargs) try: - yield from waiter - except Exception as exc: - # Workaround CPython bug #23353: using yield/yield-from in an - # except block of a generator doesn't clear properly sys.exc_info() - err = exc - else: - err = None - - if err is not None: + await waiter + except (SystemExit, KeyboardInterrupt): + raise + except BaseException: transp.close() - yield from transp._wait() - raise err + await transp._wait() + raise return transp @@ -397,7 +416,7 @@ def _make_subprocess_transport(self, protocol, args, shell, class IocpProactor: """Proactor implementation using IOCP.""" - def __init__(self, concurrency=0xffffffff): + def __init__(self, concurrency=INFINITE): self._loop = None self._results = [] self._iocp = _overlapped.CreateIoCompletionPort( @@ -407,10 +426,16 @@ def __init__(self, concurrency=0xffffffff): self._unregistered = [] self._stopped_serving = weakref.WeakSet() + def _check_closed(self): + if self._iocp is None: + raise RuntimeError('IocpProactor is closed') + def __repr__(self): - return ('<%s overlapped#=%s result#=%s>' - % (self.__class__.__name__, len(self._cache), - len(self._results))) + info = ['overlapped#=%s' % len(self._cache), + 'result#=%s' % len(self._results)] + if self._iocp is None: + info.append('closed') + return '<%s %s>' % (self.__class__.__name__, " ".join(info)) def set_loop(self, loop): self._loop = loop @@ -420,13 +445,40 @@ def select(self, timeout=None): self._poll(timeout) tmp = self._results self._results = [] - return tmp + try: + return tmp + finally: + # Needed to break cycles when an exception occurs. 
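The async rewrite of _make_subprocess_transport keeps the waiter pattern but drops the old sys.exc_info() workaround and applies the exception discipline used throughout this patch: SystemExit and KeyboardInterrupt propagate untouched, while any other BaseException tears the half-built transport down before re-raising. A generic sketch, where make_transport, close() and _wait() stand in for the real transport machinery:

import asyncio

async def start_with_waiter(make_transport):
    loop = asyncio.get_running_loop()
    waiter = loop.create_future()
    transport = make_transport(waiter)   # transport resolves the waiter
    try:
        await waiter
    except (SystemExit, KeyboardInterrupt):
        raise                            # never swallow interpreter exits
    except BaseException:
        transport.close()
        await transport._wait()          # let the transport finish closing
        raise
    return transport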
+ tmp = None def _result(self, value): fut = self._loop.create_future() fut.set_result(value) return fut + @staticmethod + def finish_socket_func(trans, key, ov): + try: + return ov.getresult() + except OSError as exc: + if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED, + _overlapped.ERROR_OPERATION_ABORTED): + raise ConnectionResetError(*exc.args) + else: + raise + + @classmethod + def _finish_recvfrom(cls, trans, key, ov, *, empty_result): + try: + return cls.finish_socket_func(trans, key, ov) + except OSError as exc: + # WSARecvFrom will report ERROR_PORT_UNREACHABLE when the same + # socket is used to send to an address that is not listening. + if exc.winerror == _overlapped.ERROR_PORT_UNREACHABLE: + return empty_result, None + else: + raise + def recv(self, conn, nbytes, flags=0): self._register_with_iocp(conn) ov = _overlapped.Overlapped(NULL) @@ -438,16 +490,50 @@ def recv(self, conn, nbytes, flags=0): except BrokenPipeError: return self._result(b'') - def finish_recv(trans, key, ov): - try: - return ov.getresult() - except OSError as exc: - if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: - raise ConnectionResetError(*exc.args) - else: - raise + return self._register(ov, conn, self.finish_socket_func) + + def recv_into(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + try: + if isinstance(conn, socket.socket): + ov.WSARecvInto(conn.fileno(), buf, flags) + else: + ov.ReadFileInto(conn.fileno(), buf) + except BrokenPipeError: + return self._result(0) - return self._register(ov, conn, finish_recv) + return self._register(ov, conn, self.finish_socket_func) + + def recvfrom(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + try: + ov.WSARecvFrom(conn.fileno(), nbytes, flags) + except BrokenPipeError: + return self._result((b'', None)) + + return self._register(ov, conn, partial(self._finish_recvfrom, + empty_result=b'')) + + def recvfrom_into(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + try: + ov.WSARecvFromInto(conn.fileno(), buf, flags) + except BrokenPipeError: + return self._result((0, None)) + + return self._register(ov, conn, partial(self._finish_recvfrom, + empty_result=0)) + + def sendto(self, conn, buf, flags=0, addr=None): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + + ov.WSASendTo(conn.fileno(), buf, flags, addr) + + return self._register(ov, conn, self.finish_socket_func) def send(self, conn, buf, flags=0): self._register_with_iocp(conn) @@ -457,16 +543,7 @@ def send(self, conn, buf, flags=0): else: ov.WriteFile(conn.fileno(), buf) - def finish_send(trans, key, ov): - try: - return ov.getresult() - except OSError as exc: - if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: - raise ConnectionResetError(*exc.args) - else: - raise - - return self._register(ov, conn, finish_send) + return self._register(ov, conn, self.finish_socket_func) def accept(self, listener): self._register_with_iocp(listener) @@ -483,12 +560,11 @@ def finish_accept(trans, key, ov): conn.settimeout(listener.gettimeout()) return conn, conn.getpeername() - @coroutine - def accept_coro(future, conn): + async def accept_coro(future, conn): # Coroutine closing the accept socket if the future is cancelled try: - yield from future - except futures.CancelledError: + await future + except exceptions.CancelledError: conn.close() raise @@ -498,6 +574,14 @@ def accept_coro(future, conn): return future def connect(self, conn, 
address): + if conn.type == socket.SOCK_DGRAM: + # WSAConnect will complete immediately for UDP sockets so we don't + # need to register any IOCP operation + _overlapped.WSAConnect(conn.fileno(), address) + fut = self._loop.create_future() + fut.set_result(None) + return fut + self._register_with_iocp(conn) # The socket needs to be locally bound before we call ConnectEx(). try: @@ -520,6 +604,18 @@ def finish_connect(trans, key, ov): return self._register(ov, conn, finish_connect) + def sendfile(self, sock, file, offset, count): + self._register_with_iocp(sock) + ov = _overlapped.Overlapped(NULL) + offset_low = offset & 0xffff_ffff + offset_high = (offset >> 32) & 0xffff_ffff + ov.TransmitFile(sock.fileno(), + msvcrt.get_osfhandle(file.fileno()), + offset_low, offset_high, + count, 0, 0) + + return self._register(ov, sock, self.finish_socket_func) + def accept_pipe(self, pipe): self._register_with_iocp(pipe) ov = _overlapped.Overlapped(NULL) @@ -537,13 +633,12 @@ def finish_accept_pipe(trans, key, ov): return self._register(ov, pipe, finish_accept_pipe) - @coroutine - def connect_pipe(self, address): + async def connect_pipe(self, address): delay = CONNECT_PIPE_INIT_DELAY while True: - # Unfortunately there is no way to do an overlapped connect to a pipe. - # Call CreateFile() in a loop until it doesn't fail with - # ERROR_PIPE_BUSY + # Unfortunately there is no way to do an overlapped connect to + # a pipe. Call CreateFile() in a loop until it doesn't fail with + # ERROR_PIPE_BUSY. try: handle = _overlapped.ConnectPipe(address) break @@ -553,7 +648,7 @@ def connect_pipe(self, address): # ConnectPipe() failed with ERROR_PIPE_BUSY: retry later delay = min(delay * 2, CONNECT_PIPE_MAX_DELAY) - yield from tasks.sleep(delay, loop=self._loop) + await tasks.sleep(delay) return windows_utils.PipeHandle(handle) @@ -573,6 +668,8 @@ def _wait_cancel(self, event, done_callback): return fut def _wait_for_handle(self, handle, timeout, _is_cancel): + self._check_closed() + if timeout is None: ms = _winapi.INFINITE else: @@ -615,6 +712,8 @@ def _register_with_iocp(self, obj): # that succeed immediately. def _register(self, ov, obj, callback): + self._check_closed() + # Return a future which will be set with the result of the # operation when it completes. The future's value is actually # the value returned by callback(). @@ -651,6 +750,7 @@ def _unregister(self, ov): already be signalled (pending in the proactor event queue). It is also safe if the event is never signalled (because it was cancelled). """ + self._check_closed() self._unregistered.append(ov) def _get_accept_socket(self, family): @@ -707,8 +807,10 @@ def _poll(self, timeout=None): else: f.set_result(value) self._results.append(f) + finally: + f = None - # Remove unregisted futures + # Remove unregistered futures for ov in self._unregistered: self._cache.pop(ov.address, None) self._unregistered.clear() @@ -720,8 +822,12 @@ def _stop_serving(self, obj): self._stopped_serving.add(obj) def close(self): + if self._iocp is None: + # already closed + return + # Cancel remaining registered operations. - for address, (fut, ov, obj, callback) in list(self._cache.items()): + for fut, ov, obj, callback in list(self._cache.values()): if fut.cancelled(): # Nothing to do with cancelled futures pass @@ -742,14 +848,25 @@ def close(self): context['source_traceback'] = fut._source_traceback self._loop.call_exception_handler(context) + # Wait until all cancelled overlapped complete: don't exit with running + # overlapped to prevent a crash. 
Display progress every second if the + # loop is still running. + msg_update = 1.0 + start_time = time.monotonic() + next_msg = start_time + msg_update while self._cache: - if not self._poll(1): - logger.debug('taking long time to close proactor') + if next_msg <= time.monotonic(): + logger.debug('%r is running after closing for %.1f seconds', + self, time.monotonic() - start_time) + next_msg = time.monotonic() + msg_update + + # handle a few events, or timeout + self._poll(msg_update) self._results = [] - if self._iocp is not None: - _winapi.CloseHandle(self._iocp) - self._iocp = None + + _winapi.CloseHandle(self._iocp) + self._iocp = None def __del__(self): self.close() @@ -773,8 +890,12 @@ def callback(f): SelectorEventLoop = _WindowsSelectorEventLoop -class _WindowsDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy): +class WindowsSelectorEventLoopPolicy(events.BaseDefaultEventLoopPolicy): _loop_factory = SelectorEventLoop -DefaultEventLoopPolicy = _WindowsDefaultEventLoopPolicy +class WindowsProactorEventLoopPolicy(events.BaseDefaultEventLoopPolicy): + _loop_factory = ProactorEventLoop + + +DefaultEventLoopPolicy = WindowsProactorEventLoopPolicy diff --git a/Lib/asyncio/windows_utils.py b/Lib/asyncio/windows_utils.py index 7c63fb904b..ef277fac3e 100644 --- a/Lib/asyncio/windows_utils.py +++ b/Lib/asyncio/windows_utils.py @@ -1,6 +1,4 @@ -""" -Various Windows specific bits and pieces -""" +"""Various Windows specific bits and pieces.""" import sys @@ -11,13 +9,12 @@ import itertools import msvcrt import os -import socket import subprocess import tempfile import warnings -__all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle'] +__all__ = 'pipe', 'Popen', 'PIPE', 'PipeHandle' # Constants/globals @@ -29,61 +26,14 @@ _mmap_counter = itertools.count() -if hasattr(socket, 'socketpair'): - # Since Python 3.5, socket.socketpair() is now also available on Windows - socketpair = socket.socketpair -else: - # Replacement for socket.socketpair() - def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): - """A socket pair usable as a self-pipe, for Windows. - - Origin: https://gist.github.com/4325783, by Geert Jansen. - Public domain. - """ - if family == socket.AF_INET: - host = '127.0.0.1' - elif family == socket.AF_INET6: - host = '::1' - else: - raise ValueError("Only AF_INET and AF_INET6 socket address " - "families are supported") - if type != socket.SOCK_STREAM: - raise ValueError("Only SOCK_STREAM socket type is supported") - if proto != 0: - raise ValueError("Only protocol zero is supported") - - # We create a connected TCP socket. Note the trick with setblocking(0) - # that prevents us from having to create a thread. 
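connect_pipe() above has no overlapped equivalent to fall back on, so it retries CreateFile() with a capped exponential backoff whenever the pipe reports busy. The cadence in isolation, with try_connect and BusyError as hypothetical stand-ins for _overlapped.ConnectPipe and the ERROR_PIPE_BUSY check (the delay values mirror the module's constants):

import asyncio

CONNECT_PIPE_INIT_DELAY = 0.001   # seconds
CONNECT_PIPE_MAX_DELAY = 0.100

async def connect_with_backoff(try_connect, BusyError):
    delay = CONNECT_PIPE_INIT_DELAY
    while True:
        try:
            return try_connect()
        except BusyError:
            # Double the delay each round, but never sleep longer
            # than the cap.
            delay = min(delay * 2, CONNECT_PIPE_MAX_DELAY)
            await asyncio.sleep(delay)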
- lsock = socket.socket(family, type, proto) - try: - lsock.bind((host, 0)) - lsock.listen(1) - # On IPv6, ignore flow_info and scope_id - addr, port = lsock.getsockname()[:2] - csock = socket.socket(family, type, proto) - try: - csock.setblocking(False) - try: - csock.connect((addr, port)) - except (BlockingIOError, InterruptedError): - pass - csock.setblocking(True) - ssock, _ = lsock.accept() - except: - csock.close() - raise - finally: - lsock.close() - return (ssock, csock) - - # Replacement for os.pipe() using handles instead of fds def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): """Like os.pipe() but with overlapped support and using handles not fds.""" - address = tempfile.mktemp(prefix=r'\\.\pipe\python-pipe-%d-%d-' % - (os.getpid(), next(_mmap_counter))) + address = tempfile.mktemp( + prefix=r'\\.\pipe\python-pipe-{:d}-{:d}-'.format( + os.getpid(), next(_mmap_counter))) if duplex: openmode = _winapi.PIPE_ACCESS_DUPLEX @@ -138,10 +88,10 @@ def __init__(self, handle): def __repr__(self): if self._handle is not None: - handle = 'handle=%r' % self._handle + handle = f'handle={self._handle!r}' else: handle = 'closed' - return '<%s %s>' % (self.__class__.__name__, handle) + return f'<{self.__class__.__name__} {handle}>' @property def handle(self): @@ -149,7 +99,7 @@ def handle(self): def fileno(self): if self._handle is None: - raise ValueError("I/O operatioon on closed pipe") + raise ValueError("I/O operation on closed pipe") return self._handle def close(self, *, CloseHandle=_winapi.CloseHandle): @@ -157,10 +107,9 @@ def close(self, *, CloseHandle=_winapi.CloseHandle): CloseHandle(self._handle) self._handle = None - def __del__(self): + def __del__(self, _warn=warnings.warn): if self._handle is not None: - warnings.warn("unclosed %r" % self, ResourceWarning, - source=self) + _warn(f"unclosed {self!r}", ResourceWarning, source=self) self.close() def __enter__(self):
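Both _UnixWritePipeTransport and PipeHandle now spell their finalizer as __del__(self, _warn=warnings.warn). A sketch of why that default argument matters: it binds warn at class-definition time, so the finalizer can still report the leak during interpreter shutdown, when module globals (including the warnings module itself) may already have been cleared. Handle here is a hypothetical minimal wrapper:

import warnings

class Handle:
    def __init__(self, handle):
        self._handle = handle

    def close(self):
        self._handle = None

    def __del__(self, _warn=warnings.warn):
        if self._handle is not None:
            # _warn still works even if the global `warnings` name is gone.
            _warn(f'unclosed {self!r}', ResourceWarning, source=self)
            self.close()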