45
45
#include "mbedtls/debug.h"
46
46
#include "mbedtls/error.h"
47
47
48
+ // flags for _mp_obj_ssl_socket_t.poll_flag that control the poll ioctl
49
+ // the issue is that when using ipoll we may be polling only for reading, and the socket may never
50
+ // become readable because mbedtls needs to write soemthing (like a handshake or renegotiation) and
51
+ // so poll never returns "it's readable" or "it's writable" and so nothing ever makes progress.
52
+ // See also the commit message for
53
+ // https://github.com/micropython/micropython/commit/9c7c082396f717a8a8eb845a0af407e78d38165f
54
+ #define READ_NEEDS_WRITE 0x1 // mbedtls_ssl_read said "I need a write"
55
+ #define WRITE_NEEDS_READ 0x2 // mbedtls_ssl_write said "I need a read"
56
+
48
57
typedef struct _mp_obj_ssl_socket_t {
49
58
mp_obj_base_t base ;
50
59
mp_obj_t sock ;
@@ -55,6 +64,7 @@ typedef struct _mp_obj_ssl_socket_t {
55
64
mbedtls_x509_crt cacert ;
56
65
mbedtls_x509_crt cert ;
57
66
mbedtls_pk_context pkey ;
67
+ uint8_t poll_flag ;
58
68
} mp_obj_ssl_socket_t ;
59
69
60
70
struct ssl_args {
@@ -92,6 +102,26 @@ STATIC NORETURN void mbedtls_raise_error(int err) {
92
102
#endif
93
103
}
94
104
105
+ STATIC mp_obj_t mod_ssl_errstr (mp_obj_t err_in ) {
106
+ size_t err = mp_obj_get_int (err_in );
107
+ vstr_t vstr ;
108
+ vstr_init_len (& vstr , 80 );
109
+
110
+ // Including mbedtls_strerror takes about 16KB on the esp32 due to all the strings
111
+ #if 1
112
+ vstr .buf [0 ] = 0 ;
113
+ mbedtls_strerror (err , vstr .buf , vstr .alloc );
114
+ vstr .len = strlen (vstr .buf );
115
+ if (vstr .len == 0 ) {
116
+ return MP_OBJ_NULL ;
117
+ }
118
+ #else
119
+ vstr_printf (vstr , "mbedtls error -0x%x\n" , - err );
120
+ #endif
121
+ return mp_obj_new_str_from_vstr (& mp_type_bytes , & vstr );
122
+ }
123
+ STATIC MP_DEFINE_CONST_FUN_OBJ_1 (mod_ssl_errstr_obj , mod_ssl_errstr );
124
+
95
125
// _mbedtls_ssl_send is called by mbedtls to send bytes onto the underlying socket
96
126
STATIC int _mbedtls_ssl_send (void * ctx , const byte * buf , size_t len ) {
97
127
mp_obj_t sock = * (mp_obj_t * )ctx ;
@@ -214,10 +244,10 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
214
244
}
215
245
}
216
246
247
+ o -> poll_flag = 0 ;
217
248
if (args -> do_handshake .u_bool ) {
218
249
while ((ret = mbedtls_ssl_handshake (& o -> ssl )) != 0 ) {
219
250
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE ) {
220
<
8000
code class="diff-text syntax-highlighted-line deletion">- printf ("mbedtls_ssl_handshake error: %d/-0x%x\n" , ret , - ret );
221
251
goto cleanup ;
222
252
}
223
253
}
@@ -267,6 +297,7 @@ STATIC void socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kin
267
297
STATIC mp_uint_t socket_read (mp_obj_t o_in , void * buf , mp_uint_t size , int * errcode ) {
268
298
mp_obj_ssl_socket_t * o = MP_OBJ_TO_PTR (o_in );
269
299
300
+ o -> poll_flag &= ~READ_NEEDS_WRITE ; // clear flag
270
301
int ret = mbedtls_ssl_read (& o -> ssl , buf , size );
271
302
if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY ) {
272
303
// end of stream
@@ -281,6 +312,7 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
281
312
// If handshake is not finished, read attempt may end up in protocol
282
313
// wanting to write next handshake message. The same may happen with
283
314
// renegotation.
315
+ o -> poll_flag |= READ_NEEDS_WRITE ; // set flag
284
316
ret = MP_EWOULDBLOCK ;
285
317
}
286
318
* errcode = ret ;
@@ -306,6 +338,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_recv_obj, socket_recv);
306
338
STATIC mp_uint_t socket_write (mp_obj_t o_in , const void * buf , mp_uint_t size , int * errcode ) {
307
339
mp_obj_ssl_socket_t * o = MP_OBJ_TO_PTR (o_in );
308
340
341
+ o -> poll_flag &= ~WRITE_NEEDS_READ ; // clear flag
309
342
int ret = mbedtls_ssl_write (& o -> ssl , buf , size );
310
343
if (ret >= 0 ) {
311
344
return ret ;
@@ -316,6 +349,7 @@ STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in
316
349
// If handshake is not finished, write attempt may end up in protocol
317
350
// wanting to read next handshake message. The same may happen with
318
351
// renegotation.
352
+ o -> poll_flag |= WRITE_NEEDS_READ ; // set flag
319
353
ret = MP_EWOULDBLOCK ;
320
354
}
321
355
* errcode = ret ;
@@ -355,6 +389,33 @@ STATIC mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, i
355
389
mbedtls_ssl_config_free (& self -> conf );
356
390
mbedtls_ctr_drbg_free (& self -> ctr_drbg );
357
391
mbedtls_entropy_free (& self -> entropy );
392
+ } else if (request == MP_STREAM_POLL ) {
393
+ // If we're polling to read but not write but mbedtls previously said it needs to write in
394
+ // order to be able to read then poll for both and if either is available pretend the socket
395
+ // is readable. When the app then performs a read, mbedtls is happy to perform the writes as
396
+ // well. Essentially, what we're ensuring is that one of mbedtls' read/write functions is
397
+ // called as soon as the socket can do something.
398
+ if ((arg & MP_STREAM_POLL_RD ) && !(arg & MP_STREAM_POLL_WR ) &&
399
+ self -> poll_flag & READ_NEEDS_WRITE ) {
400
+ arg |= MP_STREAM_POLL_WR ;
401
+ mp_uint_t ret = mp_get_stream (self -> sock )-> ioctl (self -> sock , request , arg , errcode );
402
+ if (ret & MP_STREAM_POLL_WR ) {
403
+ ret |= MP_STREAM_POLL_RD ;
404
+ ret &= ~MP_STREAM_POLL_WR ;
405
+ }
406
+ return ret ;
407
+ // Now comes the same logic flipped around for write
408
+ } else if ((arg & MP_STREAM_POLL_WR ) && !(arg & MP_STREAM_POLL_RD ) &&
409
+ self -> poll_flag & WRITE_NEEDS_READ ) {
410
+ arg |= MP_STREAM_POLL_RD ;
411
+ mp_uint_t ret = mp_get_stream (self -> sock )-> ioctl (self -> sock , request , arg , errcode );
412
+ if (ret & MP_STREAM_POLL_RD ) {
413
+ ret |= MP_STREAM_POLL_WR ;
414
+ ret &= ~MP_STREAM_POLL_RD ;
415
+ }
416
+ return ret ;
417
+ }
418
+ // fall-through if there's no wonky XX_NEEDS_YY situation
358
419
}
359
420
// Pass all requests down to the underlying socket
360
421
return mp_get_stream (self -> sock )-> ioctl (self -> sock , request , arg , errcode );
@@ -418,6 +479,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_KW(mod_ssl_wrap_socket_obj, 1, mod_ssl_wrap_socke
418
479
STATIC const mp_rom_map_elem_t mp_module_ssl_globals_table [] = {
419
480
{ MP_ROM_QSTR (MP_QSTR___name__ ), MP_ROM_QSTR (MP_QSTR_ussl ) },
420
481
{ MP_ROM_QSTR (MP_QSTR_wrap_socket ), MP_ROM_PTR (& mod_ssl_wrap_socket_obj ) },
482
+ { MP_ROM_QSTR (MP_QSTR_errstr ), MP_ROM_PTR (& mod_ssl_errstr_obj ) },
421
483
};
422
484
423
485
STATIC MP_DEFINE_CONST_DICT (mp_module_ssl_globals , mp_module_ssl_globals_table );
0 commit comments