8000 extmod/asyncio: Add ssl support with SSLContext. · micropython/micropython@d283028 · GitHub
[go: up one dir, main page]

Skip to content

Commit d283028

Browse files
committed
extmod/asyncio: Add ssl support with SSLContext.
This adds asyncio ssl support with SSLContext and the corresponding tests in `tests/net_inet` and `tests/multi_net` Note that not doing the handshake on connect will delegate the handshake to the following `mbedtls_ssl_read/write` calls. However if the handshake fails when a client certificate is required and not presented by the peer, it needs to be notified of this handshake error (otherwise it will hang until timeout if any). Finally at MicroPython side raise the proper mbedtls error code and message. Signed-off-by: Carlos Gil <carlosgilglez@gmail.com>
1 parent 4c9c285 commit d283028

18 files changed

+971
-5
lines changed

extmod/asyncio/stream.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def readline(self):
6363
while True:
6464
yield core._io_queue.queue_read(self.s)
6565
l2 = self.s.readline() # may do multiple reads but won't block
66+
if l2 is None:
67+
continue
6668
l += l2
6769
if not l2 or l[-1] == 10: # \n (check l in case l2 is str)
6870
return l
@@ -100,19 +102,29 @@ def drain(self):
100102
# Create a TCP stream connection to a remote host
101103
#
102104
# async
103-
def open_connection(host, port):
105+
def open_connection(host, port, ssl=None, server_hostname=None):
104106
from errno import EINPROGRESS
105107
import socket
106108

107109
ai = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)[0] # TODO this is blocking!
108110
s = socket.socket(ai[0], ai[1], ai[2])
109111
s.set 6D40 blocking(False)
110-
ss = Stream(s)
111112
try:
112113
s.connect(ai[-1])
113114
except OSError as er:
114115
if er.errno != EINPROGRESS:
115116
raise er
117+
# wrap with SSL, if requested
118+
if ssl:
119+
if ssl is True:
120+
import ssl as _ssl
121+
122+
ssl = _ssl.SSLContext(_ssl.PROTOCOL_TLS_CLIENT)
123+
if not server_hostname:
124+
server_hostname = host
125+
s = ssl.wrap_socket(s, server_hostname=server_hostname, do_handshake_on_connect=False)
126+
s.setblocking(False)
127+
ss = Stream(s)
116128
yield core._io_queue.queue_write(s)
117129
return ss, ss
118130

@@ -135,7 +147,7 @@ def close(self):
135147
async def wait_closed(self):
136148
await self.task
137149

138-
async def _serve(self, s, cb):
150+
async def _serve(self, s, cb, ssl):
139151
self.state = False
140152
# Accept incoming connections
141153
while True:
@@ -156,14 +168,21 @@ async def _serve(self, s, cb):
156168
except:
157169
# Ignore a failed accept
158170
continue
171+
if ssl:
172+
try:
173+
s2 = ssl.wrap_socket(s2, server_side=True, do_handshake_on_connect=False)
174+
except OSError as e:
175+
core.sys.print_exception(e)
176+
s2.close()
177+
continue
159178
s2.setblocking(False)
160179
s2s = Stream(s2, {"peername": addr})
161180
core.create_task(cb(s2s, s2s))
162181

163182

164183
# Helper function to start a TCP stream server, running as a new task
165184
# TODO could use an accept-callback on socket read activity instead of creating a task
166-
async def start_server(cb, host, port, backlog=5):
185+
async def start_server(cb, host, port, backlog=5, ssl=None):
167186
import socket
168187

169188
# Create and bind server socket.
@@ -176,7 +195,7 @@ async def start_server(cb, host, port, backlog=5):
176195

177196
# Create and return server object and task.
178197
srv = Server()
179-
srv.task = core.create_task(srv._serve(s, cb))
198+
srv.task = core.create_task(srv._serve(s, cb, ssl))
180199
try:
181200
# Ensure that the _serve task has been scheduled so that it gets to
182201
# handle cancellation.

extmod/modssl_mbedtls.c

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,50 @@ STATIC NORETURN void mbedtls_raise_error(int err) {
165165
#endif
166166
}
167167

168+
STATIC void ssl_check_async_handshake_failure(mp_obj_ssl_socket_t *sslsock, int *errcode) {
169+
170+
if (
171+
#if MBEDTLS_VERSION_NUMBER >= 0x03000000
172+
(*errcode < 0) && (mbedtls_ssl_is_handshake_over(&sslsock->ssl) == 0) && (*errcode != MBEDTLS_ERR_SSL_CONN_EOF)
173+
#else
174+
(*errcode < 0) && (*errcode != MBEDTLS_ERR_SSL_CONN_EOF)
175+
#endif
176+
) {
177+
// Async handshake is done by mbdetls_ssl_read/write
178+
// if return code is MBEDTLS_ERR_XX (i.e < 0) and handshake is not done due to
179+
// handshake failure notify peer
180+
// with proper error code and raise mp error with mbedtls_raise_error
181+
182+
if (*errcode == MBEDTLS_ERR_SSL_NO_CLIENT_CERTIFICATE) {
183+
// Check if TLSv1.3 and use proper alert for this case (to be implemented)
184+
// uint8_t alert = MBEDTLS_SSL_ALERT_MSG_CERT_REQUIRED; tlsv1.3
185+
// uint8_t alert = MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE; tlsv1.2
186+
mbedtls_ssl_send_alert_message(&sslsock->ssl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
187+
MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE);
188+
}
189+
190+
if (*errcode == MBEDTLS_ERR_X509_CERT_VERIFY_FAILED) {
191+
// The certificate may have been rejected for several reasons.
192+
uint32_t flags = 0;
193+
uint ret = 0;
194+
char xcbuf[512];
195+
flags = mbedtls_ssl_get_verify_result(&sslsock->ssl);
196+
ret = mbedtls_x509_crt_verify_info(xcbuf, sizeof(xcbuf), "\n", flags);
197+
// The length of the string written (not including the terminated nul byte),
198+
// or a negative err code.
199+
if (ret > 0) {
200+
sslsock->sock = MP_OBJ_NULL;
201+
mbedtls_ssl_free(&sslsock->ssl);
202+
mp_raise_msg_varg(&mp_type_ValueError, MP_ERROR_TEXT("%s"), xcbuf);
203+
}
204+
}
205+
206+
sslsock->sock = MP_OBJ_NULL;
207+
mbedtls_ssl_free(&sslsock->ssl);
208+
mbedtls_raise_error(*errcode);
209+
210+
}
211+
}
168212
/******************************************************************************/
169213
// SSLContext type.
170214

@@ -630,6 +674,7 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
630674
} else {
631675
o->last_error = ret;
632676
}
677+
ssl_check_async_handshake_failure(o, &ret);
633678
*errcode = ret;
634679
return MP_STREAM_ERROR;
635680
}
@@ -658,6 +703,7 @@ STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in
658703
} else {
659704
o->last_error = ret;
660705
}
706+
ssl_check_async_handshake_failure(o, &ret);
661707
*errcode = ret;
662708
return MP_STREAM_ERROR;
663709
}
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Test asyncio TCP server and client using start_server() and open_connection()
2+
3+
try:
4+
import asyncio
5+
import ssl
6+
except ImportError:
7+
print("SKIP")
8+
raise SystemExit
9+
10+
PORT = 8000
11+
PORT_TLS = 8004
12+
13+
14+
async def handle_tcp_connection(reader, writer):
15+
# Test that peername exists (but don't check its value, it changes)
16+
writer.get_extra_info("peername")
17+
18+
data = await reader.read(100)
19+
print("echo:", data)
20+
writer.write(data)
21+
await writer.drain()
22+
23+
print("close")
24+
writer.close()
25+
await writer.wait_closed()
26+
27+
print("done")
28+
ev.set()
29+
30+
31+
async def handle_tls_connection(reader, writer):
32+
# Test that peername exists (but don't check its value, it changes)
33+
writer.get_extra_info("peername")
34+
35+
data = await reader.read(100)
36+
print("echo:", data)
37+
writer.write(data)
38+
await writer.drain()
39+
40+
print("close")
41+
writer.close()
42+
await writer.wait_closed()
43+
44+
print("done")
45+
ev1.set()
46+
47+
48+
async def tcp_server():
49+
global ev
50+
ev = asyncio.Event()
51+
server = await asyncio.start_server(handle_tcp_connection, "0.0.0.0", PORT)
52+
print("tcp server running")
53+
54+
multitest.next()
55+
async with server:
56+
await asyncio.wait_for(ev.wait(), 10)
57+
58+
59+
async def tcp_client(message):
60+
await asyncio.sleep(1)
61+
reader, writer = await asyncio.open_connection(IP, PORT)
62+
print("write:", message)
63+
writer.write(message)
64+
await writer.drain()
65+
data = await reader.read(100)
66+
print("read:", data)
67+
68+
69+
# These are test certificates. See tests/README.md for details.
70+
cert = cafile = "multi_net/rsa_cert.der"
71+
72+
key = "multi_net/rsa_key.der"
73+
74+
75+
async def tls_server():
76+
global ev1
77+
78+
server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
79+
server_ctx.load_cert_chain(cert, keyfile=key)
80+
ev1 = asyncio.Event()
81+
server = await asyncio.start_server(handle_tls_connection, "0.0.0.0", PORT_TLS, ssl=server_ctx)
82+
print("tls server running")
83+
multitest.next()
84+
async with server:
85+
await asyncio.wait_for(ev1.wait(), 10)
86+
87+
88+
async def tls_client(message):
89+
client_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
90+
client_ctx.verify_mode = ssl.CERT_REQUIRED
91+
client_ctx.load_verify_locations(cafile=cafile)
92+
await asyncio.sleep(1)
93+
reader, writer = await asyncio.open_connection(
94+
IP, PORT_TLS, ssl=client_ctx, server_hostname="micropython.local"
95+
)
96+
print("write:", message)
97+
writer.write(message)
98+
await writer.drain()
99+
data = await reader.read(100)
100+
print("read:", data)
101+
102+
103+
async def tcp_tls_server():
104+
print("TCP SERVER")
105+
await tcp_server()
106+
107+
print("TLS SERVER")
108+
await tls_server()
109+
110+
111+
async def tcp_tls_client():
112+
print("TCP CLIENT")
113+
await tcp_client(b"client data")
114+
multitest.next()
115+
print("TLS CLIENT")
116+
await tls_client(b"client data")
117+
118+
119+
def instance0():
120+
multitest.globals(IP=multitest.get_network_ip())
121+
asyncio.run(tcp_tls_server())
122+
123+
124+
def instance1():
125+
multitest.next()
126+
asyncio.run(tcp_tls_client())
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
--- instance0 ---
2+
TCP SERVER
3+
tcp server running
4+
echo: b'client data'
5+
close
6+
done
7+
TLS SERVER
8+
tls server running
9+
NEXT
10+
echo: b'client data'
11+
close
12+
done
13+
--- instance1 ---
14+
TCP CLIENT
15+
write: b'client data'
16+
read: b'client data'
17+
NEXT
18+
TLS CLIENT
19+
write: b'client data'
20+
read: b'client data'
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Test uasyncio TCP server and client using start_server() and open_connection()
2+
3+
try:
4+
import asyncio
5+
import ssl
6+
except ImportError:
7+
print("SKIP")
8+
raise SystemExit
9+
10+
PORT = 8000
11+
12+
# These are test certificates. See tests/README.md for details.
13+
cert = cafile = "multi_net/rsa_cert.der"
14+
15+
key = "multi_net/rsa_key.der"
16+
17+
18+
async def handle_connection(reader, writer):
19+
# Test that peername exists (but don't check its value, it changes)
20+
writer.get_extra_info("peername")
21+
22+
data = await reader.read(100)
23+
print("echo:", data)
24+
writer.write(data)
25+
await writer.drain()
26+
27+
print("close")
28+
writer.close()
29+
await writer.wait_closed()
30+
31+
print("done")
32+
ev.set()
33+
34+
35+
async def tcp_server():
36+
global ev
37+
38+
server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
39+
server_ctx.load_cert_chain(cert, keyfile=key)
40+
ev = asyncio.Event()
41+
server = await asyncio.start_server(handle_connection, "0.0.0.0", PORT, ssl=server_ctx)
42+
print("server running")
43+
multitest.next()
44+
async with server:
45+
await asyncio.wait_for(ev.wait(), 10)
46+
47+
48+
async def tcp_client(message):
49+
client_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
50+
client_ctx.verify_mode = ssl.CERT_REQUIRED
51+
client_ctx.load_verify_locations(cafile=cafile)
52+
reader, writer = await asyncio.open_connection(
53+
IP, PORT, ssl=client_ctx, server_hostname="micropython.local"
54+
)
55+
print("write:", message)
56+
writer.write(message)
57+
await writer.drain()
58+
data = await reader.read(100)
59+
print("read:", data)
60+
61+
62+
def instance0():
63+
multitest.globals(IP=multitest.get_network_ip())
64+
asyncio.run(tcp_server())
65+
66+
67+
def instance1():
68+
multitest.next()
69+
asyncio.run(tcp_client(b"client data"))
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
--- instance0 ---
2+
server running
3+
echo: b'client data'
4+
close
5+
done
6+
--- instance1 ---
7+
write: b'client data'
8+
read: b'client data'

0 commit comments

Comments
 (0)
0