diff --git a/extmod/modwebsocket.c b/extmod/modwebsocket.c index a651164b2fe3d..7d4f6cfa8f23c 100644 --- a/extmod/modwebsocket.c +++ b/extmod/modwebsocket.c @@ -28,6 +28,7 @@ #include #include +#include "py/objmodule.h" #include "py/runtime.h" #include "py/stream.h" #include "extmod/modwebsocket.h" @@ -38,14 +39,18 @@ enum { FRAME_HEADER, FRAME_OPT, PAYLOAD, CONTROL }; enum { BLOCKING_WRITE = 0x80 }; +enum { NO_WRITE_MASKING, NORMAL_WRITE_MASKING, DEBUG_WRITE_MASKING }; + typedef struct _mp_obj_websocket_t { mp_obj_base_t base; mp_obj_t sock; uint32_t msg_sz; - byte mask[4]; + byte read_mask[4]; + byte do_write_masking; + byte debug_write_mask[4]; byte state; byte to_recv; - byte mask_pos; + byte read_mask_pos; byte buf_pos; byte buf[6]; byte opts; @@ -58,16 +63,39 @@ typedef struct _mp_obj_websocket_t { STATIC mp_uint_t websocket_write(mp_obj_t self_in, const void *buf, mp_uint_t size, int *errcode); STATIC mp_obj_t websocket_make_new(const mp_obj_type_t *type, size_t n_args, size_t n_kw, const mp_obj_t *args) { - mp_arg_check_num(n_args, n_kw, 1, 2, false); + static const mp_arg_t allowed_args[] = { + { MP_QSTR_sock, MP_ARG_REQUIRED | MP_ARG_OBJ, {.u_obj = MP_OBJ_NULL} }, + { MP_QSTR_use_blocking_writes, MP_ARG_BOOL, {.u_bool = false} }, + { MP_QSTR_is_client, MP_ARG_BOOL, {.u_bool = false} }, + { MP_QSTR_debug_mask, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_obj = MP_OBJ_NULL} }, + }; + + // parse args + struct { + mp_arg_val_t sock, use_blocking_writes, is_client, debug_mask; + } arg_vals; + mp_arg_parse_all_kw_array(n_args, n_kw, args, + MP_ARRAY_SIZE(allowed_args), allowed_args, (mp_arg_val_t*)&arg_vals); + mp_obj_websocket_t *o = m_new_obj(mp_obj_websocket_t); o->base.type = type; - o->sock = args[0]; + o->sock = arg_vals.sock.u_obj; + o->do_write_masking = !arg_vals.is_client.u_bool ? NO_WRITE_MASKING : NORMAL_WRITE_MASKING; + if (arg_vals.debug_mask.u_obj != MP_OBJ_NULL) { + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(arg_vals.debug_mask.u_obj, &bufinfo, MP_BUFFER_READ); + if (bufinfo.len != 4) { + mp_raise_ValueError("debug mask must have length of 4"); + } + o->do_write_masking = DEBUG_WRITE_MASKING; + memcpy(o->debug_write_mask, bufinfo.buf, 4); + } o->state = FRAME_HEADER; o->to_recv = 2; - o->mask_pos = 0; + o->read_mask_pos = 0; o->buf_pos = 0; o->opts = FRAME_TXT; - if (n_args > 1 && args[1] == mp_const_true) { + if (arg_vals.use_blocking_writes.u_bool) { o->opts |= BLOCKING_WRITE; } return MP_OBJ_FROM_PTR(o); @@ -111,7 +139,7 @@ STATIC mp_uint_t websocket_read(mp_obj_t self_in, void *buf, mp_uint_t size, int // Reset mask in case someone will use "simplified" protocol // without masks. - memset(self->mask, 0, sizeof(self->mask)); + memset(self->read_mask, 0, sizeof(self->read_mask)); int to_recv = 0; size_t sz = self->buf[1] & 0x7f; @@ -149,7 +177,7 @@ STATIC mp_uint_t websocket_read(mp_obj_t self_in, void *buf, mp_uint_t size, int } if (self->buf_pos >= 4) { // Last 4 bytes is mask - memcpy(self->mask, self->buf + self->buf_pos - 4, 4); + memcpy(self->read_mask, self->buf + self->buf_pos - 4, 4); } self->buf_pos = 0; if ((self->last_flags & FRAME_OPCODE_MASK) >= FRAME_CLOSE) { @@ -176,7 +204,7 @@ STATIC mp_uint_t websocket_read(mp_obj_t self_in, void *buf, mp_uint_t size, int sz = out_sz; for (byte *p = buf; sz--; p++) { - *p ^= self->mask[self->mask_pos++ & 3]; + *p ^= self->read_mask[self->read_mask_pos++ & 3]; } self->msg_sz -= out_sz; @@ -186,7 +214,7 @@ STATIC mp_uint_t websocket_read(mp_obj_t self_in, void *buf, mp_uint_t size, int last_state = self->state; self->state = FRAME_HEADER; self->to_recv = 2; - self->mask_pos = 0; + self->read_mask_pos = 0; self->buf_pos = 0; // Handle control frame @@ -218,7 +246,7 @@ STATIC mp_uint_t websocket_read(mp_obj_t self_in, void *buf, mp_uint_t size, int STATIC mp_uint_t websocket_write(mp_obj_t self_in, const void *buf, mp_uint_t size, int *errcode) { mp_obj_websocket_t *self = MP_OBJ_TO_PTR(self_in); assert(size < 0x10000); - byte header[4] = {0x80 | (self->opts & FRAME_OPCODE_MASK)}; + byte header[8] = {0x80 | (self->opts & FRAME_OPCODE_MASK)}; int hdr_sz; if (size < 126) { header[1] = size; @@ -229,6 +257,34 @@ STATIC mp_uint_t websocket_write(mp_obj_t self_in, const void *buf, mp_uint_t si header[3] = size & 0xff; hdr_sz = 4; } + if (self->do_write_masking != NO_WRITE_MASKING) { + hdr_sz += 4; + header[1] |= 0x80; + if (self->do_write_masking == NORMAL_WRITE_MASKING) { + + // RFC6455 Section 5.3 states that the masking key must be derived + // from a strong source of entropy. The "urandom" module doesn't + // qualify in this regard, but there isn't any cross-platform + // alternative. Fortunately, the purpose of masking is not + // cryptographically motivated. The "urandom" module should be + // seeded though, otherwise upon restart, the same sequence of + // masks will always be used. A seed could be derived from a + // network resource, a network interface's characteristics or + // statistics, or a platform specific resource. Examples of using + // a platform specific resource include reading an ESP8266's + // 32-bit Random Number Generator register, or reading consecutive + // values from a floating analog pin. + mp_obj_t dest[3]; + mp_load_method(mp_module_get(MP_QSTR_urandom), MP_QSTR_getrandbits, dest); + dest[2] = mp_obj_new_int(32); + unsigned int randbits = MP_OBJ_SMALL_INT_VALUE(mp_call_method_n_kw(1, 0, dest)); + for (int i = 0; i < 4; ++i) { + header[hdr_sz - 4 + i] = (randbits >> ((i ^ 3) << 3)) & 0xff; + } + } else if (self->do_write_masking == DEBUG_WRITE_MASKING) { + memcpy(&header[hdr_sz - 4], self->debug_write_mask, 4); + } + } mp_obj_t dest[3]; if (self->opts & BLOCKING_WRITE) { @@ -239,7 +295,15 @@ STATIC mp_uint_t websocket_write(mp_obj_t self_in, const void *buf, mp_uint_t si mp_uint_t out_sz = mp_stream_write_exactly(self->sock, header, hdr_sz, errcode); if (*errcode == 0) { - out_sz = mp_stream_write_exactly(self->sock, buf, size, errcode); + if (self->do_write_masking == NO_WRITE_MASKING) { + out_sz = mp_stream_write_exactly(self->sock, buf, size, errcode); + } else { + byte masked_buf[size]; + for (mp_uint_t i = 0; i < size; ++i) { + masked_buf[i] = ((byte*)buf)[i] ^ header[hdr_sz - 4 + (i & 3)]; + } + out_sz = mp_stream_write_exactly(self->sock, masked_buf, size, errcode); + } } if (self->opts & BLOCKING_WRITE) { diff --git a/tests/extmod/websocket_basic.py b/tests/extmod/websocket_basic.py index 9a80503a0373f..26dceb9351c26 100644 --- a/tests/extmod/websocket_basic.py +++ b/tests/extmod/websocket_basic.py @@ -1,4 +1,5 @@ try: + import urandom import uio import uerrno import websocket @@ -6,15 +7,20 @@ print("SKIP") raise SystemExit +# When writing a websocket frame as a client, the masking key used is obtained +# from the "urandom" module. Seeding the random number generator with a +# constant guarantees that the masking keys used are deterministic. +urandom.seed(0x875513b0) + # put raw data in the stream and do a websocket read def ws_read(msg, sz): ws = websocket.websocket(uio.BytesIO(msg)) return ws.read(sz) # do a websocket write and then return the raw data from the stream -def ws_write(msg, sz): +def ws_write(msg, sz, **kwargs): s = uio.BytesIO() - ws = websocket.websocket(s) + ws = websocket.websocket(s, **kwargs) ws.write(msg) s.seek(0) return s.read(sz) @@ -31,9 +37,13 @@ def ws_write(msg, sz): print(ws_read(b'\x81~\x00\x80' + b'ping' * 32, 128)) print(ws_write(b"pong" * 32, 132)) -# mask (returned data will be 'mask' ^ 'mask') +# read mask (returned data will be 'mask' ^ 'mask') print(ws_read(b"\x81\x84maskmask", 4)) +# write mask +print(ws_write(b"pong", 10, debug_mask=b"\x01\x00\x01\x00")) +print(ws_write(b"pong", 10, is_client=True)) + # close control frame s = uio.BytesIO(b'\x88\x00') # FRAME_CLOSE ws = websocket.websocket(s) diff --git a/tests/extmod/websocket_basic.py.exp b/tests/extmod/websocket_basic.py.exp index 2d7657b535407..828d21b350c6e 100644 --- a/tests/extmod/websocket_basic.py.exp +++ b/tests/extmod/websocket_basic.py.exp @@ -4,6 +4,8 @@ b'\x81\x04pong' b'pingpingpingpingpingpingpingpingpingpingpingpingpingpingpingpingpingpingpingpingpingpingpingpingpingpingpingpingpingpingpingping' b'\x81~\x00\x80pongpongpongpongpongpongpongpongpongpongpongpongpongpongpongpongpongpongpongpongpongpongpongpongpongpongpongpongpongpongpongpong' b'\x00\x00\x00\x00' +b'\x81\x84\x01\x00\x01\x00qoog' +b'\x81\x84\x90\x03\xbf\xee\xe0l\xd1\x89' b'' b'\x81\x02\x88\x00' b'ping'