8000 GH-6785: asyncio.wait no longer calls ensure_future · graingert/distributed@896f770 · GitHub
[go: up one dir, main page]

Skip to content

Commit 896f770

Browse files
committed
daskGH-6785: asyncio.wait no longer calls ensure_future
python/cpython#95601
1 parent 81f0b67 commit 896f770

File tree

3 files changed

+26
-8
lines changed

3 files changed

+26
-8
lines changed

distributed/deploy/spec.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import logging
77
import math
88
import weakref
9+
from collections.abc import Awaitable, Generator
910
from contextlib import suppress
1011
from inspect import isawaitable
11-
from typing import TYPE_CHECKING, Any, ClassVar
12+
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar
1213

1314
from tornado import gen
1415
from tornado.ioloop import IOLoop
@@ -108,6 +109,16 @@ async def __aexit__(self, exc_type, exc_value, traceback):
108109
await self.close()
109110

110111

112+
_T = TypeVar("_T")
113+
114+
115+
async def _wrap_awaitable(aw: Awaitable[_T]) -> _T:
116+
return await aw
117+
118+
119+
_T_spec_cluster = TypeVar("_T_spec_cluster", bound="SpecCluster")
120+
121+
111122
class SpecCluster(Cluster):
112123
"""Cluster that requires a full specification of workers
113124
@@ -327,7 +338,7 @@ def _correct_state(self):
327338
self._correct_state_waiting = task
328339
return task
329340

330-
async def _correct_state_internal(self):
341+
async def _correct_state_internal(self) -> None:
331342
async with self._lock:
332343
self._correct_state_waiting = None
333344

@@ -363,7 +374,9 @@ async def _correct_state_internal(self):
363374
self._created.add(worker)
364375
workers.append(worker)
365376
if workers:
366-
await asyncio.wait(workers)
377+
await asyncio.wait(
378+
[asyncio.create_task(_wrap_awaitable(w)) for w in workers]
379+
)
367380
for w in workers:
368381
w._cluster = weakref.ref(self)
369382
await w # for tornado gen.coroutine support
@@ -392,14 +405,19 @@ def f():
392405
asyncio.get_running_loop().call_later(delay, f)
393406
super()._update_worker_status(op, msg)
394407

395-
def __await__(self):
396-
async def _():
408+
def __await__(self: _T_spec_cluster) -> Generator[Any, Any, _T_spec_cluster]:
409+
async def _() -> _T_spec_cluster:
397410
if self.status == Status.created:
398411
await self._start()
399412
await self.scheduler
400413
await self._correct_state()
401414
if self.workers:
402-
await asyncio.wait(list(self.workers.values())) # maybe there are more
415+
await asyncio.wait(
416+
[
417+
asyncio.create_task(_wrap_awaitable(w))
418+
for w in self.workers.values()
419+
]
420+
) # maybe there are more
403421
return self
404422

405423
return _().__await__()

distributed/tests/test_scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2625,7 +2625,7 @@ async def test_task_unique_groups(c, s, a, b):
26252625
x = c.submit(sum, [1, 2])
26262626
y = c.submit(len, [1, 2])
26272627
z = c.submit(sum, [3, 4])
2628-
await asyncio.wait([x, y, z])
2628+
await asyncio.gather(x, y, z)
26292629

26302630
assert s.task_prefixes["len"].states["memory"] == 1
26312631
assert s.task_prefixes["sum"].states["memory"] == 2

distributed/tests/test_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2772,7 +2772,7 @@ async def test_forget_dependents_after_release(c, s, a):
27722772
fut = c.submit(inc, 1, key="f-1")
27732773
fut2 = c.submit(inc, fut, key="f-2")
27742774

2775-
await asyncio.wait([fut, fut2])
2775+
await asyncio.gather(fut, fut2)
27762776

27772777
assert fut.key in a.state.tasks
27782778
assert fut2.key in a.state.tasks

0 commit comments

Comments
 (0)
0