4
4
5
5
import contextlib
6
6
7
+ from . import events
8
+ from . import exceptions as exceptions_mod
7
9
from . import locks
8
10
from . import tasks
9
- from . import taskgroups
10
11
11
- class _Done (Exception ):
12
- pass
13
12
14
- async def staggered_race (coro_fns , delay ):
13
+ async def staggered_race (coro_fns , delay , * , loop = None ):
15
14
"""Run coroutines with staggered start times and take the first to finish.
16
15
17
16
This method takes an iterable of coroutine functions. The first one is
@@ -43,6 +42,8 @@ async def staggered_race(coro_fns, delay):
43
42
delay: amount of time, in seconds, between starting coroutines. If
44
43
``None``, the coroutines will run sequentially.
45
44
45
+ loop: the event loop to use.
46
+
46
47
Returns:
47
48
tuple *(winner_result, winner_index, exceptions)* where
48
49
@@ -61,11 +62,36 @@ async def staggered_race(coro_fns, delay):
61
62
62
63
"""
63
64
# TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
65
+ loop = loop or events .get_running_loop ()
66
+ enum_coro_fns = enumerate (coro_fns )
64
67
winner_result = None
65
68
winner_index = None
66
69
exceptions = []
70
+ running_tasks = []
71
+
72
+ async def run_one_coro (previous_failed ) -> None :
73
+ # Wait for the previous task to finish, or for delay seconds
74
+ if previous_failed is not None :
75
+ with contextlib .suppress (exceptions_mod .TimeoutError ):
76
+ # Use asyncio.wait_for() instead of asyncio.wait() here, so
77
+ # that if we get cancelled at this point, Event.wait() is also
78
+ # cancelled, otherwise there will be a "Task destroyed but it is
79
+ # pending" later.
80
+ await tasks .wait_for (previous_failed .wait (), delay )
81
+ # Get the next coroutine to run
82
+ try :
83
+ this_index , coro_fn = next (enum_coro_fns )
84
+ except StopIteration :
85
+ return
86
+ # Start task that will run the next coroutine
87
+ this_failed = locks .Event ()
88
+ next_task = loop .create_task (run_one_coro (this_failed ))
89
+ running_tasks .append (next_task )
90
+ assert len (running_tasks ) == this_index + 2
91
+ # Prepare place to put this coroutine's exceptions if not won
92
+ exceptions .append (None )
93
+ assert len (exceptions ) == this_index + 1
67
94
68
- async def run_one_coro (this_index , coro_fn , this_failed ):
69
95
try :
70
96
result = await coro_fn ()
71
97
except (SystemExit , KeyboardInterrupt ):
@@ -79,17 +105,34 @@ async def run_one_coro(this_index, coro_fn, this_failed):
79
105
assert winner_index is None
80
106
winner_index = this_index
81
107
winner_result = result
82
- raise _Done
83
-
108
+ # Cancel all other tasks. We take care to not cancel the current
109
+ # task as well. If we do so, then since there is no `await` after
110
+ # here and CancelledError are usually thrown at one, we will
111
+ # encounter a curious corner case where the current task will end
112
+ # up as done() == True, cancelled() == False, exception() ==
113
+ # asyncio.CancelledError. This behavior is specified in
114
+ # https://bugs.python.org/issue30048
115
+ for i , t in enumerate (running_tasks ):
116
+ if i != this_index :
117
+ t .cancel ()
118
+
119
+ first_task = loop .create_task (run_one_coro (None ))
120
+ running_tasks .append (first_task )
84
121
try :
85
- async with taskgroups .TaskGroup () as tg :
86
- for this_index , coro_fn in enumerate (coro_fns ):
87
- this_failed = locks .Event ()
88
- exceptions .append (None )
89
- tg .create_task (run_one_coro (this_index , coro_fn , this_failed ))
90
- with contextlib .suppress (TimeoutError ):
91
- await tasks .wait_for (this_failed .wait (), delay )
92
- except* _Done :
93
- pass
94
-
95
- return winner_result , winner_index , exceptions
122
+ # Wait for a growing list of tasks to all finish: poor man's version of
123
+ # curio's TaskGroup or trio's nursery
124
+ done_count = 0
125
+ while done_count != len (running_tasks ):
126
+ done , _ = await tasks .wait (running_tasks )
127
+ done_count = len (done )
128
+ # If run_one_coro raises an unhandled exception, it's probably a
129
+ # programming error, and I want to see it.
130
+ if __debug__ :
131
+ for d in done :
132
+ if d .done () and not d .cancelled () and d .exception ():
133
+ raise d .exception ()
134
+ return winner_result , winner_index , exceptions
135
+ finally :
136
+ # Make sure no tasks are left running if we leave this function
137
+ for t in running_tasks :
138
+ t .cancel ()
0 commit comments