diff --git a/.lintrunner.toml b/.lintrunner.toml index 7270725400f4d1..25eea8bf5e0c79 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -18,6 +18,8 @@ exclude_patterns = [ 'torch/_inductor/autoheuristic/artifacts/**', 'scripts/**', 'test/generated_type_hints_smoketest.py', + # CPython tests + 'test/dynamo/cpython/**', # Tests from the NumPy test suite 'test/torch_np/numpy_test/**/*.py', 'third_party/**', @@ -398,6 +400,7 @@ exclude_patterns=[ 'tools/clang_format_hash/**', 'test/cpp/jit/upgrader_models/*.ptl', 'test/cpp/jit/upgrader_models/*.ptl.ff', + 'test/dynamo/cpython/**', '**/*.png', '**/*.gz', '**/*.patch', @@ -936,6 +939,7 @@ include_patterns = [ exclude_patterns = [ 'test/run_test.py', '**/fb/**', + 'test/dynamo/cpython/3.13/**', 'test/quantization/**', # should be run through test/test_quantization.py 'test/jit/**', # should be run through test/test_jit.py 'test/ao/sparsity/**', # should be run through test/test_ao_sparsity.py @@ -1131,6 +1135,7 @@ exclude_patterns = [ 'caffe2/**/*.pyi', 'fb/**', '**/fb/**', + 'test/dynamo/cpython/**', 'third_party/**/*.py', 'third_party/**/*.pyi', 'torch/_vendor/**', @@ -1536,6 +1541,7 @@ exclude_patterns = [ 'functorch/notebooks/**', 'torch/_inductor/fx_passes/serialized_patterns/**', 'torch/_inductor/autoheuristic/artifacts/**', + 'test/dynamo/cpython/**', 'scripts/**', 'third_party/**', 'fb/**', diff --git a/test/dynamo/cpython/3_13/CHANGES.txt b/test/dynamo/cpython/3_13/CHANGES.txt new file mode 100644 index 00000000000000..ad541cea2b31a8 --- /dev/null +++ b/test/dynamo/cpython/3_13/CHANGES.txt @@ -0,0 +1,9 @@ +This subdirectory contains a selection of tests from the CPython repository (branch: v3.13.0):\ +https://github.com/python/cpython/releases/tag/v3.13.0 + +Modifications were made to ensure compatibility with the Dynamo infrastructure: ++ Monkey-patched `unittest.TestCase` to `torch._dynamo.test_case.CPythonTestCase`. ++ Replaced `unittest.main()` with `torch._dynamo.test_case.run_tests()`. ++ Assigned test "owners." ++ Annotated CPU-intensive tests with the `@slowTest` decorator. ++ Adjusted imports to use `import module` instead of `from test import module`. diff --git a/test/dynamo/cpython/3_13/LICENSE b/test/dynamo/cpython/3_13/LICENSE new file mode 100644 index 00000000000000..1c9c8bddbbf32c --- /dev/null +++ b/test/dynamo/cpython/3_13/LICENSE @@ -0,0 +1,46 @@ +PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +-------------------------------------------- + +1. This LICENSE AGREEMENT is between the Python Software Foundation +("PSF"), and the Individual or Organization ("Licensee") accessing and +otherwise using this software ("Python") in source or binary form and +its associated documentation. + +2. Subject to the terms and conditions of this License Agreement, PSF hereby +grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +analyze, test, perform and/or display publicly, prepare derivative works, +distribute, and otherwise use Python alone or in any derivative version, +provided, however, that PSF's License Agreement and PSF's notice of copyright, +i.e., "Copyright (c) 2001 Python Software Foundation; All Rights Reserved" +are retained in Python alone or in any derivative version prepared by Licensee. + +3. In the event Licensee prepares a derivative work that is based on +or incorporates Python or any part thereof, and wants to make +the derivative work available to others as provided herein, then +Licensee hereby agrees to include in any such work a brief summary of +the changes made to Python. + +4. PSF is making Python available to Licensee on an "AS IS" +basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +INFRINGE ANY THIRD PARTY RIGHTS. + +5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. + +6. This License Agreement will automatically terminate upon a material +breach of its terms and conditions. + +7. Nothing in this License Agreement shall be deemed to create any +relationship of agency, partnership, or joint venture between PSF and +Licensee. This License Agreement does not grant permission to use PSF +trademarks or trade name in a trademark sense to endorse or promote +products or services of Licensee, or any third party. + +8. By copying, installing or otherwise using Python, Licensee +agrees to be bound by the terms and conditions of this License +Agreement. diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index 74ff84dbb9e36b..d0216ed5903850 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -1,7 +1,6 @@ # Owner(s): ["module: dynamo"] import contextlib import sys -import traceback import unittest from contextlib import contextmanager @@ -9,18 +8,12 @@ import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.exc import InternalTorchDynamoError -from torch._dynamo.testing import ( - EagerAndRecordGraphs, - normalize_gm, - same, - skipIfNotPy311, -) +from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm, same from torch._dynamo.utils import counters from torch.nn import functional as F from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, - make_dynamo_test, parametrize, ) @@ -37,6 +30,16 @@ k_glb = 0 +@contextlib.contextmanager +def set_default_dtype(dtype): + old_dtype = torch.get_default_dtype() + try: + torch.set_default_dtype(dtype) + yield + finally: + torch.set_default_dtype(old_dtype) + + class CustomizedCtxManager: def __init__(self, mode): self.prev = torch.is_grad_enabled() @@ -2700,319 +2703,6 @@ def fn(t): self.assertEqual(y, t.sin()) -class CPythonContextManagerTestCase(torch._dynamo.test_case.CPythonTestCase): - # Tests taken from CPython source code in cpython/Lib/test/test_contextlib.py - # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_contextlib.py - - @make_dynamo_test - def test_contextmanager_plain(self): - state = [] - - @contextmanager - def woohoo(): - state.append(1) - yield 42 - state.append(999) - - with woohoo() as x: - self.assertEqual(state, [1]) - self.assertEqual(x, 42) - state.append(x) - self.assertEqual(state, [1, 42, 999]) - - @skipIfNotPy311 - @make_dynamo_test - def test_contextmanager_finally(self): - state = [] - - @contextmanager - def woohoo(): - state.append(1) - try: - yield 42 - finally: - state.append(999) - - with self.assertRaises(ZeroDivisionError): - with woohoo() as x: - self.assertEqual(state, [1]) - self.assertEqual(x, 42) - state.append(x) - raise ZeroDivisionError - self.assertEqual(state, [1, 42, 999]) - - @unittest.expectedFailure - @make_dynamo_test - def test_contextmanager_traceback(self): - @contextmanager - def f(): - yield - - try: - with f(): - 1 / 0 - except ZeroDivisionError as e: - frames = traceback.extract_tb(e.__traceback__) - - self.assertEqual(len(frames), 1) - self.assertEqual(frames[0].name, "test_contextmanager_traceback") - self.assertEqual(frames[0].line, "1/0") - - # Repeat with RuntimeError (which goes through a different code path) - try: - with f(): - raise NotImplementedError(42) - except NotImplementedError as e: - frames = traceback.extract_tb(e.__traceback__) - - self.assertEqual(len(frames), 1) - self.assertEqual(frames[0].name, "test_contextmanager_traceback") - self.assertEqual(frames[0].line, "raise NotImplementedError(42)") - - @make_dynamo_test - def test_contextmanager_no_reraise(self): - @contextmanager - def whee(): - yield - - ctx = whee() - ctx.__enter__() - # Calling __exit__ should not result in an exception - self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None)) - - @make_dynamo_test - def test_contextmanager_trap_yield_after_throw(self): - @contextmanager - def whoo(): - try: - yield - except Exception: # noqa: E722 - yield - - ctx = whoo() - ctx.__enter__() - with self.assertRaises(RuntimeError): - ctx.__exit__(TypeError, TypeError("foo"), None) - - @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") - def test_contextmanager_except(self): - state = [] - - @contextmanager - def woohoo(): - state.append(1) - try: - yield 42 - except ZeroDivisionError as e: - state.append(e.args[0]) - self.assertEqual(state, [1, 42, 999]) - - with woohoo() as x: - self.assertEqual(state, [1]) - self.assertEqual(x, 42) - state.append(x) - raise ZeroDivisionError(999) - self.assertEqual(state, [1, 42, 999]) - - @unittest.expectedFailure - @make_dynamo_test - def test_contextmanager_except_stopiter(self): - @contextmanager - def woohoo(): - yield - - class StopIterationSubclass(StopIteration): - pass - - for stop_exc in (StopIteration("spam"), StopIterationSubclass("spam")): - with self.subTest(type=type(stop_exc)): - try: - with woohoo(): - raise stop_exc - except Exception as ex: - self.assertIs(ex, stop_exc) - else: - self.fail(f"{stop_exc} was suppressed") - - @unittest.expectedFailure - @make_dynamo_test - def test_contextmanager_except_pep479(self): - code = """\ -from __future__ import generator_stop -from contextlib import contextmanager -@contextmanager -def woohoo(): - yield -""" - locals = {} - exec(code, locals, locals) - woohoo = locals["woohoo"] - - stop_exc = StopIteration("spam") - try: - with woohoo(): - raise stop_exc - except Exception as ex: - self.assertIs(ex, stop_exc) - else: - self.fail("StopIteration was suppressed") - - @unittest.expectedFailure - @make_dynamo_test - def test_contextmanager_do_not_unchain_non_stopiteration_exceptions(self): - @contextmanager - def test_issue29692(): - try: - yield - except Exception as exc: - raise RuntimeError("issue29692:Chained") from exc - - try: - with test_issue29692(): - raise ZeroDivisionError - except Exception as ex: - self.assertIs(type(ex), RuntimeError) - self.assertEqual(ex.args[0], "issue29692:Chained") - self.assertIsInstance(ex.__cause__, ZeroDivisionError) - - try: - with test_issue29692(): - raise StopIteration("issue29692:Unchained") - except Exception as ex: - self.assertIs(type(ex), StopIteration) - self.assertEqual(ex.args[0], "issue29692:Unchained") - self.assertIsNone(ex.__cause__) - - @unittest.expectedFailure - @make_dynamo_test - def _create_contextmanager_attribs(self): - def attribs(**kw): - def decorate(func): - for k, v in kw.items(): - setattr(func, k, v) - return func - - return decorate - - @contextmanager - @attribs(foo="bar") - def baz(spam): - """Whee!""" - - return baz - - @unittest.expectedFailure - @make_dynamo_test - def test_contextmanager_attribs(self): - baz = self._create_contextmanager_attribs() - self.assertEqual(baz.__name__, "baz") - self.assertEqual(baz.foo, "bar") - - @make_dynamo_test - def test_keywords(self): - # Ensure no keyword arguments are inhibited - @contextmanager - def woohoo(self, func, args, kwds): - yield (self, func, args, kwds) - - with woohoo(self=11, func=22, args=33, kwds=44) as target: - self.assertEqual(target, (11, 22, 33, 44)) - - @unittest.expectedFailure - @make_dynamo_test - def test_param_errors(self): - @contextmanager - def woohoo(a, *, b): - yield - - with self.assertRaises(TypeError): - woohoo() - with self.assertRaises(TypeError): - woohoo(3, 5) - with self.assertRaises(TypeError): - woohoo(b=3) - - @make_dynamo_test - def test_recursive(self): - depth = 0 - - @contextmanager - def woohoo(): - nonlocal depth - before = depth - depth += 1 - yield - depth -= 1 - self.assertEqual(depth, before) - - @woohoo() - def recursive(): - if depth < 10: - recursive() - - recursive() - self.assertEqual(depth, 0) - - @skipIfNotPy311 - @make_dynamo_test - def test_contextmanager_trap_no_yield(self): - @contextmanager - def whoo(): - if False: - yield - - ctx = whoo() - with self.assertRaises(RuntimeError): - ctx.__enter__() - - @make_dynamo_test - def test_contextmanager_trap_second_yield(self): - @contextmanager - def whoo(): - yield - yield - - ctx = whoo() - ctx.__enter__() - with self.assertRaises(RuntimeError): - ctx.__exit__(None, None, None) - - @unittest.expectedFailure - @make_dynamo_test - def test_contextmanager_wrap_runtimeerror(self): - @contextmanager - def woohoo(): - try: - yield - except Exception as exc: - raise RuntimeError(f"caught {exc}") from exc - - with self.assertRaises(RuntimeError): - with woohoo(): - 1 / 0 - - # If the context manager wrapped StopIteration in a RuntimeError, - # we also unwrap it, because we can't tell whether the wrapping was - # done by the generator machinery or by the generator itself. - with self.assertRaises(StopIteration): - with woohoo(): - raise StopIteration - - @make_dynamo_test - def test_contextmanager_non_normalised(self): - @contextmanager - def whoo(): - try: - yield - except RuntimeError: - raise SyntaxError # noqa: B904 - - ctx = whoo() - ctx.__enter__() - with self.assertRaises(SyntaxError): - ctx.__exit__(RuntimeError, None, None) - - instantiate_parametrized_tests(CtxManagerTests) instantiate_parametrized_tests(ContextlibContextManagerTests) diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index 926b227cd8094b..a5a4d8fa009259 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -2,7 +2,6 @@ import contextlib import sys -import unittest import torch import torch._dynamo.config @@ -905,238 +904,6 @@ def test_raise_set___context__(self): assert exc2.__context__ is None -class CPythonExceptionTests(torch._dynamo.test_case.TestCase): - # Tests taken from CPython source code in cpython/Lib/test/test_exceptions.py - # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_exceptions.py - def setUp(self): - self._u_prev = torch._dynamo.config.enable_trace_unittest - torch._dynamo.config.enable_trace_unittest = True - - def tearDown(self): - torch._dynamo.config.enable_trace_unittest = self._u_prev - - @make_dynamo_test - def testChainingAttrs(self): - e = Exception() - assert e.__context__ is None - assert e.__cause__ is None - - e = TypeError() - assert e.__context__ is None - assert e.__cause__ is None - - e = MyException() - assert e.__context__ is None - assert e.__cause__ is None - - @make_dynamo_test - def testChainingDescriptors(self): - try: - raise Exception # noqa: TRY002 - except Exception as exc: - e = exc - - assert e.__context__ is None - assert e.__cause__ is None - assert e.__suppress_context__ is False - - e.__context__ = NameError() - e.__cause__ = None - assert isinstance(e.__context__, NameError) - assert e.__cause__ is None - assert e.__suppress_context__ is True - e.__suppress_context__ = False - assert e.__suppress_context__ is False - - @make_dynamo_test - def test_context_of_exception_in_try_and_finally(self): - try: - try: - te = TypeError(1) - raise te - finally: - ve = ValueError(2) - raise ve - except Exception as e: - exc = e - - assert exc is ve - assert exc.__context__ is te - - @make_dynamo_test - def test_context_of_exception_in_except_and_finally(self): - try: - try: - te = TypeError(1) - raise te - except Exception: # noqa: E722 - ve = ValueError(2) - raise ve # noqa: B904 - finally: - oe = OSError(3) - raise oe - except Exception as e: - exc = e - - assert exc is oe - assert exc.__context__ is ve - assert exc.__context__.__context__ is te - - @make_dynamo_test - def test_context_of_exception_in_else_and_finally(self): - try: - try: - pass - except Exception: # noqa: E722 - pass - else: - ve = ValueError(1) - raise ve - finally: - oe = OSError(2) - raise oe - except Exception as e: - exc = e - - assert exc is oe - assert exc.__context__ is ve - - @make_dynamo_test - def test_raise_does_not_create_context_chain_cycle(self): - A = AssertionError - B = BytesWarning - C = ConnectionError - - # Create a context chain: - # C -> B -> A - # Then raise A in context of C. - try: - try: - raise A - except A as a_: - a = a_ - try: - raise B - except B as b_: - b = b_ - try: - raise C - except C as c_: - c = c_ - self.assertIsInstance(a, A) - self.assertIsInstance(b, B) - self.assertIsInstance(c, C) - self.assertIsNone(a.__context__) - self.assertIs(b.__context__, a) - self.assertIs(c.__context__, b) - raise a # noqa: B904 - except A as e: - exc = e - - # Expect A -> C -> B, without cycle - self.assertIs(exc, a) - self.assertIs(a.__context__, c) - self.assertIs(c.__context__, b) - self.assertIsNone(b.__context__) - - @make_dynamo_test - def test_no_hang_on_context_chain_cycle1(self): - # See issue 25782. Cycle in context chain. - - def cycle(): - try: - raise ValueError(1) - except ValueError as ex: - ex.__context__ = ex - raise TypeError(2) # noqa: B904 - - try: - cycle() - except Exception as e: - exc = e - - self.assertIsInstance(exc, TypeError) - self.assertIsInstance(exc.__context__, ValueError) - self.assertIs(exc.__context__.__context__, exc.__context__) - - @unittest.expectedFailure - @make_dynamo_test - def test_no_hang_on_context_chain_cycle2(self): - # See issue 25782. Cycle at head of context chain. - - A = AssertionError - B = BytesWarning - C = ConnectionError - - # Context cycle: - # +-----------+ - # V | - # C --> B --> A - with self.assertRaises(C) as cm: - try: - raise A() # noqa: RSE102 - except A as _a: - a = _a - try: - raise B() # noqa: RSE102 - except B as _b: - b = _b - try: - raise C() # noqa: RSE102 - except C as _c: - c = _c - a.__context__ = c - raise c # noqa: B904 - - self.assertIs(cm.exception, c) - # Verify the expected context chain cycle - self.assertIs(c.__context__, b) - self.assertIs(b.__context__, a) - self.assertIs(a.__context__, c) - - @make_dynamo_test - def test_no_hang_on_context_chain_cycle3(self): - # See issue 25782. Longer context chain with cycle. - A = AssertionError - B = BytesWarning - C = ConnectionError - D = DeprecationWarning - E = Exception - - # Context cycle: - # +-----------+ - # V | - # E --> D --> C --> B --> A - with self.assertRaises(E) as cm: - try: - raise A - except A as _a: - a = _a - try: - raise B - except B as _b: - b = _b - try: - raise C - except C as _c: - c = _c - a.__context__ = c - try: - raise D - except D as _d: - d = _d - e = E() - raise e # noqa: B904 - - self.assertIs(cm.exception, e) - # Verify the expected context chain cycle - self.assertIs(e.__context__, d) - self.assertIs(d.__context__, c) - self.assertIs(c.__context__, b) - self.assertIs(b.__context__, a) - self.assertIs(a.__context__, c) - - instantiate_parametrized_tests(ExceptionTests) diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index fb1697286f6e02..adf1e5aff0d398 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -1481,331 +1481,6 @@ def fn(t): self._compile_check(fn) -class GeneratorCloseCPythonTests(GeneratorTestsBase): - # Taken from commit - # https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10 - # changed the tests a little bit to run them inside dynamo - # + replaced all self.assert* calls to plain assert statements - - def test_close_no_return_value(self): - def f(): - yield - - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - gen = f() - gen.send(None) - assert gen.close() is None - return t.sin() - - t = torch.randn(2) - fn(t) - - def test_close_return_value(self): - def f(): - try: - yield - # close() raises GeneratorExit here, which is caught - except GeneratorExit: - return 0 # noqa: B901 - - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - gen = f() - gen.send(None) - assert gen.close() == 0 - return t.sin() - - t = torch.randn(2) - fn(t) - - def test_close_not_catching_exit(self): - def f(): - yield - # close() raises GeneratorExit here, which isn't caught and - # therefore propagates -- no return value - return 0 # noqa: B901 - - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - gen = f() - gen.send(None) - assert gen.close() is None - return t.sin() - - t = torch.randn(2) - fn(t) - - def test_close_not_started(self): - def f(): - try: - yield - except GeneratorExit: - return 0 # noqa: B901 - - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - gen = f() - assert gen.close() is None - return t.sin() - - t = torch.randn(2) - fn(t) - - def test_close_exhausted(self): - def f(): - try: - yield - except GeneratorExit: - return 0 # noqa: B901 - - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - gen = f() - next(gen) - z = 0 - try: - next(gen) # -> StopIteration - except StopIteration: - z = 1 - except Exception as e: - # anything other than StopIteration should fail - raise AssertionError from e - assert z == 1 - assert gen.close() is None - return t.sin() - - t = torch.randn(2) - fn(t) - - def test_close_closed(self): - def f(): - try: - yield - except GeneratorExit: - return 0 # noqa: B901 - - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - gen = f() - gen.send(None) - assert gen.close() == 0 - assert gen.close() is None - return t.sin() - - t = torch.randn(2) - fn(t) - - def test_close_raises(self): - def f(): - try: - yield - except GeneratorExit: - pass - raise RuntimeError - - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - gen = f() - gen.send(None) - z = 0 - try: - gen.close() # -> RuntimeError - except RuntimeError: - z = 1 - except Exception as e: - raise AssertionError from e - assert z == 1 - return t.sin() - - t = torch.randn(2) - fn(t) - - -class GeneratorThrowCpythonTests(GeneratorTestsBase): - # Taken from commit - # https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10 - # changed the tests a little bit to run them inside dynamo - # + replaced all self.assert* calls to plain assert statements - - def test_exception_context_with_yield(self): - def f(): - try: - raise KeyError("a") - except Exception: - yield - - def fn(t): - gen = f() - gen.send(None) - try: - gen.throw(ValueError) - except ValueError as e: - context = e.__context__ - assert (type(context), context.args) == (KeyError, ("a",)) - except Exception as e: - raise AssertionError from e - return t.sin() - - self._compile_check(fn) - - def test_exception_context_with_yield_inside_generator(self): - # Check that the context is also available from inside the generator - # with yield, as opposed to outside. - def f(): - z = 0 - try: - raise KeyError("a") - except Exception: - try: - yield - except Exception as exc: - z = 1 - assert type(exc) == ValueError - context = exc.__context__ - assert (type(context), context.args) == (KeyError, ("a",)) - yield "b" - finally: - assert z == 1 - - def fn(t): - gen = f() - gen.send(None) - actual = gen.throw(ValueError) - # This ensures that the assertions inside were executed. - assert actual == "b" - return t.sin() - - self._compile_check(fn) - - def test_exception_context_with_yield_from(self): - def f(): - yield - - def g(): - try: - raise KeyError("a") - except Exception: - yield from f() - - def fn(t): - gen = g() - gen.send(None) - try: - gen.throw(ValueError) - except ValueError as e: - context = e.__context__ - assert (type(context), context.args) == (KeyError, ("a",)) - except Exception as e: - raise AssertionError from e - return t.sin() - - self._compile_check(fn) - - def test_exception_context_with_yield_from_with_context_cycle(self): - # Check trying to create an exception context cycle: - # https://bugs.python.org/issue40696 - has_cycle = None - - def f(): - yield - - def g(exc): - nonlocal has_cycle - try: - raise exc - except Exception: - try: - yield from f() - except Exception as exc: - has_cycle = exc is exc.__context__ - yield - - def fn(t): - exc = KeyError("a") - gen = g(exc) - gen.send(None) - gen.throw(exc) - # This also distinguishes from the initial has_cycle=None. - assert has_cycle is False - return t.sin() - - self._compile_check(fn) - - def test_throw_after_none_exc_type(self): - def g(): - try: - raise KeyError - except KeyError: - pass - - try: - yield - except Exception: - raise RuntimeError # noqa: B904 - - def fn(t): - gen = g() - gen.send(None) - z = 0 - try: - gen.throw(ValueError) - except RuntimeError: - z += 1 - except Exception: - raise AssertionError # noqa: B904 - assert z == 1 - return t.sin() - - self._compile_check(fn) - - -class GeneratorCPythonTests(GeneratorTestsBase): - # Taken from commit - # https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10 - # changed the tests a little bit to run them inside dynamo - # + replaced all self.assert* calls to plain assert statements - - def test_send_non_none_to_new_gen(self): - def f(): - yield 1 - - def fn(t): - g = f() - z = 0 - try: - g.send(0) - except TypeError: - z += 1 - except Exception as e: - raise AssertionError from e - assert z == 1 - assert next(g) == 1 - return t.sin() - - self._compile_check(fn) - - def test_issue103488(self): - def gen_raises(): - yield 1 - raise ValueError - - def loop(): - try: - for _ in gen_raises(): - if True is False: # noqa: PLR0133 - return - except ValueError: - pass - - def fn(t): - # This should not raise - loop() - return t.sin() - - self._compile_check(fn) - - instantiate_parametrized_tests(GeneratorTests) instantiate_parametrized_tests(TestGeneratorSend) instantiate_parametrized_tests(TestGeneratorClose) diff --git a/test/dynamo/test_generator_stop.py b/test/dynamo/test_generator_stop.py deleted file mode 100644 index 7091d3d371378f..00000000000000 --- a/test/dynamo/test_generator_stop.py +++ /dev/null @@ -1,52 +0,0 @@ -# Owner(s): ["module: dynamo"] - -import sys -import unittest - -import torch -import torch._dynamo.test_case -from torch.testing._internal.common_utils import make_dynamo_test - - -class TestPEP479(torch._dynamo.test_case.CPythonTestCase): - # Tests taken from CPython source code in cpython/Lib/test/test_generator_stop.py - # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_generator_stop.py - @unittest.skipIf(sys.version_info < (3, 12), "Test does not work in Python < 3.12") - @make_dynamo_test - def test_stopiteration_wrapping(self): - def f(): - raise StopIteration - - def g(): - yield f() - - with self.assertRaises(RuntimeError) as cm: - next(g()) - self.assertEqual("generator raised StopIteration", str(cm.exception)) - - @unittest.skipIf(sys.version_info < (3, 12), "Test does not work in Python < 3.12") - @make_dynamo_test - def test_stopiteration_wrapping_context(self): - def f(): - raise StopIteration - - def g(): - yield f() - - try: - next(g()) - except RuntimeError as exc: - self.assertIs(type(exc.__cause__), StopIteration) - self.assertIs(type(exc.__context__), StopIteration) - self.assertTrue(exc.__suppress_context__) - else: - self.fail( - "__cause__, __context__, or __suppress_context__ " - "were not properly set" - ) - - -if __name__ == "__main__": - from torch._dynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_raise.py b/test/dynamo/test_raise.py deleted file mode 100644 index 9a95d23226c022..00000000000000 --- a/test/dynamo/test_raise.py +++ /dev/null @@ -1,563 +0,0 @@ -# Owner(s): ["module: dynamo"] - -# ruff: noqa -# flake8: noqa - -import sys -import types -import unittest - -import torch -import torch._dynamo.config -import torch._dynamo.test_case -import torch._functorch.config -import torch.nn -import torch.utils.checkpoint -from torch.testing._internal.common_utils import make_dynamo_test - - -def get_tb(): - try: - raise OSError() - except: - return sys.exc_info()[2] - - -class Context: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, exc_tb): - return True - - -class MyException(Exception): - def __init__(self): - raise RuntimeError() - - -class ContextManager: - def __enter__(self): - pass - - def __exit__(self, t, v, tb): - raise NameError - - -class TestRaise(torch._dynamo.test_case.CPythonTestCase): - # Tests taken from CPython source code in cpython/Lib/test/test_raise.py - # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py - @make_dynamo_test - def test_invalid_reraise(self): - try: - raise - except RuntimeError as e: - self.assertIn("No active exception", str(e)) - else: - self.fail("No exception raised") - - @make_dynamo_test - def test_reraise(self): - try: - try: - raise IndexError - except IndexError as e: - exc1 = e - raise - except IndexError as exc2: - self.assertIs(exc1, exc2) - else: - self.fail("No exception raised") - - @make_dynamo_test - def test_except_reraise(self): - def reraise(): - try: - raise TypeError("foo") - except: - try: - raise KeyError("caught") - except KeyError: - pass - raise - - self.assertRaises(TypeError, reraise) - - @make_dynamo_test - def test_finally_reraise(self): - def reraise(): - try: - raise TypeError("foo") - except: - try: - raise KeyError("caught") - finally: - raise - - self.assertRaises(KeyError, reraise) - - @make_dynamo_test - def test_nested_reraise(self): - def nested_reraise(): - raise - - def reraise(): - try: - raise TypeError("foo") - except: - nested_reraise() - - self.assertRaises(TypeError, reraise) - - @make_dynamo_test - def test_raise_from_None(self): - try: - try: - raise TypeError("foo") - except: - raise ValueError() from None - except ValueError as e: - self.assertIsInstance(e.__context__, TypeError) - self.assertIsNone(e.__cause__) - - @make_dynamo_test - def test_with_reraise1(self): - def reraise(): - try: - raise TypeError("foo") - except: - with Context(): - pass - raise - - self.assertRaises(TypeError, reraise) - - @make_dynamo_test - def test_with_reraise2(self): - def reraise(): - try: - raise TypeError("foo") - except: - with Context(): - raise KeyError("caught") - raise - - self.assertRaises(TypeError, reraise) - - @make_dynamo_test - def test_yield_reraise(self): - def reraise(): - try: - raise TypeError("foo") - except: - yield 1 - raise - - g = reraise() - next(g) - self.assertRaises(TypeError, lambda: next(g)) - self.assertRaises(StopIteration, lambda: next(g)) - - @make_dynamo_test - def test_erroneous_exception(self): - try: - raise MyException - except RuntimeError: - pass - else: - self.fail("No exception raised") - - @unittest.expectedFailure # object - @make_dynamo_test - def test_new_returns_invalid_instance(self): - # See issue #11627. - class MyException2(Exception): - def __new__(cls, *args): - return object() - - with self.assertRaises(TypeError): - raise MyException2 - - @unittest.expectedFailure # Assertion with non-string message - @make_dynamo_test - def test_assert_with_tuple_arg(self): - try: - assert False, (3,) - except AssertionError as e: - self.assertEqual(str(e), "(3,)") - - -class TestCause(torch._dynamo.test_case.TestCase): - # Tests taken from CPython source code in cpython/Lib/test/test_raise.py - # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py - def setUp(self): - self._prev = torch._dynamo.config.enable_trace_unittest - torch._dynamo.config.enable_trace_unittest = True - - def tearDown(self): - torch._dynamo.config.enable_trace_unittest = self._prev - - @make_dynamo_test - def testCauseSyntax(self): - try: - try: - try: - raise TypeError - except Exception: - raise ValueError from None - except ValueError as exc: - self.assertIsNone(exc.__cause__) - self.assertTrue(exc.__suppress_context__) - exc.__suppress_context__ = False - raise exc - except ValueError as exc: - e = exc - - self.assertIsNone(e.__cause__) - self.assertFalse(e.__suppress_context__) - self.assertIsInstance(e.__context__, TypeError) - - @make_dynamo_test - def test_invalid_cause(self): - try: - raise IndexError from 5 - except TypeError as e: - self.assertIn("exception cause", str(e)) - else: - self.fail("No exception raised") - - @make_dynamo_test - def test_class_cause(self): - try: - raise IndexError from KeyError - except IndexError as e: - self.assertIsInstance(e.__cause__, KeyError) - else: - self.fail("No exception raised") - - @make_dynamo_test - def test_instance_cause(self): - cause = KeyError() - try: - raise IndexError from cause - except IndexError as e: - self.assertIs(e.__cause__, cause) - else: - self.fail("No exception raised") - - @make_dynamo_test - def test_erroneous_cause(self): - try: - raise IndexError from MyException - except RuntimeError: - pass - else: - self.fail("No exception raised") - - -class TestTraceback(torch._dynamo.test_case.TestCase): - # Tests taken from CPython source code in cpython/Lib/test/test_raise.py - # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py - def setUp(self): - self._prev = torch._dynamo.config.enable_trace_unittest - torch._dynamo.config.enable_trace_unittest = True - - def tearDown(self): - torch._dynamo.config.enable_trace_unittest = self._prev - - @unittest.expectedFailure # Dynamo doesn't track traceback - @make_dynamo_test - def test_sets_traceback(self): - try: - raise IndexError() - except IndexError as e: - self.assertIsInstance(e.__traceback__, types.TracebackType) - else: - self.fail("No exception raised") - - @unittest.expectedFailure # Dynamo doesn't track traceback - @make_dynamo_test - def test_accepts_traceback(self): - tb = get_tb() - try: - raise IndexError().with_traceback(tb) - except IndexError as e: - self.assertNotEqual(e.__traceback__, tb) - self.assertEqual(e.__traceback__.tb_next, tb) - else: - self.fail("No exception raised") - - -class TestTracebackType(torch._dynamo.test_case.TestCase): - # Tests taken from CPython source code in cpython/Lib/test/test_raise.py - # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py - def setUp(self): - self._prev = torch._dynamo.config.enable_trace_unittest - torch._dynamo.config.enable_trace_unittest = True - - def tearDown(self): - torch._dynamo.config.enable_trace_unittest = self._prev - - def raiser(self): - raise ValueError - - @unittest.expectedFailure # Dynamo doesn't track traceback - @make_dynamo_test - def test_attrs(self): - try: - self.raiser() - except Exception as exc: - tb = exc.__traceback__ - - self.assertIsInstance(tb.tb_next, types.TracebackType) - self.assertIs(tb.tb_frame, sys._getframe()) - self.assertIsInstance(tb.tb_lasti, int) - self.assertIsInstance(tb.tb_lineno, int) - - self.assertIs(tb.tb_next.tb_next, None) - - # Invalid assignments - with self.assertRaises(TypeError): - del tb.tb_next - - with self.assertRaises(TypeError): - tb.tb_next = "asdf" - - # Loops - with self.assertRaises(ValueError): - tb.tb_next = tb - - with self.assertRaises(ValueError): - tb.tb_next.tb_next = tb - - # Valid assignments - tb.tb_next = None - self.assertIs(tb.tb_next, None) - - new_tb = get_tb() - tb.tb_next = new_tb - self.assertIs(tb.tb_next, new_tb) - - @unittest.expectedFailure # Dynamo doesn't track traceback - @make_dynamo_test - def test_constructor(self): - other_tb = get_tb() - frame = sys._getframe() - - tb = types.TracebackType(other_tb, frame, 1, 2) - self.assertEqual(tb.tb_next, other_tb) - self.assertEqual(tb.tb_frame, frame) - self.assertEqual(tb.tb_lasti, 1) - self.assertEqual(tb.tb_lineno, 2) - - tb = types.TracebackType(None, frame, 1, 2) - self.assertEqual(tb.tb_next, None) - - with self.assertRaises(TypeError): - types.TracebackType("no", frame, 1, 2) - - with self.assertRaises(TypeError): - types.TracebackType(other_tb, "no", 1, 2) - - with self.assertRaises(TypeError): - types.TracebackType(other_tb, frame, "no", 2) - - with self.assertRaises(TypeError): - types.TracebackType(other_tb, frame, 1, "nuh-uh") - - -class TestContext(torch._dynamo.test_case.TestCase): - # Tests taken from CPython source code in cpython/Lib/test/test_raise.py - # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py - def setUp(self): - self._prev = torch._dynamo.config.enable_trace_unittest - torch._dynamo.config.enable_trace_unittest = True - - def tearDown(self): - torch._dynamo.config.enable_trace_unittest = self._prev - - @unittest.expectedFailure # missing Exception.__eq__ - @make_dynamo_test - def test_instance_context_instance_raise(self): - context = IndexError() - try: - try: - raise context - except: - raise OSError() - except OSError as e: - self.assertEqual(e.__context__, context) - else: - self.fail("No exception raised") - - @unittest.expectedFailure # missing Exception.__eq__ and Exception.__repr__ - @make_dynamo_test - def test_class_context_instance_raise(self): - context = IndexError - try: - try: - raise context - except: - raise OSError() - except OSError as e: - self.assertNotEqual(e.__context__, context) - self.assertIsInstance(e.__context__, context) - else: - self.fail("No exception raised") - - @unittest.expectedFailure # missing Exception.__eq__ and Exception.__repr__ - @make_dynamo_test - def test_class_context_class_raise(self): - context = IndexError - try: - try: - raise context - except: - raise OSError - except OSError as e: - self.assertNotEqual(e.__context__, context) - self.assertIsInstance(e.__context__, context) - else: - self.fail("No exception raised") - - @make_dynamo_test - def test_c_exception_context(self): - try: - try: - raise ZeroDivisionError - except: - raise OSError - except OSError as e: - self.assertIsInstance(e.__context__, ZeroDivisionError) - else: - self.fail("No exception raised") - - @make_dynamo_test - def test_c_exception_raise(self): - try: - try: - raise ZeroDivisionError - except: - raise NameError - except NameError as e: - self.assertIsInstance(e.__context__, ZeroDivisionError) - else: - self.fail("No exception raised") - - @make_dynamo_test - def test_noraise_finally(self): - try: - try: - pass - finally: - raise OSError - except OSError as e: - self.assertIsNone(e.__context__) - else: - self.fail("No exception raised") - - @make_dynamo_test - def test_raise_finally(self): - try: - try: - raise ZeroDivisionError - finally: - raise OSError - except OSError as e: - self.assertIsInstance(e.__context__, ZeroDivisionError) - else: - self.fail("No exception raised") - - @make_dynamo_test - def test_context_manager(self): - try: - with ContextManager(): - raise ZeroDivisionError - except NameError as e: - self.assertIsInstance(e.__context__, ZeroDivisionError) - else: - self.fail("No exception raised") - - @make_dynamo_test - def test_cycle_broken(self): - # Self-cycles (when re-raising a caught exception) are broken - try: - try: - raise ZeroDivisionError - except ZeroDivisionError as e: - raise e - except ZeroDivisionError as e: - self.assertIsNone(e.__context__) - - @make_dynamo_test - def test_reraise_cycle_broken(self): - # Non-trivial context cycles (through re-raising a previous exception) - # are broken too. - try: - try: - raise NameError - except NameError as a: - try: - raise ZeroDivisionError - except ZeroDivisionError: - raise a - except NameError as e: - self.assertIsNone(e.__context__.__context__) - - @make_dynamo_test - def test_3118(self): - # deleting the generator caused the __context__ to be cleared - def gen(): - try: - yield 1 - finally: - pass - - def f(): - g = gen() - next(g) - try: - try: - raise ValueError - except: - del g - raise KeyError - except Exception as e: - self.assertIsInstance(e.__context__, ValueError) - - f() - - @unittest.expectedFailure # too CPython specific(?) - @make_dynamo_test - def test_3611(self): - # A re-raised exception in a __del__ caused the __context__ - # to be cleared - class C: - def __del__(self): - try: - raise ZeroDivisionError - except: - raise - - def f(): - x = C() - try: - try: - x.x - except AttributeError: - del x - raise TypeError - except Exception as e: - self.assertNotEqual(e.__context__, None) - self.assertIsInstance(e.__context__, AttributeError) - - with support.catch_unraisable_exception() as cm: - f() - - self.assertEqual(ZeroDivisionError, cm.unraisable.exc_type) - - -if __name__ == "__main__": - from torch._dynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_sys.py b/test/dynamo/test_sys.py deleted file mode 100644 index 3b72ecb36d998c..00000000000000 --- a/test/dynamo/test_sys.py +++ /dev/null @@ -1,107 +0,0 @@ -# Owner(s): ["module: dynamo"] -import sys -import unittest - -import torch -import torch._dynamo.test_case -from torch.testing._internal.common_utils import make_dynamo_test - - -class SysTests(torch._dynamo.test_case.TestCase): - def test_exc_info(self): - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - try: - raise ValueError - except Exception: - typ, _, _ = sys.exc_info() - if typ is ValueError: - return t.sin() - else: - return t.cos() - - t = torch.randn(2) - y = fn(t) - self.assertEqual(y, t.sin()) - - -class CPythonActiveExceptionTests(torch._dynamo.test_case.CPythonTestCase): - # Tests taken from CPython source code in cpython/Lib/test/test_sys.py - # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_sys.py - - @make_dynamo_test - def test_exc_info_no_exception(self): - self.assertEqual(sys.exc_info(), (None, None, None)) - - @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") - @make_dynamo_test - def test_sys_exception_no_exception(self): - self.assertEqual(sys.exception(), None) - - @make_dynamo_test - def test_exc_info_with_exception_instance(self): - def f(): - raise ValueError(42) - - try: - f() - except Exception as e_: - e = e_ - exc_info = sys.exc_info() - - self.assertIsInstance(e, ValueError) - self.assertIs(exc_info[0], ValueError) - self.assertIs(exc_info[1], e) - self.assertIs(exc_info[2], e.__traceback__) - - @make_dynamo_test - def test_exc_info_with_exception_type(self): - def f(): - raise ValueError - - try: - f() - except Exception as e_: - e = e_ - exc_info = sys.exc_info() - - self.assertIsInstance(e, ValueError) - self.assertIs(exc_info[0], ValueError) - self.assertIs(exc_info[1], e) - self.assertIs(exc_info[2], e.__traceback__) - - @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") - @make_dynamo_test - def test_sys_exception_with_exception_instance(self): - def f(): - raise ValueError(42) - - try: - f() - except Exception as e_: - e = e_ - exc = sys.exception() - - self.assertIsInstance(e, ValueError) - self.assertIs(exc, e) - - @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") - @make_dynamo_test - def test_sys_exception_with_exception_type(self): - def f(): - raise ValueError - - try: - f() - except Exception as e_: - e = e_ - exc = sys.exception() - - self.assertIsInstance(e, ValueError) - self.assertIs(exc, e) - - -if __name__ == "__main__": - from torch._dynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_unittest.py b/test/dynamo/test_unittest.py index 244785e01bcd99..de1364d424d7f4 100644 --- a/test/dynamo/test_unittest.py +++ b/test/dynamo/test_unittest.py @@ -1,8 +1,5 @@ # Owner(s): ["module: dynamo"] -import sys import unittest -import warnings -from itertools import product import torch import torch._dynamo.test_case @@ -28,591 +25,6 @@ def test_SkipTest(self): self.assertEqual(z, 1) -class CPythonTest_Assertions(torch._dynamo.test_case.CPythonTestCase): - # Tests taken from CPython source code in cpython/Lib/test/test_unittest/test_assertions.py - # https://github.com/python/cpython/blob/3.13/Lib/test/test_unittest/test_assertions.py - - @make_dynamo_test - def test_AlmostEqual(self): - self.assertAlmostEqual(1.00000001, 1.0) - self.assertNotAlmostEqual(1.0000001, 1.0) - self.assertRaises(self.failureException, self.assertAlmostEqual, 1.0000001, 1.0) - self.assertRaises( - self.failureException, self.assertNotAlmostEqual, 1.00000001, 1.0 - ) - - self.assertAlmostEqual(1.1, 1.0, places=0) - self.assertRaises( - self.failureException, self.assertAlmostEqual, 1.1, 1.0, places=1 - ) - - self.assertAlmostEqual(0, 0.1 + 0.1j, places=0) - self.assertNotAlmostEqual(0, 0.1 + 0.1j, places=1) - self.assertRaises( - self.failureException, self.assertAlmostEqual, 0, 0.1 + 0.1j, places=1 - ) - self.assertRaises( - self.failureException, self.assertNotAlmostEqual, 0, 0.1 + 0.1j, places=0 - ) - - self.assertAlmostEqual(float("inf"), float("inf")) - self.assertRaises( - self.failureException, self.assertNotAlmostEqual, float("inf"), float("inf") - ) - - @make_dynamo_test - def test_AmostEqualWithDelta(self): - self.assertAlmostEqual(1.1, 1.0, delta=0.5) - self.assertAlmostEqual(1.0, 1.1, delta=0.5) - self.assertNotAlmostEqual(1.1, 1.0, delta=0.05) - self.assertNotAlmostEqual(1.0, 1.1, delta=0.05) - - self.assertAlmostEqual(1.0, 1.0, delta=0.5) - self.assertRaises( - self.failureException, self.assertNotAlmostEqual, 1.0, 1.0, delta=0.5 - ) - - self.assertRaises( - self.failureException, self.assertAlmostEqual, 1.1, 1.0, delta=0.05 - ) - self.assertRaises( - self.failureException, self.assertNotAlmostEqual, 1.1, 1.0, delta=0.5 - ) - - self.assertRaises( - TypeError, self.assertAlmostEqual, 1.1, 1.0, places=2, delta=2 - ) - self.assertRaises( - TypeError, self.assertNotAlmostEqual, 1.1, 1.0, places=2, delta=2 - ) - - @make_dynamo_test - def test_assertRaises(self): - def _raise(e): - raise e - - self.assertRaises(KeyError, _raise, KeyError) - self.assertRaises(KeyError, _raise, KeyError("key")) - try: - self.assertRaises(KeyError, lambda: None) - except self.failureException as e: - self.assertIn("KeyError not raised", str(e)) - else: - self.fail("assertRaises() didn't fail") - try: - self.assertRaises(KeyError, _raise, ValueError) - except ValueError: - pass - else: - self.fail("assertRaises() didn't let exception pass through") - with self.assertRaises(KeyError) as cm: - try: - raise KeyError - except Exception as e: - exc = e - raise - self.assertIs(cm.exception, exc) - - with self.assertRaises(KeyError): - raise KeyError("key") - try: - with self.assertRaises(KeyError): - pass - except self.failureException as e: - self.assertIn("KeyError not raised", str(e)) - else: - self.fail("assertRaises() didn't fail") - try: - with self.assertRaises(KeyError): - raise ValueError - except ValueError: - pass - else: - self.fail("assertRaises() didn't let exception pass through") - - @make_dynamo_test - def testAssertNotRegex(self): - self.assertNotRegex("Ala ma kota", r"r+") - try: - self.assertNotRegex("Ala ma kota", r"k.t", "Message") - except self.failureException as e: - self.assertIn("Message", e.args[0]) - else: - self.fail("assertNotRegex should have failed.") - - -class CPythonTestLongMessage(torch._dynamo.test_case.CPythonTestCase): - """Test that the individual asserts honour longMessage. - This actually tests all the message behaviour for - asserts that use longMessage.""" - - def setUp(self): - super().setUp() - - class TestableTestFalse(unittest.TestCase): - longMessage = False - failureException = self.failureException - - def testTest(self): - pass - - class TestableTestTrue(unittest.TestCase): - longMessage = True - failureException = self.failureException - - def testTest(self): - pass - - self.testableTrue = TestableTestTrue("testTest") - self.testableFalse = TestableTestFalse("testTest") - - def testDefault(self): - self.assertTrue(unittest.TestCase.longMessage) - - def test_formatMsg(self): - self.assertEqual(self.testableFalse._formatMessage(None, "foo"), "foo") - self.assertEqual(self.testableFalse._formatMessage("foo", "bar"), "foo") - - self.assertEqual(self.testableTrue._formatMessage(None, "foo"), "foo") - self.assertEqual(self.testableTrue._formatMessage("foo", "bar"), "bar : foo") - - # This blows up if _formatMessage uses string concatenation - self.testableTrue._formatMessage(object(), "foo") - - def test_formatMessage_unicode_error(self): - one = "".join(chr(i) for i in range(255)) - # this used to cause a UnicodeDecodeError constructing msg - self.testableTrue._formatMessage(one, "\uFFFD") - - def assertMessages(self, methodName, args, errors): - """ - Check that methodName(*args) raises the correct error messages. - errors should be a list of 4 regex that match the error when: - 1) longMessage = False and no msg passed; - 2) longMessage = False and msg passed; - 3) longMessage = True and no msg passed; - 4) longMessage = True and msg passed; - """ - - def getMethod(i): - useTestableFalse = i < 2 - if useTestableFalse: - test = self.testableFalse - else: - test = self.testableTrue - return getattr(test, methodName) - - for i, expected_regex in enumerate(errors): - testMethod = getMethod(i) - kwargs = {} - withMsg = i % 2 - if withMsg: - kwargs = {"msg": "oops"} - - # with self.assertRaisesRegex( - # self.failureException, expected_regex=expected_regex - # ): - # testMethod(*args, **kwargs) - with self.assertRaises(self.failureException) as cm: - testMethod(*args, **kwargs) - self.assertRegex(str(cm.exception), expected_regex) - - @make_dynamo_test - def testAssertTrue(self): - self.assertMessages( - "assertTrue", - (False,), - [ - "False is not true", - "oops", - "False is not true", - "False is not true : oops", - ], - ) - - @make_dynamo_test - def testAssertFalse(self): - self.assertMessages( - "assertFalse", - (True,), - [ - "True is not false", - "oops", - "True is not false", - "True is not false : oops", - ], - ) - - @make_dynamo_test - def testNotEqual(self): - self.assertMessages( - "assertNotEqual", (1, 1), ["1 == 1", "oops", "1 == 1", "1 == 1 : oops"] - ) - - @make_dynamo_test - def testAlmostEqual(self): - self.assertMessages( - "assertAlmostEqual", - (1, 2), - [ - r"^1 != 2 within 7 places \(1 difference\)$", - "^oops$", - r"^1 != 2 within 7 places \(1 difference\)$", - r"^1 != 2 within 7 places \(1 difference\) : oops$", - ], - ) - - @make_dynamo_test - def testNotAlmostEqual(self): - self.assertMessages( - "assertNotAlmostEqual", - (1, 1), - [ - "^1 == 1 within 7 places$", - "^oops$", - "^1 == 1 within 7 places$", - "^1 == 1 within 7 places : oops$", - ], - ) - - @make_dynamo_test - def test_baseAssertEqual(self): - self.assertMessages( - "_baseAssertEqual", - (1, 2), - ["^1 != 2$", "^oops$", "^1 != 2$", "^1 != 2 : oops$"], - ) - - @unittest.expectedFailure - @make_dynamo_test - def testAssertSequenceEqual(self): - # Error messages are multiline so not testing on full message - # assertTupleEqual and assertListEqual delegate to this method - self.assertMessages( - "assertSequenceEqual", - ([], [None]), - [r"\+ \[None\]$", "^oops$", r"\+ \[None\]$", r"\+ \[None\] : oops$"], - ) - - @make_dynamo_test - def testAssertSetEqual(self): - self.assertMessages( - "assertSetEqual", - (set(), set([None])), # noqa: C405 - ["None$", "^oops$", "None$", "None : oops$"], - ) - - @make_dynamo_test - def testAssertIn(self): - self.assertMessages( - "assertIn", - (None, []), - [ - r"^None not found in \[\]$", - "^oops$", - r"^None not found in \[\]$", - r"^None not found in \[\] : oops$", - ], - ) - - @make_dynamo_test - def testAssertNotIn(self): - self.assertMessages( - "assertNotIn", - (None, [None]), - [ - r"^None unexpectedly found in \[None\]$", - "^oops$", - r"^None unexpectedly found in \[None\]$", - r"^None unexpectedly found in \[None\] : oops$", - ], - ) - - @unittest.expectedFailure - @make_dynamo_test - def testAssertDictEqual(self): - self.assertMessages( - "assertDictEqual", - ({}, {"key": "value"}), - [ - r"\+ \{'key': 'value'\}$", - "^oops$", - r"\+ \{'key': 'value'\}$", - r"\+ \{'key': 'value'\} : oops$", - ], - ) - - @unittest.expectedFailure - @make_dynamo_test - def testAssertMultiLineEqual(self): - self.assertMessages( - "assertMultiLineEqual", - ("", "foo"), - [r"\+ foo\n$", "^oops$", r"\+ foo\n$", r"\+ foo\n : oops$"], - ) - - @make_dynamo_test - def testAssertLess(self): - self.assertMessages( - "assertLess", - (2, 1), - [ - "^2 not less than 1$", - "^oops$", - "^2 not less than 1$", - "^2 not less than 1 : oops$", - ], - ) - - @make_dynamo_test - def testAssertLessEqual(self): - self.assertMessages( - "assertLessEqual", - (2, 1), - [ - "^2 not less than or equal to 1$", - "^oops$", - "^2 not less than or equal to 1$", - "^2 not less than or equal to 1 : oops$", - ], - ) - - @make_dynamo_test - def testAssertGreater(self): - self.assertMessages( - "assertGreater", - (1, 2), - [ - "^1 not greater than 2$", - "^oops$", - "^1 not greater than 2$", - "^1 not greater than 2 : oops$", - ], - ) - - @make_dynamo_test - def testAssertGreaterEqual(self): - self.assertMessages( - "assertGreaterEqual", - (1, 2), - [ - "^1 not greater than or equal to 2$", - "^oops$", - "^1 not greater than or equal to 2$", - "^1 not greater than or equal to 2 : oops$", - ], - ) - - @make_dynamo_test - def testAssertIsNone(self): - self.assertMessages( - "assertIsNone", - ("not None",), - [ - "^'not None' is not None$", - "^oops$", - "^'not None' is not None$", - "^'not None' is not None : oops$", - ], - ) - - @make_dynamo_test - def testAssertIsNotNone(self): - self.assertMessages( - "assertIsNotNone", - (None,), - [ - "^unexpectedly None$", - "^oops$", - "^unexpectedly None$", - "^unexpectedly None : oops$", - ], - ) - - @make_dynamo_test - def testAssertIs(self): - self.assertMessages( - "assertIs", - (None, "foo"), - [ - "^None is not 'foo'$", - "^oops$", - "^None is not 'foo'$", - "^None is not 'foo' : oops$", - ], - ) - - @make_dynamo_test - def testAssertIsNot(self): - self.assertMessages( - "assertIsNot", - (None, None), - [ - "^unexpectedly identical: None$", - "^oops$", - "^unexpectedly identical: None$", - "^unexpectedly identical: None : oops$", - ], - ) - - @make_dynamo_test - def testAssertRegex(self): - self.assertMessages( - "assertRegex", - ("foo", "bar"), - [ - "^Regex didn't match:", - "^oops$", - "^Regex didn't match:", - "^Regex didn't match: (.*) : oops$", - ], - ) - - @make_dynamo_test - def testAssertNotRegex(self): - self.assertMessages( - "assertNotRegex", - ("foo", "foo"), - [ - "^Regex matched:", - "^oops$", - "^Regex matched:", - "^Regex matched: (.*) : oops$", - ], - ) - - def assertMessagesCM(self, methodName, args, func, errors): - """ - Check that the correct error messages are raised while executing: - with method(*args): - func() - *errors* should be a list of 4 regex that match the error when: - 1) longMessage = False and no msg passed; - 2) longMessage = False and msg passed; - 3) longMessage = True and no msg passed; - 4) longMessage = True and msg passed; - """ - p = product((self.testableFalse, self.testableTrue), ({}, {"msg": "oops"})) - for (cls, kwargs), err in zip(p, errors): - method = getattr(cls, methodName) - # with self.assertRaisesRegex(cls.failureException, err): - with self.assertRaises(cls.failureException) as c: - with method(*args, **kwargs) as cm: # noqa: F841 - func() - self.assertRegex(str(c.exception), err) - - @make_dynamo_test - def testAssertRaises(self): - self.assertMessagesCM( - "assertRaises", - (TypeError,), - lambda: None, - [ - "^TypeError not raised$", - "^oops$", - "^TypeError not raised$", - "^TypeError not raised : oops$", - ], - ) - - @unittest.expectedFailure - @make_dynamo_test - def testAssertRaisesRegex(self): - self.assertMessagesCM( - "assertRaisesRegex", - (TypeError, "unused regex"), - lambda: None, - [ - "^TypeError not raised$", - "^oops$", - "^TypeError not raised$", - "^TypeError not raised : oops$", - ], - ) - - # test error raised but with wrong message - def raise_wrong_message(): - raise TypeError("foo") - - self.assertMessagesCM( - "assertRaisesRegex", - (TypeError, "regex"), - raise_wrong_message, - [ - '^"regex" does not match "foo"$', - "^oops$", - '^"regex" does not match "foo"$', - '^"regex" does not match "foo" : oops$', - ], - ) - - @unittest.expectedFailure - @make_dynamo_test - def testAssertWarns(self): - self.assertMessagesCM( - "assertWarns", - (UserWarning,), - lambda: None, - [ - "^UserWarning not triggered$", - "^oops$", - "^UserWarning not triggered$", - "^UserWarning not triggered : oops$", - ], - ) - - @unittest.expectedFailure - @unittest.skipIf(sys.version_info < (3, 13), "feature landed in 3.13") - @make_dynamo_test - def test_assertNotWarns(self): - def warn_future(): - warnings.warn("xyz", FutureWarning, stacklevel=2) - - self.assertMessagesCM( - "_assertNotWarns", - (FutureWarning,), - warn_future, - [ - "^FutureWarning triggered$", - "^oops$", - "^FutureWarning triggered$", - "^FutureWarning triggered : oops$", - ], - ) - - @unittest.expectedFailure - @make_dynamo_test - def testAssertWarnsRegex(self): - # test error not raised - self.assertMessagesCM( - "assertWarnsRegex", - (UserWarning, "unused regex"), - lambda: None, - [ - "^UserWarning not triggered$", - "^oops$", - "^UserWarning not triggered$", - "^UserWarning not triggered : oops$", - ], - ) - - # test warning raised but with wrong message - def raise_wrong_message(): - warnings.warn("foo") - - self.assertMessagesCM( - "assertWarnsRegex", - (UserWarning, "regex"), - raise_wrong_message, - [ - '^"regex" does not match "foo"$', - "^oops$", - '^"regex" does not match "foo"$', - '^"regex" does not match "foo" : oops$', - ], - ) - - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/run_test.py b/test/run_test.py index 9d6ea165c18a32..44ec58fd6b2ac7 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -1593,6 +1593,13 @@ def get_selected_tests(options) -> list[str]: ] ) + if sys.version_info[:2] < (3, 13): + # Skip tests for older Python versions as they may use syntax or features + # not supported in those versions + options.exclude.extend( + [test for test in selected_tests if test.startswith("dynamo/cpython/3_13/")] + ) + selected_tests = exclude_tests(options.exclude, selected_tests) if sys.platform == "win32" and not options.ignore_win_blocklist: diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py index ac505c0de02a72..1da20cd32a07e8 100644 --- a/torch/_dynamo/test_case.py +++ b/torch/_dynamo/test_case.py @@ -1,3 +1,5 @@ +# mypy: allow-untyped-defs + """Testing utilities for Dynamo, providing a specialized TestCase class and test running functionality. This module extends PyTorch's testing framework with Dynamo-specific testing capabilities. @@ -10,8 +12,13 @@ import contextlib import importlib +import inspect import logging import os +import pathlib +import re +import sys +import unittest from typing import Union import torch @@ -98,7 +105,70 @@ def tearDown(self) -> None: class CPythonTestCase(TestCase): + """ + Test class for CPython tests located in "test/dynamo/CPython/Py_version/*". + + This class enables specific features that are disabled by default, such as + tracing through unittest methods. + """ + _stack: contextlib.ExitStack + dynamo_strict_nopython = True + + # Restore original unittest methods to simplify tracing CPython test cases. + assertEqual = unittest.TestCase.assertEqual # type: ignore[assignment] + assertNotEqual = unittest.TestCase.assertNotEqual # type: ignore[assignment] + assertTrue = unittest.TestCase.assertTrue + assertFalse = unittest.TestCase.assertFalse + assertIs = unittest.TestCase.assertIs + assertIsNot = unittest.TestCase.assertIsNot + assertIsNone = unittest.TestCase.assertIsNone + assertIsNotNone = unittest.TestCase.assertIsNotNone + assertIn = unittest.TestCase.assertIn + assertNotIn = unittest.TestCase.assertNotIn + assertIsInstance = unittest.TestCase.assertIsInstance + assertNotIsInstance = unittest.TestCase.assertNotIsInstance + assertAlmostEqual = unittest.TestCase.assertAlmostEqual + assertNotAlmostEqual = unittest.TestCase.assertNotAlmostEqual + assertGreater = unittest.TestCase.assertGreater + assertGreaterEqual = unittest.TestCase.assertGreaterEqual + assertLess = unittest.TestCase.assertLess + assertLessEqual = unittest.TestCase.assertLessEqual + assertRegex = unittest.TestCase.assertRegex + assertNotRegex = unittest.TestCase.assertNotRegex + assertCountEqual = unittest.TestCase.assertCountEqual + assertMultiLineEqual = unittest.TestCase.assertMultiLineEqual + assertSequenceEqual = unittest.TestCase.assertSequenceEqual + assertListEqual = unittest.TestCase.assertListEqual + assertTupleEqual = unittest.TestCase.assertTupleEqual + assertSetEqual = unittest.TestCase.assertSetEqual + assertDictEqual = unittest.TestCase.assertDictEqual + assertRaises = unittest.TestCase.assertRaises + assertRaisesRegex = unittest.TestCase.assertRaisesRegex + assertWarns = unittest.TestCase.assertWarns + assertWarnsRegex = unittest.TestCase.assertWarnsRegex + assertLogs = unittest.TestCase.assertLogs + fail = unittest.TestCase.fail + failureException = unittest.TestCase.failureException + + def compile_fn(self, fn, backend, nopython): + # We want to compile only the test function, excluding any setup code + # from unittest + method = getattr(self, self._testMethodName) + method = torch._dynamo.optimize(backend, nopython=nopython)(method) + setattr(self, self._testMethodName, method) + return fn + + def _dynamo_test_key(self): + suffix = super()._dynamo_test_key() + test_cls = self.__class__ + test_file = inspect.getfile(test_cls).split(os.sep)[-1].split(".")[0] + py_ver = re.search(r"/([\d_]+)/", inspect.getfile(test_cls)) + if py_ver: + py_ver = py_ver.group().strip(os.sep).replace("_", "") # type: ignore[assignment] + else: + return suffix + return f"CPython{py_ver}-{test_file}-{suffix}" @classmethod def tearDownClass(cls) -> None: @@ -107,6 +177,24 @@ def tearDownClass(cls) -> None: @classmethod def setUpClass(cls) -> None: + # Skip test if python versions doesn't match + normalized_path = pathlib.PurePath("dynamo/cpython").as_posix() + regex = re.escape(normalized_path) + r"\b\d+_\d{2}\b" + m = re.search(regex, inspect.getfile(cls)) + if m: + test_py_ver = tuple(map(int, m.group().split("_"))) + py_ver = sys.version_info[:2] + if py_ver != test_py_ver: + expected = ".".join(map(str, test_py_ver)) + got = ".".join(map(str, py_ver)) + raise unittest.SkipTest( + f"Test requires Python {expected} but got Python {got}" + ) + else: + raise unittest.SkipTest( + f"Test requires a specific Python version but not found in path {inspect.getfile(cls)}" + ) + super().setUpClass() cls._stack = contextlib.ExitStack() # type: ignore[attr-defined] cls._stack.enter_context( # type: ignore[attr-defined] diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 5bb4cede2165ce..cb76d3c61a93b8 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1989,11 +1989,11 @@ def call_getattr( ) ): unimplemented_v2( - gb_type="Failed to trace builtin operator", + gb_type="Failed to trace unittest method", context=f"function: unittest.TestCase.{name}", - explanation=f"Dynamo does not know how to trace builtin operator `{name}` ", + explanation=f"Dynamo does not know how to trace unittest method `{name}` ", hints=[ - f"Avoid calling builtin `{name}`. " + f"Avoid calling `TestCase.{name}`. " "Please report an issue to PyTorch.", ], ) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 7dbf1dbd084092..e9f5154f09962b 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -3157,6 +3157,13 @@ def wrapper(self, *args, **kwargs): def wrap_with_cuda_memory_check(self, method): return self.wrap_method_with_policy(method, self.assertLeaksNoCudaTensors) + def _dynamo_test_key(self): + return f"{self.__class__.__name__}.{self._testMethodName}" + + def compile_fn(self, fn, backend, nopython): + # Allows subclasses to control compilation + return torch._dynamo.optimize(backend, nopython=nopython)(fn) + def _run_custom(self, result=None): using_unittest = isinstance(result, unittest.TestResult) @@ -3232,16 +3239,16 @@ def _run_custom(self, result=None): with unittest.mock.patch("torch._dynamo.config.suppress_errors", suppress_errors), maybe_disable_size_asserts: if TEST_WITH_AOT_EAGER: - super_run = torch._dynamo.optimize("aot_eager_decomp_partition")(super_run) + super_run = self.compile_fn(super_run, "aot_eager_decomp_partition", nopython) elif TEST_WITH_TORCHDYNAMO or TEST_WITH_TORCHINDUCTOR: if TEST_WITH_TORCHINDUCTOR: - super_run = torch._dynamo.optimize("inductor")(super_run) + super_run = self.compile_fn(super_run, "inductor", nopython) else: # Assume eager-generated GraphModules will not error out. # If we do, this is probably a Dynamo bug! - super_run = torch._dynamo.optimize("eager_noexcept", nopython=nopython)(super_run) + super_run = self.compile_fn(super_run, "eager_noexcept", nopython) - key = f"{self.__class__.__name__}.{self._testMethodName}" + key = self._dynamo_test_key() def expect_failure(f, file_name): @wraps(f)