8000 gh-124309: Modernize the `staggered_race` implementation to support eager task factories by ZeroIntensity · Pull Request #124390 · python/cpython · GitHub
[go: up one dir, main page]

Skip to content

gh-124309: Modernize the staggered_race implementation to support eager task factories #124390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Partial implementation.
  • Loading branch information
ZeroIntensity committed Sep 23, 2024
commit 4002695554ad6c6ef73ef1380b569c611d0ba268
106 changes: 38 additions & 68 deletions Lib/asyncio/staggered.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from . import events
from . import exceptions as exceptions_mod
from . import locks
from . import tasks
from . import taskgroups


async def staggered_race(coro_fns, delay, *, loop=None):
Expand Down Expand Up @@ -63,76 +63,46 @@ async def staggered_race(coro_fns, delay, *, loop=None):
"""
# TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
loop = loop or events.get_running_loop()
enum_coro_fns = enumerate(coro_fns)
winner_result = None
winner_index = None
exceptions = []
running_tasks = []

async def run_one_coro(previous_failed) -> None:
# Wait for the previous task to finish, or for delay seconds
if previous_failed is not None:
with contextlib.suppress(exceptions_mod.TimeoutError):
# Use asyncio.wait_for() instead of asyncio.wait() here, so
# that if we get cancelled at this point, Event.wait() is also
# cancelled, otherwise there will be a "Task destroyed but it is
# pending" later.
await tasks.wait_for(previous_failed.wait(), delay)
# Get the next coroutine to run
try:
this_index, coro_fn = next(enum_coro_fns)
except StopIteration:
return
# Start task that will run the next coroutine
this_failed = locks.Event()
next_task = loop.create_task(run_one_coro(this_failed))
running_tasks.append(next_task)
assert len(running_tasks) == this_index + 2
# Prepare place to put this coroutine's exceptions if not won
exceptions.append(None)
assert len(exceptions) == this_index + 1

def future_callback(index, future, task_group):
assert future.done()

try:
result = await coro_fn()
except (SystemExit, KeyboardInterrupt):
10000 raise
except BaseException as e:
exceptions[this_index] = e
this_failed.set() # Kickstart the next coroutine
error = future.exception()
except exceptions_mod.CancelledError as cancelled_error:
exceptions[index] = cancelled_error
else:
# Store winner's results
nonlocal winner_index, winner_result
assert winner_index is None
winner_index = this_index
winner_result = result
# Cancel all other tasks. We take care to not cancel the current
# task as well. If we do so, then since there is no `await` after
# here and CancelledError are usually thrown at one, we will
# encounter a curious corner case where the current task will end
# up as done() == True, cancelled() == False, exception() ==
# asyncio.CancelledError. This behavior is specified in
# https://bugs.python.org/issue30048
for i, t in enumerate(running_tasks):
if i != this_index:
t.cancel()

first_task = loop.create_task(run_one_coro(None))
running_tasks.append(first_task)
try:
# Wait for a growing list of tasks to all finish: poor man's version of
# curio's TaskGroup or trio's nursery
done_count = 0
while done_count != len(running_tasks):
done, _ = await tasks.wait(running_tasks)
done_count = len(done)
# If run_one_coro raises an unhandled exception, it's probably a
# programming error, and I want to see it.
if __debug__:
for d in done:
if d.done() and not d.cancelled() and d.exception():
raise d.exception()
return winner_result, winner_index, exceptions
finally:
# Make sure no tasks are left running if we leave this function
for t in running_tasks:
t.cancel()
exceptions[index] = error
task_group._errors.remove(error)

nonlocal winner_result, winner_index
if (winner_result is None) and (not task_group._aborting):
# If this is in an eager task factory, it's possible
# for multiple tasks to get here. In that case, we want
# only the first one to win and the rest to no-op before
# cancellation.
winner_result = future.result()
winner_index = index
task_group._abort()

async with taskgroups.TaskGroup() as task_group:
for index, coro in enumerate(coro_fns):
if task_group._aborting:
break

def wrapper(idx):
return lambda future: future_callback(idx, future, task_group)

exceptions.append(None)
task = task_group.create_task(coro())
task.add_done_callback(wrapper(index))

if delay is not None:
await tasks.sleep(delay)
else:
await task

return winner_result, winner_index, exceptions
20 changes: 20 additions & 0 deletions Lib/test/test_asyncio/test_eager_task_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,26 @@ async def run():

self.run_coro(run())

def test_staggered_race_with_eager_tasks(self):
# See GH-124309
async def coro(amount):
await asyncio.sleep(amount)
return amount

async def run():
winner, index, excs = await asyncio.staggered.staggered_race(
[
lambda: coro(1),
lambda: coro(0),
lambda: coro(2)
],
delay=None
)
self.assertEqual(winner, 0)
self.assertEqual(index, 1)

self.run_coro(run())


class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
Task = tasks._PyTask
Expand Down
4 changes: 4 additions & 0 deletions Lib/test/test_asyncio/test_staggered.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,7 @@ async def coro(index):
self.assertEqual(len(excs), 2)
self.assertIsInstance(excs[0], ValueError)
self.assertIsInstance(excs[1], ValueError)


if __name__ == "__main__":
unittest.main()
0