8000 bpo-23749: Implement loop.start_tls() · python/cpython@bf71f87 · GitHub
[go: up one dir, main page]

Skip to content

Commit bf71f87

Browse files
committed
bpo-23749: Implement loop.start_tls()
1 parent e5f7dcc commit bf71f87

File tree

9 files changed

+586
-57
lines changed

9 files changed

+586
-57
lines changed

Lib/asyncio/base_events.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@
2929
import warnings
3030
import weakref
3131

32+
try:
33+
import ssl
34+
except ImportError: # pragma: no cover
35+
ssl = None
36+
3237
from . import coroutines
3338
from . import events
3439
from . import futures
@@ -795,6 +800,49 @@ async def _create_connection_transport(
795800

796801
return transport, protocol
797802

803+
async def start_tls(self, transport, protocol, sslcontext, *,
804+
server_side=False,
805+
ssl_handshake_timeout=None):
806+
"""Upgrade transport to TLS.
807+
808+
Return a new transport that *protocol* should start using
809+
immediately.
810+
"""
811+
if ssl is None:
812+
raise RuntimeError('Python ssl module is not available')
813+
814+
if not isinstance(sslcontext, ssl.SSLContext):
815+
raise TypeError(
816+
f'sslcontext is expected to be an instance of ssl.SSLContext, '
817+
f'got {sslcontext!r}')
818+
819+
if not getattr(transport, '_start_tls_compatible', False):
820+
raise TypeError(
821+
f'transport {self!r} is not supported by start_tls()')
822+
823+
if transport.get_write_buffer_size():
824+
if transport._buffer_drained_fut is not None:
825+
raise RuntimeError(
826+
f'cannot start_tls(); another operation is awaiting '
827+
f'for the write buffer to drain')
828+
829+
waiter = self.create_future()
830+
transport._buffer_drained_fut = waiter
831+
await waiter
832+
833+
assert transport._buffer_drained_fut is None
834+
835+
waiter = self.create_future()
836+
app_transport = self._make_ssl_transport(
837+
transport._sock, protocol, sslcontext, waiter,
838+
server_side=server_side,
839+
ssl_handshake_timeout=ssl_handshake_timeout,
840+
server=transport._server,
841+
call_connection_made=False)
842+
843+
await waiter
844+
return app_transport
845+
798846
async def create_datagram_endpoint(self, protocol_factory,
799847
local_addr=None, remote_addr=None, *,
800848
family=0, proto=0, flags=0,

Lib/asyncio/events.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,16 @@ async def create_server(
305305
"""
306306
raise NotImplementedError
307307

308+
async def start_tls(self, transport, protocol, sslcontext, *,
309+
server_side=False,
310+
ssl_handshake_timeout=None):
311+
"""Upgrade a transport to TLS.
312+
313+
Return a new transport that *protocol* should start using
314+
immediately.
315+
"""
316+
raise NotImplementedError
317+
308318
async def create_unix_connection(
309319
self, protocol_factory, path=None, *,
310320
ssl=None, sock=None,

Lib/asyncio/proactor_events.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,18 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport,
223223
transports.WriteTransport):
224224
"""Transport for write pipes."""
225225

226+
def __init__(self, *args, **kwargs):
227+
super().__init__(*args, **kwargs)
228+
self._buffer_drained_fut = None
229+
226230
def write(self, data):
231+
if self._buffer_drained_fut is not None:
232+
raise RuntimeError(
233+
f'cannot write: transport {self!r} is waiting until its '
234+
f'write buffer is drained')
235+
return self._write(data)
236+
237+
def _write(self, data):
227238
if not isinstance(data, (bytes, bytearray, memoryview)):
228239
raise TypeError(
229240
f"data argument must be a bytes-like object, "
@@ -280,6 +291,11 @@ def _loop_writing(self, f=None, data=None):
280291
# and it may add more data to the buffer (even causing the
281292
# protocol to be paused again).
282293
self._maybe_resume_protocol()
294+
295+
if self._buffer_drained_fut is not None:
296+
self._buffer_drained_fut.set_result(None)
297+
self._buffer_drained_fut = None
298+
283299
else:
284300
self._write_fut = self._loop._proactor.send(self._sock, data)
285301
if not self._write_fut.done():
@@ -343,6 +359,8 @@ class _ProactorSocketTransport(_ProactorReadPipeTransport,
343359
transports.Transport):
344360
"""Transport for connected sockets."""
345361

362+
_start_tls_compatible = True
363+
346364
def _set_extra(self, sock):
347365
self._extra['socket'] = sock
348366

@@ -393,11 +411,13 @@ def _make_ssl_transport(
393411
self, rawsock, protocol, sslcontext, waiter=None,
394412
*, server_side=False, server_hostname=None,
395413
extra=None, server=None,
396-
ssl_handshake_timeout=None):
414+
ssl_handshake_timeout=None,
415+
call_connection_made=True):
397416
ssl_protocol = sslproto.SSLProtocol(
398417
self, protocol, sslcontext, waiter,
399418
server_side, server_hostname,
400-
ssl_handshake_timeout=ssl_handshake_timeout)
419+
ssl_handshake_timeout=ssl_handshake_timeout,
420+
call_connection_made=call_connection_made)
401421
_ProactorSocketTransport(self, rawsock, ssl_protocol,
402422
extra=extra, server=server)
403423
return ssl_protocol._app_transport

Lib/asyncio/selector_events.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,13 @@ def _make_ssl_transport(
7474
self, rawsock, protocol, sslcontext, waiter=None,
7575
*, server_side=False, server_hostname=None,
7676
extra=None, server=None,
77-
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
77+
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
78+
call_connection_made=False):
7879
ssl_protocol = sslproto.SSLProtocol(
7980
self, protocol, sslcontext, waiter,
8081
server_side, server_hostname,
81-
ssl_handshake_timeout=ssl_handshake_timeout)
82+
ssl_handshake_timeout=ssl_handshake_timeout,
83+
call_connection_made=call_connection_made)
8284
_SelectorSocketTransport(self, rawsock, ssl_protocol,
8385
extra=extra, server=server)
8486
return ssl_protocol._app_transport
@@ -694,12 +696,16 @@ def get_write_buffer_size(self):
694696

695697
class _SelectorSocketTransport(_SelectorTransport):
696698

699+
_start_tls_compatible = True
700+
697701
def __init__(self, loop, sock, protocol, waiter=None,
698702
extra=None, server=None):
699703
super().__init__(loop, sock, protocol, extra, server)
700704
self._eof = False
701705
self._paused = False
702706

707+
self._buffer_drained_fut = None
708+
703709
# Disable the Nagle algorithm -- small writes will be
704710
# sent without waiting for the TCP ACK. This generally
705711
# decreases the latency (in some cases significantly.)
@@ -758,6 +764,13 @@ def _read_ready(self):
758764
self.close()
759765

760766
def write(self, data):
767+
if self._buffer_drained_fut is not None:
768+
raise RuntimeError(
769+
f'cannot write: transport {self!r} is waiting until its '
770+
f'write buffer is drained')
771+
return self._write(data)
772+
773+
def _write(self, data):
761774
if not isinstance(data, (bytes, bytearray, memoryview)):
762775
raise TypeError(f'data argument must be a bytes-like object, '
763776
f'not {type(data).__name__!r}')
@@ -816,6 +829,10 @@ def _write_ready(self):
816829
elif self._eof:
817830
self._sock.shutdown(socket.SHUT_WR)
818831

832+
if self._buffer_drained_fut is not None:
833+
self._buffer_drained_fut.set_result(None)
834+
self._buffer_drained_fut = None
835+
819836
def write_eof(self):
820837
if self._eof:
821838
return

0 commit comments

Comments
 (0)
0