diff --git a/Doc/library/asyncio-eventloop.rst b/Doc/library/asyncio-eventloop.rst index c51990eff8deec..6fe95687c151de 100644 --- a/Doc/library/asyncio-eventloop.rst +++ b/Doc/library/asyncio-eventloop.rst @@ -1271,6 +1271,15 @@ Allows customizing how exceptions are handled in the event loop. (see :meth:`call_exception_handler` documentation for details about context). + If the handler is called on behalf of a :class:`~asyncio.Task` or + :class:`~asyncio.Handle`, it is run in the + :class:`contextvars.Context` of that task or callback handle. + + .. versionchanged:: 3.12 + + The handler may be called in the :class:`~contextvars.Context` + of the task or handle where the exception originated. + .. method:: loop.get_exception_handler() Return the current exception handler, or ``None`` if no custom @@ -1474,6 +1483,13 @@ Callback Handles A callback wrapper object returned by :meth:`loop.call_soon`, :meth:`loop.call_soon_threadsafe`. + .. method:: get_context() + + Return the :class:`contextvars.Context` object + associated with the handle. + + .. versionadded:: 3.12 + .. method:: cancel() Cancel the callback. If the callback has already been canceled diff --git a/Doc/library/asyncio-task.rst b/Doc/library/asyncio-task.rst index ade969220ea701..d922f614954fcd 100644 --- a/Doc/library/asyncio-task.rst +++ b/Doc/library/asyncio-task.rst @@ -1097,6 +1097,13 @@ Task Object .. versionadded:: 3.8 + .. method:: get_context() + + Return the :class:`contextvars.Context` object + associated with the task. + + .. versionadded:: 3.12 + .. method:: get_name() Return the name of the Task. diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 66202f09794db8..c8a2f9f25634ef 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -1808,7 +1808,22 @@ def call_exception_handler(self, context): exc_info=True) else: try: - self._exception_handler(self, context) + ctx = None + thing = context.get("task") + if thing is None: + # Even though Futures don't have a context, + # Task is a subclass of Future, + # and sometimes the 'future' key holds a Task. + thing = context.get("future") + if thing is None: + # Handles also have a context. + thing = context.get("handle") + if thing is not None and hasattr(thing, "get_context"): + ctx = thing.get_context() + if ctx is not None and hasattr(ctx, "run"): + ctx.run(self._exception_handler, self, context) + else: + self._exception_handler(self, context) except (SystemExit, KeyboardInterrupt): raise except BaseException as exc: diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py index 0d26ea545baa5d..a327ba54a323a8 100644 --- a/Lib/asyncio/events.py +++ b/Lib/asyncio/events.py @@ -61,6 +61,9 @@ def __repr__(self): info = self._repr_info() return '<{}>'.format(' '.join(info)) + def get_context(self): + return self._context + def cancel(self): if not self._cancelled: self._cancelled = True diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py index e48da0f2008829..8d6dfcd81b7377 100644 --- a/Lib/asyncio/tasks.py +++ b/Lib/asyncio/tasks.py @@ -139,6 +139,9 @@ def __repr__(self): def get_coro(self): return self._coro + def get_context(self): + return self._context + def get_name(self): return self._name diff --git a/Lib/test/test_asyncio/test_futures2.py b/Lib/test/test_asyncio/test_futures2.py index 71279b69c7929e..9e7a5775a70383 100644 --- a/Lib/test/test_asyncio/test_futures2.py +++ b/Lib/test/test_asyncio/test_futures2.py @@ -1,5 +1,6 @@ # IsolatedAsyncioTestCase based tests import asyncio +import contextvars import traceback import unittest from asyncio import tasks @@ -27,6 +28,46 @@ async def raise_exc(): else: self.fail('TypeError was not raised') + async def test_task_exc_handler_correct_context(self): + # see https://github.com/python/cpython/issues/96704 + name = contextvars.ContextVar('name', default='foo') + exc_handler_called = False + + def exc_handler(*args): + self.assertEqual(name.get(), 'bar') + nonlocal exc_handler_called + exc_handler_called = True + + async def task(): + name.set('bar') + 1/0 + + loop = asyncio.get_running_loop() + loop.set_exception_handler(exc_handler) + self.cls(task()) + await asyncio.sleep(0) + self.assertTrue(exc_handler_called) + + async def test_handle_exc_handler_correct_context(self): + # see https://github.com/python/cpython/issues/96704 + name = contextvars.ContextVar('name', default='foo') + exc_handler_called = False + + def exc_handler(*args): + self.assertEqual(name.get(), 'bar') + nonlocal exc_handler_called + exc_handler_called = True + + def callback(): + name.set('bar') + 1/0 + + loop = asyncio.get_running_loop() + loop.set_exception_handler(exc_handler) + loop.call_soon(callback) + await asyncio.sleep(0) + self.assertTrue(exc_handler_called) + @unittest.skipUnless(hasattr(tasks, '_CTask'), 'requires the C _asyncio module') class CFutureTests(FutureTests, unittest.IsolatedAsyncioTestCase): diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py index 04bdf648313148..2491285206bcd7 100644 --- a/Lib/test/test_asyncio/test_tasks.py +++ b/Lib/test/test_asyncio/test_tasks.py @@ -2482,6 +2482,17 @@ def test_get_coro(self): finally: loop.close() + def test_get_context(self): + loop = asyncio.new_event_loop() + coro = coroutine_function() + context = contextvars.copy_context() + try: + task = self.new_task(loop, coro, context=context) + loop.run_until_complete(task) + self.assertIs(task.get_context(), context) + finally: + loop.close() + def add_subclass_tests(cls): BaseTask = cls.Task diff --git a/Misc/NEWS.d/next/Library/2022-09-18-04-51-30.gh-issue-96704.DmamRX.rst b/Misc/NEWS.d/next/Library/2022-09-18-04-51-30.gh-issue-96704.DmamRX.rst new file mode 100644 index 00000000000000..6ac99197e685c1 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2022-09-18-04-51-30.gh-issue-96704.DmamRX.rst @@ -0,0 +1 @@ +Pass the correct ``contextvars.Context`` when a ``asyncio`` exception handler is called on behalf of a task or callback handle. This adds a new ``Task`` method, ``get_context``, and also a new ``Handle`` method with the same name. If this method is not found on a task object (perhaps because it is a third-party library that does not yet provide this method), the context prevailing at the time the exception handler is called is used. diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c index 909171150bdd36..efa0d2d6906eab 100644 --- a/Modules/_asynciomodule.c +++ b/Modules/_asynciomodule.c @@ -2409,6 +2409,18 @@ _asyncio_Task_get_coro_impl(TaskObj *self) return self->task_coro; } +/*[clinic input] +_asyncio.Task.get_context +[clinic start generated code]*/ + +static PyObject * +_asyncio_Task_get_context_impl(TaskObj *self) +/*[clinic end generated code: output=6996f53d3dc01aef input=87c0b209b8fceeeb]*/ +{ + Py_INCREF(self->task_context); + return self->task_context; +} + /*[clinic input] _asyncio.Task.get_name [clinic start generated code]*/ @@ -2536,6 +2548,7 @@ static PyMethodDef TaskType_methods[] = { _ASYNCIO_TASK_GET_NAME_METHODDEF _ASYNCIO_TASK_SET_NAME_METHODDEF _ASYNCIO_TASK_GET_CORO_METHODDEF + _ASYNCIO_TASK_GET_CONTEXT_METHODDEF {"__class_getitem__", Py_GenericAlias, METH_O|METH_CLASS, PyDoc_STR("See PEP 585")}, {NULL, NULL} /* Sentinel */ }; diff --git a/Modules/clinic/_asynciomodule.c.h b/Modules/clinic/_asynciomodule.c.h index daf524c3456ca8..ddec54c8d7c2bc 100644 --- a/Modules/clinic/_asynciomodule.c.h +++ b/Modules/clinic/_asynciomodule.c.h @@ -772,6 +772,23 @@ _asyncio_Task_get_coro(TaskObj *self, PyObject *Py_UNUSED(ignored)) return _asyncio_Task_get_coro_impl(self); } +PyDoc_STRVAR(_asyncio_Task_get_context__doc__, +"get_context($self, /)\n" +"--\n" +"\n"); + +#define _ASYNCIO_TASK_GET_CONTEXT_METHODDEF \ + {"get_context", (PyCFunction)_asyncio_Task_get_context, METH_NOARGS, _asyncio_Task_get_context__doc__}, + +static PyObject * +_asyncio_Task_get_context_impl(TaskObj *self); + +static PyObject * +_asyncio_Task_get_context(TaskObj *self, PyObject *Py_UNUSED(ignored)) +{ + return _asyncio_Task_get_context_impl(self); +} + PyDoc_STRVAR(_asyncio_Task_get_name__doc__, "get_name($self, /)\n" "--\n" @@ -1172,4 +1189,4 @@ _asyncio__leave_task(PyObject *module, PyObject *const *args, Py_ssize_t nargs, exit: return return_value; } -/*[clinic end generated code: output=459a7c7f21bbc290 input=a9049054013a1b77]*/ +/*[clinic end generated code: output=f117b2246eaf7a55 input=a9049054013a1b77]*/