46
46
#include "mbedtls/debug.h"
47
47
#include "mbedtls/error.h"
48
48
49
+ // flags for _mp_obj_ssl_socket_t.poll_flag that control the poll ioctl
50
+ // the issue is that when using ipoll we may be polling only for reading, and the socket may never
51
+ // become readable because mbedtls needs to write soemthing (like a handshake or renegotiation) and
52
+ // so poll never returns "it's readable" or "it's writable" and so nothing ever makes progress.
53
+ // See also the commit message for
54
+ // https://github.com/micropython/micropython/commit/9c7c082396f717a8a8eb845a0af407e78d38165f
55
+ #define READ_NEEDS_WRITE 0x1 // mbedtls_ssl_read said "I need a write"
56
+ #define WRITE_NEEDS_READ 0x2 // mbedtls_ssl_write said "I need a read"
57
+
49
58
typedef struct _mp_obj_ssl_socket_t {
50
59
mp_obj_base_t base ;
51
60
mp_obj_t sock ;
@@ -56,6 +65,8 @@ typedef struct _mp_obj_ssl_socket_t {
56
65
mbedtls_x509_crt cacert ;
57
66
mbedtls_x509_crt cert ;
58
67
mbedtls_pk_context pkey ;
68
+ uint8_t poll_flag ;
69
+ uint8_t poll_by_read ; // true: at next poll try to read first
59
70
} mp_obj_ssl_socket_t ;
60
71
61
72
struct ssl_args {
@@ -116,6 +127,27 @@ STATIC NORETURN void mbedtls_raise_error(int err) {
116
127
#endif
117
128
}
118
129
130
+ STATIC mp_obj_t mod_ssl_errstr (mp_obj_t err_in ) {
131
+ size_t err = mp_obj_get_int (err_in );
132
+ vstr_t vstr ;
133
+ vstr_init_len (& vstr , 80 );
134
+
135
+ // Including mbedtls_strerror takes about 16KB on the esp32 due to all the strings
136
+ #if 1
137
+ vstr .buf [0 ] = 0 ;
138
+ mbedtls_strerror (err , vstr .buf , vstr .alloc );
139
+ vstr .len = strlen (vstr .buf );
140
+ if (vstr .len == 0 ) {
141
+ return MP_OBJ_NULL ;
142
+ }
143
+ #else
144
+ vstr_printf (vstr , "mbedtls error -0x%x\n" , - err );
145
+ #endif
146
+ return mp_obj_new_str_from_vstr (& mp_type_bytes , & vstr );
147
+ }
148
+ STATIC MP_DEFINE_CONST_FUN_OBJ_1 (mod_ssl_errstr_obj , mod_ssl_errstr );
149
+
150
+ // _mbedtls_ssl_send is called by mbedtls to send bytes onto the underlying socket
119
151
STATIC int _mbedtls_ssl_send (void * ctx , const byte * buf , size_t len ) {
120
152
mp_obj_t sock = * (mp_obj_t * )ctx ;
121
153
@@ -237,6 +269,8 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
237
269
}
238
270
}
239
271
272
+ o -> poll_flag = 0 ;
273
+ o -> poll_by_read = 0 ;
240
274
if (args -> do_handshake .u_bool ) {
241
275
while ((ret = mbedtls_ssl_handshake (& o -> ssl )) != 0 ) {
242
276
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE ) {
@@ -289,12 +323,16 @@ STATIC void socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kin
289
323
STATIC mp_uint_t socket_read (mp_obj_t o_in , void * buf , mp_uint_t size , int * errcode ) {
290
324
mp_obj_ssl_socket_t * o = MP_OBJ_TO_PTR (o_in );
291
325
326
+ o -> poll_flag &= ~READ_NEEDS_WRITE ; // clear flag
292
327
int ret = mbedtls_ssl_read (& o -> ssl , buf , size );
293
328
if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY ) {
294
329
// end of stream
295
330
return 0 ;
296
331
}
297
332
if (ret >= 0 ) {
333
+ // if we got all we wanted, for the next poll try a read first 'cause
334
+ // there may be data in the mbedtls record buffer
335
+ o -> poll_by_read = ret == size ;
298
336
return ret ;
299
337
}
300
338
if (ret == MBEDTLS_ERR_SSL_WANT_READ ) {
@@ -303,6 +341,7 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
303
341
// If handshake is not finished, read attempt may end up in protocol
304
342
// wanting to write next handshake message. The same may happen with
305
343
// renegotation.
344
+ o -> poll_flag |= READ_NEEDS_WRITE ; // set flag
306
345
ret = MP_EWOULDBLOCK ;
307
346
}
308
347
* errcode = ret ;
@@ -312,6 +351,7 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
312
351
STATIC mp_uint_t socket_write (mp_obj_t o_in , const void * buf , mp_uint_t size , int * errcode ) {
313
352
mp_obj_ssl_socket_t * o = MP_OBJ_TO_PTR (o_in );
314
353
354
+ o -> poll_flag &= ~WRITE_NEEDS_READ ; // clear flag
315
355
int ret = mbedtls_ssl_write (& o -> ssl , buf , size );
316
356
if (ret >= 0 ) {
317
357
return ret ;
@@ -322,6 +362,7 @@ STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in
322
362
// If handshake is not finished, write attempt may end up in protocol
323
363
// wanting to read next handshake message. The same may happen with
324
364
// renegotation.
365
+ o -> poll_flag |= WRITE_NEEDS_READ ; // set flag
325
366
ret = MP_EWOULDBLOCK ;
326
367
}
327
368
* errcode = ret ;
@@ -348,6 +389,41 @@ STATIC mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, i
348
389
mbedtls_ssl_config_free (& self -> conf );
349
390
mbedtls_ctr_drbg_free (& self -> ctr_drbg );
350
391
mbedtls_entropy_free (& self -> entropy );
392
+ } else if (request == MP_STREAM_POLL ) {
393
+ mp_uint_t ret = 0 ;
394
+ // If the last read returned everything asked for there may be more in the mbedtls buffer,
395
+ // so find out. (There doesn't seem to be an equivalent issue with writes.)
396
+ if ((arg & MP_STREAM_POLL_RD ) && self -> poll_by_read ) {
397
+ size_t avail = mbedtls_ssl_get_bytes_avail (& self -> ssl );
398
+ if (avail > 0 ) ret = MP_STREAM_POLL_RD ;
399
+ }
400
+ // If we're polling to read but not write but mbedtls previously said it needs to write in
401
+ // order to be able to read then poll for both and if either is available pretend the socket
402
+ // is readable. When the app then performs a read, mbedtls is happy to perform the writes as
403
+ // well. Essentially, what we're ensuring is that one of mbedtls' read/write functions is
404
+ // called as soon as the socket can do something.
405
+ if ((arg & MP_STREAM_POLL_RD ) && !(arg & MP_STREAM_POLL_WR ) &&
406
+ self -> poll_flag & READ_NEEDS_WRITE ) {
407
+ arg |= MP_STREAM_POLL_WR ;
408
+ ret |= mp_get_stream (self -> sock )-> ioctl (self -> sock , request , arg , errcode );
409
+ if (ret & MP_STREAM_POLL_WR ) {
410
+ ret |= MP_STREAM_POLL_RD ;
411
+ ret &= ~MP_STREAM_POLL_WR ;
412
+ }
413
+ return ret ;
414
+ // Now comes the same logic flipped around for write
415
+ } else if ((arg & MP_STREAM_POLL_WR ) && !(arg & MP_STREAM_POLL_RD ) &&
416
+ self -> poll_flag & WRITE_NEEDS_READ ) {
417
+ arg |= MP_STREAM_POLL_RD ;
418
+ ret |= mp_get_stream (self -> sock )-> ioctl (self -> sock , request , arg , errcode );
419
+ if (ret & MP_STREAM_POLL_RD ) {
420
+ ret |= MP_STREAM_POLL_WR ;
421
+ ret &= ~MP_STREAM_POLL_RD ;
422
+ }
423
+ return ret ;
424
+ }
425
+ // Pass down to underlying socket
426
+ return ret | mp_get_stream (self -> sock )-> ioctl (self -> sock , request , arg , errcode );
351
427
}
352
428
// Pass all requests down to the underlying socket
353
429
return mp_get_stream (self -> sock )-> ioctl (self -> sock , request , arg , errcode );
@@ -409,6 +485,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_KW(mod_ssl_wrap_socket_obj, 1, mod_ssl_wrap_socke
409
485
STATIC const mp_rom_map_elem_t mp_module_ssl_globals_table [] = {
410
486
{ MP_ROM_QSTR (MP_QSTR___name__ ), MP_ROM_QSTR (MP_QSTR_ussl ) },
411
487
{ MP_ROM_QSTR (MP_QSTR_wrap_socket ), MP_ROM_PTR (& mod_ssl_wrap_socket_obj ) },
488
+ { MP_ROM_QSTR (MP_QSTR_errstr ), MP_ROM_PTR (& mod_ssl_errstr_obj ) },
412
489
};
413
490
414
491
STATIC MP_DEFINE_CONST_DICT (mp_module_ssl_globals , mp_module_ssl_globals_table );
0 commit comments