8000 bpo-33654: Support protocol type switching in SSLTransport.set_protoc… · python/cpython@2179022 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2179022

Browse files
1st1asvetlov
authored andcommitted
bpo-33654: Support protocol type switching in SSLTransport.set_protocol() (#7194)
1 parent f295587 commit 2179022

File tree

3 files changed

+44
-15
lines changed

3 files changed

+44
-15
lines changed

Lib/asyncio/sslproto.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def get_extra_info(self, name, default=None):
295295
return self._ssl_protocol._get_extra_info(name, default)
296296

297297
def set_protocol(self, protocol):
298-
self._ssl_protocol._app_protocol = protocol
298+
self._ssl_protocol._set_app_protocol(protocol)
299299

300300
def get_protocol(self):
301301
return self._ssl_protocol._app_protocol
@@ -440,9 +440,7 @@ def __init__(self, loop, app_protocol, sslcontext, waiter,
440440

441441
self._waiter = waiter
442442
self._loop = loop
443-
self._app_protocol = app_protocol
444-
self._app_protocol_is_buffer = \
445-
isinstance(app_protocol, protocols.BufferedProtocol)
443+
self._set_app_protocol(app_protocol)
446444
self._app_transport = _SSLProtocolTransport(self._loop, self)
447445
# _SSLPipe instance (None until the connection is made)
448446
self._sslpipe = None
@@ -454,6 +452,11 @@ def __init__(self, loop, app_protocol, sslcontext, waiter,
454452
self._call_connection_made = call_connection_made
455453
self._ssl_handshake_timeout = ssl_handshake_timeout
456454

455+
def _set_app_protocol(self, app_protocol):
456+
self._app_protocol = app_protocol
457+
self._app_protocol_is_buffer = \
458+
isinstance(app_protocol, protocols.BufferedProtocol)
459+
457460
def _wakeup_waiter(self, exc=None):
458461
if self._waiter is None:
459462
return

Lib/test/test_asyncio/test_sslproto.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ def test_start_tls_client_buf_proto_1(self):
302302

303303
server_context = test_utils.simple_server_sslcontext()
304304
client_context = test_utils.simple_client_sslcontext()
305+
client_con_made_calls = 0
305306

306307
def serve(sock):
307308
sock.settimeout(self.TIMEOUT)
@@ -315,20 +316,21 @@ def serve(sock):
315316
data = sock.recv_all(len(HELLO_MSG))
316317
self.assertEqual(len(data), len(HELLO_MSG))
317318

319+
sock.sendall(b'2')
320+
data = sock.recv_all(len(HELLO_MSG))
321+
self.assertEqual(len(data), len(HELLO_MSG))
322+
318323
sock.shutdown(socket.SHUT_RDWR)
319324
sock.close()
320325

321-
class ClientProto(asyncio.BufferedProtocol):
322-
def __init__(self, on_data, on_eof):
326+
class ClientProtoFirst(asyncio.BufferedProtocol):
327+
def __init__(self, on_data):
323328
self.on_data = on_data
324-
self.on_eof = on_eof
325-
self.con_made_cnt = 0
326329
self.buf = bytearray(1)
327330

328-
def connection_made(proto, tr):
329-
proto.con_made_cnt += 1
330-
# Ensure connection_made gets called only once.
331-
self.assertEqual(proto.con_made_cnt, 1)
331+
def connection_made(self, tr):
332+
nonlocal client_con_made_calls
333+
client_con_made_calls += 1
332334

333335
def get_buffer(self, sizehint):
334336
return self.buf
@@ -337,27 +339,50 @@ def buffer_updated(self, nsize):
337339
assert nsize == 1
338340
self.on_data.set_result(bytes(self.buf[:nsize]))
339341

342+
class ClientProtoSecond(asyncio.Protocol):
343+
def __init__(self, on_data, on_eof):
344+
self.on_data = on_data
345+
self.on_eof = on_eof
346+
self.con_made_cnt = 0
347+
348+
def connection_made(self, tr):
349+
nonlocal client_con_made_calls
350+
client_con_made_calls += 1
351+
352+
def data_received(self, data):
353+
self.on_data.set_result(data)
354+
340355
def eof_received(self):
341356
self.on_eof.set_result(True)
342357

343358
async def client(addr):
344359
await asyncio.sleep(0.5, loop=self.loop)
345360

346-
on_data = self.loop.create_future()
361+
on_data1 = self.loop.create_future()
362+
on_data2 = self.loop.create_future()
347363
on_eof = self.loop.create_future()
348364

349365
tr, proto = await self.loop.create_connection(
350-
lambda: ClientProto(on_data, on_eof), *addr)
366+
lambda: ClientProtoFirst(on_data1), *addr)
351367

352368
tr.write(HELLO_MSG)
353369
new_tr = await self.loop.start_tls(tr, proto, client_context)
354370

355-
self.assertEqual(await on_data, b'O')
371+
self.assertEqual(await on_data1, b'O')
372+
new_tr.write(HELLO_MSG)
373+
374+
new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
375+
self.assertEqual(await on_data2, b'2')
356376
new_tr.write(HELLO_MSG)
357377
await on_eof
358378

359379
new_tr.close()
360380

381+
# connection_made() should be called only once -- when
382+
# we establish connection for the first time. Start TLS
383+
# doesn't call connection_made() on application protocols.
384+
self.assertEqual(client_con_made_calls, 1)
385+
361386
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
362387
self.loop.run_until_complete(
363388
asyncio.wait_for(client(srv.addr),
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Support protocol type switching in SSLTransport.set_protocol().

0 commit comments

Comments
 (0)
0