|
6 | 6 | import logging
|
7 | 7 | import math
|
8 | 8 | import weakref
|
| 9 | +from collections.abc import Awaitable, Generator |
9 | 10 | from contextlib import suppress
|
10 | 11 | from inspect import isawaitable
|
11 |
| -from typing import TYPE_CHECKING, Any, ClassVar |
| 12 | +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar |
12 | 13 |
|
13 | 14 | from tornado import gen
|
14 | 15 | from tornado.ioloop import IOLoop
|
@@ -108,6 +109,16 @@ async def __aexit__(self, exc_type, exc_value, traceback):
|
108 | 109 | await self.close()
|
109 | 110 |
|
110 | 111 |
|
| 112 | +_T = TypeVar("_T") |
| 113 | + |
| 114 | + |
| 115 | +async def _wrap_awaitable(aw: Awaitable[_T]) -> _T: |
| 116 | + return await aw |
| 117 | + |
| 118 | + |
| 119 | +_T_spec_cluster = TypeVar("_T_spec_cluster", bound="SpecCluster") |
| 120 | + |
| 121 | + |
111 | 122 | class SpecCluster(Cluster):
|
112 | 123 | """Cluster that requires a full specification of workers
|
113 | 124 |
|
@@ -327,7 +338,7 @@ def _correct_state(self):
|
327 | 338 | self._correct_state_waiting = task
|
328 | 339 | return task
|
329 | 340 |
|
330 |
| - async def _correct_state_internal(self): |
| 341 | + async def _correct_state_internal(self) -> None: |
331 | 342 | async with self._lock:
|
332 | 343 | self._correct_state_waiting = None
|
333 | 344 |
|
@@ -363,7 +374,9 @@ async def _correct_state_internal(self):
|
363 | 374 | self._created.add(worker)
|
364 | 375 | workers.append(worker)
|
365 | 376 | if workers:
|
366 |
| - await asyncio.wait(workers) |
| 377 | + await asyncio.wait( |
| 378 | + [asyncio.create_task(_wrap_awaitable(w)) for w in workers] |
| 379 | + ) |
367 | 380 | for w in workers:
|
368 | 381 | w._cluster = weakref.ref(self)
|
369 | 382 | await w # for tornado gen.coroutine support
|
@@ -392,14 +405,19 @@ def f():
|
392 | 405 | asyncio.get_running_loop().call_later(delay, f)
|
393 | 406 | super()._update_worker_status(op, msg)
|
394 | 407 |
|
395 |
| - def __await__(self): |
396 |
| - async def _(): |
| 408 | + def __await__(self: _T_spec_cluster) -> Generator[Any, Any, _T_spec_cluster]: |
| 409 | + async def _() -> _T_spec_cluster: |
397 | 410 | if self.status == Status.created:
|
398 | 411 | await self._start()
|
399 | 412 | await self.scheduler
|
400 | 413 | await self._correct_state()
|
401 | 414 | if self.workers:
|
402 |
| - await asyncio.wait(list(self.workers.values())) # maybe there are more |
| 415 | + await asyncio.wait( |
| 416 | + [ |
| 417 | + asyncio.create_task(_wrap_awaitable(w)) |
| 418 | + for w in self.workers.values() |
| 419 | + ] |
| 420 | + ) # maybe there are more |
403 | 421 | return self
|
404 | 422 |
|
405 | 423 | return _().__await__()
|
|
0 commit comments