28
28
#include <stdint.h>
29
29
#include <string.h>
30
30
31
+ #include "py/objmodule.h"
31
32
#include "py/runtime.h"
32
33
#include "py/stream.h"
33
34
#include "extmod/modwebsocket.h"
@@ -38,14 +39,18 @@ enum { FRAME_HEADER, FRAME_OPT, PAYLOAD, CONTROL };
38
39
39
40
enum { BLOCKING_WRITE = 0x80 };
40
41
42
+ enum { NO_WRITE_MASKING , NORMAL_WRITE_MASKING , DEBUG_WRITE_MASKING };
43
+
41
44
typedef struct _mp_obj_websocket_t {
42
45
mp_obj_base_t base ;
43
46
mp_obj_t sock ;
44
47
uint32_t msg_sz ;
45
- byte mask [4 ];
48
+ byte read_mask [4 ];
49
+ byte do_write_masking ;
50
+ byte debug_write_mask [4 ];
46
51
byte state ;
47
52
byte to_recv ;
48
- byte mask_pos ;
53
+ byte read_mask_pos ;
49
54
byte buf_pos ;
50
55
byte buf [6 ];
51
56
byte opts ;
@@ -58,16 +63,39 @@ typedef struct _mp_obj_websocket_t {
58
63
STATIC mp_uint_t websocket_write (mp_obj_t self_in , const void * buf , mp_uint_t size , int * errcode );
59
64
60
65
STATIC mp_obj_t websocket_make_new (const mp_obj_type_t * type , size_t n_args , size_t n_kw , const mp_obj_t * args ) {
61
- mp_arg_check_num (n_args , n_kw , 1 , 2 , false);
66
+ static const mp_arg_t allowed_args [] = {
67
+ { MP_QSTR_sock , MP_ARG_REQUIRED | MP_ARG_OBJ , {.u_obj = MP_OBJ_NULL } },
68
+ { MP_QSTR_use_blocking_writes , MP_ARG_BOOL , {.u_bool = false} },
69
+ { MP_QSTR_is_client , MP_ARG_BOOL , {.u_bool = false} },
70
+ { MP_QSTR_debug_mask , MP_ARG_KW_ONLY | MP_ARG_OBJ , {.u_obj = MP_OBJ_NULL } },
71
+ };
72
+
73
+ // parse args
74
+ struct {
75
+ mp_arg_val_t sock , use_blocking_writes , is_client , debug_mask ;
76
+ } arg_vals ;
77
+ mp_arg_parse_all_kw_array (n_args , n_kw , args ,
78
+ MP_ARRAY_SIZE (allowed_args ), allowed_args , (mp_arg_val_t * )& arg_vals );
79
+
62
80
mp_obj_websocket_t * o = m_new_obj (mp_obj_websocket_t );
63
81
o -> base .type = type ;
64
- o -> sock = args [0 ];
82
+ o -> sock = arg_vals .sock .u_obj ;
83
+ o -> do_write_masking = !arg_vals .is_client .u_bool ? NO_WRITE_MASKING : NORMAL_WRITE_MASKING ;
84
+ if (arg_vals .debug_mask .u_obj != MP_OBJ_NULL ) {
85
+ mp_buffer_info_t bufinfo ;
86
+ mp_get_buffer_raise (arg_vals .debug_mask .u_obj , & bufinfo , MP_BUFFER_READ );
87
+ if (bufinfo .len != 4 ) {
88
+ mp_raise_ValueError ("debug mask must have length of 4" );
89
+ }
90
+ o -> do_write_masking = DEBUG_WRITE_MASKING ;
91
+ memcpy (o -> debug_write_mask , bufinfo .buf , 4 );
92
+ }
65
93
o -> state = FRAME_HEADER ;
66
94
o -> to_recv = 2 ;
67
- o -> mask_pos = 0 ;
95
+ o -> read_mask_pos = 0 ;
68
96
o -> buf_pos = 0 ;
69
97
o -> opts = FRAME_TXT ;
70
- if (n_args > 1 && args [ 1 ] == mp_const_true ) {
98
+ if (arg_vals . use_blocking_writes . u_bool ) {
71
99
o -> opts |= BLOCKING_WRITE ;
72
100
}
73
101
return MP_OBJ_FROM_PTR (o );
@@ -111,7 +139,7 @@ STATIC mp_uint_t websocket_read(mp_obj_t self_in, void *buf, mp_uint_t size, int
111
139
112
140
// Reset mask in case someone will use "simplified" protocol
113
141
// without masks.
114
- memset (self -> mask , 0 , sizeof (self -> mask ));
142
+ memset (self -> read_mask , 0 , sizeof (self -> read_mask ));
115
143
116
144
int to_recv = 0 ;
117
145
size_t sz = self -> buf [1 ] & 0x7f ;
@@ -149,7 +177,7 @@ STATIC mp_uint_t websocket_read(mp_obj_t self_in, void *buf, mp_uint_t size, int
149
177
}
150
178
if (self -> buf_pos >= 4 ) {
151
179
// Last 4 bytes is mask
152
- memcpy (self -> mask , self -> buf + self -> buf_pos - 4 , 4 );
180
+ memcpy (self -> read_mask , self -> buf + self -> buf_pos - 4 , 4 );
153
181
}
154
182
self -> buf_pos = 0 ;
155
183
if ((self -> last_flags & FRAME_OPCODE_MASK ) >= FRAME_CLOSE ) {
@@ -176,7 +204,7 @@ STATIC mp_uint_t websocket_read(mp_obj_t self_in, void *buf, mp_uint_t size, int
176
204
177
205
sz = out_sz ;
178
206
for (byte * p = buf ; sz -- ; p ++ ) {
179
- * p ^= self -> mask [self -> mask_pos ++ & 3 ];
207
+ * p ^= self -> read_mask [self -> read_mask_pos ++ & 3 ];
180
208
}
181
209
182
210
self -> msg_sz -= out_sz ;
@@ -186,7 +214,7 @@ STATIC mp_uint_t websocket_read(mp_obj_t self_in, void *buf, mp_uint_t size, int
186
214
last_state = self -> state ;
187
215
self -> state = FRAME_HEADER ;
188
216
self -> to_recv = 2 ;
189
- self -> mask_pos = 0 ;
217
+ self -> read_mask_pos = 0 ;
190
218
self -> buf_pos = 0 ;
191
219
192
220
// Handle control frame
@@ -218,7 +246,7 @@ STATIC mp_uint_t websocket_read(mp_obj_t self_in, void *buf, mp_uint_t size, int
218
246
STATIC mp_uint_t websocket_write (mp_obj_t self_in , const void * buf , mp_uint_t size , int * errcode ) {
219
247
mp_obj_websocket_t * self = MP_OBJ_TO_PTR (self_in );
220
248
assert (size < 0x10000 );
221
- byte header [4 ] = {0x80 | (self -> opts & FRAME_OPCODE_MASK )};
249
+ byte header [8 ] = {0x80 | (self -> opts & FRAME_OPCODE_MASK )};
222
250
int hdr_sz ;
223
251
if (size < 126 ) {
224
252
header [1 ] = size ;
@@ -229,6 +257,34 @@ STATIC mp_uint_t websocket_write(mp_obj_t self_in, const void *buf, mp_uint_t si
229
257
header [3 ] = size & 0xff ;
230
258
hdr_sz = 4 ;
231
259
}
260
+ if (self -> do_write_masking != NO_WRITE_MASKING ) {
261
+ hdr_sz += 4 ;
262
+ header [1 ] |= 0x80 ;
263
+ if (self -> do_write_masking == NORMAL_WRITE_MASKING ) {
264
+
265
+ // RFC6455 Section 5.3 states that the masking key must be derived
266
+ // from a strong source of entropy. The "urandom" module doesn't
267
+ // qualify in this regard, but there isn't any cross-platform
268
+ // alternative. Fortunately, the purpose of masking is not
269
+ // cryptographically motivated. The "urandom" module should be
270
+ // seeded though, otherwise upon restart, the same sequence of
271
+ // masks will always be used. A seed could be derived from a
272
+ // network resource, a network interface's characteristics or
273
+ // statistics, or a platform specific resource. Examples of using
274
+ // a platform specific resource include reading an ESP8266's
275
+ // 32-bit Random Number Generator register, or reading consecutive
276
+ // values from a floating analog pin.
277
+ mp_obj_t dest [3 ];
278
+ mp_load_method (mp_module_get (MP_QSTR_urandom ), MP_QSTR_getrandbits , dest );
279
+ dest [2 ] = mp_obj_new_int (32 );
280
+ unsigned int randbits = MP_OBJ_SMALL_INT_VALUE (mp_call_method_n_kw (1 , 0 , dest ));
281
+ for (int i = 0 ; i < 4 ; ++ i ) {
282
+ header [hdr_sz - 4 + i ] = (randbits >> ((i ^ 3 ) << 3 )) & 0xff ;
283
+ }
284
+ } else if (self -> do_write_masking == DEBUG_WRITE_MASKING ) {
285
+ memcpy (& header [hdr_sz - 4 ], self -> debug_write_mask , 4 );
286
+ }
287
+ }
232
288
233
289
mp_obj_t dest [3 ];
234
290
if (self -> opts & BLOCKING_WRITE ) {
@@ -239,7 +295,15 @@ STATIC mp_uint_t websocket_write(mp_obj_t self_in, const void *buf, mp_uint_t si
239
295
240
296
mp_uint_t out_sz = mp_stream_write_exactly (self -> sock , header , hdr_sz , errcode );
241
297
if (* errcode == 0 ) {
242
- out_sz = mp_stream_write_exactly (self -> sock , buf , size , errcode );
298
+ if (self -> do_write_masking == NO_WRITE_MASKING ) {
299
+ out_sz = mp_stream_write_exactly (self -> sock , buf , size , errcode );
300
+ } else {
301
+ byte masked_buf [size ];
302
+ for (mp_uint_t i = 0 ; i < size ; ++ i ) {
303
+ masked_buf [i ] = ((byte * )buf )[i ] ^ header [hdr_sz - 4 + (i & 3 )];
304
+ }
305
+ out_sz = mp_stream_write_exactly (self -> sock , masked_buf , size , errcode );
306
+ }
243
307
}
244
308
245
309
if (self -> opts & BLOCKING_WRITE ) {
0 commit comments