8000 extmod/uasyncio: add SSL support and fix SSL errors (esp32 primarily) by tve · Pull Request #5815 · micropython/micropython · GitHub
[go: up one dir, main page]

Skip to content

extmod/uasyncio: add SSL support and fix SSL errors (esp32 primarily) #5815

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 44 additions & 4 deletions extmod/modussl_mbedtls.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include "mbedtls/entropy.h"
#include "mbedtls/ctr_drbg.h"
#include "mbedtls/debug.h"
#include "mbedtls/error.h"

typedef struct _mp_obj_ssl_socket_t {
mp_obj_base_t base;
Expand Down Expand Up @@ -74,6 +75,14 @@ STATIC void mbedtls_debug(void *ctx, int level, const char *file, int line, cons
}
#endif

STATIC NORETURN void mbedtls_raise_error(int err) {
char error_buf[80];
mbedtls_strerror(err, error_buf, sizeof(error_buf));
//printf("mbedtls error -0x%x : %s\n", -err, error_buf);
mp_raise_msg_varg(&mp_type_OSError, "MBEDTLS -0x%x: %s", -err, error_buf);
}

// _mbedtls_ssl_send is called my mbedtls to send bytes onto the underlying socket
STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
mp_obj_t sock = *(mp_obj_t *)ctx;

Expand All @@ -85,12 +94,13 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
if (mp_is_nonblocking_error(err)) {
return MBEDTLS_ERR_SSL_WANT_WRITE;
}
return -err;
return -err; // convert an MP_ERRNO to something mbedtls passes through as error
} else {
return out_sz;
}
}

// _mbedtls_ssl_recv is called my mbedtls to receive bytes from the underlying socket
STATIC int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
mp_obj_t sock = *(mp_obj_t *)ctx;

Expand Down Expand Up @@ -129,7 +139,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
mbedtls_pk_init(&o->pkey);
mbedtls_ctr_drbg_init(&o->ctr_drbg);
#ifdef MBEDTLS_DEBUG_C
// Debug level (0-4)
// Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose
mbedtls_debug_set_threshold(0);
#endif

Expand Down Expand Up @@ -197,7 +207,6 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
if (args->do_handshake.u_bool) {
while ((ret = mbedtls_ssl_handshake(&o->ssl)) != 0) {
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
printf("mbedtls_ssl_handshake error: -%x\n", -ret);
goto cleanup;
}
}
Expand All @@ -221,7 +230,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
} else if (ret == MBEDTLS_ERR_X509_BAD_INPUT_DATA) {
mp_raise_ValueError("invalid cert");
} else {
mp_raise_OSError(MP_EIO);
mbedtls_raise_error(ret);
}
}

Expand Down Expand Up @@ -267,6 +276,22 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
return MP_STREAM_ERROR;
}

STATIC mp_obj_t socket_recv(mp_obj_t self_in, mp_obj_t len_in) {
size_t len = mp_obj_get_int(len_in);
vstr_t vstr;
vstr_init_len(&vstr, len);

int errcode;
int ret = socket_read(self_in, vstr.buf, len, &errcode);
if (ret == MP_STREAM_ERROR) {
mp_raise_OSError(errcode);
}

vstr.len = ret;
return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_recv_obj, socket_recv);

STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) {
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);

Expand All @@ -286,6 +311,19 @@ STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in
return MP_STREAM_ERROR;
}

STATIC mp_obj_t socket_send(mp_obj_t self_in, mp_obj_t buf_in) {
mp_buffer_info_t bufinfo;
mp_get_buffer_raise(buf_in, &bufinfo, MP_BUFFER_READ);

int errcode;
int r = socket_write(self_in, bufinfo.buf, bufinfo.len, &errcode);
if (r == MP_STREAM_ERROR) {
mp_raise_OSError(errcode);
}
return mp_obj_new_int(r);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_send_obj, socket_send);

STATIC mp_obj_t socket_setblocking(mp_obj_t self_in, mp_obj_t flag_in) {
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(self_in);
mp_obj_t sock = o->sock;
Expand Down Expand Up @@ -315,7 +353,9 @@ STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = {
{ MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mp_stream_read_obj) },
{ MP_ROM_QSTR(MP_QSTR_readinto), MP_ROM_PTR(&mp_stream_readinto_obj) },
{ MP_ROM_QSTR(MP_QSTR_readline), MP_ROM_PTR(&mp_stream_unbuffered_readline_obj) },
{ MP_ROM_QSTR(MP_QSTR_recv), MP_ROM_PTR(&socket_recv_obj) },
{ MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mp_stream_write_obj) },
{ MP_ROM_QSTR(MP_QSTR_send), MP_ROM_PTR(&socket_send_obj) },
{ MP_ROM_QSTR(MP_QSTR_setblocking), MP_ROM_PTR(&socket_setblocking_obj) },
{ MP_ROM_QSTR(MP_QSTR_close), MP_ROM_PTR(&mp_stream_close_obj) },
#if MICROPY_PY_USSL_FINALISER
Expand Down
22 changes: 21 additions & 1 deletion extmod/uasyncio/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

from . import core

try:
import ssl as modssl # module is used in function that has an ssl parameter
except:
ssl_class = None

class Stream:
def __init__(self, s, e={}):
Expand Down Expand Up @@ -54,7 +58,7 @@ async def drain(self):


# Create a TCP stream connection to a remote host
async def open_connection(host, port):
async def open_connection(host, port, ssl=None, server_hostname=None):
from uerrno import EINPROGRESS
import usocket as socket

Expand All @@ -68,6 +72,22 @@ async def open_connection(host, port):
if er.args[0] != EINPROGRESS:
raise er
yield core._io_queue.queue_write(s)
# wrap with SSL, if requested
if ssl:
if not modssl:
raise ValueError("SSL not supported")
if ssl is True:
ssl = {} # spec says to use ssl.create_default_context() but we don't have that
elif isinstance(ssl, dict):
# non-standard: accept dict with KW args suitable to call ssl.wrap_socket()
if server_hostname:
# spec: server_hostname sets or overrides the hostname that the target server’s
# certificate will be matched against.
ssl["server_hostname"] = server_hostname
else:
# spec says we should handle ssl.SSLContext object here, but ain't got that
raise ValueError("invalid ssl param")
s = modssl.wrap_socket(s, **ssl)
return ss, ss


Expand Down
9 changes: 6 additions & 3 deletions ports/esp32/modsocket.c
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,8 @@ int _socket_send(socket_obj_t *sock, const char *data, size_t datalen) {
MP_THREAD_GIL_EXIT();
int r = lwip_write(sock->fd, data + sentlen, datalen - sentlen);
MP_THREAD_GIL_ENTER();
if (r < 0 && errno != EWOULDBLOCK) {
// lwip returns EINPROGRESS when trying to send right after a non-blocking connect
if (r < 0 && errno != EWOULDBLOCK && errno != EINPROGRESS) {
exception_from_errno(errno);
}
if (r > 0) {
Expand All @@ -580,7 +581,7 @@ int _socket_send(socket_obj_t *sock, const char *data, size_t datalen) {
check_for_exceptions();
}
if (sentlen == 0) {
mp_raise_OSError(MP_ETIMEDOUT);
mp_raise_OSError(sock->retries == 0 ? MP_EWOULDBLOCK : MP_ETIMEDOUT);
}
return sentlen;
}
Expand Down Expand Up @@ -663,7 +664,9 @@ STATIC mp_uint_t socket_stream_write(mp_obj_t self_in, const void *buf, mp_uint_
if (r > 0) {
return r;
}
if (r < 0 && errno != EWOULDBLOCK) {
// lwip returns EINPROGRESS when trying to write right after a non-blocking connect
if (r < 0 && errno != EWOULDBLOCK && errno != EINPROGRESS) {
printf("socket_stream_write error %d\n", errno);
*errcode = errno;
return MP_STREAM_ERROR;
}
Expand Down
75 changes: 69 additions & 6 deletions tests/net_hosted/connect_nonblock.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,83 @@
# test that socket.connect() on a non-blocking socket raises EINPROGRESS
# and that an immediate write/send/read/recv does the right thing

try:
import usocket as socket
import usocket as socket, ussl as ssl, sys
except:
import socket
import socket, ssl, sys


def test(peer_addr):
def do_connect(peer_addr, tls):
s = socket.socket()
s.setblocking(False)
try:
s.connect(peer_addr)
except OSError as er:
print(er.args[0] == 115) # 115 is EINPROGRESS
print('connect:', er.args[0] == 115) # 115 is EINPROGRESS
# wrap with ssl/tls if desired
if tls:
try:
if sys.implementation.name == 'micropython':
s = ssl.wrap_socket(s)
else:
s = ssl.wrap_socket(s, do_handshake_on_connect=False)
print("wrap: True")
except Exception as e:
print("wrap:", e)
return s

def test(peer_addr, tls=False):
# a fresh socket is opened for each combination because MP on linux is too fast

# hasRW is used to force socket.read and socket.write to None in CPython since it doesn't
# have these methods and we expect a None result
hasRW = sys.implementation.name == 'micropython' or tls or None

# connect + send
s = do_connect(peer_addr, tls)
# send -> EAGAIN
try:
print('send ret:', s.send(b'1234'))
except OSError as er:
print('send:', er, er.args[0] == 11) # 11 is EAGAIN
s.close()

# connect + write
s = do_connect(peer_addr, tls)
import time
time.sleep(0.2)
# write -> None
try:
ret = hasRW and s.write(b'1234') # None in CPython
print('write ret:', ret, ret is None)
except OSError as er:
print('write:', er, False) # should not raise
except ValueError as er: # CPython
print('write:', er, er.args[0] == 'Write on closed or unwrapped SSL socket.')
s.close()

# connect + recv
s = do_connect(peer_addr, tls)
# recv -> EAGAIN
try:
print('recv ret:', s.recv(10))
except OSError as er:
print('recv:', er, er.args[0] == 11) # 11 is EAGAIN
s.close()

# connect + read
s = do_connect(peer_addr, tls)
# read -> None
try:
ret = hasRW and s.read(10)
print('read ret:', ret, ret is None)
except OSError as er:
print('read:', er, False) # should not raise
except ValueError as er: # CPython
print('read:', er, er.args[0] == 'Read on closed or unwrapped SSL socket.')
s.close()

if __name__ == "__main__":
test(socket.getaddrinfo('micropython.org', 80)[0][-1])
print("--- Plain sockets ---")
test(socket.getaddrinfo('micropython.org', 80)[0][-1], False)
print("--- SSL sockets ---")
test(socket.getaddrinfo('micropython.org', 443)[0][-1], True)
43 changes: 43 additions & 0 deletions tests/net_inet/ssl_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# test that socket.connect() on a non-blocking socket raises EINPROGRESS
# and that an immediate write/send/read/recv does the right thing

try:
import usocket as socket, ussl as ssl, sys
except:
import socket, ssl, sys

def test(addr, hostname, block=True):
print("---", hostname or addr)
s = socket.socket()
s.setblocking(block)
try:
s.connect(addr)
print('connected')
except OSError as e:
if e.args[0] != 115: # 115 == EINPROGRESS
raise

try:
s = ssl.wrap_socket(s)
print("wrap: True")
except OSError as e:
print("wrap:", e)

if not block:
try:
while s.write(b'0') is None:
pass
except OSError as e:
print("write:", e)
s.close()

if __name__ == "__main__":
# connect to plain HTTP port, oops!
addr = socket.getaddrinfo('micropython.org', 80)[0][-1]
test(addr, None)
# connect to plain HTTP port, oops!
addr = socket.getaddrinfo('micropython.org', 80)[0][-1]
test(addr, None, False)
# connect to server with self-signed cert, oops!
addr = socket.getaddrinfo('test.mosquitto.org', 8883)[0][-1]
test(addr, 'test.mosquitto.org')
0