8000 ssl: work on anything implementing the socket protocol · adafruit/circuitpython@c793a02 · GitHub
[go: up one dir, main page]

Skip to content

Commit c793a02

Browse files
committed
ssl: work on anything implementing the socket protocol
In principle this allows core SSL code to be used with e.g., wiznet or airlift sockets. It might actually be useful with wiznet ethernet devices (it's probably not with airlift)
1 parent 5973c4a commit c793a02

File tree

7 files changed

+114
-61
lines changed

7 files changed

+114
-61
lines changed

shared-bindings/ssl/SSLContext.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,9 @@ STATIC mp_obj_t ssl_sslcontext_wrap_socket(size_t n_args, const mp_obj_t *pos_ar
200200
mp_raise_ValueError(MP_ERROR_TEXT("Server side context cannot have hostname"));
201201
}
202202

203-
socketpool_socket_obj_t *sock = args[ARG_sock].u_obj;
203+
mp_obj_t sock_obj = args[ARG_sock].u_obj;
204204

205-
return common_hal_ssl_sslcontext_wrap_socket(self, sock, server_side, server_hostname);
205+
return common_hal_ssl_sslcontext_wrap_socket(self, sock_obj, server_side, server_hostname);
206206
}
207207
STATIC MP_DEFINE_CONST_FUN_OBJ_KW(ssl_sslcontext_wrap_socket_obj, 1, ssl_sslcontext_wrap_socket);
208208

shared-bindings/ssl/SSLContext.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,14 @@
3333
#include "common-hal/ssl/SSLContext.h"
3434
#endif
3535

36-
#include "shared-bindings/socketpool/Socket.h"
3736
#include "shared-bindings/ssl/SSLSocket.h"
3837

3938
extern const mp_obj_type_t ssl_sslcontext_type;
4039

4140
void common_hal_ssl_sslcontext_construct(ssl_sslcontext_obj_t *self);
4241

4342
ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t *self,
44-
socketpool_socket_obj_t *sock, bool server_side, const char *server_hostname);
43+
mp_obj_t socket, bool server_side, const char *server_hostname);
4544

4645
void common_hal_ssl_sslcontext_load_verify_locations(ssl_sslcontext_obj_t *self,
4746
const char *cadata);

shared-bindings/ssl/SSLSocket.c

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(ssl_sslsocket___exit___obj, 4, 4, ssl
7373
//| Returns a tuple of (new_socket, remote_address)"""
7474
STATIC mp_obj_t ssl_sslsocket_accept(mp_obj_t self_in) {
7575
ssl_sslsocket_obj_t *self = MP_OBJ_TO_PTR(self_in);
76-
uint8_t ip[4];
77-
uint32_t port;
78-
79-
ssl_sslsocket_obj_t *sslsock = common_hal_ssl_sslsocket_accept(self, ip, &port);
80-
81-
mp_obj_t tuple_contents[2];
82-
tuple_contents[0] = MP_OBJ_FROM_PTR(sslsock);
83-
tuple_contents[1] = netutils_format_inet_addr(ip, port, NETUTILS_BIG);
84-
return mp_obj_new_tuple(2, tuple_contents);
76+
return common_hal_ssl_sslsocket_accept(self);
8577
}
8678
STATIC MP_DEFINE_CONST_FUN_OBJ_1(ssl_sslsocket_accept_obj, ssl_sslsocket_accept);
8779

@@ -96,14 +88,7 @@ STATIC mp_obj_t ssl_sslsocket_bind(mp_obj_t self_in, mp_obj_t addr_in) {
9688
mp_obj_t *addr_items;
9789
mp_obj_get_array_fixed_n(addr_in, 2, &addr_items);
9890

99-
size_t hostlen;
100-
const char *host = mp_obj_str_get_data(addr_items[0], &hostlen);
101-
mp_int_t port = mp_obj_get_int(addr_items[1]);
102-
if (port < 0) {
103-
mp_raise_ValueError(MP_ERROR_TEXT("port must be >= 0"));
104-
}
105-
106-
size_t error = common_hal_ssl_sslsocket_bind(self, host, hostlen, (uint32_t)port);
91+
size_t error = common_hal_ssl_sslsocket_bind(self, addr_in);
10792
if (error != 0) {
10893
mp_raise_OSError(error);
10994
}
@@ -128,18 +113,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_1(ssl_sslsocket_close_obj, ssl_sslsocket_close);
128113
//| ...
129114
STATIC mp_obj_t ssl_sslsocket_connect(mp_obj_t self_in, mp_obj_t addr_in) {
130115
ssl_sslsocket_obj_t *self = MP_OBJ_TO_PTR(self_in);
131-
132-
mp_obj_t *addr_items;
133-
mp_obj_get_array_fixed_n(addr_in, 2, &addr_items);
134-
135-
size_t hostlen;
136-
const char *host = mp_obj_str_get_data(addr_items[0], &hostlen);
137-
mp_int_t port = mp_obj_get_int(addr_items[1]);
138-
if (port < 0) {
139-
mp_raise_ValueError(MP_ERROR_TEXT("port must be >= 0"));
140-
}
141-
142-
common_hal_ssl_sslsocket_connect(self, host, hostlen, (uint32_t)port);
116+
common_hal_ssl_sslsocket_connect(self, addr_in);
143117

144118
return mp_const_none;
145119
}

shared-bindings/ssl/SSLSocket.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@
3434

3535
extern const mp_obj_type_t ssl_sslsocket_type;
3636

37-
ssl_sslsocket_obj_t *common_hal_ssl_sslsocket_accept(ssl_sslsocket_obj_t *self, uint8_t *ip, uint32_t *port);
38-
size_t common_hal_ssl_sslsocket_bind(ssl_sslsocket_obj_t *self, const char *host, size_t hostlen, uint32_t port);
37+
mp_obj_t common_hal_ssl_sslsocket_accept(ssl_sslsocket_obj_t *self);
38+
size_t common_hal_ssl_sslsocket_bind(ssl_sslsocket_obj_t *self, mp_obj_t addr);
3939
void common_hal_ssl_sslsocket_close(ssl_sslsocket_obj_t *self);
40-
void common_hal_ssl_sslsocket_connect(ssl_sslsocket_obj_t *self, const char *host, size_t hostlen, uint32_t port);
40+
void common_hal_ssl_sslsocket_connect(ssl_sslsocket_obj_t *self, mp_obj_t addr);
4141
bool common_hal_ssl_sslsocket_get_closed(ssl_sslsocket_obj_t *self);
4242
bool common_hal_ssl_sslsocket_get_connected(ssl_sslsocket_obj_t *self);
4343
bool common_hal_ssl_sslsocket_listen(ssl_sslsocket_obj_t *self, int backlog);

shared-module/ssl/SSLContext.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
#include "shared-bindings/ssl/SSLContext.h"
2828
#include "shared-bindings/ssl/SSLSocket.h"
29-
#include "shared-bindings/socketpool/SocketPool.h"
3029

3130
#include "py/runtime.h"
3231
#include "py/stream.h"

shared-module/ssl/SSLSocket.c

Lines changed: 96 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,10 @@
2727
*/
2828

2929
#include "shared-bindings/ssl/SSLSocket.h"
30-
#include "shared-bindings/socketpool/Socket.h"
3130
#include "shared-bindings/ssl/SSLContext.h"
32-
#include "shared-bindings/socketpool/SocketPool.h"
33-
#include "shared-bindings/socketpool/Socket.h"
3431

3532
#include "shared/runtime/interrupt_char.h"
33+
#include "shared/netutils/netutils.h"
3634
#include "py/mperrno.h"
3735
#include "py/mphal.h"
3836
#include "py/objstr.h"
@@ -108,11 +106,72 @@ STATIC NORETURN void mbedtls_raise_error(int err) {
108106
#endif
109107
}
110108

109+
STATIC int call_method_errno(size_t n_args, const mp_obj_t *args) {
110+
nlr_buf_t nlr;
111+
mp_int_t result = -MP_EINVAL;
112+
if (nlr_push(&nlr) == 0) {
113+
mp_obj_t obj_result = mp_call_method_n_kw(n_args, 0, args);
114+
result = (obj_result == mp_const_none) ? 0 : mp_obj_get_int(obj_result);
115+
nlr_pop();
116+
return result;
117+
} else {
118+
mp_obj_t exc = MP_OBJ_FROM_PTR(nlr.ret_val);
119+
if (nlr_push(&nlr) == 0) {
120+
result = -mp_obj_get_int(mp_load_attr(exc, MP_QSTR_errno));
121+
nlr_pop();
122+
}
123+
}
124+
return result;
125+
}
126+
127+
static int ssl_socket_send(ssl_sslsocket_obj_t *self, const byte *buf, size_t len) {
128+
mp_obj_array_t mv;
129+
mp_obj_memoryview_init(&mv, 'B', 0, len, (void *)buf);
130+
131+
self->send_args[2] = MP_OBJ_FROM_PTR(&mv);
132+
return call_method_errno(1, self->send_args);
133+
}
134+
135+
static int ssl_socket_recv_into(ssl_sslsocket_obj_t *self, byte *buf, size_t len) {
136+
mp_obj_array_t mv;
137+
mp_obj_memoryview_init(&mv, 'B' | MP_OBJ_ARRAY_TYPECODE_FLAG_RW, 0, len, buf);
138+
139+
self->recv_into_args[2] = MP_OBJ_FROM_PTR(&mv);
140+
return call_method_errno(1, self->recv_into_args);
141+
}
142+
143+
static int ssl_socket_connect(ssl_sslsocket_obj_t *self, mp_obj_t addr_in) {
144+
self->connect_args[2] = addr_in;
145+
return call_method_errno(1, self->connect_args);
146+
}
147+
148+
static int ssl_socket_bind(ssl_sslsocket_obj_t *self, mp_obj_t addr_in) {
149+
self->bind_args[2] = addr_in;
150+
return call_method_errno(1, self->bind_args);
151+
}
152+
153+
static int ssl_socket_close(ssl_sslsocket_obj_t *self) {
154+
return call_method_errno(0, self->close_args);
155+
}
156+
157+
static int ssl_socket_settimeout(ssl_sslsocket_obj_t *self, mp_int_t timeout_ms) {
158+
self->settimeout_args[2] = mp_obj_new_float(timeout_ms * MICROPY_FLOAT_CONST(1e-3));
159+
return call_method_errno(1, self->settimeout_args);
160+
}
161+
162+
static int ssl_socket_listen(ssl_sslsocket_obj_t *self, mp_int_t backlog) {
163+
self->listen_args[2] = MP_OBJ_NEW_SMALL_INT(backlog);
164+
return call_method_errno(1, self->listen_args);
165+
}
166+
167+
static mp_obj_t ssl_socket_accept(ssl_sslsocket_obj_t *self) {
168+
return mp_call_method_n_kw(0, 0, self->accept_args);
169+
}
170+
111171
STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
112-
mp_obj_t sock = *(mp_obj_t *)ctx;
172+
ssl_sslsocket_obj_t *self = (ssl_sslsocket_obj_t *)ctx;
113173

114-
// mp_uint_t out_sz = sock_stream->write(sock, buf, len, &err);
115-
mp_int_t out_sz = socketpool_socket_send(sock, buf, len);
174+
mp_int_t out_sz = ssl_socket_send(self, buf, len);
116175
DEBUG_PRINT("socket_send() -> %d", out_sz);
117176
if (out_sz < 0) {
118177
int err = -out_sz;
@@ -128,9 +187,9 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
128187

129188
// _mbedtls_ssl_recv is called by mbedtls to receive bytes from the underlying socket
130189
STATIC int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
131-
mp_obj_t sock = *(mp_obj_t *)ctx;
190+
ssl_sslsocket_obj_t *self = (ssl_sslsocket_obj_t *)ctx;
132191

133-
mp_int_t out_sz = socketpool_socket_recv_into(sock, buf, len);
192+
mp_int_t out_sz = ssl_socket_recv_into(self, buf, len);
134193
DEBUG_PRINT("socket_recv() -> %d", out_sz);
135194
if (out_sz < 0) {
136195
int err = -out_sz;
@@ -155,16 +214,26 @@ static int urandom_adapter(void *unused, unsigned char *buf, size_t n) {
155214
#endif
156215

157216
ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t *self,
158-
socketpool_socket_obj_t *socket, bool server_side, const char *server_hostname) {
217+
mp_obj_t socket, bool server_side, const char *server_hostname) {
159218

160-
if (socket->type != SOCKETPOOL_SOCK_STREAM) {
219+
mp_int_t socket_type = mp_obj_get_int(mp_load_attr(socket, MP_QSTR_type));
220+
if (socket_type != SOCKETPOOL_SOCK_STREAM) {
161221
mp_raise_RuntimeError(MP_ERROR_TEXT("Invalid socket for TLS"));
162222
}
163223

164224
ssl_sslsocket_obj_t *o = m_new_obj_with_finaliser(ssl_sslsocket_obj_t);
165225
o->base.type = &ssl_sslsocket_type;
166226
o->ssl_context = self;
167-
o->sock = socket;
227+
o->sock_obj = socket;
228+
229+
mp_load_method(socket, MP_QSTR_accept, o->accept_args);
230+
mp_load_method(socket, MP_QSTR_bind, o->bind_args);
231+
mp_load_method(socket, MP_QSTR_close, o->close_args);
232+
mp_load_method(socket, MP_QSTR_connect, o->connect_args);
233+
mp_load_method(socket, MP_QSTR_listen, o->listen_args);
234+
mp_load_method(socket, MP_QSTR_recv_into, o->recv_into_args);
235+
mp_load_method(socket, MP_QSTR_send, o->send_args);
236+
mp_load_method(socket, MP_QSTR_settimeout, o->settimeout_args);
168237

169238
mbedtls_ssl_init(&o->ssl);
170239
mbedtls_ssl_config_init(&o->conf);
@@ -223,7 +292,7 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
223292
}
224293
}
225294

226-
mbedtls_ssl_set_bio(&o->ssl, &o->sock, _mbedtls_ssl_send, _mbedtls_ssl_recv, NULL);
295+
mbedtls_ssl_set_bio(&o->ssl, o, _mbedtls_ssl_send, _mbedtls_ssl_recv, NULL);
227296

228297
if (self->cert_buf.buf != NULL) {
229298
#if MBEDTLS_VERSION_MAJOR >= 3
@@ -292,13 +361,13 @@ mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t
292361
mbedtls_raise_error(ret);
293362
}
294363

295-
size_t common_hal_ssl_sslsocket_bind(ssl_sslsocket_obj_t *self, const char *host, size_t hostlen, uint32_t port) {
296-
return common_hal_socketpool_socket_bind(self->sock, host, hostlen, port);
364+
size_t common_hal_ssl_sslsocket_bind(ssl_sslsocket_obj_t *self, mp_obj_t addr_in) {
365+
return ssl_socket_bind(self, addr_in);
297366
}
298367

299368
void common_hal_ssl_sslsocket_close(ssl_sslsocket_obj_t *self) {
300369
self->closed = true;
301-
common_hal_socketpool_socket_close(self->sock);
370+
ssl_socket_close(self);
302371
mbedtls_pk_free(&self->pkey);
303372
mbedtls_x509_crt_free(&self->cert);
304373
mbedtls_x509_crt_free(&self->cacert);
@@ -344,8 +413,8 @@ STATIC void do_handshake(ssl_sslsocket_obj_t *self) {
344413
}
345414
}
346415

347-
void common_hal_ssl_sslsocket_connect(ssl_sslsocket_obj_t *self, const char *host, size_t hostlen, uint32_t port) {
348-
common_hal_socketpool_socket_connect(self->sock, host, hostlen, port);
416+
void common_hal_ssl_sslsocket_connect(ssl_sslsocket_obj_t *self, mp_obj_t addr_in) {
417+
ssl_socket_connect(self, addr_in);
349418
do_handshake(self);
350419
}
351420

@@ -358,16 +427,21 @@ bool common_hal_ssl_sslsocket_get_connected(ssl_sslsocket_obj_t *self) {
358427
}
359428

360429
bool common_hal_ssl_sslsocket_listen(ssl_sslsocket_obj_t *self, int backlog) {
361-
return common_hal_socketpool_socket_listen(self->sock, backlog);
430+
return ssl_socket_listen(self, backlog);
362431
}
363432

364-
ssl_sslsocket_obj_t *common_hal_ssl_sslsocket_accept(ssl_sslsocket_obj_t *self, uint8_t *ip, uint32_t *port) {
365-
socketpool_socket_obj_t *sock = common_hal_socketpool_socket_accept(self->sock, ip, port);
433+
mp_obj_t common_hal_ssl_sslsocket_accept(ssl_sslsocket_obj_t *self) {
434+
mp_obj_t accepted = ssl_socket_accept(self);
435+
mp_obj_t sock = mp_obj_subscr(accepted, MP_OBJ_NEW_SMALL_INT(0), MP_OBJ_SENTINEL);
366436
ssl_sslsocket_obj_t *sslsock = common_hal_ssl_sslcontext_wrap_socket(self->ssl_context, sock, true, NULL);
367437
do_handshake(sslsock);
368-
return sslsock;
438+
mp_obj_t peer = mp_obj_subscr(accepted, MP_OBJ_NEW_SMALL_INT(0), MP_OBJ_SENTINEL);
439+
mp_obj_t tuple_contents[2];
440+
tuple_contents[0] = MP_OBJ_FROM_PTR(sslsock);
441+
tuple_contents[1] = peer;
442+
return mp_obj_new_tuple(2, tuple_contents);
369443
}
370444

371445
void common_hal_ssl_sslsocket_settimeout(ssl_sslsocket_obj_t *self, uint32_t timeout_ms) {
372-
common_hal_socketpool_socket_settimeout(self->sock, timeout_ms);
446+
ssl_socket_settimeout(self, timeout_ms);
373447
}

shared-module/ssl/SSLSocket.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
#include "py/obj.h"
3131

3232
#include "shared-module/ssl/SSLContext.h"
33-
#include "common-hal/socketpool/Socket.h"
3433

3534
#include "mbedtls/platform.h"
3635
#include "mbedtls/ssl.h"
@@ -41,7 +40,7 @@
4140

4241
typedef struct ssl_sslsocket_obj {
4342
mp_obj_base_t base;
44-
socketpool_socket_obj_t *sock;
43+
mp_obj_t sock_obj;
4544
ssl_sslcontext_obj_t *ssl_context;
4645
mbedtls_entropy_context entropy;
4746
mbedtls_ctr_drbg_context ctr_drbg;
@@ -51,4 +50,12 @@ typedef struct ssl_sslsocket_obj {
5150
mbedtls_x509_crt cert;
5251
mbedtls_pk_context pkey;
5352
bool closed;
53+
mp_obj_t accept_args[2];
54+
mp_obj_t bind_args[3];
55+
mp_obj_t close_args[2];
56+
mp_obj_t connect_args[3];
57+
mp_obj_t listen_args[3];
58+
mp_obj_t recv_into_args[3];
59+
mp_obj_t send_args[3];
60+
mp_obj_t settimeout_args[3];
5461
} ssl_sslsocket_obj_t;

0 commit comments

Comments
 (0)
0