diff --git a/Lib/asyncio/queues.py b/Lib/asyncio/queues.py index a9656a6df561ba..d591d0ebab481b 100644 --- a/Lib/asyncio/queues.py +++ b/Lib/asyncio/queues.py @@ -1,6 +1,14 @@ -__all__ = ('Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty') +__all__ = ( + 'Queue', + 'PriorityQueue', + 'LifoQueue', + 'QueueFull', + 'QueueEmpty', + 'QueueShutDown', +) import collections +import enum import heapq from types import GenericAlias @@ -18,6 +26,17 @@ class QueueFull(Exception): pass +class QueueShutDown(Exception): + """Raised when putting on to or getting from a shut-down Queue.""" + pass + + +class _QueueState(enum.Enum): + alive = "alive" + shutdown = "shutdown" + shutdown_immediate = "shutdown-immediate" + + class Queue(mixins._LoopBoundMixin): """A queue, useful for coordinating producer and consumer coroutines. @@ -41,6 +60,7 @@ def __init__(self, maxsize=0): self._finished = locks.Event() self._finished.set() self._init(maxsize) + self._shutdown_state = _QueueState.alive # These three are overridable in subclasses. @@ -113,6 +133,8 @@ async def put(self, item): Put an item into the queue. If the queue is full, wait until a free slot is available before adding item. """ + if self._shutdown_state != _QueueState.alive: + raise QueueShutDown while self.full(): putter = self._get_loop().create_future() self._putters.append(putter) @@ -132,6 +154,8 @@ async def put(self, item): # the call. Wake up the next in line. self._wakeup_next(self._putters) raise + if self._shutdown_state != _QueueState.alive: + raise QueueShutDown return self.put_nowait(item) def put_nowait(self, item): @@ -139,6 +163,8 @@ def put_nowait(self, item): If no free slot is immediately available, raise QueueFull. """ + if self._shutdown_state != _QueueState.alive: + raise QueueShutDown if self.full(): raise QueueFull self._put(item) @@ -151,7 +177,11 @@ async def get(self): If queue is empty, wait until an item is available. """ + if self._shutdown_state == _QueueState.shutdown_immediate: + raise QueueShutDown while self.empty(): + if self._shutdown_state != _QueueState.alive: + raise QueueShutDown getter = self._get_loop().create_future() self._getters.append(getter) try: @@ -170,6 +200,8 @@ async def get(self): # the call. Wake up the next in line. self._wakeup_next(self._getters) raise + if self._shutdown_state == _QueueState.shutdown_immediate: + raise QueueShutDown return self.get_nowait() def get_nowait(self): @@ -178,7 +210,11 @@ def get_nowait(self): Return an item if one is immediately available, else raise QueueEmpty. """ if self.empty(): + if self._shutdown_state != _QueueState.alive: + raise QueueShutDown raise QueueEmpty + elif self._shutdown_state == _QueueState.shutdown_immediate: + raise QueueShutDown item = self._get() self._wakeup_next(self._putters) return item @@ -214,6 +250,29 @@ async def join(self): if self._unfinished_tasks > 0: await self._finished.wait() + def shutdown(self, immediate=False): + """Shut-down the queue, making queue gets and puts raise. + + By default, gets will only raise once the queue is empty. Set + 'immediate' to True to make gets raise immediately instead. + + All blocked callers of put() will be unblocked, and also get() + and join() if 'immediate'. The QueueShutDown exception is raised. + """ + if immediate: + self._shutdown_state = _QueueState.shutdown_immediate + while self._getters: + getter = self._getters.popleft() + if not getter.done(): + getter.set_result(None) + else: + self._shutdown_state = _QueueState.shutdown + while self._putters: + putter = self._putters.popleft() + if not putter.done(): + putter.set_result(None) + # Release 'joined' tasks/coros + self._finished.set() class PriorityQueue(Queue): """A subclass of Queue; retrieves entries in priority order (lowest first). diff --git a/Lib/multiprocessing/queues.py b/Lib/multiprocessing/queues.py index daf9ee94a19431..5220504369937d 100644 --- a/Lib/multiprocessing/queues.py +++ b/Lib/multiprocessing/queues.py @@ -17,8 +17,9 @@ import types import weakref import errno +import ctypes -from queue import Empty, Full +from queue import Empty, Full, ShutDown import _multiprocessing @@ -28,6 +29,10 @@ from .util import debug, info, Finalize, register_after_fork, is_exiting +_queue_alive = 0 +_queue_shutdown = 1 +_queue_shutdown_immediate = 2 + # # Queue type using a pipe, buffer and thread # @@ -50,6 +55,9 @@ def __init__(self, maxsize=0, *, ctx): # For use by concurrent.futures self._ignore_epipe = False self._reset() + self._shutdown_state = context._default_context.Value( + ctypes.c_uint8, lock=self._rlock + ) if sys.platform != 'win32': register_after_fork(self, Queue._after_fork) @@ -86,20 +94,28 @@ def _reset(self, after_fork=False): def put(self, obj, block=True, timeout=None): if self._closed: raise ValueError(f"Queue {self!r} is closed") + if self._shutdown_state.value != _queue_alive: + raise ShutDown if not self._sem.acquire(block, timeout): raise Full with self._notempty: + if self._shutdown_state.value != _queue_alive: + raise ShutDown if self._thread is None: self._start_thread() self._buffer.append(obj) self._notempty.notify() def get(self, block=True, timeout=None): + if self._shutdown_state.value == _queue_shutdown_immediate: + raise ShutDown if self._closed: raise ValueError(f"Queue {self!r} is closed") if block and timeout is None: with self._rlock: + if self._shutdown_state.value != _queue_alive: + raise ShutDown res = self._recv_bytes() self._sem.release() else: @@ -111,13 +127,19 @@ def get(self, block=True, timeout=None): if block: timeout = deadline - time.monotonic() if not self._poll(timeout): + if self._shutdown_state.value != _queue_alive: + raise ShutDown raise Empty + if self._shutdown_state.value != _queue_alive : + raise ShutDown elif not self._poll(): raise Empty res = self._recv_bytes() self._sem.release() finally: self._rlock.release() + if self._shutdown_state.value == _queue_shutdown: + raise ShutDown # unserialize the data after having released the lock return _ForkingPickler.loads(res) @@ -158,6 +180,14 @@ def cancel_join_thread(self): except AttributeError: pass + def shutdown(self, immediate=False): + with self._rlock: + if immediate: + self._shutdown_state = _queue_shutdown_immediate + self._notempty.notify_all() + else: + self._shutdown_state = _queue_shutdown + def _start_thread(self): debug('Queue._start_thread()') @@ -329,6 +359,8 @@ def task_done(self): def join(self): with self._cond: + if self._shutdown_state.value == _queue_shutdown_immediate: + return if not self._unfinished_tasks._semlock._is_zero(): self._cond.wait() diff --git a/Lib/queue.py b/Lib/queue.py index 55f50088460f9e..f08dbd47f188ee 100644 --- a/Lib/queue.py +++ b/Lib/queue.py @@ -25,6 +25,15 @@ class Full(Exception): pass +class ShutDown(Exception): + '''Raised when put/get with shut-down queue.''' + + +_queue_alive = "alive" +_queue_shutdown = "shutdown" +_queue_shutdown_immediate = "shutdown-immediate" + + class Queue: '''Create a queue object with a given maximum size. @@ -54,6 +63,9 @@ def __init__(self, maxsize=0): self.all_tasks_done = threading.Condition(self.mutex) self.unfinished_tasks = 0 + # Queue shut-down state + self.shutdown_state = _queue_alive + def task_done(self): '''Indicate that a formerly enqueued task is complete. @@ -87,6 +99,8 @@ def join(self): ''' with self.all_tasks_done: while self.unfinished_tasks: + if self.shutdown_state == _queue_shutdown_immediate: + return self.all_tasks_done.wait() def qsize(self): @@ -130,6 +144,8 @@ def put(self, item, block=True, timeout=None): is immediately available, else raise the Full exception ('timeout' is ignored in that case). ''' + if self.shutdown_state != _queue_alive: + raise ShutDown with self.not_full: if self.maxsize > 0: if not block: @@ -138,6 +154,8 @@ def put(self, item, block=True, timeout=None): elif timeout is None: while self._qsize() >= self.maxsize: self.not_full.wait() + if self.shutdown_state != _queue_alive: + raise ShutDown elif timeout < 0: raise ValueError("'timeout' must be a non-negative number") else: @@ -147,6 +165,8 @@ def put(self, item, block=True, timeout=None): if remaining <= 0.0: raise Full self.not_full.wait(remaining) + if self.shutdown_state != _queue_alive: + raise ShutDown self._put(item) self.unfinished_tasks += 1 self.not_empty.notify() @@ -162,22 +182,36 @@ def get(self, block=True, timeout=None): available, else raise the Empty exception ('timeout' is ignored in that case). ''' + if self.shutdown_state == _queue_shutdown_immediate: + raise ShutDown with self.not_empty: if not block: if not self._qsize(): + if self.shutdown_state != _queue_alive: + raise ShutDown raise Empty elif timeout is None: while not self._qsize(): + if self.shutdown_state != _queue_alive: + raise ShutDown self.not_empty.wait() + if self.shutdown_state != _queue_alive: + raise ShutDown elif timeout < 0: raise ValueError("'timeout' must be a non-negative number") else: endtime = time() + timeout while not self._qsize(): + if self.shutdown_state != _queue_alive: + raise ShutDown remaining = endtime - time() if remaining <= 0.0: raise Empty self.not_empty.wait(remaining) + if self.shutdown_state != _queue_alive: + raise ShutDown + if self.shutdown_state == _queue_shutdown_immediate: + raise ShutDown item = self._get() self.not_full.notify() return item @@ -198,6 +232,28 @@ def get_nowait(self): ''' return self.get(block=False) + def shutdown(self, immediate=False): + '''Shut-down the queue, making queue gets and puts raise. + + By default, gets will only raise once the queue is empty. Set + 'immediate' to True to make gets raise immediately instead. + + All blocked callers of put() will be unblocked, and also get() + and join() if 'immediate'. The ShutDown exception is raised. + ''' + with self.mutex: + if immediate: + self.shutdown_state = _queue_shutdown_immediate + self.not_empty.notify_all() + # set self.unfinished_tasks to 0 + # to break the loop in 'self.join()' + # when quits from `wait()` + self.unfinished_tasks = 0 + self.all_tasks_done.notify_all() + else: + self.shutdown_state = _queue_shutdown + self.not_full.notify_all() + # Override these methods to implement other queue organizations # (e.g. stack or priority queue). # These will only be called with appropriate locks held diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index 9a2db24b4bd597..d9264ed62f5526 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -1277,6 +1277,37 @@ def test_closed_queue_put_get_exceptions(self): q.put('foo') with self.assertRaisesRegex(ValueError, 'is closed'): q.get() + + def test_shutdown_empty(self): + q = multiprocessing.Queue() + q.shutdown() + with self.assertRaises( + pyqueue.ShutDown, msg="Didn't appear to shut-down queue" + ): + q.put("data") + with self.assertRaises( + pyqueue.ShutDown, msg="Didn't appear to shut-down queue" + ): + q.get() + + def test_shutdown_nonempty(self): + q = multiprocessing.Queue() + q.put("data") + q.shutdown() + q.get() + with self.assertRaises( + pyqueue.ShutDown, msg="Didn't appear to shut-down queue" + ): + q.get() + + def test_shutdown_immediate(self): + q = multiprocessing.Queue() + q.put("data") + q.shutdown(immediate=True) + with self.assertRaises( + pyqueue.ShutDown, msg="Didn't appear to shut-down queue" + ): + q.get() # # # diff --git a/Lib/test/test_asyncio/test_queues.py b/Lib/test/test_asyncio/test_queues.py index 2d058ccf6a8c72..75b016f399a13b 100644 --- a/Lib/test/test_asyncio/test_queues.py +++ b/Lib/test/test_asyncio/test_queues.py @@ -522,5 +522,149 @@ class PriorityQueueJoinTests(_QueueJoinTestMixin, unittest.IsolatedAsyncioTestCa q_class = asyncio.PriorityQueue +class _QueueShutdownTestMixin: + q_class = None + + async def test_empty(self): + q = self.q_class() + q.shutdown() + with self.assertRaises( + asyncio.QueueShutDown, msg="Didn't appear to shut-down queue" + ): + await q.put("data") + with self.assertRaises( + asyncio.QueueShutDown, msg="Didn't appear to shut-down queue" + ): + await q.get() + + async def test_nonempty(self): + q = self.q_class() + q.put_nowait("data") + q.shutdown() + await q.get() + with self.assertRaises( + asyncio.QueueShutDown, msg="Didn't appear to shut-down queue" + ): + await q.get() + + async def test_immediate(self): + q = self.q_class() + q.put_nowait("data") + q.shutdown(immediate=True) + with self.assertRaises( + asyncio.QueueShutDown, msg="Didn't appear to shut-down queue" + ): + await q.get() + async def test_repr_shutdown(self): + q = self.q_class() + q.shutdown() + self.assertIn("shutdown", repr(q)) + + q = self.q_class() + q.shutdown(immediate=True) + self.assertIn("shutdown-immediate", repr(q)) + + async def test_get_shutdown_immediate(self): + results = [] + maxsize = 2 + delay = 1e-3 + + async def get_q(q): + try: + msg = await q.get() + results.append(False) + except asyncio.QueueShutDown: + results.append(True) + return True + + async def shutdown(q, delay, immediate): + await asyncio.sleep(delay) + q.shutdown(immediate) + return True + + q = self.q_class(maxsize) + t = [asyncio.create_task(get_q(q)) for _ in range(maxsize)] + t += [asyncio.create_task(shutdown(q, delay, True))] + res = await asyncio.gather(*t) + + self.assertEqual(results, [True]*maxsize) + + async def test_put_shutdown(self): + maxsize = 2 + results = [] + delay = 1e-3 + + async def put_twice(q, delay, msg): + await q.put(msg) + await asyncio.sleep(delay) + try: + await q.put(msg+maxsize) + results.append(False) + except asyncio.QueueShutDown: + results.append(True) + return msg + + async def shutdown(q, delay, immediate): + await asyncio.sleep(delay) + q.shutdown(immediate) + + q = self.q_class(maxsize) + t = [asyncio.create_task(put_twice(q, delay, i+1)) for i in range(maxsize)] + t += [asyncio.create_task(shutdown(q, delay*2, False))] + res = await asyncio.gather(*t) + + self.assertEqual(results, [True]*maxsize) + + async def test_put_and_join_shutdown(self): + maxsize = 2 + results = [] + delay = 1e-3 + + async def put_twice(q, delay, msg): + await q.put(msg) + await asyncio.sleep(delay) + try: + await q.put(msg+maxsize) + results.append(False) + except asyncio.QueueShutDown: + results.append(True) + return msg + + async def shutdown(q, delay, immediate): + await asyncio.sleep(delay) + q.shutdown(immediate) + + async def join(q, delay): + await asyncio.sleep(delay) + await q.join() + results.append(True) + return True + + q = self.q_class(maxsize) + t = [asyncio.create_task(put_twice(q, delay, i+1)) for i in range(maxsize)] + t += [asyncio.create_task(shutdown(q, delay*2, True)), + asyncio.create_task(join(q, delay))] + res = await asyncio.gather(*t) + + self.assertEqual(results, [True]*(maxsize+1)) + +class QueueShutdownTests( + _QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase +): + q_class = asyncio.Queue + + +class LifoQueueShutdownTests( + _QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase +): + q_class = asyncio.LifoQueue + + +class PriorityQueueShutdownTests( + _QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase +): + q_class = asyncio.PriorityQueue + + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_queue.py b/Lib/test/test_queue.py index 33113a72e6b6a9..354299b9a5b16a 100644 --- a/Lib/test/test_queue.py +++ b/Lib/test/test_queue.py @@ -241,6 +241,41 @@ def test_shrinking_queue(self): with self.assertRaises(self.queue.Full): q.put_nowait(4) + def test_shutdown_empty(self): + q = self.type2test() + q.shutdown() + try: + q.put("data") + self.fail("Didn't appear to shut-down queue") + except self.queue.ShutDown: + pass + try: + q.get() + self.fail("Didn't appear to shut-down queue") + except self.queue.ShutDown: + pass + + def test_shutdown_nonempty(self): + q = self.type2test() + q.put("data") + q.shutdown() + q.get() + try: + q.get() + self.fail("Didn't appear to shut-down queue") + except self.queue.ShutDown: + pass + + def test_shutdown_immediate(self): + q = self.type2test() + q.put("data") + q.shutdown(immediate=True) + try: + q.get() + self.fail("Didn't appear to shut-down queue") + except self.queue.ShutDown: + pass + class QueueTest(BaseQueueTestMixin): def setUp(self):