8000 bpo-46994: Accept explicit contextvars.Context in asyncio create_task… · python/cpython@9523c0d · GitHub
[go: up one dir, main page]

Skip to content

Commit 9523c0d

Browse files
authored
bpo-46994: Accept explicit contextvars.Context in asyncio create_task() API (GH-31837)
1 parent 2153daf commit 9523c0d

File tree

13 files changed

+209
-65
lines changed

13 files changed

+209
-65
lines changed

Doc/library/asyncio-eventloop.rst

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ Creating Futures and Tasks
330330

331331
.. versionadded:: 3.5.2
332332

333-
.. method:: loop.create_task(coro, *, name=None)
333+
.. method:: loop.create_task(coro, *, name=None, context=None)
334334

335335
Schedule the execution of a :ref:`coroutine`.
336336
Return a :class:`Task` object.
@@ -342,17 +342,24 @@ Creating Futures and Tasks
342342
If the *name* argument is provided and not ``None``, it is set as
343343
the name of the task using :meth:`Task.set_name`.
344344

345+
An optional keyword-only *context* argument allows specifying a
346+
custom :class:`contextvars.Context` for the *coro* to run in.
347+
The current context copy is created when no *context* is provided.
348+
345349
.. versionchanged:: 3.8
346350
Added the *name* parameter.
347351

352+
.. versionchanged:: 3.11
353+
Added the *context* parameter.
354+
348355
.. method:: loop.set_task_factory(factory)
349356

350357
Set a task factory that will be used by
351358
:meth:`loop.create_task`.
352359

353360
If *factory* is ``None`` the default task factory will be set.
354361
Otherwise, *factory* must be a *callable* with the signature matching
355-
``(loop, coro)``, where *loop* is a reference to the active
362+
``(loop, coro, context=None)``, where *loop* is a reference to the active
356363
event loop, and *coro* is a coroutine object. The callable
357364
must return a :class:`asyncio.Future`-compatible object.
358365

Doc/library/asyncio-task.rst

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,14 +244,18 @@ Running an asyncio Program
244244
Creating Tasks
245245
==============
246246

247-
.. function:: create_task(coro, *, name=None)
247+
.. function:: create_task(coro, *, name=None, context=None)
248248

249249
Wrap the *coro* :ref:`coroutine <coroutine>` into a :class:`Task`
250250
and schedule its execution. Return the Task object.
251251

252252
If *name* is not ``None``, it is set as the name of the task using
253253
:meth:`Task.set_name`.
254254

255+
An optional keyword-only *context* argument allows specifying a
256+
custom :class:`contextvars.Context` for the *coro* to run in.
257+
The current context copy is created when no *context* is provided.
258+
255259
The task is executed in the loop returned by :func:`get_running_loop`,
256260
:exc:`RuntimeError` is raised if there is no running loop in
257261
current thread.
@@ -281,6 +285,9 @@ Creating Tasks
281285
.. versionchanged:: 3.8
282286
Added the *name* parameter.
283287

288+
.. versionchanged:: 3.11
289+
Added the *context* parameter.
290+
284291

285292
Sleeping
286293
========

Lib/asyncio/base_events.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -426,18 +426,23 @@ def create_future(self):
426426
"""Create a Future object attached to the loop."""
427427
return futures.Future(loop=self)
428428

429-
def create_task(self, coro, *, name=None):
429+
def create_task(self, coro, *, name=None, context=None):
430430
"""Schedule a coroutine object.
431431
432432
Return a task object.
433433
"""
434434
self._check_closed()
435435
if self._task_factory is None:
436-
task = tasks.Task(coro, loop=self, name=name)
436+
task = tasks.Task(coro, loop=self, name=name, context=context)
437437
if task._source_traceback:
438438
del task._source_traceback[-1]
439439
else:
440-
task = self._task_factory(self, coro)
440+
if context is None:
441+
# Use legacy API if context is not needed
442+
task = self._task_factory(self, coro)
443+
else:
444+
task = self._task_factory(self, coro, context=context)
445+
441446
tasks._set_task_name(task, name)
442447

443448
return task

Lib/asyncio/events.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def create_future(self):
274274

275275
# Method scheduling a coroutine object: create a task.
276276

277-
def create_task(self, coro, *, name=None):
277+
def create_task(self, coro, *, name=None, context=None):
278278
raise NotImplementedError
279279

280280
# Methods for interacting with threads.

Lib/asyncio/taskgroups.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,15 @@ async def __aexit__(self, et, exc, tb):
138138
me = BaseExceptionGroup('unhandled errors in a TaskGroup', errors)
139139
raise me from None
140140

141-
def create_task(self, coro, *, name=None):
141+
def create_task(self, coro, *, name=None, context=None):
142142
if not self._entered:
143143
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
144144
if self._exiting and self._unfinished_tasks == 0:
145145
raise RuntimeError(f"TaskGroup {self!r} is finished")
146-
task = self._loop.create_task(coro)
146+
if context is None:
147+
task = self._loop.create_task(coro)
148+
else:
149+
task = self._loop.create_task(coro, context=context)
147150
tasks._set_task_name(task, name)
148151
task.add_done_callback(self._on_task_done)
149152
self._unfinished_tasks += 1

Lib/asyncio/tasks.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
9393
# status is still pending
9494
_log_destroy_pending = True
9595

96-
def __init__(self, coro, *, loop=None, name=None):
96+
def __init__(self, coro, *, loop=None, name=None, context=None):
9797
super().__init__(loop=loop)
9898
if self._source_traceback:
9999
del self._source_traceback[-1]
@@ -112,7 +112,10 @@ def __init__(self, coro, *, loop=None, name=None):
112112
self._must_cancel = False
113113
self._fut_waiter = None
114114
self._coro = coro
115-
self._context = contextvars.copy_context()
115+
if context is None:
116+
self._context = contextvars.copy_context()
117+
else:
118+
self._context = context
116119

117120
self._loop.call_soon(self.__step, context=self._context)
118121
_register_task(self)
@@ -360,13 +363,18 @@ def __wakeup(self, future):
360363
Task = _CTask = _asyncio.Task
361364

362365

363-
def create_task(coro, *, name=None):
366+
def create_task(coro, *, name=None, context=None):
364367
"""Schedule the execution of a coroutine object in a spawn task.
365368
366369
Return a Task object.
367370
"""
368371
loop = events.get_running_loop()
369-
task = loop.create_task(coro)
372+
if context is None:
373+
# Use legacy API if context is not needed
374+
task = loop.create_task(coro)
375+
else:
376+
task = loop.create_task(coro, context=context)
377+
370378
_set_task_name(task, name)
371379
return task
372380

Lib/test/test_asyncio/test_taskgroups.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33

44
import asyncio
5+
import contextvars
56
10000
67
from asyncio import taskgroups
78
import unittest
@@ -708,6 +709,23 @@ async def coro():
708709
t = g.create_task(coro(), name="yolo")
709710
self.assertEqual(t.get_name(), "yolo")
710711

712+
async def test_taskgroup_task_context(self):
713+
cvar = contextvars.ContextVar('cvar')
714+
715+
async def coro(val):
716+
await asyncio.sleep(0)
717+
cvar.set(val)
718+
719+
async with taskgroups.TaskGroup() as g:
720+
ctx = contextvars.copy_context()
721+
self.assertIsNone(ctx.get(cvar))
722+
t1 = g.create_task(coro(1), context=ctx)
723+
await t1
724+
self.assertEqual(1, ctx.get(cvar))
725+
t2 = g.create_task(coro(2), context=ctx)
726+
await t2
727+
self.assertEqual(2, ctx.get(cvar))
728+
711729

712730
if __name__ == "__main__":
713731
unittest.main()

Lib/test/test_asyncio/test_tasks.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ class BaseTaskTests:
9595
Task = None
9696
Future = None
9797

98-
def new_task(self, loop, coro, name='TestTask'):
99-
return self.__class__.Task(coro, loop=loop, name=name)
98+
def new_task(self, loop, coro, name='TestTask', context=None):
99+
return self.__class__.Task(coro, loop=loop, name=name, context=context)
100100

101101
def new_future(self, loop):
102102
return self.__class__.Future(loop=loop)
@@ -2527,6 +2527,90 @@ async def main():
25272527

25282528
self.assertEqual(cvar.get(), -1)
25292529

2530+
def test_context_4(self):
2531+
cvar = contextvars.ContextVar('cvar')
2532+
2533+
async def coro(val):
2534+
await asyncio.sleep(0)
2535+
cvar.set(val)
2536+
2537+
async def main():
2538+
ret = []
2539+
ctx = contextvars.copy_context()
2540+
ret.append(ctx.get(cvar))
2541+
t1 = self.new_task(loop, coro(1), context=ctx)
2542+
await t1
2543+
ret.append(ctx.get(cvar))
2544+
t2 = self.new_task(loop, coro(2), context=ctx)
2545+
await t2
2546+
ret.append(ctx.get(cvar))
2547+
return ret
2548+
2549+
loop = asyncio.new_event_loop()
2550+
try:
2551+
task = self.new_task(loop, main())
2552+
ret = loop.run_until_complete(task)
2553+
finally:
2554+
loop.close()
2555+
2556+
self.assertEqual([None, 1, 2], ret)
2557+
2558+
def test_context_5(self):
2559+
cvar = contextvars.ContextVar('cvar')
2560+
2561+
async def coro(val):
2562+
await asyncio.sleep(0)
2563+
cvar.set(val)
2564+
2565+
async def main():
2566+
ret = []
2567+
ctx = contextvars.copy_context()
2568+
ret.append(ctx.get(cvar))
2569+
t1 = asyncio.create_task(coro(1), context=ctx)
2570+
await t1
2571+
ret.append(ctx.get(cvar))
2572+
t2 = asyncio.create_task(coro(2), context=ctx)
2573+
await t2
2574+
ret.append(ctx.get(cvar))
2575+
return ret
2576+
2577+
loop = asyncio.new_event_loop()
2578+
try:
2579+
task = self.new_task(loop, main())
2580+
ret = loop.run_until_complete(task)
2581+
finally:
2582+
loop.close()
2583+
2584+
self.assertEqual([None, 1, 2], ret)
2585+
2586+
def test_context_6(self):
2587+
cvar = contextvars.ContextVar('cvar')
2588+
2589+
async def coro(val):
2590+
await asyncio.sleep(0)
2591+
cvar.set(val)
2592+
2593+
async def main():
2594+
ret = []
2595+
ctx = contextvars.copy_context()
2596+
ret.append(ctx.get(cvar))
2597+
t1 = loop.create_task(coro(1), context=ctx)
2598+
await t1
2599+
ret.append(ctx.get(cvar))
2600+
t2 = loop.create_task(coro(2), context=ctx)
2601+
await t2
2602+
ret.append(ctx.get(cvar))
2603+
return ret
2604+
2605+
loop = asyncio.new_event_loop()
2606+
try:
2607+
task = loop.create_task(main())
2608+
ret = loop.run_until_complete(task)
2609+
finally:
2610+
loop.close()
2611+
2612+
self.assertEqual([None, 1, 2], ret)
2613+
25302614
def test_get_coro(self):
25312615
loop = asyncio.new_event_loop()
25322616
coro = coroutine_function()

Lib/unittest/async_case.py

Lines changed: 17 additions & 38 deletions
38
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextvars
23
import inspect
34
import warnings
45

@@ -34,7 +35,7 @@ class IsolatedAsyncioTestCase(TestCase):
3435
def __init__(self, methodName='runTest'):
3536
super().__init__(methodName)
3637
self._asyncioTestLoop = None
37-
self._asyncioCallsQueue = None
+
self._asyncioTestContext = contextvars.copy_context()
3839

3940
async def asyncSetUp(self):
4041
pass
@@ -58,7 +59,7 @@ def addAsyncCleanup(self, func, /, *args, **kwargs):
5859
self.addCleanup(*(func, *args), **kwargs)
5960

6061
def _callSetUp(self):
61-
self.setUp()
62+
self._asyncioTestContext.run(self.setUp)
6263
self._callAsync(self.asyncSetUp)
6364

6465
def _callTestMethod(self, method):
@@ -68,64 +69,42 @@ def _callTestMethod(self, method):
6869

6970
def _callTearDown(self):
7071
self._callAsync(self.asyncTearDown)
71-
self.tearDown()
72+
self._asyncioTestContext.run(self.tearDown)
7273

7374
def _callCleanup(self, function, *args, **kwargs):
7475
self._callMaybeAsync(function, *args, **kwargs)
7576

7677
def _callAsync(self, func, /, *args, **kwargs):
7778
assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
78-
ret = func(*args, **kwargs)
79-
assert inspect.isawaitable(ret), f'{func!r} returned non-awaitable'
80-
fut = self._asyncioTestLoop.create_future()
81-
self._asyncioCallsQueue.put_nowait((fut, ret))
82-
return self._asyncioTestLoop.run_until_complete(fut)
79+
assert inspect.iscoroutinefunction(func), f'{func!r} is not an async function'
80+
task = self._asyncioTestLoop.create_task(
81+
func(*args, **kwargs),
82+
context=self._asyncioTestContext,
83+
)
84+
return self._asyncioTestLoop.run_until_complete(task)
8385

8486
def _callMaybeAsync(self, func, /, *args, **kwargs):
8587
assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
86-
ret = func(*args, **kwargs)
87-
if inspect.isawaitable(ret):
88-
fut = self._asyncioTestLoop.create_future()
89-
self._asyncioCallsQueue.put_nowait((fut, ret))
90-
return self._asyncioTestLoop.run_until_complete(fut)
88+
if inspect.iscoroutinefunction(func):
89+
task = self._asyncioTestLoop.create_task(
90+
func(*args, **kwargs),
91+
context=self._asyncioTestContext,
92+
)
93+
return self._asyncioTestLoop.run_until_complete(task)
9194
else:
92-
return ret
93-
94-
async def _asyncioLoopRunner(self, fut):
95-
self._asyncioCallsQueue = queue = asyncio.Queue()
96-
fut.set_result(None)
97-
while True:
98-
query = await queue.get()
99-
queue.task_done()
100-
if query is None:
101-
return
102-
fut, awaitable = query
103-
try:
104-
ret = await awaitable
105-
if not fut.cancelled():
106-
fut.set_result(ret)
107-
except (SystemExit, KeyboardInterrupt):
108-
raise
109-
except (BaseException, asyncio.CancelledError) as ex:
110-
if not fut.cancelled():
111-
fut.set_exception(ex)
95+
return self._asyncioTestContext.run(func, *args, **kwargs)
11296

11397
def _setupAsyncioLoop(self):
11498
assert self._asyncioTestLoop is None, 'asyncio test loop already initialized'
11599
loop = asyncio.new_event_loop()
116100
asyncio.set_event_loop(loop)
117101
loop.set_debug(True)
118102
self._asyncioTestLoop = loop
119-
fut = loop.create_future()
120-
self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner(fut))
121-
loop.run_until_complete(fut)
122103

123104
def _tearDownAsyncioLoop(self):
124105
assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
125106
loop = self._asyncioTestLoop
126107
self._asyncioTestLoop = None
127-
self._asyncioCallsQueue.put_nowait(None)
128-
loop.run_until_complete(self._asyncioCallsQueue.join())
129108

130109
try:
131110
# cancel all tasks

0 commit comments

Comments
 (0)
0