8000 extmod/modussl: fix socket and ussl read/recv/send/write errors for n… · micropython/micropython@a370f53 · GitHub
[go: up one dir, main page]

Skip to content

Commit a370f53

Browse files
committed
extmod/modussl: fix socket and ussl read/recv/send/write errors for non-blocking sockets
1 parent 388d419 commit a370f53

File tree

9 files changed

+420
-29
lines changed

9 files changed

+420
-29
lines changed

docs/library/ussl.rst

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,23 @@ facilities for network sockets, both client-side and server-side.
1313
Functions
1414
---------
1515

16-
.. function:: ussl.wrap_socket(sock, server_side=False, keyfile=None, certfile=None, cert_reqs=CERT_NONE, ca_certs=None)
17-
16+
.. function:: ussl.wrap_socket(sock, server_side=False, keyfile=None, certfile=None, cert_reqs=CERT_NONE, ca_certs=None, do_handshake=True)
1817
Takes a `stream` *sock* (usually usocket.socket instance of ``SOCK_STREAM`` type),
1918
and returns an instance of ssl.SSLSocket, which wraps the underlying stream in
2019
an SSL context. Returned object has the usual `stream` interface methods like
21-
``read()``, ``write()``, etc. In MicroPython, the returned object does not expose
22-
socket interface and methods like ``recv()``, ``send()``. In particular, a
20+
``read()``, ``write()``, etc. as well as ``recv()``, ``send()``. In particular, a
2321
server-side SSL socket should be created from a normal socket returned from
2422
:meth:`~usocket.socket.accept()` on a non-SSL listening server socket.
2523

24+
- *do_handshake* determines whether the handshake is done as part of the ``wrap_socket``
25+
or whether it is deferred to be done as part of the initial reads or writes
26+
(there is no ``do_handshake`` method as in CPython).
27+
For blocking sockets doing the handshake immediately is standard. For non-blocking
28+
sockets (i.e. when the *sock* passed into ``wrap_socket`` is in non-blocking mode)
29+
the handshake should generally be deferred because otherwise ``wrap_socket`` blocks
30+
until it completes. Note that in AXTLS the handshake can be deferred until the first
31+
read or write but it then blocks until completion.
32+
2633
Depending on the underlying module implementation in a particular
2734
`MicroPython port`, some or all keyword arguments above may be not supported.
2835

extmod/modussl_axtls.c

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,16 @@ STATIC mp_obj_ssl_socket_t *ussl_socket_new(mp_obj_t sock, struct ssl_args *args
104104
o->ssl_sock = ssl_client_new(o->ssl_ctx, (long)sock, NULL, 0, ext);
105105

106106
if (args->do_handshake.u_bool) {
107-
int res = ssl_handshake_status(o->ssl_sock);
108-
109-
if (res != SSL_OK) {
110-
printf("ssl_handshake_status: %d\n", res);
111-
ssl_display_error(res);
112-
mp_raise_OSError(MP_EIO);
107+
int r = ssl_handshake_status(o->ssl_sock);
108+
109+
if (r != SSL_OK) {
110+
ssl_display_error(r);
111+
if (r == SSL_CLOSE_NOTIFY || r == SSL_ERROR_CONN_LOST) { // EOF
112+
r = MP_ENOTCONN;
113+
} else if (r == SSL_EAGAIN) {
114+
r = MP_EAGAIN;
115+
}
116+
ussl_raise_error(r);
113117
}
114118
}
115119

@@ -173,6 +177,22 @@ STATIC mp_uint_t ussl_socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int
173177
return size;
174178
}
175179

180+
STATIC mp_obj_t ussl_socket_recv(mp_obj_t self_in, mp_obj_t len_in) {
181+
size_t len = mp_obj_get_int(len_in);
182+
vstr_t vstr;
183+
vstr_init_len(&vstr, len);
184+
185+
int errcode;
186+
mp_uint_t ret = ussl_socket_read(self_in, vstr.buf, len, &errcode);
187+
if (ret == MP_STREAM_ERROR) {
188+
mp_raise_OSError(errcode);
189+
}
190+
191+
vstr.len = ret;
192+
return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr);
193+
}
194+
STATIC MP_DEFINE_CONST_FUN_OBJ_2(ussl_socket_recv_obj, ussl_socket_recv);
195+
176196
STATIC mp_uint_t ussl_socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) {
177197
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
178198

@@ -181,14 +201,43 @@ STATIC mp_uint_t ussl_socket_write(mp_obj_t o_in, const void *buf, mp_uint_t siz
181201
return MP_STREAM_ERROR;
182202
}
183203

184-
mp_int_t r = ssl_write(o->ssl_sock, buf, size);
204+
mp_int_t r;
205+
eagain:
206+
r = ssl_write(o->ssl_sock, buf, size);
207+
if (r == SSL_OK) {
208+
// see comment in read method
209+
if (o->blocking) {
210+
goto eagain;
211+
} else {
212+
r = SSL_EAGAIN;
213+
}
214+
}
185215
if (r < 0) {
216+
if (r == SSL_CLOSE_NOTIFY || r == SSL_ERROR_CONN_LOST) {
217+
return 0; // EOF
218+
}
219+
if (r == SSL_EAGAIN) {
220+
r = MP_EAGAIN;
221+
}
186222
*errcode = r;
187223
return MP_STREAM_ERROR;
188224
}
189225
return r;
190226
}
191227

228+
STATIC mp_obj_t ussl_socket_send(mp_obj_t self_in, mp_obj_t buf_in) {
229+
mp_buffer_info_t bufinfo;
230+
mp_get_buffer_raise(buf_in, &bufinfo, MP_BUFFER_READ);
231+
232+
int errcode;
233+
mp_uint_t r = ussl_socket_write(self_in, bufinfo.buf, bufinfo.len, &errcode);
234+
if (r == MP_STREAM_ERROR) {
235+
mp_raise_OSError(errcode);
236+
}
237+
return mp_obj_new_int(r);
238+
}
239+
STATIC MP_DEFINE_CONST_FUN_OBJ_2(ussl_socket_send_obj, ussl_socket_send);
240+
192241
STATIC mp_uint_t ussl_socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, int *errcode) {
193242
mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(o_in);
194243
if (request == MP_STREAM_CLOSE && self->ssl_sock != NULL) {
@@ -216,7 +265,9 @@ STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = {
216265
{ MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mp_stream_read_obj) },
217266
{ MP_ROM_QSTR(MP_QSTR_readinto), MP_ROM_PTR(&mp_stream_readinto_obj) },
218267
{ MP_ROM_QSTR(MP_QSTR_readline), MP_ROM_PTR(&mp_stream_unbuffered_readline_obj) },
268+
{ MP_ROM_QSTR(MP_QSTR_recv), MP_ROM_PTR(&ussl_socket_recv_obj) },
219269
{ MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mp_stream_write_obj) },
270+
{ MP_ROM_QSTR(MP_QSTR_send), MP_ROM_PTR(&ussl_socket_send_obj) },
220271
{ MP_ROM_QSTR(MP_QSTR_setblocking), MP_ROM_PTR(&ussl_socket_setblocking_obj) },
221272
{ MP_ROM_QSTR(MP_QSTR_close), MP_ROM_PTR(&mp_stream_close_obj) },
222273
#if MICROPY_PY_USSL_FINALISER

extmod/modussl_mbedtls.c

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
9191
}
9292
}
9393

94+
// _mbedtls_ssl_recv is called by mbedtls to receive bytes from the underlying socket
9495
STATIC int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
9596
mp_obj_t sock = *(mp_obj_t *)ctx;
9697

@@ -129,7 +130,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
129130
mbedtls_pk_init(&o->pkey);
130131
mbedtls_ctr_drbg_init(&o-& 10000 gt;ctr_drbg);
131132
#ifdef MBEDTLS_DEBUG_C
132-
// Debug level (0-4)
133+
// Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose
133134
mbedtls_debug_set_threshold(0);
134135
#endif
135136

@@ -197,7 +198,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
197198
if (args->do_handshake.u_bool) {
198199
while ((ret = mbedtls_ssl_handshake(&o->ssl)) != 0) {
199200
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
200-
printf("mbedtls_ssl_handshake error: -%x\n", -ret);
201+
printf("mbedtls_ssl_handshake error: %d/-0x%x\n", ret, -ret);
201202
goto cleanup;
202203
}
203204
}
@@ -221,7 +222,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
221222
} else if (ret == MBEDTLS_ERR_X509_BAD_INPUT_DATA) {
222223
mp_raise_ValueError(MP_ERROR_TEXT("invalid cert"));
223224
} else {
224-
mp_raise_OSError(MP_EIO);
225+
mbedtls_raise_error(ret);
225226
}
226227
}
227228

@@ -267,6 +268,22 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
267268
return MP_STREAM_ERROR;
268269
}
269270

271+
STATIC mp_obj_t socket_recv(mp_obj_t self_in, mp_obj_t len_in) {
272+
size_t len = mp_obj_get_int(len_in);
273+
vstr_t vstr;
274+
vstr_init_len(&vstr, len);
275+
276+
int errcode;
277+
int ret = socket_read(self_in, vstr.buf, len, &errcode);
278+
if (ret == MP_STREAM_ERROR) {
279+
mp_raise_OSError(errcode);
280+
}
281+
282+
vstr.len = ret;
283+
return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr);
284+
}
285+
STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_recv_obj, socket_recv);
286+
270287
STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) {
271288
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
272289

@@ -286,6 +303,19 @@ STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in
286303
return MP_STREAM_ERROR;
287304
}
288305

306+
STATIC mp_obj_t socket_send(mp_obj_t self_in, mp_obj_t buf_in) {
307+
mp_buffer_info_t bufinfo;
308+
mp_get_buffer_raise(buf_in, &bufinfo, MP_BUFFER_READ);
309+
310+
int errcode;
311+
int r = socket_write(self_in, bufinfo.buf, bufinfo.len, &errcode);
312+
if (r == MP_STREAM_ERROR) {
313+
mp_raise_OSError(errcode);
314+
}
315+
return mp_obj_new_int(r);
316+
}
317+
STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_send_obj, socket_send);
318+
289319
STATIC mp_obj_t socket_setblocking(mp_obj_t self_in, mp_obj_t flag_in) {
290320
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(self_in);
291321
mp_obj_t sock = o->sock;
@@ -315,7 +345,9 @@ STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = {
315345
{ MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mp_stream_read_obj) },
316346
{ MP_ROM_QSTR(MP_QSTR_readinto), MP_ROM_PTR(&mp_stream_readinto_obj) },
317347
{ MP_ROM_QSTR(MP_QSTR_readline), MP_ROM_PTR(&mp_stream_unbuffered_readline_obj) },
348+
{ MP_ROM_QSTR(MP_QSTR_recv), MP_ROM_PTR(&socket_recv_obj) },
318349
{ MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mp_stream_write_obj) },
350+
{ MP_ROM_QSTR(MP_QSTR_send), MP_ROM_PTR(&socket_send_obj) },
319351
{ MP_ROM_QSTR(MP_QSTR_setblocking), MP_ROM_PTR(&socket_setblocking_obj) },
320352
{ MP_ROM_QSTR(MP_QSTR_close), MP_ROM_PTR(&mp_stream_close_obj) },
321353
#if MICROPY_PY_USSL_FINALISER

ports/esp32/modsocket.c

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ STATIC mp_obj_t socket_accept(const mp_obj_t arg0) {
325325
if (new_fd >= 0) {
326326
break;
327327
}
328-
if (errno != EAGAIN) {
328+
if (errno != MP_EAGAIN) {
329329
exception_from_errno(errno);
330330
}
331331
check_for_exceptions();
@@ -518,7 +518,7 @@ STATIC mp_uint_t _socket_read_data(mp_obj_t self_in, void *buf, size_t size,
518518
if (r >= 0) {
519519
return r;
520520
}
521-
if (errno != EWOULDBLOCK) {
521+
if (errno != MP_EWOULDBLOCK) {
522522
*errcode = errno;
523523
return MP_STREAM_ERROR;
524524
}
@@ -571,7 +571,8 @@ int _socket_send(socket_obj_t *sock, const char *data, size_t datalen) {
571571
MP_THREAD_GIL_EXIT();
572572
int r = lwip_write(sock->fd, data + sentlen, datalen - sentlen);
573573
MP_THREAD_GIL_ENTER();
574-
if (r < 0 && errno != EWOULDBLOCK) {
574+
// lwip returns MP_EINPROGRESS when trying to send right after a non-blocking connect
575+
if (r < 0 && errno != MP_EWOULDBLOCK && errno != MP_EINPROGRESS) {
575576
exception_from_errno(errno);
576577
}
577578
if (r > 0) {
@@ -580,7 +581,7 @@ int _socket_send(socket_obj_t *sock, const char *data, size_t datalen) {
580581
check_for_exceptions();
581582
}
582583
if (sentlen == 0) {
583-
mp_raise_OSError(MP_ETIMEDOUT);
584+
mp_raise_OSError(sock->retries == 0 ? MP_EWOULDBLOCK : MP_ETIMEDOUT);
584585
}
585586
return sentlen;
586587
}
@@ -629,7 +630,7 @@ STATIC mp_obj_t socket_sendto(mp_obj_t self_in, mp_obj_t data_in, mp_obj_t addr_
629630
if (ret > 0) {
630631
return mp_obj_new_int_from_uint(ret);
631632
}
632-
if (ret == -1 && errno != EWOULDBLOCK) {
633+
if (ret == -1 && errno != MP_EWOULDBLOCK) {
633634
exception_from_errno(errno);
634635
}
635636
check_for_exceptions();
@@ -661,9 +662,12 @@ STATIC mp_uint_t socket_stream_write(mp_obj_t self_in, const void *buf, mp_uint_
661662
int r = lwip_write(sock->fd, buf, size);
662663
MP_THREAD_GIL_ENTER();
663664
if (r > 0) {
665+
printf("socket_stream_write wrote %d\n", r);
664666
return r;
665667
}
666-
if (r < 0 && errno != EWOULDBLOCK) {
668+
// lwip returns MP_EINPROGRESS when trying to write right after a non-blocking connect
669+
if (r < 0 && errno != MP_EWOULDBLOCK && errno != MP_EINPROGRESS) {
670+
printf("socket_stream_write error %d\n", errno);
667671
*errcode = errno;
668672
return MP_STREAM_ERROR;
669673
}

tests/net_hosted/connect_nonblock.py

Lines changed: 101 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,116 @@
11
# test that socket.connect() on a non-blocking socket raises EINPROGRESS
2+
# and that an immediate write/send/read/recv does the right thing
23

34
try:
4-
import usocket as socket
5+
import usocket as socket, ussl as ssl, sys, time
56
except:
6-
import socket
7+
import socket, ssl, sys, time
78

89

9-
def test(peer_addr):
10+
def dp(e):
11+
# print(e) # uncomment this line for dev&test to print the actual exceptions
12+
pass
13+
14+
15+
# do_connect establishes the socket and wraps it if requested
16+
def do_connect(peer_addr, tls, handshake):
1017
s = socket.socket()
1118
s.setblocking(False)
1219
try:
1320
s.connect(peer_addr)
1421
except OSError as er:
15-
print(er.args[0] == 115) # 115 is EINPROGRESS
22+
print("connect:", er.args[0] == 115) # 115 is EINPROGRESS
23+
# wrap with ssl/tls if desired
24+
if tls:
25+
try:
26+
if sys.implementation.name = 10000 = "micropython":
27+
s = ssl.wrap_socket(s, do_handshake=handshake)
28+
else:
29+
s = ssl.wrap_socket(s, do_handshake_on_connect=handshake)
30+
print("wrap: True")
31+
except Exception as e:
32+
dp(e)
33+
print("wrap:", e)
34+
# if handshake is set, we wait after connect() so it has time to actually happen
35+
if handshake and not tls: # with tls the handshake does it
36+
time.sleep(0.2)
37+
return s
38+
39+
40+
def test(peer_addr, tls=False, handshake=False):
41+
# a fresh socket is opened for each combination because MP on linux is too fast
42+
43+
# hasRW is false in CPython for sockets: they don't have read or write methods
44+
hasRW = sys.implementation.name == "micropython" or tls
45+
46+
# connect + send
47+
s = do_connect(peer_addr, tls, handshake)
48+
# send -> 4 or EAGAIN
49+
try:
50+
ret = s.send(b"1234")
51+
print("send:", handshake and ret == 4)
52+
except OSError as er:
53+
dp(er)
54+
print("send:", er.args[0] == 11) # 11 is EAGAIN
1655
s.close()
1756

57+
# connect + write
58+
if hasRW:
59+
s = do_connect(peer_addr, tls, handshake)
60+
# write -> None
61+
try:
62+
ret = s.write(b"1234")
63+
print("write:", ret is (4 if handshake else None))
64+
except OSError as er:
65+
dp(er)
66+
print("write:", False) # should not raise
67+
except ValueError as er: # CPython
68+
dp(er)
69+
print("write:", er.args[0] == "Write on closed or unwrapped SSL socket.")
70+
s.close()
71+
else: # fake it...
72+
print("connect:", True)
73+
print("write:", True)
74+
75+
# connect + recv
76+
s = do_connect(peer_addr, tls, handshake)
77+
# recv -> EAGAIN
78+
try:
79+
print("recv:", s.recv(10))
80+
except OSError as er:
81+
dp(er)
82+
print("recv:", er.args[0] == 11) # 11 is EAGAIN
83+
s.close()
84+
85+
# connect + read
86+
if hasRW:
87+
s = do_connect(peer_addr, tls, handshake)
88+
# read -> None
89+
try:
90+
ret = s.read(10)
91+
print("read:", ret is None)
92+
except OSError as er:
93+
dp(er)
94+
print("read:", False) # should not raise
95+
except ValueError as er: # CPython
96+
dp(er)
97+
print("read:", er.args[0] == "Read on closed or unwrapped SSL socket.")
98+
s.close()
99+
else: # fake it...
100+
print("connect:", True)
101+
print("read:", True)
102+
18103

19104
if __name__ == "__main__":
20-
test(socket.getaddrinfo("micropython.org", 80)[0][-1])
105+
# these tests use an non-existant test IP address, this way the connect takes forever and
106+
# we can see EAGAIN/None (https://tools.ietf.org/html/rfc5737)
107+
print("--- Plain sockets to nowhere ---")
108+
test(socket.getaddrinfo("192.0.2.1", 80)[0][-1], False, False)
109+
print("--- SSL sockets to nowhere ---")
110+
# this test fails with AXTLS because do_handshake=False blocks on first read/write and
111+
# there it times out until the connect is aborted
112+
test(socket.getaddrinfo("192.0.2.1", 443)[0][-1], True, False)
113+
print("--- Plain sockets ---")
114+
test(socket.getaddrinfo("micropython.org", 80)[0][-1], False, True)
115+
print("--- SSL sockets ---")
116+
test(socket.getaddrinfo("micropython.org", 443)[0][-1], True, True)

0 commit comments

Comments
 (0)
0