8000 fix uasyncio ssl; add poll patch to modussl_mbedtls · micropython/micropython@523a2f7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 523a2f7

Browse files
committed
fix uasyncio ssl; add poll patch to modussl_mbedtls
1 parent b83a62c commit 523a2f7

File tree

4 files changed

+114
-5
lines changed

4 files changed

+114
-5
lines changed

extmod/modussl_axtls.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
< 8000 div class="d-flex flex-row">
@@ -176,6 +176,7 @@ STATIC mp_uint_t ussl_socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int
176176

177177
while (o->bytes_left == 0) {
178178
mp_int_t r = ssl_read(o->ssl_sock, &o->buf);
179+
printf("R{%d}", r);
179180
if (r == SSL_OK) {
180181
// SSL_OK from ssl_read() means "everything is ok, but there's
181182
// no user data yet". It may happen e.g. if handshake is not
@@ -242,6 +243,7 @@ STATIC mp_uint_t ussl_socket_write(mp_obj_t o_in, const void *buf, mp_uint_t siz
242243
mp_int_t r;
243244
eagain:
244245
r = ssl_write(o->ssl_sock, buf, size);
246+
printf("W{%d}", r);
245247
if (r == SSL_OK) {
246248
// see comment in read method
247249
if (o->blocking) {
@@ -356,6 +358,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_KW(mod_ssl_wrap_socket_obj, 1, mod_ssl_wrap_socke
356358
STATIC const mp_rom_map_elem_t mp_module_ssl_globals_table[] = {
357359
{ MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ussl) },
358360
{ MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&mod_ssl_wrap_socket_obj) },
361+
{ MP_ROM_QSTR(MP_QSTR_errstr), MP_ROM_PTR(&mod_ssl_errstr_obj) },
359362
};
360363

361364
STATIC MP_DEFINE_CONST_DICT(mp_module_ssl_globals, mp_module_ssl_globals_table);

extmod/modussl_mbedtls.c

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,15 @@
4545
#include "mbedtls/debug.h"
4646
#include "mbedtls/error.h"
4747

48+
// flags for _mp_obj_ssl_socket_t.poll_flag that control the poll ioctl
49+
// the issue is that when using ipoll we may be polling only for reading, and the socket may never
50+
// become readable because mbedtls needs to write soemthing (like a handshake or renegotiation) and
51+
// so poll never returns "it's readable" or "it's writable" and so nothing ever makes progress.
52+
// See also the commit message for
53+
// https://github.com/micropython/micropython/commit/9c7c082396f717a8a8eb845a0af407e78d38165f
54+
#define READ_NEEDS_WRITE 0x1 // mbedtls_ssl_read said "I need a write"
55+
#define WRITE_NEEDS_READ 0x2 // mbedtls_ssl_write said "I need a read"
56+
4857
typedef struct _mp_obj_ssl_socket_t {
4958
mp_obj_base_t base;
5059
mp_obj_t sock;
@@ -55,6 +64,7 @@ typedef struct _mp_obj_ssl_socket_t {
5564
mbedtls_x509_crt cacert;
5665
mbedtls_x509_crt cert;
5766
mbedtls_pk_context pkey;
67+
uint8_t poll_flag;
5868
} mp_obj_ssl_socket_t;
5969

6070
struct ssl_args {
@@ -92,6 +102,26 @@ STATIC NORETURN void mbedtls_raise_error(int err) {
92102
#endif
93103
}
94104

105+
STATIC mp_obj_t mod_ssl_errstr(mp_obj_t err_in) {
106+
size_t err = mp_obj_get_int(err_in);
107+
vstr_t vstr;
108+
vstr_init_len(&vstr, 80);
109+
110+
// Including mbedtls_strerror takes about 16KB on the esp32 due to all the strings
111+
#if 1
112+
vstr.buf[0] = 0;
113+
mbedtls_strerror(err, vstr.buf, vstr.alloc);
114+
vstr.len = strlen(vstr.buf);
115+
if (vstr.len == 0) {
116+
return MP_OBJ_NULL;
117+
}
118+
#else
119+
vstr_printf(vstr, "mbedtls error -0x%x\n", -err);
120+
#endif
121+
return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr);
122+
}
123+
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_ssl_errstr_obj, mod_ssl_errstr);
124+
95125
// _mbedtls_ssl_send is called by mbedtls to send bytes onto the underlying socket
96126
STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
97127
mp_obj_t sock = *(mp_obj_t *)ctx;
@@ -214,10 +244,10 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
214244
}
215245
}
216246

247+
o->poll_flag = 0;
217248
if (args->do_handshake.u_bool) {
218249
while ((ret = mbedtls_ssl_handshake(&o->ssl)) != 0) {
219250
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
220< 8000 code class="diff-text syntax-highlighted-line deletion">-
printf("mbedtls_ssl_handshake error: %d/-0x%x\n", ret, -ret);
221251
goto cleanup;
222252
}
223253
}
@@ -267,6 +297,7 @@ STATIC void socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kin
267297
STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errcode) {
268298
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
269299

300+
o->poll_flag &= ~READ_NEEDS_WRITE; // clear flag
270301
int ret = mbedtls_ssl_read(&o->ssl, buf, size);
271302
if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
272303
// end of stream
@@ -281,6 +312,7 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
281312
// If handshake is not finished, read attempt may end up in protocol
282313
// wanting to write next handshake message. The same may happen with
283314
// renegotation.
315+
o->poll_flag |= READ_NEEDS_WRITE; // set flag
284316
ret = MP_EWOULDBLOCK;
285317
}
286318
*errcode = ret;
@@ -306,6 +338,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_recv_obj, socket_recv);
306338
STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) {
307339
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
308340

341+
o->poll_flag &= ~WRITE_NEEDS_READ; // clear flag
309342
int ret = mbedtls_ssl_write(&o->ssl, buf, size);
310343
if (ret >= 0) {
311344
return ret;
@@ -316,6 +349,7 @@ STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in
316349
// If handshake is not finished, write attempt may end up in protocol
317350
// wanting to read next handshake message. The same may happen with
318351
// renegotation.
352+
o->poll_flag |= WRITE_NEEDS_READ; // set flag
319353
ret = MP_EWOULDBLOCK;
320354
}
321355
*errcode = ret;
@@ -355,6 +389,33 @@ STATIC mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, i
355389
mbedtls_ssl_config_free(&self->conf);
356390
mbedtls_ctr_drbg_free(&self->ctr_drbg);
357391
mbedtls_entropy_free(&self->entropy);
392+
} else if (request == MP_STREAM_POLL) {
393+
// If we're polling to read but not write but mbedtls previously said it needs to write in
394+
// order to be able to read then poll for both and if either is available pretend the socket
395+
// is readable. When the app then performs a read, mbedtls is happy to perform the writes as
396+
// well. Essentially, what we're ensuring is that one of mbedtls' read/write functions is
397+
// called as soon as the socket can do something.
398+
if ((arg & MP_STREAM_POLL_RD) && !(arg & MP_STREAM_POLL_WR) &&
399+
self->poll_flag & READ_NEEDS_WRITE) {
400+
arg |= MP_STREAM_POLL_WR;
401+
mp_uint_t ret = mp_get_stream(self->sock)->ioctl(self->sock, request, arg, errcode);
402+
if (ret & MP_STREAM_POLL_WR) {
403+
ret |= MP_STREAM_POLL_RD;
404+
ret &= ~MP_STREAM_POLL_WR;
405+
}
406+
return ret;
407+
// Now comes the same logic flipped around for write
408+
} else if ((arg & MP_STREAM_POLL_WR) && !(arg & MP_STREAM_POLL_RD) &&
409+
self->poll_flag & WRITE_NEEDS_READ) {
410+
arg |= MP_STREAM_POLL_RD;
411+
mp_uint_t ret = mp_get_stream(self->sock)->ioctl(self->sock, request, arg, errcode);
412+
if (ret & MP_STREAM_POLL_RD) {
413+
ret |= MP_STREAM_POLL_WR;
414+
ret &= ~MP_STREAM_POLL_RD;
415+
}
416+
return ret;
417+
}
418+
// fall-through if there's no wonky XX_NEEDS_YY situation
358419
}
359420
// Pass all requests down to the underlying socket
360421
return mp_get_stream(self->sock)->ioctl(self->sock, request, arg, errcode);
@@ -418,6 +479,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_KW(mod_ssl_wrap_socket_obj, 1, mod_ssl_wrap_socke
418479
STATIC const mp_rom_map_elem_t mp_module_ssl_globals_table[] = {
419480
{ MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ussl) },
420481
{ MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&mod_ssl_wrap_socket_obj) },
482+
{ MP_ROM_QSTR(MP_QSTR_errstr), MP_ROM_PTR(&mod_ssl_errstr_obj) },
421483
};
422484

423485
STATIC MP_DEFINE_CONST_DICT(mp_module_ssl_globals, mp_module_ssl_globals_table);

extmod/uasyncio/stream.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import ssl as modssl # module is used in function that has an ssl parameter
88
except:
99
modssl = None
10-
# Note a major issue with SSL in this commit message:
11-
# https://github.com/micropython/micropython/commit/9c7c082396f717a8a8eb845a0af407e78d38165f
1210

1311
class Stream:
1412
def __init__(self, s, e={}):
@@ -67,13 +65,11 @@ async def open_connection(host, port, ssl=None, server_hostname=None):
6765
ai = socket.getaddrinfo(host, port)[0] # TODO this is blocking!
6866
s = socket.socket()
6967
s.setblocking(False)
70-
ss = Stream(s)
7168
try:
7269
s.connect(ai[-1])
7370
except OSError as er:
7471
if er.args[0] != EINPROGRESS:
7572
raise er
76-
yield core._io_queue.queue_write(s)
7773
# wrap with SSL, if requested
7874
if ssl:
7975
if not modssl:
@@ -91,6 +87,7 @@ async def open_connection(host, port, ssl=None, server_hostname=None):
9187
raise ValueError("invalid ssl param")
9288
ssl["do_handshake"] = False # as non-blocking as possible
9389
s = modssl.wrap_socket(s, **ssl)
90+
ss = Stream(s)
9491
return ss, ss
9592

9693

tests/net_inet/uasyncio_ssl.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Attempt to test the funky poll patch in modussl_mbedtls, but sadly this works
2+
# without the patch...
3+
4+
try:
5+
import uasyncio as asyncio, ussl as ssl
6+
except ImportError:
7+
try:
8+
import asyncio, ssl
9+
except ImportError:
10+
print("SKIP")
11+
raise SystemExit
12+
13+
14+
# open a connection and start by reading, not writing
15+
async def read_first(host, port):
16+
reader, writer = await asyncio.open_connection(host, port, ssl=True)
17+
18+
print("read something")
19+
inbuf = b''
20+
while len(inbuf) < 20:
21+
try:
22+
b = await reader.read(100)
23+
except OSError as e:
24+
if e.args[0] < -120:
25+
print("read SSL error -%x : %s" % (-e.args[0], ssl.errstr(e.args[0])))
26+
raise OSError(e.args[0], bytes.decode(ssl.errstr(e.args[0])))
27+
else:
28+
print("read OSError: %d / -%x" % (e.args[0], -e.args[0]))
29+
raise
30+
print("read:", b)
31+
if b is None:
32+
continue
33+
elif len(b) == 0:
34+
print("EOF")
35+
break
36+
elif len(b) > 0:
37+
inbuf += b
38+
else:
39+
raise ValueError("negative length returned by recv")
40+
41+
print("close")
42+
writer.close()
43+
await writer.wait_closed()
44+
print("done")
45+
46+
47+
asyncio.run(read_first("aspmx.l.google.com", 25))

0 commit comments

Comments
 (0)
0