From 2340e69f32666b8de2da11aa99b6dffc61a8d0ee Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Sat, 7 Oct 2023 12:02:28 +0300 Subject: [PATCH] Synchronize test_contextlib with test_contextlib_async --- Lib/test/test_contextlib.py | 46 +++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py index 5d94ec7cae4706..3dad2567015e24 100644 --- a/Lib/test/test_contextlib.py +++ b/Lib/test/test_contextlib.py @@ -173,6 +173,15 @@ def whoo(): # The "gen" attribute is an implementation detail. self.assertFalse(ctx.gen.gi_suspended) + def test_contextmanager_trap_no_yield(self): + @contextmanager + def whoo(): + if False: + yield + ctx = whoo() + with self.assertRaises(RuntimeError): + ctx.__enter__() + def test_contextmanager_trap_second_yield(self): @contextmanager def whoo(): @@ -186,6 +195,19 @@ def whoo(): # The "gen" attribute is an implementation detail. self.assertFalse(ctx.gen.gi_suspended) + def test_contextmanager_non_normalised(self): + @contextmanager + def whoo(): + try: + yield + except RuntimeError: + raise SyntaxError + + ctx = whoo() + ctx.__enter__() + with self.assertRaises(SyntaxError): + ctx.__exit__(RuntimeError, None, None) + def test_contextmanager_except(self): state = [] @contextmanager @@ -265,6 +287,25 @@ def test_issue29692(): self.assertEqual(ex.args[0], 'issue29692:Unchained') self.assertIsNone(ex.__cause__) + def test_contextmanager_wrap_runtimeerror(self): + @contextmanager + def woohoo(): + try: + yield + except Exception as exc: + raise RuntimeError(f'caught {exc}') from exc + + with self.assertRaises(RuntimeError): + with woohoo(): + 1 / 0 + + # If the context manager wrapped StopIteration in a RuntimeError, + # we also unwrap it, because we can't tell whether the wrapping was + # done by the generator machinery or by the generator itself. + with self.assertRaises(StopIteration): + with woohoo(): + raise StopIteration + def _create_contextmanager_attribs(self): def attribs(**kw): def decorate(func): @@ -276,6 +317,7 @@ def decorate(func): @attribs(foo='bar') def baz(spam): """Whee!""" + yield return baz def test_contextmanager_attribs(self): @@ -332,8 +374,11 @@ def woohoo(a, *, b): def test_recursive(self): depth = 0 + ncols = 0 @contextmanager def woohoo(): + nonlocal ncols + ncols += 1 nonlocal depth before = depth depth += 1 @@ -347,6 +392,7 @@ def recursive(): recursive() recursive() + self.assertEqual(ncols, 10) self.assertEqual(depth, 0)