8000 gh-79156: Add start_tls() method to streams API (#91453) · python/cpython@6217864 · GitHub
[go: up one dir, main page]

Skip to content

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 6217864

Browse files
arhadthedevicgood
andauthored
gh-79156: Add start_tls() method to streams API (#91453)
The existing event loop `start_tls()` method is not sufficient for connections using the streams API. The existing StreamReader works because the new transport passes received data to the original protocol. The StreamWriter must then write data to the new transport, and the StreamReaderProtocol must be updated to close the new transport correctly. The new StreamWriter `start_tls()` updates itself and the reader protocol to the new SSL transport. Co-authored-by: Ian Good <icgood@gmail.com>
1 parent bd26ef5 commit 6217864

File tree

5 files changed

+109
-0
lines changed

5 files changed

+109
-0
lines changed

Doc/library/asyncio-stream.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,24 @@ StreamWriter
295295
be resumed. When there is nothing to wait for, the :meth:`drain`
296296
returns immediately.
297297

298+
.. coroutinemethod:: start_tls(sslcontext, \*, server_hostname=None, \
299+
ssl_handshake_timeout=None)
300+
301+
Upgrade an existing stream-based connection to TLS.
302+
303+
Parameters:
304+
305+
* *sslcontext*: a configured instance of :class:`~ssl.SSLContext`.
306+
307+
* *server_hostname*: sets or overrides the host name that the target
308+
server's certificate will be matched against.
309+
310+
* *ssl_handshake_timeout* is the time in seconds to wait for the TLS
311+
handshake to complete before aborting the connection. ``60.0`` seconds
312+
if ``None`` (default).
313+
314+
.. versionadded:: 3.8
315+
298316
.. method:: is_closing()
299317

300318
Return ``True`` if the stream is closed or in the process of

Doc/whatsnew/3.11.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,10 @@ asyncio
246246
:meth:`~asyncio.AbstractEventLoop.sock_recvfrom_into`.
247247
(Contributed by Alex Grönholm in :issue:`46805`.)
248248

249+
* Add :meth:`~asyncio.streams.StreamWriter.start_tls` method for upgrading
250+
existing stream-based connections to TLS. (Contributed by Ian Good in
251+
:issue:`34975`.)
252+
249253
fractions
250254
---------
251255

Lib/asyncio/streams.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,13 @@ def _stream_reader(self):
217217
return None
218218
return self._stream_reader_wr()
219219

220+
def _replace_writer(self, writer):
221+
loop = self._loop
222+
transport = writer.transport
223+
self._stream_writer = writer
224+
self._transport = transport
225+
self._over_ssl = transport.get_extra_info('sslcontext') is not None
226+
220227
def connection_made(self, transport):
221228
if self._reject_connection:
222229
context = {
@@ -371,6 +378,20 @@ async def drain(self):
371378
await sleep(0)
372379
await self._protocol._drain_helper()
373380

381+
async def start_tls(self, sslcontext, *,
382+
server_hostname=None,
383+
ssl_handshake_timeout=None):
384+
"""Upgrade an existing stream-based connection to TLS."""
385+
server_side = self._protocol._client_connected_cb is not None
386+
protocol = self._protocol
387+
await self.drain()
388+
new_transport = await self._loop.start_tls( # type: ignore
389+
self._transport, protocol, sslcontext,
390+
server_side=server_side, server_hostname=server_hostname,
391+
ssl_handshake_timeout=ssl_handshake_timeout)
392+
self._transport = new_transport
393+
protocol._replace_writer(self)
394+
374395

375396
class StreamReader:
376397

Lib/test/test_asyncio/test_streams.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,69 @@ async def client(path):
706706

707707
self.assertEqual(messages, [])
708708

709+
@unittest.skipIf(ssl is None, 'No ssl module')
710+
def test_start_tls(self):
711+
712+
class MyServer:
713+
714+
def __init__(self, loop):
715+
self.server = None
716+
self.loop = loop
717+
718+
async def handle_client(self, client_reader, client_writer):
719+
data1 = await client_reader.readline()
720+
client_writer.write(data1)
721+
await client_writer.drain()
722+
assert client_writer.get_extra_info('sslcontext') is None
723+
await client_writer.start_tls(
724+
test_utils.simple_server_sslcontext())
725+
assert client_writer.get_extra_info('sslcontext') is not None
726+
data2 = await client_reader.readline()
727+
client_writer.write(data2)
728+
await client_writer.drain()
729+
client_writer.close()
730+
await client_writer.wait_closed()
731+
732+
def start(self):
733+
sock = socket.create_server(('127.0.0.1', 0))
734+
self.server = self.loop.run_until_complete(
735+
asyncio.start_server(self.handle_client,
736+
sock=sock))
737+
return sock.getsockname()
738+
739+
def stop(self):
740+
if self.server is not None:
741+
self.server.close()
742+
self.loop.run_until_complete(self.server.wait_closed())
743+
self.server = None
744+
745+
async def client(addr):
746+
reader, writer = await asyncio.open_connection(*addr)
747+
writer.write(b"hello world 1!\n")
748+
await writer.drain()
749+
msgback1 = await reader.readline()
750+
assert writer.get_extra_info('sslcontext') is None
751+
await writer.start_tls(test_utils.simple_client_sslcontext())
752+
assert writer.get_extra_info('sslcontext') is not None
753+
writer.write(b"hello world 2!\n")
754+
await writer.drain()
755+
msgback2 = await reader.readline()
756+
writer.close()
757+
await writer.wait_closed()
758+
return msgback1, msgback2
759+
760+
messages = []
761+
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
762+
763+
server = MyServer(self.loop)
764+
addr = server.start()
765+
msg1, msg2 = self.loop.run_until_complete(client(addr))
766+
server.stop()
767+
768+
self.assertEqual(messages, [])
769+
self.assertEqual(msg1, b"hello world 1!\n")
770+
self.assertEqual(msg2, b"hello world 2!\n")
771+
709772
@unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
710773
def test_read_all_from_pipe_reader(self):
711774
# See asyncio issue 168. This test is derived from the example
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Adds a ``start_tls()`` method to :class:`~asyncio.streams.StreamWriter`,
2+
which upgrades the connection with TLS using the given
3+
:class:`~ssl.SSLContext`.

0 commit comments

Comments
 (0)
0