diff --git a/src/test_typing.py b/src/test_typing.py index 87d707c1..99eb27a6 100644 --- a/src/test_typing.py +++ b/src/test_typing.py @@ -13,7 +13,7 @@ from typing import Tuple, List, MutableMapping from typing import Callable from typing import Generic, ClassVar, GenericMeta -from typing import cast +from typing import cast, castto from typing import get_type_hints from typing import no_type_check, no_type_check_decorator from typing import Type @@ -1321,6 +1321,61 @@ def test_errors(self): cast('hello', 42) +class CasttoTests(BaseTestCase): + + def test_basics(self): + + @castto(int) + def send_to_float(x: int): + return float(x) + + @castto(float) + def send_to_int(x: float): + return int(x) + + @castto(Any) + def send_to_any(x: T) -> T: + return x + + @castto(list) + def send_to_list(x: int) -> int: + return x + + @castto(Union[str, float]) + def send_to_Union_str_float(x: int) -> int: + return x + + @castto(AnyStr) + def send_to_AnyStr(x: int) -> int: + return x + + @castto(None) + def send_to_None(x: int) -> int: + return x + + self.assertEqual(send_to_float(42), 42) + self.assertEqual(send_to_int(42.0), 42) + self.assertIs(type(send_to_int(42.0)), int) + self.assertEqual(send_to_any(42), 42) + self.assertEqual(send_to_list(42), 42) + self.assertEqual(send_to_Union_str_float(42), 42) + self.assertEqual(send_to_AnyStr(42), 42) + self.assertEqual(send_to_None(42), 42) + + def test_errors(self): + + @castto(42) + def bogus_one(x: int) -> int: + return x + + @castto('hello') + def bogus_two(x: int) -> int: + return x + + bogus_one(42) + bogus_two(42) + + class ForwardRefTests(BaseTestCase): def test_basics(self): diff --git a/src/typing.py b/src/typing.py index c00a3a10..f8ac1471 100644 --- a/src/typing.py +++ b/src/typing.py @@ -86,6 +86,7 @@ # One-off things. 'AnyStr', 'cast', + 'castto', 'get_type_hints', 'NewType', 'no_type_check', @@ -1436,6 +1437,17 @@ def cast(typ, val): return val +def castto(typ): + """ Type cast decorator. + + Same as cast, except written as a decorator. For use in decorators + with functools.wraps. + """ + def _cast(val): + return cast(typ, val) + return _cast + + def _get_defaults(func): """Internal helper to extract the default arguments, by name.""" try: