diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py index 68499e5aeeae7e..41773337c2a1c8 100644 --- a/Lib/asyncio/sslproto.py +++ b/Lib/asyncio/sslproto.py @@ -135,7 +135,7 @@ def do_handshake(self, callback=None): assert len(appdata) == 0 return ssldata - def shutdown(self, callback=None): + def shutdown(self): """Start the SSL shutdown sequence. Return a list of ssldata. A ssldata element is a list of buffers @@ -150,7 +150,6 @@ def shutdown(self, callback=None): raise RuntimeError('shutdown in progress') assert self._state in (_WRAPPED, _DO_HANDSHAKE) self._state = _SHUTDOWN - self._shutdown_cb = callback ssldata, appdata = self.feed_ssldata(b'') assert appdata == [] or appdata == [b''] return ssldata @@ -218,8 +217,6 @@ def feed_ssldata(self, data, only_handshake=False): self._sslobj.unwrap() self._sslobj = None self._state = _UNWRAPPED - if self._shutdown_cb: - self._shutdown_cb() elif self._state == _UNWRAPPED: # Drain possible plaintext data after close_notify. @@ -636,11 +633,14 @@ def _process_write_backlog(self): self._on_handshake_complete) offset = 1 else: - ssldata = self._sslpipe.shutdown(self._finalize) + try: + ssldata = self._sslpipe.shutdown() + self._feed_ssl_data(ssldata) + finally: + self._finalize() offset = 1 - for chunk in ssldata: - self._transport.write(chunk) + self._feed_ssl_data(ssldata) if offset < len(data): self._write_backlog[0] = (data, offset) @@ -665,6 +665,10 @@ def _process_write_backlog(self): # BaseException raise + def _feed_ssl_data(self, ssldata): + for chunk in ssldata: + self._transport.write(chunk) + def _fatal_error(self, exc, message='Fatal error on transport'): # Should be called from exception handler only. if isinstance(exc, base_events._FATAL_ERROR_IGNORE): diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py index bcd236ea2632ed..b0e465a393892c 100644 --- a/Lib/test/test_asyncio/test_sslproto.py +++ b/Lib/test/test_asyncio/test_sslproto.py @@ -1,8 +1,13 @@ """Tests for asyncio/sslproto.py.""" - +import contextlib import logging +import os +import socket +import threading import unittest from unittest import mock + + try: import ssl except ImportError: @@ -13,6 +18,55 @@ from asyncio import sslproto from asyncio import test_utils +HOST = '127.0.0.1' + + +def data_file(name): + return os.path.join(os.path.dirname(__file__), name) + + +class DummySSLServer(threading.Thread): + class Protocol(asyncio.Protocol): + transport = None + + def connection_lost(self, exc): + self.transport.close() + + def connection_made(self, transport): + self.transport = transport + + def __init__(self): + super().__init__() + + self.loop = asyncio.new_event_loop() + context = ssl.SSLContext() + context.load_cert_chain(data_file('keycert3.pem')) + server_future = self.loop.create_server( + self.Protocol, *(HOST, 0), + ssl=context) + self.server = self.loop.run_until_complete(server_future) + + def run(self): + self.loop.run_forever() + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) + self.loop.close() + + def stop(self): + self.loop.call_soon_threadsafe(self.loop.stop) + + +@contextlib.contextmanager +def run_test_ssl_server(): + th = DummySSLServer() + th.start() + try: + yield th.server + finally: + th.stop() + th.join() + + @unittest.skipIf(ssl is None, 'No ssl module') class SslProtoHandshakeTests(test_utils.TestCase): @@ -121,6 +175,15 @@ def test_get_extra_info_on_closed_connection(self): ssl_proto.connection_lost(None) self.assertIsNone(ssl_proto._get_extra_info('socket')) + def test_ssl_shutdown(self): + # bpo-30698 Shutdown the ssl layer cleanly + with run_test_ssl_server() as server: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect(server.sockets[0].getsockname()) + ssl_socket = test_utils.dummy_ssl_context().wrap_socket(sock) + with contextlib.closing(ssl_socket): + ssl_socket.unwrap() + if __name__ == '__main__': unittest.main() diff --git a/Misc/NEWS b/Misc/NEWS index 824d7fde0621ae..e160da3c74ce28 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -365,6 +365,8 @@ Extension Modules Library ------- +- bpo-30698: Asyncio sslproto shutdown the ssl layer cleanly + - bpo-30038: Fix race condition between signal delivery and wakeup file descriptor. Patch by Nathaniel Smith.