8000 gh-124309: Modernize the `staggered_race` implementation to support eager task factories by ZeroIntensity · Pull Request #124390 · python/cpython · GitHub
[go: up one dir, main page]

Skip to content

gh-124309: Modernize the staggered_race implementation to support eager task factories #124390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix tests with eager task factory.
  • Loading branch information
ZeroIntensity committed Sep 23, 2024
commit e2cf78a194d8ca8046ce52e223ddd6584c062689
41 changes: 30 additions & 11 deletions Lib/asyncio/staggered.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,37 +72,56 @@ def future_callback(index, future, task_group):

try:
error = future.exception()
exceptions[index] = error
except exceptions_mod.CancelledError as cancelled_error:
# If another task finishes first and cancels this task, it
# is propagated here.
exceptions[index] = cancelled_error
return
else:
exceptions[index] = error
task_group._errors.remove(error)
if error is not None:
return

nonlocal winner_result, winner_index
if (winner_result is None) and (not task_group._aborting):
# If this is in an eager task factory, it's possible
# for multiple tasks to get here. In that case, we want
# only the first one to win and the rest to no-op before
# cancellation.
# If this is in an eager task factory, it's possible
# for multiple tasks to get here. In that case, we want
# only the first one to win and the rest to no-op before
# cancellation.
if winner_result is None and not task_group._aborting:
winner_result = future.result()
winner_index = index

# Cancel all other tasks, we win!
task_group._abort()

async with taskgroups.TaskGroup() as task_group:
for index, coro in enumerate(coro_fns):
if task_group._aborting:
break

exceptions.append(None)
task = loop.create_task(coro())

# We don't want the task group to propagate the error. Instead,
# we want to put it in our special exceptions list, so we manually
# create the task.
task.add_done_callback(task_group._on_task_done_without_propagation)
task_group._tasks.add(task)

# We need this extra wrapper here to stop the closure from having
# an incorrect index.
def wrapper(idx):
return lambda future: future_callback(idx, future, task_group)

exceptions.append(None)
task = task_group.create_task(coro())
task.add_done_callback(wrapper(index))

if delay is not None:
await tasks.sleep(delay)
await tasks.sleep(delay or 0)
else:
await task
# We don't care about exceptions here, the callback will
# deal with it.
with contextlib.suppress(BaseException):
# If there's no delay, we just wait until completion.
await task

return winner_result, winner_index, exceptions
8 changes: 6 additions & 2 deletions Lib/asyncio/taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ def _abort(self):
if not t.done():
t.cancel()

def _on_task_done(self, task):
def _on_task_done_without_propagation(self, task):
# For staggered_race()
self._tasks.discard(task)

if self._on_completed_fut is not None and not self._tasks:
Expand All @@ -209,7 +210,10 @@ def _on_task_done(self, task):
if task.cancelled():
return

exc = task.exception()
return task.exception()

def _on_task_done(self, task):
exc = self._on_task_done_without_propagation(task)
if exc is None:
return

Expand Down
5 changes: 3 additions & 2 deletions Lib/test/test_asyncio/test_eager_task_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,16 +243,17 @@ async def fail():
async def run():
winner, index, excs = await asyncio.staggered.staggered_race(
[
lambda: fail(),
lambda: asyncio.sleep(1),
lambda: asyncio.sleep(0),
lambda: fail()
],
delay=None
)
self.assertIsNone(winner)
self.assertEqual(index, 1)
self.assertIsNone(excs[index])
self.assertIsInstance(excs[2], ValueError)
self.assertIsInstance(excs[0], ValueError)
self.assertEqual(len(excs), 2)

self.run_coro(run())

Expand Down
0