|
1 | 1 | # Adapted with permission from the EdgeDB project;
|
2 | 2 | # license: PSFL.
|
3 | 3 |
|
| 4 | +import weakref |
| 5 | +import sys |
4 | 6 | import gc
|
5 | 7 | import asyncio
|
6 | 8 | import contextvars
|
@@ -28,7 +30,25 @@ def get_error_types(eg):
|
28 | 30 | return {type(exc) for exc in eg.exceptions}
|
29 | 31 |
|
30 | 32 |
|
31 |
| -class TestTaskGroup(unittest.IsolatedAsyncioTestCase): |
| 33 | +def set_gc_state(enabled): |
| 34 | + was_enabled = gc.isenabled() |
| 35 | + if enabled: |
| 36 | + gc.enable() |
| 37 | + else: |
| 38 | + gc.disable() |
| 39 | + return was_enabled |
| 40 | + |
| 41 | + |
| 42 | +@contextlib.contextmanager |
| 43 | +def disable_gc(): |
| 44 | + was_enabled = set_gc_state(enabled=False) |
| 45 | + try: |
| 46 | + yield |
| 47 | + finally: |
| 48 | + set_gc_state(enabled=was_enabled) |
| 49 | + |
| 50 | + |
| 51 | +class BaseTestTaskGroup: |
32 | 52 |
|
33 | 53 | async def test_taskgroup_01(self):
|
34 | 54 |
|
@@ -822,15 +842,15 @@ async def test_taskgroup_without_parent_task(self):
|
822 | 842 | with self.assertRaisesRegex(RuntimeError, "has not been entered"):
|
823 | 843 | tg.create_task(coro)
|
824 | 844 |
|
825 |
| - def test_coro_closed_when_tg_closed(self): |
| 845 | + async def test_coro_closed_when_tg_closed(self): |
826 | 846 | async def run_coro_after_tg_closes():
|
827 | 847 | async with taskgroups.TaskGroup() as tg:
|
828 | 848 | pass
|
829 | 849 | coro = asyncio.sleep(0)
|
830 | 850 | with self.assertRaisesRegex(RuntimeError, "is finished"):
|
831 | 851 | tg.create_task(coro)
|
832 |
| - loop = asyncio.get_event_loop() |
833 |
| - loop.run_until_complete(run_coro_after_tg_closes()) |
| 852 | + |
| 853 | + await run_coro_after_tg_closes() |
834 | 854 |
|
835 | 855 | async def test_cancelling_level_preserved(self):
|
836 | 856 | async def raise_after(t, e):
|
@@ -955,6 +975,30 @@ async def coro_fn():
|
955 | 975 | self.assertIsInstance(exc, _Done)
|
956 | 976 | self.assertListEqual(gc.get_referrers(exc), [])
|
957 | 977 |
|
| 978 | + |
| 979 | + async def test_exception_refcycles_parent_task_wr(self): |
| 980 | + """Test that TaskGroup deletes self._parent_task and create_task() deletes task""" |
| 981 | + tg = asyncio.TaskGroup() |
| 982 | + exc = None |
| 983 | + |
| 984 | + class _Done(Exception): |
| 985 | + pass |
| 986 | + |
| 987 | + async def coro_fn(): |
| 988 | + async with tg: |
| 989 | + raise _Done |
| 990 | + |
| 991 | + with disable_gc(): |
| 992 | + try: |
| 993 | + async with asyncio.TaskGroup() as tg2: |
| 994 | + task_wr = weakref.ref(tg2.create_task(coro_fn())) |
| 995 | + except* _Done as excs: |
| 996 | + exc = excs.exceptions[0].exceptions[0] |
| 997 | + |
| 998 | + self.assertIsNone(task_wr()) |
| 999 | + self.assertIsInstance(exc, _Done) |
| 1000 | + self.assertListEqual(gc.get_referrers(exc), []) |
| 1001 | + |
958 | 1002 | async def test_exception_refcycles_propagate_cancellation_error(self):
|
959 | 1003 | """Test that TaskGroup deletes propagate_cancellation_error"""
|
960 | 1004 | tg = asyncio.TaskGroup()
|
@@ -988,5 +1032,16 @@ class MyKeyboardInterrupt(KeyboardInterrupt):
|
988 | 1032 | self.assertListEqual(gc.get_referrers(exc), [])
|
989 | 1033 |
|
990 | 1034 |
|
| 1035 | +class TestTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase): |
| 1036 | + loop_factory = asyncio.EventLoop |
| 1037 | + |
| 1038 | +class TestEagerTaskTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase): |
| 1039 | + @staticmethod |
| 1040 | + def loop_factory(): |
| 1041 | + loop = asyncio.EventLoop() |
| 1042 | + loop.set_task_factory(asyncio.eager_task_factory) |
| 1043 | + return loop |
| 1044 | + |
| 1045 | + |
991 | 1046 | if __name__ == "__main__":
|
992 | 1047 | unittest.main()
|
0 commit comments