8000 gh-124309: Modernize the `staggered_race` implementation to support eager task factories by ZeroIntensity · Pull Request #124390 · python/cpython · GitHub
[go: up one dir, main page]

Skip to content

gh-124309: Modernize the staggered_race implementation to support eager task factories #124390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor to use only public APIs
Co-authored-by: Thomas Grainger <tagrain@gmail.com>
  • Loading branch information
ZeroIntensity and graingert committed Sep 24, 2024
commit e213f518b20ff398f32f44097a34a3f63f1f99b1
87 changes: 29 additions & 58 deletions Lib/asyncio/staggered.py
< 8000 td class="blob-code blob-code-deletion js-file-line">
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

import contextlib

from . import events
from . import exceptions as exceptions_mod
from . import locks
from . import tasks
from . import taskgroups


class _Done(Exception):
pass

async def staggered_race(coro_fns, delay, *, loop=None):
"""Run coroutines with staggered start times and take the first to finish.

Expand Down Expand Up @@ -61,67 +63,36 @@ async def staggered_race(coro_fns, delay, *, loop=None):
coroutine's entry is ``None``.

"""
# TODO: allow async iterables in coro_fns
loop = loop or events.get_running_loop()
# TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
winner_result = None
winner_index = None
exceptions = []

def future_callback(index, future, task_group):
assert future.done()
async def run_one_coro(this_index, coro_fn, this_failed):
try:
error = future.exception()
exceptions[index] = error
except exceptions_mod.CancelledError as cancelled_error:
# If another task finishes first and cancels this task, it
# is propagated here.
exceptions[index] = cancelled_error
return
result = await coro_fn()
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as e:
exceptions[this_index] = e
this_failed.set() # Kickstart the next coroutine
else:
if error is not None:
return

nonlocal winner_result, winner_index
# If this is in an eager task factory, it's possible
# for multiple tasks to get here. In that case, we want
# only the first one to win and the rest to no-op before
# cancellation.
if winner_result is None and not task_group._aborting:
winner_result = future.result()
winner_index = index

# Cancel all other tasks, we win!
task_group._abort()

async with taskgroups.TaskGroup() as task_group:
for index, coro in enumerate(coro_fns):
if task_group._aborting:
break

exceptions.append(None)
task = loop.create_task(coro())

# We don't want the task group to propagate the error. Instead,
# we want to put it in our special exceptions list, so we manually
# create the task.
task.add_done_callback(task_group._on_task_done_without_propagation)
task_group._tasks.add(task)

# We need this extra wrapper here to stop the closure from having
# an incorrect index.
def wrapper(idx):
return lambda future: future_callback(idx, future, task_group)

task.add_done_callback(wrapper(index))

if delay is not None:
await tasks.sleep(delay or 0)
else:
# We don't care about exceptions here, the callback will
# deal with it.
with contextlib.suppress(BaseException):
# If there's no delay, we just wait until completion.
await task
# Store winner's results
nonlocal winner_index, winner_result
# There could be more than one winner
winner_index = this_index
winner_result = result
raise _Done

try:
async with taskgroups.TaskGroup() as tg:
for this_index, coro_fn in enumerate(coro_fns):
this_failed = locks.Event()
exceptions.append(None)
tg.create_task(run_one_coro(this_index, coro_fn, this_failed))
with contextlib.suppress(TimeoutError):
await tasks.wait_for(this_failed.wait(), delay)
except* _Done:
pass

return winner_result, winner_index, exceptions
Loading
0