8000 gh-124872: Mark the thread's initial context as entered · rhansen/cpython@81d2a0d · GitHub
[go: up one dir, main page]

Skip to content

Commit 81d2a0d

Browse files
committed
pythongh-124872: Mark the thread's initial context as entered
Starting with commit 843d28f (temporarily reverted in d3c82b9 and restored in commit bee112a), it is now technically possible to access a thread's initial context created by `context_get`. Mark that context as entered so that developers cannot push that context onto the thread's stack a second time, which would cause a cycle. (Even if the `CONTEXT_SWITCHED` event is removed, this is good defensive practice, and the consistent treatment of all contexts on the stack makes it easier to understand the code.)
1 parent 8e7b2a1 commit 81d2a0d

File tree

2 files changed

+29
-9
lines changed

2 files changed

+29
-9
lines changed

Lib/test/test_capi/test_watchers.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,5 +659,21 @@ def test_exit_base_context(self):
659659
ctx.run(lambda: None)
660660
self.assertEqual(switches, [ctx, None])
661661

662+
def test_reenter_base_context(self):
663+
_testcapi.clear_context_stack()
664+
# contextvars.copy_context() creates the base context (via the
665+
# context_get C function).
666+
ctx = contextvars.copy_context()
667+
with self.context_watcher(0) as switches:
668+
ctx.run(lambda: None)
669+
self.assertEqual(len(switches), 2)
670+
self.assertEqual(switches[0], ctx)
671+
base_ctx = switches[1]
672+
self.assertIsNotNone(base_ctx)
673+
self.assertIsNot(base_ctx, ctx)
674+
with self.assertRaisesRegex(RuntimeError, 'already entered'):
675+
base_ctx.run(lambda: None)
676+
677+
662678
if __name__ == "__main__":
663679
unittest.main()

Python/context.c

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,9 @@ static inline void
184184
context_switched(PyThreadState *ts)
185185
{
186186
ts->context_ver++;
187-
// ts->context is used instead of context_get() because context_get() might
188-
// throw if ts->context is NULL.
187+
// ts->context is used instead of context_get() because if ts->context is
188+
// NULL, context_get() will either call context_switched -- causing a
189+
// double notification -- or throw.
189190
notify_context_watchers(ts, Py_CONTEXT_SWITCHED, ts->context);
190191
}
191192

@@ -478,15 +479,18 @@ context_get(void)
478479
{
479480
PyThreadState *ts = _PyThreadState_GET();
480481
assert(ts != NULL);
481-
PyContext *current_ctx = (PyContext *)ts->context;
482-
if (current_ctx == NULL) {
483-
current_ctx = context_new_empty();
484-
if (current_ctx == NULL) {
485-
return NULL;
482+
if (ts->context == NULL) {
483+
PyContext *ctx = context_new_empty();
484+
if (ctx != NULL && _PyContext_Enter(ts, (PyObject *)ctx)) {
485+
Py_UNREACHABLE();
486486
}
487-
ts->context = (PyObject *)current_ctx;
487+
assert(ts->context == (PyObject *)ctx);
488+
Py_CLEAR(ctx); // _PyContext_Enter created its own ref.
488489
}
489-
return current_ctx;
490+
// The current context may be NULL if the above context_new_empty() call
491+
// failed.
492+
assert(ts->context == NULL || PyContext_CheckExact(ts->context));
493+
return (PyContext *)ts->context;
490494
}
491495

492496
static int

0 commit comments

Comments
 (0)
0