10000 gh-124872: Mark the thread's default context as entered by rhansen · Pull Request #125638 · python/cpython · GitHub
[go: up one dir, main page]

Skip to content

gh-124872: Mark the thread's default context as entered #125638

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
gh-124872: Mark the thread's default context as entered
Starting with commit 843d28f
(temporarily reverted in d3c82b9 and
restored in commit bee112a), it is
now technically possible to access a thread's default context created
by `context_get`.  Mark that context as entered so that users cannot
push that context onto the thread's stack a second time, which would
cause a cycle.

This change also causes a `CONTEXT_SWITCHED` event to be emitted when
the default context is created, which might be important in some use
cases.

Also exit the default context when the thread exits, for symmetry and
in case the user wants to re-enter it for some reason.

(Even if the `CONTEXT_SWITCHED` event is removed, entering the default
context is good defensive practice, and the consistent treatment of
all contexts on the stack makes it easier to understand the code.)
  • Loading branch information
rhansen committed Oct 17, 2024
commit 5337fc30ede3618e15cbbf990c1982b9764c8357
3 changes: 2 additions & 1 deletion Doc/c-api/contextvars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ Context object management functions:
- ``Py_CONTEXT_SWITCHED``: The :term:`current context` has switched to a
different context. The object passed to the watch callback is the
now-current :class:`contextvars.Context` object, or None if no context is
current.
current. The thread executing the callback is guaranteed to be the thread
that experienced the context switch.

.. versionadded:: 3.14

Expand Down
4 changes: 3 additions & 1 deletion Include/cpython/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ typedef enum {
/*
* The current context has switched to a different context. The object
* passed to the watch callback is the now-current contextvars.Context
* object, or None if no context is current.
* object, or None if no context is current. The thread executing the
* callback is guaranteed to be the thread that experienced the context
* switch.
*/
Py_CONTEXT_SWITCHED = 1,
} PyContextEvent;
Expand Down
12 changes: 11 additions & 1 deletion Include/internal/pycore_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ extern PyTypeObject _PyContextTokenMissing_Type;

PyStatus _PyContext_Init(PyInterpreterState *);

// Exits any thread-owned contexts (see context_get) at the top of the thread's
// context stack. Logs a warning via PyErr_FormatUnraisable if the thread's
// context stack is non-empty afterwards (those contexts can never be exited or
// re-entered).
void _PyContext_ExitThreadOwned(PyThreadState *);


/* other API */

Expand All @@ -27,7 +33,11 @@ struct _pycontextobject {
PyContext *ctx_prev;
PyHamtObject *ctx_vars;
PyObject *ctx_weakreflist;
int ctx_entered;
_Bool ctx_entered:1;
// True for the thread's default context created by context_get. Used to
// safely determine whether the base context can be exited when clearing a
// PyThreadState.
_Bool ctx_owned_by_thread:1;
};


Expand Down
73 changes: 72 additions & 1 deletion Lib/test/test_capi/test_watchers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import threading
import unittest
import contextvars

from contextlib import contextmanager, ExitStack
from test.support import (
catch_unraisable_exception, import_helper,
gc_collect, suppress_immortalization)
gc_collect, suppress_immortalization, threading_helper)


# Skip this test if the _testcapi module isn't available.
Expand Down Expand Up @@ -659,5 +660,75 @@ def test_exit_base_context(self):
ctx.run(lambda: None)
self.assertEqual(switches, [ctx, None])

def test_reenter_default_context(self):
_testcapi.clear_context_stack()
# contextvars.copy_context() creates the thread's default context (via
# the context_get C function).
ctx = contextvars.copy_context()
with self.context_watcher(0) as switches:
ctx.run(lambda: None)
self.assertEqual(len(switches), 2)
self.assertEqual(switches[0], ctx)
base_ctx = switches[1]
self.assertIsNotNone(base_ctx)
self.assertIsNot(base_ctx, ctx)
with self.assertRaisesRegex(RuntimeError, 'already entered'):
base_ctx.run(lambda: None)

def test_default_context_enter(self):
_testcapi.clear_context_stack()
with self.context_watcher(0) as switches:
ctx = contextvars.copy_context()
ctx.run(lambda: None)
self.assertEqual(len(switches), 3)
base_ctx = switches[0]
self.assertIsNotNone(base_ctx)
self.assertEqual(switches, [base_ctx, ctx, base_ctx])

@threading_helper.requires_working_threading()
def test_default_context_exit_during_thread_cleanup(self):
# Context watchers are per-interpreter, not per-thread.
with self.context_watcher(0) as switches:
def _thread_main():
_testcapi.clear_context_stack()
# contextvars.copy_context() creates the thread's default
# context (via the context_get C function).
contextvars.copy_context()
# This test only cares about the final switch that happens when
# exiting the thread's default context during thread cleanup.
switches.clear()

thread = threading.Thread(target=_thread_main)
thread.start()
threading_helper.join_thread(thread)
self.assertEqual(switches, [None])

@threading_helper.requires_working_threading()
def test_thread_cleanup_with_entered_context(self):
unraisables = []
try:
with catch_unraisable_exception() as cm:
with self.context_watcher(0) as switches:
def _thread_main():
_testcapi.clear_context_stack()
ctx = contextvars.copy_context()
_testcapi.context_enter(ctx)
switches.clear()

thread = threading.Thread(target=_thread_main)
thread.start()
threading_helper.join_thread(thread)
unraisables.append(cm.unraisable)
self.assertEqual(switches, [])
self.assertEqual(len(unraisables), 1)
self.assertIsNotNone(unraisables[0])
self.assertRegex(unraisables[0].err_msg,
r'^Exception ignored during reset of thread state')
self.assertRegex(str(unraisables[0].exc_value), r'still entered')
finally:
# Break reference cycle
unraisables = None


if __name__ == "__main__":
unittest.main()
10 changes: 10 additions & 0 deletions Modules/_testcapi/watchers.c
F438
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,15 @@ clear_context_stack(PyObject *Py_UNUSED(self), PyObject *Py_UNUSED(args))
Py_RETURN_NONE;
}

static PyObject *
context_enter(PyObject *self, PyObject *ctx)
{
if (PyContext_Enter(ctx)) {
return NULL;
}
Py_RETURN_NONE;
}

static PyObject *
get_context_switches(PyObject *Py_UNUSED(self), PyObject *watcher_id)
{
Expand Down Expand Up @@ -841,6 +850,7 @@ static PyMethodDef test_methods[] = {
{"add_context_watcher", add_context_watcher, METH_O, NULL},
{"clear_context_watcher", clear_context_watcher, METH_O, NULL},
{"clear_context_stack", clear_context_stack, METH_NOARGS, NULL},
{"context_enter", context_enter, METH_O, NULL},
{"get_context_switches", get_context_switches, METH_O, NULL},
{"allocate_too_many_context_watchers",
(PyCFunction) allocate_too_many_context_watchers, METH_NOARGS, NULL},
Expand Down
62 changes: 53 additions & 9 deletions Python/context.c
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ context_event_name(PyContextEvent event) {
static void
notify_context_watchers(PyThreadState *ts, PyContextEvent event, PyObject *ctx)
{
// The callbacks are registered on the interpreter, not on the thread, so
// the only way callbacks can know which thread changed is by calling the
// callbacks from the affected thread.
assert(ts == _PyThreadState_GET());
if (ctx == NULL) {
// This will happen after exiting the last context in the stack, which
// can occur if context_get was never called before entering a context
Expand Down Expand Up @@ -184,8 +188,9 @@ static inline void
context_switched(PyThreadState *ts)
{
ts->context_ver++;
// ts->context is used instead of context_get() because context_get() might
// throw if ts->context is NULL.
// ts->context is used instead of context_get() because if ts->context is
// NULL, context_get() will either call context_switched -- causing a
// double notification -- or throw.
notify_context_watchers(ts, Py_CONTEXT_SWITCHED, ts->context);
}

Expand Down Expand Up @@ -244,6 +249,7 @@ _PyContext_Exit(PyThreadState *ts, PyObject *octx)

ctx->ctx_prev = NULL;
ctx->ctx_entered = 0;
ctx->ctx_owned_by_thread = 0;
context_switched(ts);
return 0;
}
Expand All @@ -257,6 +263,37 @@ PyContext_Exit(PyObject *octx)
}


void
_PyContext_ExitThreadOwned(PyThreadState *ts)
{
assert(ts != NULL);
// notify_context_watchers requires the notification to come from the
// affected thread, so we can only exit the context(s) if ts belongs to the
// current thread.
_Bool on_thread = ts == _PyThreadState_GET();
while (ts->context != NULL
&& PyContext_CheckExact(ts->context)
&& ((PyContext *)ts->context)->ctx_owned_by_thread
&& on_thread) {
if (_PyContext_Exit(ts, ts->context)) {
Py_UNREACHABLE();
}
}
if (ts->context != NULL) {
// This intentionally does not use tstate variants of these functions
// (e.g., _PyErr_GetRaisedException(ts)) because ts might not belong to
// the current thread.
PyObject *exc = PyErr_GetRaisedException();
PyErr_SetString(PyExc_RuntimeError,
"contextvars.Context object(s) still entered during "
"thread state reset");
PyErr_FormatUnraisable(
"Exception ignored during reset of thread state %p", ts);
PyErr_SetRaisedException(exc);
}
}


PyObject *
PyContextVar_New(const char *name, PyObject *def)
{
Expand Down Expand Up @@ -433,6 +470,7 @@ _context_alloc(void)
ctx->ctx_vars = NULL;
ctx->ctx_prev = NULL;
ctx->ctx_entered = 0;
ctx->ctx_owned_by_thread = 0;
ctx->ctx_weakreflist = NULL;

return ctx;
Expand Down Expand Up @@ -478,15 +516,21 @@ context_get(void)
{
PyThreadState *ts = _PyThreadState_GET();
assert(ts != NULL);
PyContext *current_ctx = (PyContext *)ts->context;
if (current_ctx == NULL) {
current_ctx = context_new_empty();
if (current_ctx == NULL) {
return NULL;
if (ts->context == NULL) {
PyContext *ctx = context_new_empty();
if (ctx != NULL) {
if (_PyContext_Enter(ts, (PyObject *)ctx)) {
Py_UNREACHABLE();
}
ctx->ctx_owned_by_thread = 1;
}
ts->context = (PyObject *)current_ctx;
assert(ts->context == (PyObject *)ctx);
Py_CLEAR(ctx); // _PyContext_Enter created its own ref.
}
return current_ctx;
// The current context may be NULL if the above context_new_empty() call
// failed.
assert(ts->context == NULL || PyContext_CheckExact(ts->context));
return (PyContext *)ts->context;
}

static int
Expand Down
4 changes: 4 additions & 0 deletions Python/pystate.c
Original file line number Diff line number Diff line change
Expand Up @@ -1650,6 +1650,10 @@ PyThreadState_Clear(PyThreadState *tstate)
"PyThreadState_Clear: warning: thread still has a frame\n");
}

// This calls callbacks registered with PyContext_AddWatcher and can call
// sys.unraisablehook.
_PyContext_ExitThreadOwned(tstate);

/* At this point tstate shouldn't be used any more,
neither to run Python code nor for other uses.

Expand Down
Loading
0