8000 bpo-33654: Support protocol type switching in SSLTransport.set_protocol() by 1st1 · Pull Request #7194 · python/cpython · GitHub
[go: up one dir, main page]

Skip to content

bpo-33654: Support protocol type switching in SSLTransport.set_protocol() #7194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 29, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions Lib/asyncio/sslproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def get_extra_info(self, name, default=None):
return self._ssl_protocol._get_extra_info(name, default)

def set_protocol(self, protocol):
self._ssl_protocol._app_protocol = protocol
self._ssl_protocol._set_app_protocol(protocol)

def get_protocol(self):
return self._ssl_protocol._app_protocol
Expand Down Expand Up @@ -440,9 +440,7 @@ def __init__(self, loop, app_protocol, sslcontext, waiter,

self._waiter = waiter
self._loop = loop
self._app_protocol = app_protocol
self._app_protocol_is_buffer = \
isinstance(app_protocol, protocols.BufferedProtocol)
self._set_app_protocol(app_protocol)
self._app_transport = _SSLProtocolTransport(self._loop, self)
# _SSLPipe instance (None until the connection is made)
self._sslpipe = None
Expand All @@ -454,6 +452,11 @@ def __init__(self, loop 8000 , app_protocol, sslcontext, waiter,
self._call_connection_made = call_connection_made
self._ssl_handshake_timeout = ssl_handshake_timeout

def _set_app_protocol(self, app_protocol):
self._app_protocol = app_protocol
self._app_protocol_is_buffer = \
isinstance(app_protocol, protocols.BufferedProtocol)

def _wakeup_waiter(self, exc=None):
if self._waiter is None:
return
Expand Down
47 changes: 36 additions & 11 deletions Lib/test/test_asyncio/test_sslproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def test_start_tls_client_buf_proto_1(self):

server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext()
client_con_made_calls = 0

def serve(sock):
sock.settimeout(self.TIMEOUT)
Expand All @@ -315,20 +316,21 @@ def serve(sock):
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))

sock.sendall(b'2')
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))

sock.shutdown(socket.SHUT_RDWR)
sock.close()

class ClientProto(asyncio.BufferedProtocol):
def __init__(self, on_data, on_eof):
class ClientProtoFirst(asyncio.BufferedProtocol):
def __init__(self, on_data):
self.on_data = on_data
self.on_eof = on_eof
self.con_made_cnt = 0
self.buf = bytearray(1)

def connection_made(proto, tr):
proto.con_made_cnt += 1
# Ensure connection_made gets called only once.
self.assertEqual(proto.con_made_cnt, 1)
def connection_made(self, tr):
nonlocal client_con_made_calls
client_con_made_calls += 1

def get_buffer(self, sizehint):
return self.buf
Expand All @@ -337,27 +339,50 @@ def buffer_updated(self, nsize):
assert nsize == 1
self.on_data.set_result(bytes(self.buf[:nsize]))

class ClientProtoSecond(asyncio.Protocol):
def __init__(self, on_data, on_eof):
self.on_data = on_data
self.on_eof = on_eof
self.con_made_cnt = 0

def connection_made(self, tr):
nonlocal client_con_made_calls
client_con_made_calls += 1

def data_received(self, data):
self.on_data.set_result(data)

def eof_received(self):
self.on_eof.set_result(True)

async def client(addr):
await asyncio.sleep(0.5, loop=self.loop)

on_data = self.loop.create_future()
on_data1 = self.loop.create_future()
on_data2 8000 = self.loop.create_future()
on_eof = self.loop.create_future()

tr, proto = await self.loop.create_connection(
lambda: ClientProto(on_data, on_eof), *addr)
lambda: ClientProtoFirst(on_data1), *addr)

tr.write(HELLO_MSG)
new_tr = await self.loop.start_tls(tr, proto, client_context)

self.assertEqual(await on_data, b'O')
self.assertEqual(await on_data1, b'O')
new_tr.write(HELLO_MSG)

new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
self.assertEqual(await on_data2, b'2')
new_tr.write(HELLO_MSG)
await on_eof

new_tr.close()

# connection_made() should be called only once -- when
# we establish connection for the first time. Start TLS
# doesn't call connection_made() on application protocols.
self.assertEqual(client_con_made_calls, 1)

with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support protocol type switching in SSLTransport.set_protocol().
0