|
25 | 25 | import functools
|
26 | 26 | import gc
|
27 | 27 | import io
|
| 28 | +import re |
28 | 29 | import sys
|
29 | 30 | import unittest
|
30 | 31 | import unittest.mock
|
31 | 32 | import sqlite3 as sqlite
|
32 | 33 |
|
33 |
| -from test.support import bigmemtest |
| 34 | +from test.support import bigmemtest, catch_unraisable_exception |
34 | 35 | from .test_dbapi import cx_limit
|
35 | 36 |
|
36 | 37 |
|
37 |
| -def with_tracebacks(strings, traceback=True): |
| 38 | +def with_tracebacks(exc, regex="", name=""): |
38 | 39 | """Convenience decorator for testing callback tracebacks."""
|
39 |
| - if traceback: |
40 |
| - strings.append('Traceback') |
41 |
| - |
42 | 40 | def decorator(func):
|
| 41 | + _regex = re.compile(regex) if regex else None |
43 | 42 | @functools.wraps(func)
|
44 | 43 | def wrapper(self, *args, **kwargs):
|
45 |
| - # First, run the test with traceback enabled. |
46 |
| - with check_tracebacks(self, strings): |
47 |
| - func(self, *args, **kwargs) |
| 44 | + with catch_unraisable_exception() as cm: |
| 45 | + # First, run the test with traceback enabled. |
| 46 | + with check_tracebacks(self, cm, exc, _regex, name): |
| 47 | + func(self, *args, **kwargs) |
48 | 48 |
|
49 | 49 | # Then run the test with traceback disabled.
|
50 | 50 | func(self, *args, **kwargs)
|
51 | 51 | return wrapper
|
52 | 52 | return decorator
|
53 | 53 |
|
| 54 | + |
54 | 55 | @contextlib.contextmanager
|
55 |
| -def check_tracebacks(self, strings): |
| 56 | +def check_tracebacks(self, cm, exc, regex, obj_name): |
56 | 57 | """Convenience context manager for testing callback tracebacks."""
|
57 | 58 | sqlite.enable_callback_tracebacks(True)
|
58 | 59 | try:
|
59 | 60 | buf = io.StringIO()
|
60 | 61 | with contextlib.redirect_stderr(buf):
|
61 | 62 | yield
|
62 |
| - tb = buf.getvalue() |
63 |
| - for s in strings: |
64 |
| - self.assertIn(s, tb) |
| 63 | + |
| 64 | + self.assertEqual(cm.unraisable.exc_type, exc) |
| 65 | + if regex: |
| 66 | + msg = str(cm.unraisable.exc_value) |
| 67 | + self.assertIsNotNone(regex.search(msg)) |
| 68 | + if obj_name: |
| 69 | + self.assertEqual(cm.unraisable.object.__name__, obj_name) |
65 | 70 | finally:
|
66 | 71 | sqlite.enable_callback_tracebacks(False)
|
67 | 72 |
|
| 73 | + |
68 | 74 | def func_returntext():
|
69 | 75 | return "foo"
|
70 | 76 | def func_returntextwithnull():
|
@@ -299,22 +305,22 @@ def test_func_return_long_long(self):
|
299 | 305 | val = cur.fetchone()[0]
|
300 | 306 | self.assertEqual(val, 1<<31)
|
301 | 307 |
|
302 |
| - @with_tracebacks(['func_raiseexception', '5/0', 'ZeroDivisionError']) |
| 308 | + @with_tracebacks(ZeroDivisionError, name="func_raiseexception") |
303 | 309 | def test_func_exception(self):
|
304 | 310 | cur = self.con.cursor()
|
305 | 311 | with self.assertRaises(sqlite.OperationalError) as cm:
|
306 | 312 | cur.execute("select raiseexception()")
|
307 | 313 | cur.fetchone()
|
308 | 314 | self.assertEqual(str(cm.exception), 'user-defined function raised exception')
|
309 | 315 |
|
310 |
| - @with_tracebacks(['func_memoryerror', 'MemoryError']) |
| 316 | + @with_tracebacks(MemoryError, name="func_memoryerror") |
311 | 317 | def test_func_memory_error(self):
|
312 | 318 | cur = self.con.cursor()
|
313 | 319 | with self.assertRaises(MemoryError):
|
314 | 320 | cur.execute("select memoryerror()")
|
315 | 321 | cur.fetchone()
|
316 | 322 |
|
317 |
| - @with_tracebacks(['func_overflowerror', 'OverflowError']) |
| 323 | + @with_tracebacks(OverflowError, name="func_overflowerror") |
318 | 324 | def test_func_overflow_error(self):
|
319 | 325 | cur = self.con.cursor()
|
320 | 326 | with self.assertRaises(sqlite.DataError):
|
@@ -426,22 +432,21 @@ def md5sum(t):
|
426 | 432 | del x,y
|
427 | 433 | gc.collect()
|
428 | 434 |
|
| 435 | + @with_tracebacks(OverflowError) |
429 | 436 | def test_func_return_too_large_int(self):
|
430 | 437 | cur = self.con.cursor()
|
431 | 438 | for value in 2**63, -2**63-1, 2**64:
|
432 | 439 | self.con.create_function("largeint", 0, lambda value=value: value)
|
433 |
| - with check_tracebacks(self, ['OverflowError']): |
434 |
| - with self.assertRaises(sqlite.DataError): |
435 |
| - cur.execute("select largeint()") |
| 440 | + with self.assertRaises(sqlite.DataError): |
| 441 | + cur.execute("select largeint()") |
436 | 442 |
|
| 443 | + @with_tracebacks(UnicodeEncodeError, "surrogates not allowed", "chr") |
437 | 444 | def test_func_return_text_with_surrogates(self):
|
438 | 445 | cur = self.con.cursor()
|
439 | 446 | self.con.create_function("pychr", 1, chr)
|
440 | 447 | for value in 0xd8ff, 0xdcff:
|
441 |
| - with check_tracebacks(self, |
442 |
| - ['UnicodeEncodeError', 'surrogates not allowed']): |
443 |
| - with self.assertRaises(sqlite.OperationalError): |
444 |
| - cur.execute("select pychr(?)", (value,)) |
| 448 | + with self.assertRaises(sqlite.OperationalError): |
| 449 | + cur.execute("select pychr(?)", (value,)) |
445 | 450 |
|
446 | 451 | @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
|
447 | 452 | @bigmemtest(size=2**31, memuse=3, dry_run=False)
|
@@ -510,23 +515,23 @@ def test_aggr_no_finalize(self):
|
510 | 515 | val = cur.fetchone()[0]
|
511 | 516 | self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
|
512 | 517 |
|
513 |
| - @with_tracebacks(['__init__', '5/0', 'ZeroDivisionError']) |
| 518 | + @with_tracebacks(ZeroDivisionError, name="AggrExceptionInInit") |
514 | 519 | def test_aggr_exception_in_init(self):
|
515 | 520 | cur = self.con.cursor()
|
516 | 521 | with self.assertRaises(sqlite.OperationalError) as cm:
|
517 | 522 | cur.execute("select excInit(t) from test")
|
518 | 523 | val = cur.fetchone()[0]
|
519 | 524 | self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")
|
520 | 525 |
|
521 |
| - @with_tracebacks(['step', '5/0', 'ZeroDivisionError']) |
| 526 | + @with_tracebacks(ZeroDivisionError, name="AggrExceptionInStep") |
522 | 527 | def test_aggr_exception_in_step(self):
|
523 | 528 | cur = self.con.cursor()
|
524 | 529 | with self.assertRaises(sqlite.OperationalError) as cm:
|
525 | 530 | cur.execute("select excStep(t) from test")
|
526 | 531 | val = cur.fetchone()[0]
|
527 | 532 | self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")
|
528 | 533 |
|
529 |
| - @with_tracebacks(['finalize', '5/0', 'ZeroDivisionError']) |
| 534 | + @with_tracebacks(ZeroDivisionError, name="AggrExceptionInFinalize") |
530 | 535 | def test_aggr_exception_in_finalize(self):
|
531 | 536 | cur = self.con.cursor()
|
532 | 537 | with self.assertRaises(sqlite.OperationalError) as cm:
|
@@ -643,11 +648,11 @@ def authorizer_cb(action, arg1, arg2, dbname, source):
|
643 | 648 | raise ValueError
|
644 | 649 | return sqlite.SQLITE_OK
|
645 | 650 |
|
646 |
| - @with_tracebacks(['authorizer_cb', 'ValueError']) |
| 651 | + @with_tracebacks(ValueError, name="authorizer_cb") |
647 | 652 | def test_table_access(self):
|
648 | 653 | super().test_table_access()
|
649 | 654 |
|
650 |
| - @with_tracebacks(['authorizer_cb', 'ValueError']) |
| 655 | + @with_tracebacks(ValueError, name="authorizer_cb") |
651 | 656 | def test_column_access(self):
|
652 | 657 | super().test_table_access()
|
653 | 658 |
|
|