diff --git a/extmod/modtls_mbedtls.c b/extmod/modtls_mbedtls.c
index 0a1b8828af5ca..3fd416d72f5ef 100644
--- a/extmod/modtls_mbedtls.c
+++ b/extmod/modtls_mbedtls.c
@@ -37,6 +37,7 @@
 #include "py/stream.h"
 #include "py/objstr.h"
 #include "py/reader.h"
+#include "py/gc.h"
 #include "extmod/vfs.h"

 // mbedtls_time_t
@@ -58,6 +59,10 @@
 #include "mbedtls/asn1.h"
 #endif

+#ifndef MICROPY_MBEDTLS_CONFIG_BARE_METAL
+#define MICROPY_MBEDTLS_CONFIG_BARE_METAL (0)
+#endif
+
 #define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR)

 // This corresponds to an SSLContext object.
@@ -545,6 +550,16 @@ static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t
     mbedtls_ssl_init(&o->ssl);

     ret = mbedtls_ssl_setup(&o->ssl, &ssl_context->conf);
+    #if !MICROPY_MBEDTLS_CONFIG_BARE_METAL
+    if (ret == MBEDTLS_ERR_SSL_ALLOC_FAILED) {
+        // If mbedTLS relies on platform libc heap for buffers (i.e. esp32
+        // port), then run a GC pass and then try again. This is useful because
+        // it may free a Python object (like an old SSL socket) whose finaliser
+        // frees some platform-level heap.
+        gc_collect();
+        ret = mbedtls_ssl_setup(&o->ssl, &ssl_context->conf);
+    }
+    #endif
     if (ret != 0) {
         goto cleanup;
     }
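
The retry above only pays off when mbedTLS takes its buffers from the platform libc heap (as on the esp32 port): a MicroPython GC pass can then return memory to that heap indirectly, by finalising dead Python objects such as old SSL sockets. Below is a minimal sketch of the same retry-once-after-GC pattern; alloc_with_gc_retry() is a hypothetical helper, not part of this patch, and gc_collect() is assumed to be the usual port-provided collector.

#include <stdlib.h>

extern void gc_collect(void); // MicroPython's collector, provided by the port (assumed here)

// Sketch only: try an allocation, give the GC one chance to run finalisers
// that may release platform-heap memory, then retry exactly once.
static void *alloc_with_gc_retry(size_t n) {
    void *p = malloc(n);
    if (p == NULL) {
        gc_collect();   // finalisers of unreachable objects may free libc heap
        p = malloc(n);  // second and final attempt
    }
    return p;
}
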
diff --git a/ports/rp2/mpthreadport.c b/ports/rp2/mpthreadport.c
index 0d3b343cab067..8fd5e5d790277 100644
--- a/ports/rp2/mpthreadport.c
+++ b/ports/rp2/mpthreadport.c
@@ -84,10 +84,10 @@ void mp_thread_deinit(void) {
     assert(get_core_num() == 0);
     // Must ensure that core1 is not currently holding the GC lock, otherwise
     // it will be terminated while holding the lock.
-    mp_thread_mutex_lock(&MP_STATE_MEM(gc_mutex), 1);
+    mp_thread_recursive_mutex_lock(&MP_STATE_MEM(gc_mutex), 1);
     multicore_reset_core1();
     core1_entry = NULL;
-    mp_thread_mutex_unlock(&MP_STATE_MEM(gc_mutex));
+    mp_thread_recursive_mutex_unlock(&MP_STATE_MEM(gc_mutex));
 }

 void mp_thread_gc_others(void) {
diff --git a/ports/rp2/mpthreadport.h b/ports/rp2/mpthreadport.h
index 67a0da0e9373a..f2f2e17bb079d 100644
--- a/ports/rp2/mpthreadport.h
+++ b/ports/rp2/mpthreadport.h
@@ -26,9 +26,10 @@
 #ifndef MICROPY_INCLUDED_RP2_MPTHREADPORT_H
 #define MICROPY_INCLUDED_RP2_MPTHREADPORT_H

-#include "pico/mutex.h"
+#include "mutex_extra.h"

 typedef struct mutex mp_thread_mutex_t;
+typedef recursive_mutex_nowait_t mp_thread_recursive_mutex_t;

 extern void *core_state[2];

@@ -65,4 +66,21 @@ static inline void mp_thread_mutex_unlock(mp_thread_mutex_t *m) {
     mutex_exit(m);
 }

+static inline void mp_thread_recursive_mutex_init(mp_thread_recursive_mutex_t *m) {
+    recursive_mutex_nowait_init(m);
+}
+
+static inline int mp_thread_recursive_mutex_lock(mp_thread_recursive_mutex_t *m, int wait) {
+    if (wait) {
+        recursive_mutex_nowait_enter_blocking(m);
+        return 1;
+    } else {
+        return recursive_mutex_nowait_try_enter(m, NULL);
+    }
+}
+
+static inline void mp_thread_recursive_mutex_unlock(mp_thread_recursive_mutex_t *m) {
+    recursive_mutex_nowait_exit(m);
+}
+
 #endif // MICROPY_INCLUDED_RP2_MPTHREADPORT_H
diff --git a/ports/rp2/pendsv.c b/ports/rp2/pendsv.c
index 905a5aa162ec5..4ba1e81604b0f 100644
--- a/ports/rp2/pendsv.c
+++ b/ports/rp2/pendsv.c
@@ -26,7 +26,7 @@

 #include <assert.h>
 #include "py/mpconfig.h"
-#include "mutex_extra.h"
+#include "py/mpthread.h"
 #include "pendsv.h"

 #if PICO_RP2040
@@ -47,21 +47,21 @@ void PendSV_Handler(void);

 // Using the nowait variant here as softtimer updates PendSV from the loop of mp_wfe_or_timeout(),
 // where we don't want the CPU event bit to be set.
-static recursive_mutex_nowait_t pendsv_mutex;
+static mp_thread_recursive_mutex_t pendsv_mutex;

 void pendsv_init(void) {
-    recursive_mutex_nowait_init(&pendsv_mutex);
+    mp_thread_recursive_mutex_init(&pendsv_mutex);
 }

 void pendsv_suspend(void) {
     // Recursive Mutex here as either core may call pendsv_suspend() and expect
     // both mutual exclusion (other core can't enter pendsv_suspend() at the
     // same time), and that no PendSV handler will run.
-    recursive_mutex_nowait_enter_blocking(&pendsv_mutex);
+    mp_thread_recursive_mutex_lock(&pendsv_mutex, 1);
 }

 void pendsv_resume(void) {
-    recursive_mutex_nowait_exit(&pendsv_mutex);
+    mp_thread_recursive_mutex_unlock(&pendsv_mutex);

     // Run pendsv if needed. Find an entry with a dispatch and call pendsv dispatch
     // with it. If pendsv runs it will service all slots.
@@ -97,7 +97,7 @@ void pendsv_schedule_dispatch(size_t slot, pendsv_dispatch_t f) {

 // PendSV interrupt handler to perform background processing.
 void PendSV_Handler(void) {
-    if (!recursive_mutex_nowait_try_enter(&pendsv_mutex, NULL)) {
+    if (!mp_thread_recursive_mutex_lock(&pendsv_mutex, 0)) {
         // Failure here means core 1 holds pendsv_mutex. ISR will
         // run again after core 1 calls pendsv_resume().
         return;
@@ -117,5 +117,5 @@ void PendSV_Handler(void) {
         }
     }

-    recursive_mutex_nowait_exit(&pendsv_mutex);
+    mp_thread_recursive_mutex_unlock(&pendsv_mutex);
 }
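
These rp2 changes route the pendsv code through a new port-level recursive mutex API, so the same primitive can also guard the GC (see py/gc.c below) while keeping the existing "nowait" behaviour. A hedged usage sketch follows; the two function names are illustrative only. The blocking form mirrors pendsv_suspend(), and the non-blocking form mirrors how PendSV_Handler() returns early when the other core holds the lock.

#include "py/mpthread.h" // declares the mp_thread_recursive_mutex_* API added by this patch

static mp_thread_recursive_mutex_t example_mutex; // init once with mp_thread_recursive_mutex_init()

// Sketch: blocking entry; safe to nest on the same core because the mutex is recursive.
static void example_suspend(void) {
    mp_thread_recursive_mutex_lock(&example_mutex, 1);
    // ... critical work with PendSV held off ...
    mp_thread_recursive_mutex_unlock(&example_mutex);
}

// Sketch: non-blocking entry; skip the work if the other core holds the mutex.
static void example_handler(void) {
    if (!mp_thread_recursive_mutex_lock(&example_mutex, 0)) {
        return;
    }
    // ... background work ...
    mp_thread_recursive_mutex_unlock(&example_mutex);
}
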
diff --git a/ports/unix/mbedtls/mbedtls_config_port.h b/ports/unix/mbedtls/mbedtls_config_port.h
index c619de9b8b1b9..aec65e6581e73 100644
--- a/ports/unix/mbedtls/mbedtls_config_port.h
+++ b/ports/unix/mbedtls/mbedtls_config_port.h
@@ -32,7 +32,18 @@
 // Enable mbedtls modules
 #define MBEDTLS_TIMING_C

+#if defined(MICROPY_UNIX_COVERAGE)
+// Test the "bare metal" memory management in the coverage build
+#define MICROPY_MBEDTLS_CONFIG_BARE_METAL (1)
+#endif
+
 // Include common mbedtls configuration.
 #include "extmod/mbedtls/mbedtls_config_common.h"

+#if defined(MICROPY_UNIX_COVERAGE)
+// See comment above, but fall back to the default platform entropy functions
+#undef MBEDTLS_ENTROPY_HARDWARE_ALT
+#undef MBEDTLS_NO_PLATFORM_ENTROPY
+#endif
+
 #endif /* MICROPY_INCLUDED_MBEDTLS_CONFIG_H */
diff --git a/ports/unix/mpthreadport.c b/ports/unix/mpthreadport.c
index 16ac4da8bf56f..5172645bc147a 100644
--- a/ports/unix/mpthreadport.c
+++ b/ports/unix/mpthreadport.c
@@ -65,7 +65,7 @@ static pthread_key_t tls_key;
 // The mutex is used for any code in this port that needs to be thread safe.
 // Specifically for thread management, access to the linked list is one example.
 // But also, e.g. scheduler state.
-static pthread_mutex_t thread_mutex;
+static mp_thread_recursive_mutex_t thread_mutex;
 static mp_thread_t *thread;

 // this is used to synchronise the signal handler of the thread
@@ -78,11 +78,11 @@ static sem_t thread_signal_done;
 #endif

 void mp_thread_unix_begin_atomic_section(void) {
-    pthread_mutex_lock(&thread_mutex);
+    mp_thread_recursive_mutex_lock(&thread_mutex, true);
 }

 void mp_thread_unix_end_atomic_section(void) {
-    pthread_mutex_unlock(&thread_mutex);
+    mp_thread_recursive_mutex_unlock(&thread_mutex);
 }

 // this signal handler is used to scan the regs and stack of a thread
@@ -113,10 +113,7 @@ void mp_thread_init(void) {

     // Needs to be a recursive mutex to emulate the behavior of
     // BEGIN_ATOMIC_SECTION on bare metal.
-    pthread_mutexattr_t thread_mutex_attr;
-    pthread_mutexattr_init(&thread_mutex_attr);
-    pthread_mutexattr_settype(&thread_mutex_attr, PTHREAD_MUTEX_RECURSIVE);
-    pthread_mutex_init(&thread_mutex, &thread_mutex_attr);
+    mp_thread_recursive_mutex_init(&thread_mutex);

     // create first entry in linked list of all threads
     thread = malloc(sizeof(mp_thread_t));
@@ -321,6 +318,26 @@ void mp_thread_mutex_unlock(mp_thread_mutex_t *mutex) {
     // TODO check return value
 }

+#if MICROPY_PY_THREAD_RECURSIVE_MUTEX
+
+void mp_thread_recursive_mutex_init(mp_thread_recursive_mutex_t *mutex) {
+    pthread_mutexattr_t attr;
+    pthread_mutexattr_init(&attr);
+    pthread_mutexattr_settype(&attr, PTHREAD_MUTEX_RECURSIVE);
+    pthread_mutex_init(mutex, &attr);
+    pthread_mutexattr_destroy(&attr);
+}
+
+int mp_thread_recursive_mutex_lock(mp_thread_recursive_mutex_t *mutex, int wait) {
+    return mp_thread_mutex_lock(mutex, wait);
+}
+
+void mp_thread_recursive_mutex_unlock(mp_thread_recursive_mutex_t *mutex) {
+    mp_thread_mutex_unlock(mutex);
+}
+
+#endif // MICROPY_PY_THREAD_RECURSIVE_MUTEX
+
 #endif // MICROPY_PY_THREAD

 // this is used even when MICROPY_PY_THREAD is disabled
diff --git a/ports/unix/mpthreadport.h b/ports/unix/mpthreadport.h
index b365f200edf97..a38223e720b45 100644
--- a/ports/unix/mpthreadport.h
+++ b/ports/unix/mpthreadport.h
@@ -28,6 +28,7 @@
 #include <pthread.h>

 typedef pthread_mutex_t mp_thread_mutex_t;
+typedef pthread_mutex_t mp_thread_recursive_mutex_t;

 void mp_thread_init(void);
 void mp_thread_deinit(void);
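
On unix the recursive mutex is simply a pthread_mutex_t created with the PTHREAD_MUTEX_RECURSIVE attribute, and the coverage variant opts into the "bare metal" mbedTLS memory management so the new allocation-retry path is exercised in CI. For reference, a standalone sketch of the same recursive-mutex setup outside of MicroPython (names here are illustrative):

#include <pthread.h>

static pthread_mutex_t demo_mutex;

// Sketch: equivalent to what mp_thread_recursive_mutex_init() does in the patch
// above; the attribute object can be destroyed once the mutex is initialised.
static void demo_mutex_init_recursive(void) {
    pthread_mutexattr_t attr;
    pthread_mutexattr_init(&attr);
    pthread_mutexattr_settype(&attr, PTHREAD_MUTEX_RECURSIVE);
    pthread_mutex_init(&demo_mutex, &attr);
    pthread_mutexattr_destroy(&attr);
}
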
diff --git a/py/gc.c b/py/gc.c
index bee44925076f0..eda63187b20a3 100644
--- a/py/gc.c
+++ b/py/gc.c
@@ -113,13 +113,28 @@
 #endif

 #if MICROPY_PY_THREAD && !MICROPY_PY_THREAD_GIL
-#define GC_ENTER() mp_thread_mutex_lock(&MP_STATE_MEM(gc_mutex), 1)
-#define GC_EXIT() mp_thread_mutex_unlock(&MP_STATE_MEM(gc_mutex))
+#define GC_MUTEX_INIT() mp_thread_recursive_mutex_init(&MP_STATE_MEM(gc_mutex))
+#define GC_ENTER() mp_thread_recursive_mutex_lock(&MP_STATE_MEM(gc_mutex), 1)
+#define GC_EXIT() mp_thread_recursive_mutex_unlock(&MP_STATE_MEM(gc_mutex))
 #else
+// Either no threading, or assume callers to gc_collect() hold the GIL
+#define GC_MUTEX_INIT()
 #define GC_ENTER()
 #define GC_EXIT()
 #endif

+// Static functions for individual steps of the GC mark/sweep sequence
+static void gc_collect_start_common(void);
+static void *gc_get_ptr(void **ptrs, int i);
+#if MICROPY_GC_SPLIT_HEAP
+static void gc_mark_subtree(mp_state_mem_area_t *area, size_t block);
+#else
+static void gc_mark_subtree(size_t block);
+#endif
+static void gc_deal_with_stack_overflow(void);
+static void gc_sweep_run_finalisers(void);
+static void gc_sweep_free_blocks(void);
+
 // TODO waste less memory; currently requires that all entries in alloc_table have a corresponding block in pool
 static void gc_setup_area(mp_state_mem_area_t *area, void *start, void *end) {
     // calculate parameters for GC (T=total, A=alloc table, F=finaliser table, P=pool; all in bytes):
@@ -210,9 +225,7 @@ void gc_init(void *start, void *end) {
     MP_STATE_MEM(gc_alloc_amount) = 0;
     #endif

-    #if MICROPY_PY_THREAD && !MICROPY_PY_THREAD_GIL
-    mp_thread_mutex_init(&MP_STATE_MEM(gc_mutex));
-    #endif
+    GC_MUTEX_INIT();
 }

 #if MICROPY_GC_SPLIT_HEAP
@@ -334,12 +347,12 @@ void gc_lock(void) {
     // - each thread has its own gc_lock_depth so there are no races between threads;
     // - a hard interrupt will only change gc_lock_depth during its execution, and
     //   upon return will restore the value of gc_lock_depth.
-    MP_STATE_THREAD(gc_lock_depth)++;
+    MP_STATE_THREAD(gc_lock_depth) += (1 << GC_LOCK_DEPTH_SHIFT);
 }

 void gc_unlock(void) {
     // This does not need to be atomic, See comment above in gc_lock.
-    MP_STATE_THREAD(gc_lock_depth)--;
+    MP_STATE_THREAD(gc_lock_depth) -= (1 << GC_LOCK_DEPTH_SHIFT);
 }

 bool gc_is_locked(void) {
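
After this change gc_lock_depth is no longer a plain counter: bit 0 carries a "collection in progress" flag and the real lock depth lives in the remaining bits. GC_COLLECT_FLAG and GC_LOCK_DEPTH_SHIFT are defined later in this patch (py/mpstate.h) and are both 0 when finalisers are disabled, which restores the old behaviour. A small self-contained sketch of the encoding, assuming the finaliser-enabled values (flag = 1, shift = 1); the DEMO_ names are illustrative:

#include <stdint.h>

// Sketch of the packed gc_lock_depth encoding when MICROPY_ENABLE_FINALISER is
// set: bit 0 = "GC collection in progress", bits 1..15 = lock depth.
#define DEMO_COLLECT_FLAG      1
#define DEMO_LOCK_DEPTH_SHIFT  1

static uint16_t demo_lock_depth;

static void demo_lock(void) {
    demo_lock_depth += (1 << DEMO_LOCK_DEPTH_SHIFT);  // same as gc_lock() above
}
static void demo_unlock(void) {
    demo_lock_depth -= (1 << DEMO_LOCK_DEPTH_SHIFT);  // same as gc_unlock() above
}
static int demo_depth(void) {
    return demo_lock_depth >> DEMO_LOCK_DEPTH_SHIFT;  // what heap_unlock()/heap_locked() report
}
static int demo_collecting(void) {
    return demo_lock_depth & DEMO_COLLECT_FLAG;       // set by gc_collect_start_common()
}
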
@@ -378,6 +391,64 @@ static inline mp_state_mem_area_t *gc_get_ptr_area(const void *ptr) {
 #endif
 #endif

+void gc_collect_start(void) {
+    gc_collect_start_common();
+    #if MICROPY_GC_ALLOC_THRESHOLD
+    MP_STATE_MEM(gc_alloc_amount) = 0;
+    #endif
+
+    // Trace root pointers. This relies on the root pointers being organised
+    // correctly in the mp_state_ctx structure. We scan nlr_top, dict_locals,
+    // dict_globals, then the root pointer section of mp_state_vm.
+    void **ptrs = (void **)(void *)&mp_state_ctx;
+    size_t root_start = offsetof(mp_state_ctx_t, thread.dict_locals);
+    size_t root_end = offsetof(mp_state_ctx_t, vm.qstr_last_chunk);
+    gc_collect_root(ptrs + root_start / sizeof(void *), (root_end - root_start) / sizeof(void *));
+
+    #if MICROPY_ENABLE_PYSTACK
+    // Trace root pointers from the Python stack.
+    ptrs = (void **)(void *)MP_STATE_THREAD(pystack_start);
+    gc_collect_root(ptrs, (MP_STATE_THREAD(pystack_cur) - MP_STATE_THREAD(pystack_start)) / sizeof(void *));
+    #endif
+}
+
+static void gc_collect_start_common(void) {
+    GC_ENTER();
+    assert((MP_STATE_THREAD(gc_lock_depth) & GC_COLLECT_FLAG) == 0);
+    MP_STATE_THREAD(gc_lock_depth) |= GC_COLLECT_FLAG;
+    MP_STATE_MEM(gc_stack_overflow) = 0;
+}
+
+void gc_collect_root(void **ptrs, size_t len) {
+    #if !MICROPY_GC_SPLIT_HEAP
+    mp_state_mem_area_t *area = &MP_STATE_MEM(area);
+    #endif
+    for (size_t i = 0; i < len; i++) {
+        MICROPY_GC_HOOK_LOOP(i);
+        void *ptr = gc_get_ptr(ptrs, i);
+        #if MICROPY_GC_SPLIT_HEAP
+        mp_state_mem_area_t *area = gc_get_ptr_area(ptr);
+        if (!area) {
+            continue;
+        }
+        #else
+        if (!VERIFY_PTR(ptr)) {
+            continue;
+        }
+        #endif
+        size_t block = BLOCK_FROM_PTR(area, ptr);
+        if (ATB_GET_KIND(area, block) == AT_HEAD) {
+            // An unmarked head: mark it, and mark all its children
+            ATB_HEAD_TO_MARK(area, block);
+            #if MICROPY_GC_SPLIT_HEAP
+            gc_mark_subtree(area, block);
+            #else
+            gc_mark_subtree(block);
+            #endif
+        }
+    }
+}
+
 // Take the given block as the topmost block on the stack. Check all it's
 // children: mark the unmarked child blocks and put those newly marked
 // blocks on the stack. When all children have been checked, pop off the
@@ -456,6 +527,25 @@ static void gc_mark_subtree(size_t block)
     }
 }

+void gc_sweep_all(void) {
+    gc_collect_start_common();
+    gc_collect_end();
+}
+
+void gc_collect_end(void) {
+    gc_deal_with_stack_overflow();
+    gc_sweep_run_finalisers();
+    gc_sweep_free_blocks();
+    #if MICROPY_GC_SPLIT_HEAP
+    MP_STATE_MEM(gc_last_free_area) = &MP_STATE_MEM(area);
+    #endif
+    for (mp_state_mem_area_t *area = &MP_STATE_MEM(area); area != NULL; area = NEXT_AREA(area)) {
+        area->gc_last_free_atb_index = 0;
+    }
+    MP_STATE_THREAD(gc_lock_depth) &= ~GC_COLLECT_FLAG;
+    GC_EXIT();
+}
+
 static void gc_deal_with_stack_overflow(void) {
     while (MP_STATE_MEM(gc_stack_overflow)) {
         MP_STATE_MEM(gc_stack_overflow) = 0;
@@ -477,29 +567,20 @@ static void gc_deal_with_stack_overflow(void) {
     }
 }

-static void gc_sweep(void) {
-    #if MICROPY_PY_GC_COLLECT_RETVAL
-    MP_STATE_MEM(gc_collected) = 0;
-    #endif
-    // free unmarked heads and their tails
-    int free_tail = 0;
-    #if MICROPY_GC_SPLIT_HEAP_AUTO
-    mp_state_mem_area_t *prev_area = NULL;
-    #endif
-    for (mp_state_mem_area_t *area = &MP_STATE_MEM(area); area != NULL; area = NEXT_AREA(area)) {
-        size_t end_block = area->gc_alloc_table_byte_len * BLOCKS_PER_ATB;
-        if (area->gc_last_used_block < end_block) {
-            end_block = area->gc_last_used_block + 1;
-        }
-
-        size_t last_used_block = 0;
-
-        for (size_t block = 0; block < end_block; block++) {
-            MICROPY_GC_HOOK_LOOP(block);
-            switch (ATB_GET_KIND(area, block)) {
-                case AT_HEAD:
-                    #if MICROPY_ENABLE_FINALISER
-                    if (FTB_GET(area, block)) {
+// Run finalisers for all to-be-freed blocks
+static void gc_sweep_run_finalisers(void) {
+    #if MICROPY_ENABLE_FINALISER
+    for (const mp_state_mem_area_t *area = &MP_STATE_MEM(area); area != NULL; area = NEXT_AREA(area)) {
+        assert(area->gc_last_used_block <= area->gc_alloc_table_byte_len * BLOCKS_PER_ATB);
+        // Small speed optimisation: skip over empty FTB blocks
+        size_t ftb_end = area->gc_last_used_block / BLOCKS_PER_FTB; // index is inclusive
+        for (size_t ftb_idx = 0; ftb_idx <= ftb_end; ftb_idx++) {
+            byte ftb = area->gc_finaliser_table_start[ftb_idx];
+            size_t block = ftb_idx * BLOCKS_PER_FTB;
+            while (ftb) {
+                MICROPY_GC_HOOK_LOOP(block);
+                if (ftb & 1) { // FTB_GET(area, block) shortcut
+                    if (ATB_GET_KIND(area, block) == AT_HEAD) {
                         mp_obj_base_t *obj = (mp_obj_base_t *)PTR_FROM_BLOCK(area, block);
                         if (obj->type != NULL) {
                             // if the object has a type then see if it has a __del__ method
@@ -519,9 +600,35 @@ static void gc_sweep(void) {
                         // clear finaliser flag
                         FTB_CLEAR(area, block);
                     }
-                    #endif
+                }
+                ftb >>= 1;
+                block++;
+            }
+        }
+    }
+    #endif // MICROPY_ENABLE_FINALISER
+}
+
+// Free unmarked heads and their tails
+static void gc_sweep_free_blocks(void) {
+    #if MICROPY_PY_GC_COLLECT_RETVAL
+    MP_STATE_MEM(gc_collected) = 0;
+    #endif
+    int free_tail = 0;
+    #if MICROPY_GC_SPLIT_HEAP_AUTO
+    mp_state_mem_area_t *prev_area = NULL;
+    #endif
+
+    for (mp_state_mem_area_t *area = &MP_STATE_MEM(area); area != NULL; area = NEXT_AREA(area)) {
+        size_t last_used_block = 0;
+        assert(area->gc_last_used_block <= area->gc_alloc_table_byte_len * BLOCKS_PER_ATB);
+
+        for (size_t block = 0; block <= area->gc_last_used_block; block++) {
+            MICROPY_GC_HOOK_LOOP(block);
+            switch (ATB_GET_KIND(area, block)) {
+                case AT_HEAD:
                     free_tail = 1;
-                    DEBUG_printf("gc_sweep(%p)\n", (void *)PTR_FROM_BLOCK(area, block));
+                    DEBUG_printf("gc_sweep_free_blocks(%p)\n", (void *)PTR_FROM_BLOCK(area, block));
                     #if MICROPY_PY_GC_COLLECT_RETVAL
                     MP_STATE_MEM(gc_collected)++;
                     #endif
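
The old single-pass gc_sweep() is split in two: gc_sweep_run_finalisers() walks the finaliser table (FTB) and calls __del__ on every unmarked head while the whole heap is still intact, and only afterwards does gc_sweep_free_blocks() actually reclaim blocks. The FTB walk skips table bytes that are zero. A simplified, self-contained sketch of that bit-walking idea, assuming one finaliser bit per block and eight blocks per table byte, with a callback standing in for the finaliser call:

#include <stddef.h>

// Sketch: visit only the blocks whose finaliser bit is set, one table byte
// at a time, skipping zero bytes entirely.
static void walk_finaliser_bits(const unsigned char *table, size_t last_block,
    void (*visit)(size_t block)) {
    size_t last_byte = last_block / 8; // inclusive, 8 blocks per byte
    for (size_t i = 0; i <= last_byte; i++) {
        unsigned char bits = table[i];
        size_t block = i * 8;
        while (bits) {
            if (bits & 1) {
                visit(block); // e.g. run the object's __del__ here
            }
            bits >>= 1;
            block++;
        }
    }
}
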
@@ -552,7 +659,7 @@ static void gc_sweep(void) {
         #if MICROPY_GC_SPLIT_HEAP_AUTO
         // Free any empty area, aside from the first one
         if (last_used_block == 0 && prev_area != NULL) {
-            DEBUG_printf("gc_sweep free empty area %p\n", area);
+            DEBUG_printf("gc_sweep_free_blocks free empty area %p\n", area);
             NEXT_AREA(prev_area) = NEXT_AREA(area);
             MP_PLAT_FREE_HEAP(area);
             area = prev_area;
@@ -562,29 +669,6 @@ static void gc_sweep(void) {
     }
 }

-void gc_collect_start(void) {
-    GC_ENTER();
-    MP_STATE_THREAD(gc_lock_depth)++;
-    #if MICROPY_GC_ALLOC_THRESHOLD
-    MP_STATE_MEM(gc_alloc_amount) = 0;
-    #endif
-    MP_STATE_MEM(gc_stack_overflow) = 0;
-
-    // Trace root pointers. This relies on the root pointers being organised
-    // correctly in the mp_state_ctx structure. We scan nlr_top, dict_locals,
-    // dict_globals, then the root pointer section of mp_state_vm.
-    void **ptrs = (void **)(void *)&mp_state_ctx;
-    size_t root_start = offsetof(mp_state_ctx_t, thread.dict_locals);
-    size_t root_end = offsetof(mp_state_ctx_t, vm.qstr_last_chunk);
-    gc_collect_root(ptrs + root_start / sizeof(void *), (root_end - root_start) / sizeof(void *));
-
-    #if MICROPY_ENABLE_PYSTACK
-    // Trace root pointers from the Python stack.
-    ptrs = (void **)(void *)MP_STATE_THREAD(pystack_start);
-    gc_collect_root(ptrs, (MP_STATE_THREAD(pystack_cur) - MP_STATE_THREAD(pystack_start)) / sizeof(void *));
-    #endif
-}
-
 // Address sanitizer needs to know that the access to ptrs[i] must always be
 // considered OK, even if it's a load from an address that would normally be
 // prohibited (due to being undefined, in a red zone, etc).
@@ -600,56 +684,6 @@ static void *gc_get_ptr(void **ptrs, int i) {
     return ptrs[i];
 }

-void gc_collect_root(void **ptrs, size_t len) {
-    #if !MICROPY_GC_SPLIT_HEAP
-    mp_state_mem_area_t *area = &MP_STATE_MEM(area);
-    #endif
-    for (size_t i = 0; i < len; i++) {
-        MICROPY_GC_HOOK_LOOP(i);
-        void *ptr = gc_get_ptr(ptrs, i);
-        #if MICROPY_GC_SPLIT_HEAP
-        mp_state_mem_area_t *area = gc_get_ptr_area(ptr);
-        if (!area) {
-            continue;
-        }
-        #else
-        if (!VERIFY_PTR(ptr)) {
-            continue;
-        }
-        #endif
-        size_t block = BLOCK_FROM_PTR(area, ptr);
-        if (ATB_GET_KIND(area, block) == AT_HEAD) {
-            // An unmarked head: mark it, and mark all its children
-            ATB_HEAD_TO_MARK(area, block);
-            #if MICROPY_GC_SPLIT_HEAP
-            gc_mark_subtree(area, block);
-            #else
-            gc_mark_subtree(block);
-            #endif
-        }
-    }
-}
-
-void gc_collect_end(void) {
-    gc_deal_with_stack_overflow();
-    gc_sweep();
-    #if MICROPY_GC_SPLIT_HEAP
-    MP_STATE_MEM(gc_last_free_area) = &MP_STATE_MEM(area);
-    #endif
-    for (mp_state_mem_area_t *area = &MP_STATE_MEM(area); area != NULL; area = NEXT_AREA(area)) {
-        area->gc_last_free_atb_index = 0;
-    }
-    MP_STATE_THREAD(gc_lock_depth)--;
-    GC_EXIT();
-}
-
-void gc_sweep_all(void) {
-    GC_ENTER();
-    MP_STATE_THREAD(gc_lock_depth)++;
-    MP_STATE_MEM(gc_stack_overflow) = 0;
-    gc_collect_end();
-}
-
 void gc_info(gc_info_t *info) {
     GC_ENTER();
     info->total = 0;
@@ -883,10 +917,13 @@ void *gc_alloc(size_t n_bytes, unsigned int alloc_flags) {

 // force the freeing of a piece of memory
 // TODO: freeing here does not call finaliser
 void gc_free(void *ptr) {
-    if (MP_STATE_THREAD(gc_lock_depth) > 0) {
-        // Cannot free while the GC is locked. However free is an optimisation
-        // to reclaim the memory immediately, this means it will now be left
-        // until the next collection.
+    // Cannot free while the GC is locked, unless we're only doing a gc sweep.
+    // However free is an optimisation to reclaim the memory immediately, this
+    // means it will now be left until the next collection.
+    //
+    // (We have the optimisation to free immediately from inside a gc sweep so
+    // that finalisers can free more memory when trying to avoid MemoryError.)
+    if (MP_STATE_THREAD(gc_lock_depth) & ~GC_COLLECT_FLAG) {
         return;
     }
@@ -911,7 +948,8 @@ void gc_free(void *ptr) {
     #endif

     size_t block = BLOCK_FROM_PTR(area, ptr);
-    assert(ATB_GET_KIND(area, block) == AT_HEAD);
+    assert(ATB_GET_KIND(area, block) == AT_HEAD
+        || (ATB_GET_KIND(area, block) == AT_MARK && (MP_STATE_THREAD(gc_lock_depth) & GC_COLLECT_FLAG)));

     #if MICROPY_ENABLE_FINALISER
     FTB_CLEAR(area, block);
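
Two details of the new gc_free() behaviour: because the GC_COLLECT_FLAG bit is masked out of the lock-depth test, a finaliser running during the sweep may free memory immediately rather than leaving it for the next collection, and the relaxed assert also accepts AT_MARK blocks freed while a collection is in progress. The entry check reduces to roughly the following (a sketch of the test only; the flag value is assumed from py/mpstate.h below):

#include <stdint.h>

#define DEMO_COLLECT_FLAG 1 // bit 0 of gc_lock_depth, per py/mpstate.h in this patch

// Sketch: gc_free() proceeds only when no caller holds the GC lock; the
// "collection in progress" bit alone does not count as a lock.
static int demo_gc_free_allowed(uint16_t gc_lock_depth) {
    return (gc_lock_depth & ~(uint16_t)DEMO_COLLECT_FLAG) == 0;
}
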
diff --git a/py/modmicropython.c b/py/modmicropython.c
index 1bf0a000c2026..d1a687f10e14e 100644
--- a/py/modmicropython.c
+++ b/py/modmicropython.c
@@ -132,13 +132,13 @@ static MP_DEFINE_CONST_FUN_OBJ_0(mp_micropython_heap_lock_obj, mp_micropython_he

 static mp_obj_t mp_micropython_heap_unlock(void) {
     gc_unlock();
-    return MP_OBJ_NEW_SMALL_INT(MP_STATE_THREAD(gc_lock_depth));
+    return MP_OBJ_NEW_SMALL_INT(MP_STATE_THREAD(gc_lock_depth) >> GC_LOCK_DEPTH_SHIFT);
 }
 static MP_DEFINE_CONST_FUN_OBJ_0(mp_micropython_heap_unlock_obj, mp_micropython_heap_unlock);

 #if MICROPY_PY_MICROPYTHON_HEAP_LOCKED
 static mp_obj_t mp_micropython_heap_locked(void) {
-    return MP_OBJ_NEW_SMALL_INT(MP_STATE_THREAD(gc_lock_depth));
+    return MP_OBJ_NEW_SMALL_INT(MP_STATE_THREAD(gc_lock_depth) >> GC_LOCK_DEPTH_SHIFT);
 }
 static MP_DEFINE_CONST_FUN_OBJ_0(mp_micropython_heap_locked_obj, mp_micropython_heap_locked);
 #endif
diff --git a/py/mpconfig.h b/py/mpconfig.h
index e84d258a1220a..64138a9ea7acc 100644
--- a/py/mpconfig.h
+++ b/py/mpconfig.h
@@ -1629,6 +1629,11 @@ typedef double mp_float_t;
 #define MICROPY_PY_THREAD_GIL_VM_DIVISOR (32)
 #endif

+// Is a recursive mutex type in use?
+#ifndef MICROPY_PY_THREAD_RECURSIVE_MUTEX
+#define MICROPY_PY_THREAD_RECURSIVE_MUTEX (MICROPY_PY_THREAD && !MICROPY_PY_THREAD_GIL)
+#endif
+
 // Extended modules

 #ifndef MICROPY_PY_ASYNCIO
diff --git a/py/mpstate.h b/py/mpstate.h
index 54eca596daa28..138c5617300e8 100644
--- a/py/mpstate.h
+++ b/py/mpstate.h
@@ -77,6 +77,18 @@ typedef struct _mp_sched_item_t {
     mp_obj_t arg;
 } mp_sched_item_t;

+// gc_lock_depth field is a combination of the GC_COLLECT_FLAG
+// bit and a lock depth shifted GC_LOCK_DEPTH_SHIFT bits left.
+#if MICROPY_ENABLE_FINALISER
+#define GC_COLLECT_FLAG 1
+#define GC_LOCK_DEPTH_SHIFT 1
+#else
+// If finalisers are disabled then this check doesn't matter, as gc_lock()
+// is called anywhere else that heap can't be changed. So save some code size.
+#define GC_COLLECT_FLAG 0
+#define GC_LOCK_DEPTH_SHIFT 0
+#endif
+
 // This structure holds information about a single contiguous area of
 // memory reserved for the memory manager.
 typedef struct _mp_state_mem_area_t {
@@ -133,7 +145,7 @@ typedef struct _mp_state_mem_t {

     #if MICROPY_PY_THREAD && !MICROPY_PY_THREAD_GIL
     // This is a global mutex used to make the GC thread-safe.
-    mp_thread_mutex_t gc_mutex;
+    mp_thread_recursive_mutex_t gc_mutex;
     #endif
 } mp_state_mem_t;

@@ -268,6 +280,7 @@ typedef struct _mp_state_thread_t {
     #endif

     // Locking of the GC is done per thread.
+    // See GC_LOCK_DEPTH_SHIFT for an explanation of this field.
     uint16_t gc_lock_depth;

     ////////////////////////////////////////////////////////////
diff --git a/py/mpthread.h b/py/mpthread.h
index f335cc02911fc..795f230bb4a0c 100644
--- a/py/mpthread.h
+++ b/py/mpthread.h
@@ -48,6 +48,12 @@ void mp_thread_mutex_init(mp_thread_mutex_t *mutex);
 int mp_thread_mutex_lock(mp_thread_mutex_t *mutex, int wait);
 void mp_thread_mutex_unlock(mp_thread_mutex_t *mutex);

+#if MICROPY_PY_THREAD_RECURSIVE_MUTEX
+void mp_thread_recursive_mutex_init(mp_thread_recursive_mutex_t *mutex);
+int mp_thread_recursive_mutex_lock(mp_thread_recursive_mutex_t *mutex, int wait);
+void mp_thread_recursive_mutex_unlock(mp_thread_recursive_mutex_t *mutex);
+#endif
+
 #endif // MICROPY_PY_THREAD

 #if MICROPY_PY_THREAD && MICROPY_PY_THREAD_GIL
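
Since the depth now sits in the upper bits, heap_unlock() and heap_locked() shift it back down before returning, so the values visible from Python are unchanged. MICROPY_PY_THREAD_RECURSIVE_MUTEX defaults to enabled exactly when threading is built without the GIL (the case that needs the GC mutex); a port could also switch it on explicitly, along the lines of this hypothetical mpconfigport.h fragment (not part of the patch; the port must then also supply the type and functions in its mpthreadport files):

// Hypothetical mpconfigport.h fragment: enable the recursive mutex API even
// where it would not be selected automatically. The port's mpthreadport.h/.c
// must also define mp_thread_recursive_mutex_t and the three functions.
#define MICROPY_PY_THREAD_RECURSIVE_MUTEX (1)
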
diff --git a/tests/extmod/ssl_noleak.py b/tests/extmod/ssl_noleak.py
new file mode 100644
index 0000000000000..870032d58e63e
--- /dev/null
+++ b/tests/extmod/ssl_noleak.py
@@ -0,0 +1,50 @@
+# Ensure that SSLSockets can be allocated sequentially
+# without running out of available memory.
+try:
+    import io
+    import tls
+except ImportError:
+    print("SKIP")
+    raise SystemExit
+
+import unittest
+
+
+class TestSocket(io.IOBase):
+    def write(self, buf):
+        return len(buf)
+
+    def readinto(self, buf):
+        return 0
+
+    def ioctl(self, cmd, arg):
+        return 0
+
+    def setblocking(self, value):
+        pass
+
+
+ITERS = 128
+
+
+class TLSNoLeaks(unittest.TestCase):
+    def test_unique_context(self):
+        for n in range(ITERS):
+            print(n)
+            s = TestSocket()
+            ctx = tls.SSLContext(tls.PROTOCOL_TLS_CLIENT)
+            ctx.verify_mode = tls.CERT_NONE
+            s = ctx.wrap_socket(s, do_handshake_on_connect=False)
+
+    def test_shared_context(self):
+        # Single SSLContext, multiple sockets
+        ctx = tls.SSLContext(tls.PROTOCOL_TLS_CLIENT)
+        ctx.verify_mode = tls.CERT_NONE
+        for n in range(ITERS):
+            print(n)
+            s = TestSocket()
+            s = ctx.wrap_socket(s, do_handshake_on_connect=False)
+
+
+if __name__ == "__main__":
+    unittest.main()