diff --git a/Doc/c-api/contextvars.rst b/Doc/c-api/contextvars.rst index b7c6550ff34aac..504393fe1c687b 100644 --- a/Doc/c-api/contextvars.rst +++ b/Doc/c-api/contextvars.rst @@ -126,7 +126,8 @@ Context object management functions: - ``Py_CONTEXT_SWITCHED``: The :term:`current context` has switched to a different context. The object passed to the watch callback is the now-current :class:`contextvars.Context` object, or None if no context is - current. + current. The thread executing the callback is guaranteed to be the thread + that experienced the context switch. .. versionadded:: 3.14 diff --git a/Include/cpython/context.h b/Include/cpython/context.h index 3a7a4b459c09ad..7430db962bffbd 100644 --- a/Include/cpython/context.h +++ b/Include/cpython/context.h @@ -31,7 +31,9 @@ typedef enum { /* * The current context has switched to a different context. The object * passed to the watch callback is the now-current contextvars.Context - * object, or None if no context is current. + * object, or None if no context is current. The thread executing the + * callback is guaranteed to be the thread that experienced the context + * switch. */ Py_CONTEXT_SWITCHED = 1, } PyContextEvent; diff --git a/Include/internal/pycore_context.h b/Include/internal/pycore_context.h index c2b98d15da68fa..6db141eb2b0b48 100644 --- a/Include/internal/pycore_context.h +++ b/Include/internal/pycore_context.h @@ -15,6 +15,15 @@ extern PyTypeObject _PyContextTokenMissing_Type; PyStatus _PyContext_Init(PyInterpreterState *); +// Exits any thread-owned contexts (see context_get) at the top of the given +// thread's context stack. The given thread state is not required to belong to +// the calling thread; if not, the thread is assumed to have exited (or not yet +// started) and no Py_CONTEXT_SWITCHED event is emitted for any context +// changes. Logs a warning via PyErr_FormatUnraisable if the thread's context +// stack is non-empty afterwards (because those contexts can never be exited or +// re-entered). +void _PyContext_ExitThreadOwned(PyThreadState *); + /* other API */ @@ -27,7 +36,11 @@ struct _pycontextobject { PyContext *ctx_prev; PyHamtObject *ctx_vars; PyObject *ctx_weakreflist; - int ctx_entered; + _Bool ctx_entered:1; + // True for the thread's default context created by context_get. Used to + // safely determine whether the base context can be exited when clearing a + // PyThreadState. + _Bool ctx_owned_by_thread:1; }; diff --git a/Lib/test/test_capi/test_watchers.py b/Lib/test/test_capi/test_watchers.py index 8644479d83d5ed..bee67fe595c15f 100644 --- a/Lib/test/test_capi/test_watchers.py +++ b/Lib/test/test_capi/test_watchers.py @@ -1,10 +1,11 @@ +import threading import unittest import contextvars from contextlib import contextmanager, ExitStack from test.support import ( catch_unraisable_exception, import_helper, - gc_collect) + gc_collect, threading_helper) # Skip this test if the _testcapi module isn't available. @@ -674,5 +675,75 @@ def test_exit_base_context(self): ctx.run(lambda: None) self.assertEqual(switches, [ctx, None]) + def test_reenter_default_context(self): + _testcapi.clear_context_stack() + # contextvars.copy_context() creates the thread's default context (via + # the context_get C function). + ctx = contextvars.copy_context() + with self.context_watcher(0) as switches: + ctx.run(lambda: None) + self.assertEqual(len(switches), 2) + self.assertEqual(switches[0], ctx) + base_ctx = switches[1] + self.assertIsNotNone(base_ctx) + self.assertIsNot(base_ctx, ctx) + with self.assertRaisesRegex(RuntimeError, 'already entered'): + base_ctx.run(lambda: None) + + def test_default_context_enter(self): + _testcapi.clear_context_stack() + with self.context_watcher(0) as switches: + ctx = contextvars.copy_context() + ctx.run(lambda: None) + self.assertEqual(len(switches), 3) + base_ctx = switches[0] + self.assertIsNotNone(base_ctx) + self.assertEqual(switches, [base_ctx, ctx, base_ctx]) + + @threading_helper.requires_working_threading() + def test_default_context_exit_during_thread_cleanup(self): + # Context watchers are per-interpreter, not per-thread. + with self.context_watcher(0) as switches: + def _thread_main(): + _testcapi.clear_context_stack() + # contextvars.copy_context() creates the thread's default + # context (via the context_get C function). + contextvars.copy_context() + # This test only cares about the final switch that happens when + # exiting the thread's default context during thread cleanup. + switches.clear() + + thread = threading.Thread(target=_thread_main) + thread.start() + threading_helper.join_thread(thread) + self.assertEqual(switches, [None]) + + @threading_helper.requires_working_threading() + def test_thread_cleanup_with_entered_context(self): + unraisables = [] + try: + with catch_unraisable_exception() as cm: + with self.context_watcher(0) as switches: + def _thread_main(): + _testcapi.clear_context_stack() + ctx = contextvars.copy_context() + _testcapi.context_enter(ctx) + switches.clear() + + thread = threading.Thread(target=_thread_main) + thread.start() + threading_helper.join_thread(thread) + unraisables.append(cm.unraisable) + self.assertEqual(switches, []) + self.assertEqual(len(unraisables), 1) + self.assertIsNotNone(unraisables[0]) + self.assertRegex(unraisables[0].err_msg, + r'^Exception ignored during reset of thread state') + self.assertRegex(str(unraisables[0].exc_value), r'still entered') + finally: + # Break reference cycle + unraisables = None + + if __name__ == "__main__": unittest.main() diff --git a/Modules/_testcapi/watchers.c b/Modules/_testcapi/watchers.c index 321d3aeffb6ad1..b105cd1302d467 100644 --- a/Modules/_testcapi/watchers.c +++ b/Modules/_testcapi/watchers.c @@ -724,6 +724,15 @@ clear_context_stack(PyObject *Py_UNUSED(self), PyObject *Py_UNUSED(args)) Py_RETURN_NONE; } +static PyObject * +context_enter(PyObject *self, PyObject *ctx) +{ + if (PyContext_Enter(ctx) < 0) { + return NULL; + } + Py_RETURN_NONE; +} + static PyObject * get_context_switches(PyObject *Py_UNUSED(self), PyObject *watcher_id) { @@ -841,6 +850,7 @@ static PyMethodDef test_methods[] = { {"add_context_watcher", add_context_watcher, METH_O, NULL}, {"clear_context_watcher", clear_context_watcher, METH_O, NULL}, {"clear_context_stack", clear_context_stack, METH_NOARGS, NULL}, + {"context_enter", context_enter, METH_O, NULL}, {"get_context_switches", get_context_switches, METH_O, NULL}, {"allocate_too_many_context_watchers", (PyCFunction) allocate_too_many_context_watchers, METH_NOARGS, NULL}, diff --git a/Python/context.c b/Python/context.c index 95aa82206270f9..7a1fc1d0ebac0e 100644 --- a/Python/context.c +++ b/Python/context.c @@ -113,6 +113,11 @@ context_event_name(PyContextEvent event) { static void notify_context_watchers(PyThreadState *ts, PyContextEvent event, PyObject *ctx) { + // The callbacks are registered on the interpreter, not on the thread, so + // the only way callbacks can know which thread changed is by calling the + // callbacks from the affected thread. + assert(ts != NULL); + assert(ts == _PyThreadState_GET()); if (ctx == NULL) { // This will happen after exiting the last context in the stack, which // can occur if context_get was never called before entering a context @@ -184,12 +189,14 @@ static inline void context_switched(PyThreadState *ts) { ts->context_ver++; - // ts->context is used instead of context_get() because context_get() might - // throw if ts->context is NULL. + // ts->context is used instead of context_get() because if ts->context is + // NULL, context_get() will either call context_switched -- causing a + // double notification -- or throw. notify_context_watchers(ts, Py_CONTEXT_SWITCHED, ts->context); } +// ts is not required to belong to the calling thread. static int _PyContext_Enter(PyThreadState *ts, PyObject *octx) { @@ -197,8 +204,8 @@ _PyContext_Enter(PyThreadState *ts, PyObject *octx) PyContext *ctx = (PyContext *)octx; if (ctx->ctx_entered) { - _PyErr_Format(ts, PyExc_RuntimeError, - "cannot enter context: %R is already entered", ctx); + PyErr_Format(PyExc_RuntimeError, + "cannot enter context: %R is already entered", ctx); return -1; } @@ -206,7 +213,6 @@ _PyContext_Enter(PyThreadState *ts, PyObject *octx) ctx->ctx_entered = 1; ts->context = Py_NewRef(ctx); - context_switched(ts); return 0; } @@ -216,10 +222,15 @@ PyContext_Enter(PyObject *octx) { PyThreadState *ts = _PyThreadState_GET(); assert(ts != NULL); - return _PyContext_Enter(ts, octx); + if (_PyContext_Enter(ts, octx) < 0) { + return -1; + } + context_switched(ts); + return 0; } +// ts is not required to belong to the calling thread. static int _PyContext_Exit(PyThreadState *ts, PyObject *octx) { @@ -244,7 +255,7 @@ _PyContext_Exit(PyThreadState *ts, PyObject *octx) ctx->ctx_prev = NULL; ctx->ctx_entered = 0; - context_switched(ts); + ctx->ctx_owned_by_thread = 0; return 0; } @@ -253,7 +264,49 @@ PyContext_Exit(PyObject *octx) { PyThreadState *ts = _PyThreadState_GET(); assert(ts != NULL); - return _PyContext_Exit(ts, octx); + if (_PyContext_Exit(ts, octx) < 0) { + return -1; + } + context_switched(ts); + return 0; +} + + +void +_PyContext_ExitThreadOwned(PyThreadState *ts) +{ + assert(ts != NULL); + while (ts->context != NULL + && PyContext_CheckExact(ts->context) + && ((PyContext *)ts->context)->ctx_owned_by_thread) { + if (_PyContext_Exit(ts, ts->context) < 0) { + // Exiting a context that is already known to be at the top of the + // stack cannot fail. + Py_UNREACHABLE(); + } + // notify_context_watchers() requires the notification to come from the + // affected thread, so context_switched() must not be called if ts + // doesn't belong to the current thread. However, it's OK to skip + // calling it in this case: this function is only called when resetting + // a PyThreadState, so if the calling thread doesn't own ts, then the + // owning thread must not be running anymore (it must have just + // finished because a thread-owned context exists here). + if (ts == _PyThreadState_GET()) { + context_switched(ts); + } + } + if (ts->context != NULL) { + // This intentionally does not use tstate variants of these functions + // (e.g., _PyErr_GetRaisedException(ts)) because ts might not belong to + // the current thread. + PyObject *exc = PyErr_GetRaisedException(); + PyErr_SetString(PyExc_RuntimeError, + "contextvars.Context object(s) still entered during " + "thread state reset"); + PyErr_FormatUnraisable( + "Exception ignored during reset of thread state"); + PyErr_SetRaisedException(exc); + } } @@ -433,6 +486,7 @@ _context_alloc(void) ctx->ctx_vars = NULL; ctx->ctx_prev = NULL; ctx->ctx_entered = 0; + ctx->ctx_owned_by_thread = 0; ctx->ctx_weakreflist = NULL; return ctx; @@ -478,15 +532,18 @@ context_get(void) { PyThreadState *ts = _PyThreadState_GET(); assert(ts != NULL); - PyContext *current_ctx = (PyContext *)ts->context; - if (current_ctx == NULL) { - current_ctx = context_new_empty(); - if (current_ctx == NULL) { + if (ts->context == NULL) { + PyContext *ctx = context_new_empty(); + if (ctx == NULL || _PyContext_Enter(ts, (PyObject *)ctx) < 0) { return NULL; } - ts->context = (PyObject *)current_ctx; + ctx->ctx_owned_by_thread = 1; + assert(ts->context == (PyObject *)ctx); + Py_CLEAR(ctx); // _PyContext_Enter created its own ref. + context_switched(ts); } - return current_ctx; + assert(PyContext_CheckExact(ts->context)); + return (PyContext *)ts->context; } static int @@ -715,6 +772,7 @@ context_run(PyContext *self, PyObject *const *args, if (_PyContext_Enter(ts, (PyObject *)self)) { return NULL; } + context_switched(ts); PyObject *call_result = _PyObject_VectorcallTstate( ts, args[0], args + 1, nargs - 1, kwnames); @@ -723,6 +781,7 @@ context_run(PyContext *self, PyObject *const *args, Py_XDECREF(call_result); return NULL; } + context_switched(ts); return call_result; } diff --git a/Python/pystate.c b/Python/pystate.c index ded5fde9c4bb51..6b1ef5d4ee98cd 100644 --- a/Python/pystate.c +++ b/Python/pystate.c @@ -1654,6 +1654,10 @@ PyThreadState_Clear(PyThreadState *tstate) "PyThreadState_Clear: warning: thread still has a frame\n"); } + // This calls callbacks registered with PyContext_AddWatcher and can call + // sys.unraisablehook. + _PyContext_ExitThreadOwned(tstate); + /* At this point tstate shouldn't be used any more, neither to run Python code nor for other uses.