27
27
*/
28
28
29
29
#include "shared-bindings/ssl/SSLSocket.h"
30
- #include "shared-bindings/socketpool/Socket.h"
31
30
#include "shared-bindings/ssl/SSLContext.h"
32
- #include "shared-bindings/socketpool/SocketPool.h"
33
- #include "shared-bindings/socketpool/Socket.h"
34
31
35
32
#include "shared/runtime/interrupt_char.h"
33
+ #include "shared/netutils/netutils.h"
36
34
#include "py/mperrno.h"
37
35
#include "py/mphal.h"
38
36
#include "py/objstr.h"
@@ -108,11 +106,72 @@ STATIC NORETURN void mbedtls_raise_error(int err) {
108
106
#endif
109
107
}
110
108
109
+ STATIC int call_method_errno (size_t n_args , const mp_obj_t * args ) {
110
+ nlr_buf_t nlr ;
111
+ mp_int_t result = - MP_EINVAL ;
112
+ if (nlr_push (& nlr ) == 0 ) {
113
+ mp_obj_t obj_result = mp_call_method_n_kw (n_args , 0 , args );
114
+ result = (obj_result == mp_const_none ) ? 0 : mp_obj_get_int (obj_result );
115
+ nlr_pop ();
116
+ return result ;
117
+ } else {
118
+ mp_obj_t exc = MP_OBJ_FROM_PTR (nlr .ret_val );
119
+ if (nlr_push (& nlr ) == 0 ) {
120
+ result = - mp_obj_get_int (mp_load_attr (exc , MP_QSTR_errno ));
121
+ nlr_pop ();
122
+ }
123
+ }
124
+ return result ;
125
+ }
126
+
127
+ static int ssl_socket_send (ssl_sslsocket_obj_t * self , const byte * buf , size_t len ) {
128
+ mp_obj_array_t mv ;
129
+ mp_obj_memoryview_init (& mv , 'B' , 0 , len , (void * )buf );
130
+
131
+ self -> send_args [2 ] = MP_OBJ_FROM_PTR (& mv );
132
+ return call_method_errno (1 , self -> send_args );
133
+ }
134
+
135
+ static int ssl_socket_recv_into (ssl_sslsocket_obj_t * self , byte * buf , size_t len ) {
136
+ mp_obj_array_t mv ;
137
+ mp_obj_memoryview_init (& mv , 'B' | MP_OBJ_ARRAY_TYPECODE_FLAG_RW , 0 , len , buf );
138
+
139
+ self -> recv_into_args [2 ] = MP_OBJ_FROM_PTR (& mv );
140
+ return call_method_errno (1 , self -> recv_into_args );
141
+ }
142
+
143
+ static int ssl_socket_connect (ssl_sslsocket_obj_t * self , mp_obj_t addr_in ) {
144
+ self -> connect_args [2 ] = addr_in ;
145
+ return call_method_errno (1 , self -> connect_args );
146
+ }
147
+
148
+ static int ssl_socket_bind (ssl_sslsocket_obj_t * self , mp_obj_t addr_in ) {
149
+ self -> bind_args [2 ] = addr_in ;
150
+ return call_method_errno (1 , self -> bind_args );
151
+ }
152
+
153
+ static int ssl_socket_close (ssl_sslsocket_obj_t * self ) {
154
+ return call_method_errno (0 , self -> close_args );
155
+ }
156
+
157
+ static int ssl_socket_settimeout (ssl_sslsocket_obj_t * self , mp_int_t timeout_ms ) {
158
+ self -> settimeout_args [2 ] = mp_obj_new_float (timeout_ms * MICROPY_FLOAT_CONST (1e-3 ));
159
+ return call_method_errno (1 , self -> settimeout_args );
160
+ }
161
+
162
+ static int ssl_socket_listen (ssl_sslsocket_obj_t * self , mp_int_t backlog ) {
163
+ self -> listen_args [2 ] = MP_OBJ_NEW_SMALL_INT (backlog );
164
+ return call_method_errno (1 , self -> listen_args );
165
+ }
166
+
167
+ static mp_obj_t ssl_socket_accept (ssl_sslsocket_obj_t * self ) {
168
+ return mp_call_method_n_kw (0 , 0 , self -> accept_args );
169
+ }
170
+
111
171
STATIC int _mbedtls_ssl_send (void * ctx , const byte * buf , size_t len ) {
112
- mp_obj_t sock = * ( mp_obj_t * )ctx ;
172
+ ssl_sslsocket_obj_t * self = ( ssl_sslsocket_obj_t * )ctx ;
113
173
114
- // mp_uint_t out_sz = sock_stream->write(sock, buf, len, &err);
115
- mp_int_t out_sz = socketpool_socket_send (sock , buf , len );
174
+ mp_int_t out_sz = ssl_socket_send (self , buf , len );
116
175
DEBUG_PRINT ("socket_send() -> %d" , out_sz );
117
176
if (out_sz < 0 ) {
118
177
int err = - out_sz ;
@@ -128,9 +187,9 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
128
187
129
188
// _mbedtls_ssl_recv is called by mbedtls to receive bytes from the underlying socket
130
189
STATIC int _mbedtls_ssl_recv (void * ctx , byte * buf , size_t len ) {
131
- mp_obj_t sock = * ( mp_obj_t * )ctx ;
190
+ ssl_sslsocket_obj_t * self = ( ssl_sslsocket_obj_t * )ctx ;
132
191
133
- mp_int_t out_sz = socketpool_socket_recv_into ( sock , buf , len );
192
+ mp_int_t out_sz = ssl_socket_recv_into ( self , buf , len );
134
193
DEBUG_PRINT ("socket_recv() -> %d" , out_sz );
135
194
if (out_sz < 0 ) {
136
195
int err = - out_sz ;
@@ -155,16 +214,26 @@ static int urandom_adapter(void *unused, unsigned char *buf, size_t n) {
155
214
#endif
156
215
157
216
ssl_sslsocket_obj_t * common_hal_ssl_sslcontext_wrap_socket (ssl_sslcontext_obj_t * self ,
158
- socketpool_socket_obj_t * socket , bool server_side , const char * server_hostname ) {
217
+ mp_obj_t socket , bool server_side , const char * server_hostname ) {
159
218
160
- if (socket -> type != SOCKETPOOL_SOCK_STREAM ) {
219
+ mp_int_t socket_type = mp_obj_get_int (mp_load_attr (socket , MP_QSTR_type ));
220
+ if (socket_type != SOCKETPOOL_SOCK_STREAM ) {
161
221
mp_raise_RuntimeError (MP_ERROR_TEXT ("Invalid socket for TLS" ));
162
222
}
163
223
164
224
ssl_sslsocket_obj_t * o = m_new_obj_with_finaliser (ssl_sslsocket_obj_t );
165
225
o -> base .type = & ssl_sslsocket_type ;
166
226
o -> ssl_context = self ;
167
- o -> sock = socket ;
227
+ o -> sock_obj = socket ;
228
+
229
+ mp_load_method (socket , MP_QSTR_accept , o -> accept_args );
230
+ mp_load_method (socket , MP_QSTR_bind , o -> bind_args );
231
+ mp_load_method (socket , MP_QSTR_close , o -> close_args );
232
+ mp_load_method (socket , MP_QSTR_connect , o -> connect_args );
233
+ mp_load_method (socket , MP_QSTR_listen , o -> listen_args );
234
+ mp_load_method (socket , MP_QSTR_recv_into , o -> recv_into_args );
235
+ mp_load_method (socket , MP_QSTR_send , o -> send_args );
236
+ mp_load_method (socket , MP_QSTR_settimeout , o -> settimeout_args );
168
237
169
238
mbedtls_ssl_init (& o -> ssl );
170
239
mbedtls_ssl_config_init (& o -> conf );
@@ -223,7 +292,7 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
223
292
}
224
293
}
225
294
226
- mbedtls_ssl_set_bio (& o -> ssl , & o -> sock , _mbedtls_ssl_send , _mbedtls_ssl_recv , NULL );
295
+ mbedtls_ssl_set_bio (& o -> ssl , o , _mbedtls_ssl_send , _mbedtls_ssl_recv , NULL );
227
296
228
297
if (self -> cert_buf .buf != NULL ) {
229
298
#if MBEDTLS_VERSION_MAJOR >= 3
@@ -292,13 +361,13 @@ mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t
292
361
mbedtls_raise_error (ret );
293
362
}
294
363
295
- size_t common_hal_ssl_sslsocket_bind (ssl_sslsocket_obj_t * self , const char * host , size_t hostlen , uint32_t port ) {
296
- return common_hal_socketpool_socket_bind (self -> sock , host , hostlen , port );
364
+ size_t common_hal_ssl_sslsocket_bind (ssl_sslsocket_obj_t * self , mp_obj_t addr_in ) {
365
+ return ssl_socket_bind (self , addr_in );
297
366
}
298
367
299
368
void common_hal_ssl_sslsocket_close (ssl_sslsocket_obj_t * self ) {
300
369
self -> closed = true;
301
- common_hal_socketpool_socket_close (self -> sock );
370
+ ssl_socket_close (self );
302
371
mbedtls_pk_free (& self -> pkey );
303
372
mbedtls_x509_crt_free (& self -> cert );
304
373
mbedtls_x509_crt_free (& self -> cacert );
@@ -344,8 +413,8 @@ STATIC void do_handshake(ssl_sslsocket_obj_t *self) {
344
413
}
345
414
}
346
415
347
- void common_hal_ssl_sslsocket_connect (ssl_sslsocket_obj_t * self , const char * host , size_t hostlen , uint32_t port ) {
348
- common_hal_socketpool_socket_connect (self -> sock , host , hostlen , port );
416
+ void common_hal_ssl_sslsocket_connect (ssl_sslsocket_obj_t * self , mp_obj_t addr_in ) {
417
+ ssl_socket_connect (self , addr_in );
349
418
do_handshake (self );
350
419
}
351
420
@@ -358,16 +427,21 @@ bool common_hal_ssl_sslsocket_get_connected(ssl_sslsocket_obj_t *self) {
358
427
}
359
428
360
429
bool common_hal_ssl_sslsocket_listen (ssl_sslsocket_obj_t * self , int backlog ) {
361
- return common_hal_socketpool_socket_listen (self -> sock , backlog );
430
+ return ssl_socket_listen (self , backlog );
362
431
}
363
432
364
- ssl_sslsocket_obj_t * common_hal_ssl_sslsocket_accept (ssl_sslsocket_obj_t * self , uint8_t * ip , uint32_t * port ) {
365
- socketpool_socket_obj_t * sock = common_hal_socketpool_socket_accept (self -> sock , ip , port );
433
+ mp_obj_t common_hal_ssl_sslsocket_accept (ssl_sslsocket_obj_t * self ) {
434
+ mp_obj_t accepted = ssl_socket_accept (self );
435
+ mp_obj_t sock = mp_obj_subscr (accepted , MP_OBJ_NEW_SMALL_INT (0 ), MP_OBJ_SENTINEL );
366
436
ssl_sslsocket_obj_t * sslsock = common_hal_ssl_sslcontext_wrap_socket (self -> ssl_context , sock , true, NULL );
367
437
do_handshake (sslsock );
368
- return sslsock ;
438
+ mp_obj_t peer = mp_obj_subscr (accepted , MP_OBJ_NEW_SMALL_INT (0 ), MP_OBJ_SENTINEL );
439
+ mp_obj_t tuple_contents [2 ];
440
+ tuple_contents [0 ] = MP_OBJ_FROM_PTR (sslsock );
441
+ tuple_contents [1 ] = peer ;
442
+ return mp_obj_new_tuple (2 , tuple_contents );
369
443
}
370
444
371
445
void common_hal_ssl_sslsocket_settimeout (ssl_sslsocket_obj_t * self , uint32_t timeout_ms ) {
372
- common_hal_socketpool_socket_settimeout (self -> sock , timeout_ms );
446
+ ssl_socket_settimeout (self , timeout_ms );
373
447
}
0 commit comments