diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py index 62ae5622a8..8094962ccf 100644 --- a/Lib/test/test_dataclasses.py +++ b/Lib/test/test_dataclasses.py @@ -1906,6 +1906,8 @@ def new_method(self): c = Alias(10, 1.0) self.assertEqual(c.new_method(), 1.0) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_generic_dynamic(self): T = TypeVar('T') @@ -3250,6 +3252,8 @@ def test_classvar_module_level_import(self): # won't exist on the instance. self.assertNotIn('not_iv4', c.__dict__) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_text_annotations(self): from test import dataclass_textanno diff --git a/Lib/test/test_fractions.py b/Lib/test/test_fractions.py index 41f93390b5..628a5073a9 100644 --- a/Lib/test/test_fractions.py +++ b/Lib/test/test_fractions.py @@ -617,6 +617,8 @@ def testConversions(self): self.assertTypedEquals(0.1+0j, complex(F(1,10))) + # TODO: RUSTPYTHON + @unittest.expectedFailure def testSupportsInt(self): # See bpo-44547. f = F(3, 2) @@ -624,6 +626,8 @@ def testSupportsInt(self): self.assertEqual(int(f), 1) self.assertEqual(type(int(f)), int) + # TODO: RUSTPYTHON + @unittest.expectedFailure def testIntGuaranteesIntReturn(self): # Check that int(some_fraction) gives a result of exact type `int` # even if the fraction is using some other Integral type for its diff --git a/Lib/test/test_genericalias.py b/Lib/test/test_genericalias.py index 560b96f7e3..4d630ed166 100644 --- a/Lib/test/test_genericalias.py +++ b/Lib/test/test_genericalias.py @@ -173,6 +173,8 @@ def test_exposed_type(self): self.assertEqual(a.__args__, (int,)) self.assertEqual(a.__parameters__, ()) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_parameters(self): from typing import List, Dict, Callable D0 = dict[str, int] @@ -212,6 +214,8 @@ def test_parameters(self): self.assertEqual(L5.__args__, (Callable[[K, V], K],)) self.assertEqual(L5.__parameters__, (K, V)) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_parameter_chaining(self): from typing import List, Dict, Union, Callable self.assertEqual(list[T][int], list[int]) @@ -271,6 +275,8 @@ class MyType(type): with self.assertRaises(TypeError): MyType[int] + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_pickle(self): alias = GenericAlias(list, T) for proto in range(pickle.HIGHEST_PROTOCOL + 1): @@ -280,6 +286,8 @@ def test_pickle(self): self.assertEqual(loaded.__args__, alias.__args__) self.assertEqual(loaded.__parameters__, alias.__parameters__) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_copy(self): class X(list): def __copy__(self): @@ -303,6 +311,8 @@ def test_union(self): self.assertEqual(a.__args__, (list[int], list[str])) self.assertEqual(a.__parameters__, ()) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_union_generic(self): a = typing.Union[list[T], tuple[T, ...]] self.assertEqual(a.__args__, (list[T], tuple[T, ...])) diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py index 59dc9814fb..21b787db23 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -392,7 +392,6 @@ def test(i, format_spec, result): test(123456, "1=20", '11111111111111123456') test(123456, "*=20", '**************123456') - @unittest.expectedFailure @run_with_locale('LC_NUMERIC', 'en_US.UTF8') def test_float__format__locale(self): # test locale support for __format__ code 'n' @@ -742,6 +741,8 @@ def test_instancecheck_and_subclasscheck(self): self.assertTrue(issubclass(dict, x)) self.assertFalse(issubclass(list, x)) + # TODO: RUSTPYTHON + @unittest.expectedFailure 
def test_instancecheck_and_subclasscheck_order(self): T = typing.TypeVar('T') @@ -788,6 +789,8 @@ def __subclasscheck__(cls, sub): self.assertTrue(issubclass(int, x)) self.assertRaises(ZeroDivisionError, issubclass, list, x) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_or_type_operator_with_TypeVar(self): TV = typing.TypeVar('T') assert TV | str == typing.Union[TV, str] @@ -795,6 +798,8 @@ def test_or_type_operator_with_TypeVar(self): self.assertIs((int | TV)[int], int) self.assertIs((TV | int)[int], int) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_union_args(self): def check(arg, expected): clear_typing_caches() @@ -825,6 +830,8 @@ def check(arg, expected): check(x | None, (x, type(None))) check(None | x, (type(None), x)) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_union_parameter_chaining(self): T = typing.TypeVar("T") S = typing.TypeVar("S") @@ -869,6 +876,8 @@ def eq(actual, expected, typed=True): eq(x[NT], int | NT | bytes) eq(x[S], int | S | bytes) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_union_pickle(self): orig = list[T] | int for proto in range(pickle.HIGHEST_PROTOCOL + 1): @@ -878,6 +887,8 @@ def test_union_pickle(self): self.assertEqual(loaded.__args__, orig.__args__) self.assertEqual(loaded.__parameters__, orig.__parameters__) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_union_copy(self): orig = list[T] | int for copied in (copy.copy(orig), copy.deepcopy(orig)): @@ -885,12 +896,16 @@ def test_union_copy(self): self.assertEqual(copied.__args__, orig.__args__) self.assertEqual(copied.__parameters__, orig.__parameters__) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_union_parameter_substitution_errors(self): T = typing.TypeVar("T") x = int | T with self.assertRaises(TypeError): x[int, str] + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_or_type_operator_with_forward(self): T = typing.TypeVar('T') ForwardAfter = T | 'Forward' diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index b6a167f998..9f3963f724 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -1,41 +1,59 @@ +# TODO: RUSTPYTHON +import unittest import contextlib import collections +import collections.abc +from collections import defaultdict +from functools import lru_cache, wraps, reduce +import gc +import inspect +import itertools +import operator import pickle import re import sys -from unittest import TestCase, main, skipUnless, skip -# TODO: RUSTPYTHON -import unittest +from unittest import TestCase, main, skip +from unittest.mock import patch from copy import copy, deepcopy -from typing import Any, NoReturn -from typing import TypeVar, AnyStr +from typing import Any, NoReturn, Never, assert_never +from typing import overload, get_overloads, clear_overloads +from typing import TypeVar, TypeVarTuple, Unpack, AnyStr from typing import T, KT, VT # Not in __all__. 
 from typing import Union, Optional, Literal
 from typing import Tuple, List, Dict, MutableMapping
 from typing import Callable
 from typing import Generic, ClassVar, Final, final, Protocol
-from typing import cast, runtime_checkable
+from typing import assert_type, cast, runtime_checkable
 from typing import get_type_hints
-from typing import get_origin, get_args
-from typing import is_typeddict
+from typing import get_origin, get_args, get_protocol_members
+from typing import override
+from typing import is_typeddict, is_protocol
+from typing import reveal_type
+from typing import dataclass_transform
 from typing import no_type_check, no_type_check_decorator
 from typing import Type
-from typing import NewType
-from typing import NamedTuple, TypedDict
+from typing import NamedTuple, NotRequired, Required, ReadOnly, TypedDict
 from typing import IO, TextIO, BinaryIO
 from typing import Pattern, Match
 from typing import Annotated, ForwardRef
+from typing import Self, LiteralString
 from typing import TypeAlias
 from typing import ParamSpec, Concatenate, ParamSpecArgs, ParamSpecKwargs
-from typing import TypeGuard
+from typing import TypeGuard, TypeIs, NoDefault
 import abc
+import textwrap
 import typing
 import weakref
 import types
-from test import mod_generics_cache
-from test import _typed_dict_helper
+from test.support import captured_stderr, cpython_only, infinite_recursion, requires_docstrings, import_helper
+from test.typinganndata import ann_module695, mod_generics_cache, _typed_dict_helper
+
+
+CANNOT_SUBCLASS_TYPE = 'Cannot subclass special typing classes'
+NOT_A_BASE_TYPE = "type 'typing.%s' is not an acceptable base type"
+CANNOT_SUBCLASS_INSTANCE = 'Cannot subclass an instance of %s'
 class BaseTestCase(TestCase):
@@ -59,6 +77,18 @@ def clear_caches(self):
             f()
+def all_pickle_protocols(test_func):
+    """Runs `test_func` with various values for `proto` argument."""
+
+    @wraps(test_func)
+    def wrapper(self):
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+            with self.subTest(pickle_proto=proto):
+                test_func(self, proto=proto)
+
+    return wrapper
+
+
 class Employee:
     pass
@@ -81,28 +111,56 @@ def test_any_instance_type_error(self):
         with self.assertRaises(TypeError):
             isinstance(42, Any)
-    def test_any_subclass_type_error(self):
-        with self.assertRaises(TypeError):
-            issubclass(Employee, Any)
-        with self.assertRaises(TypeError):
-            issubclass(Any, Employee)
-
     def test_repr(self):
         self.assertEqual(repr(Any), 'typing.Any')
+        class Sub(Any): pass
+        self.assertEqual(
+            repr(Sub),
+            f"<class '{__name__}.AnyTests.test_repr.<locals>.Sub'>",
+        )
+
     def test_errors(self):
         with self.assertRaises(TypeError):
             issubclass(42, Any)
         with self.assertRaises(TypeError):
             Any[int]  # Any is not a generic type.
-    def test_cannot_subclass(self):
-        with self.assertRaises(TypeError):
-            class A(Any):
-                pass
-        with self.assertRaises(TypeError):
-            class A(type(Any)):
-                pass
+    def test_can_subclass(self):
+        class Mock(Any): pass
+        self.assertTrue(issubclass(Mock, Any))
+        self.assertIsInstance(Mock(), Mock)
+
+        class Something: pass
+        self.assertFalse(issubclass(Something, Any))
+        self.assertNotIsInstance(Something(), Mock)
+
+        class MockSomething(Something, Mock): pass
+        self.assertTrue(issubclass(MockSomething, Any))
+        ms = MockSomething()
+        self.assertIsInstance(ms, MockSomething)
+        self.assertIsInstance(ms, Something)
+        self.assertIsInstance(ms, Mock)
+
+    def test_subclassing_with_custom_constructor(self):
+        class Sub(Any):
+            def __init__(self, *args, **kwargs): pass
+        # The instantiation must not fail.
+ Sub(0, s="") + + def test_multiple_inheritance_with_custom_constructors(self): + class Foo: + def __init__(self, x): + self.x = x + + class Bar(Any, Foo): + def __init__(self, x, y): + self.y = y + super().__init__(x) + + b = Bar(1, 2) + self.assertEqual(b.x, 1) + self.assertEqual(b.y, 2) def test_cannot_instantiate(self): with self.assertRaises(TypeError): @@ -117,48 +175,270 @@ def test_any_works_with_alias(self): typing.IO[Any] -class NoReturnTests(BaseTestCase): +class BottomTypeTestsMixin: + bottom_type: ClassVar[Any] + + def test_equality(self): + self.assertEqual(self.bottom_type, self.bottom_type) + self.assertIs(self.bottom_type, self.bottom_type) + self.assertNotEqual(self.bottom_type, None) + + def test_get_origin(self): + self.assertIs(get_origin(self.bottom_type), None) - def test_noreturn_instance_type_error(self): + def test_instance_type_error(self): + with self.assertRaises(TypeError): + isinstance(42, self.bottom_type) + + def test_subclass_type_error(self): + with self.assertRaises(TypeError): + issubclass(Employee, self.bottom_type) + with self.assertRaises(TypeError): + issubclass(NoReturn, self.bottom_type) + + def test_not_generic(self): with self.assertRaises(TypeError): - isinstance(42, NoReturn) + self.bottom_type[int] - def test_noreturn_subclass_type_error(self): + def test_cannot_subclass(self): + with self.assertRaisesRegex(TypeError, + 'Cannot subclass ' + re.escape(str(self.bottom_type))): + class A(self.bottom_type): + pass + with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): + class B(type(self.bottom_type)): + pass + + def test_cannot_instantiate(self): with self.assertRaises(TypeError): - issubclass(Employee, NoReturn) + self.bottom_type() with self.assertRaises(TypeError): - issubclass(NoReturn, Employee) + type(self.bottom_type)() + + +class NoReturnTests(BottomTypeTestsMixin, BaseTestCase): + bottom_type = NoReturn def test_repr(self): self.assertEqual(repr(NoReturn), 'typing.NoReturn') - def test_not_generic(self): + def test_get_type_hints(self): + def some(arg: NoReturn) -> NoReturn: ... + def some_str(arg: 'NoReturn') -> 'typing.NoReturn': ... + + expected = {'arg': NoReturn, 'return': NoReturn} + for target in [some, some_str]: + with self.subTest(target=target): + self.assertEqual(gth(target), expected) + + def test_not_equality(self): + self.assertNotEqual(NoReturn, Never) + self.assertNotEqual(Never, NoReturn) + + +class NeverTests(BottomTypeTestsMixin, BaseTestCase): + bottom_type = Never + + def test_repr(self): + self.assertEqual(repr(Never), 'typing.Never') + + def test_get_type_hints(self): + def some(arg: Never) -> Never: ... + def some_str(arg: 'Never') -> 'typing.Never': ... 
+ + expected = {'arg': Never, 'return': Never} + for target in [some, some_str]: + with self.subTest(target=target): + self.assertEqual(gth(target), expected) + + +class AssertNeverTests(BaseTestCase): + def test_exception(self): + with self.assertRaises(AssertionError): + assert_never(None) + + value = "some value" + with self.assertRaisesRegex(AssertionError, value): + assert_never(value) + + # Make sure a huge value doesn't get printed in its entirety + huge_value = "a" * 10000 + with self.assertRaises(AssertionError) as cm: + assert_never(huge_value) + self.assertLess( + len(cm.exception.args[0]), + typing._ASSERT_NEVER_REPR_MAX_LENGTH * 2, + ) + + +class SelfTests(BaseTestCase): + def test_equality(self): + self.assertEqual(Self, Self) + self.assertIs(Self, Self) + self.assertNotEqual(Self, None) + + def test_basics(self): + class Foo: + def bar(self) -> Self: ... + class FooStr: + def bar(self) -> 'Self': ... + class FooStrTyping: + def bar(self) -> 'typing.Self': ... + + for target in [Foo, FooStr, FooStrTyping]: + with self.subTest(target=target): + self.assertEqual(gth(target.bar), {'return': Self}) + self.assertIs(get_origin(Self), None) + + def test_repr(self): + self.assertEqual(repr(Self), 'typing.Self') + + def test_cannot_subscript(self): with self.assertRaises(TypeError): - NoReturn[int] + Self[int] def test_cannot_subclass(self): - with self.assertRaises(TypeError): - class A(NoReturn): + with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): + class C(type(Self)): pass + with self.assertRaisesRegex(TypeError, + r'Cannot subclass typing\.Self'): + class D(Self): + pass + + def test_cannot_init(self): + with self.assertRaises(TypeError): + Self() + with self.assertRaises(TypeError): + type(Self)() + + def test_no_isinstance(self): + with self.assertRaises(TypeError): + isinstance(1, Self) + with self.assertRaises(TypeError): + issubclass(int, Self) + + def test_alias(self): + # TypeAliases are not actually part of the spec + alias_1 = Tuple[Self, Self] + alias_2 = List[Self] + alias_3 = ClassVar[Self] + self.assertEqual(get_args(alias_1), (Self, Self)) + self.assertEqual(get_args(alias_2), (Self,)) + self.assertEqual(get_args(alias_3), (Self,)) + + +class LiteralStringTests(BaseTestCase): + def test_equality(self): + self.assertEqual(LiteralString, LiteralString) + self.assertIs(LiteralString, LiteralString) + self.assertNotEqual(LiteralString, None) + + def test_basics(self): + class Foo: + def bar(self) -> LiteralString: ... + class FooStr: + def bar(self) -> 'LiteralString': ... + class FooStrTyping: + def bar(self) -> 'typing.LiteralString': ... 
+ + for target in [Foo, FooStr, FooStrTyping]: + with self.subTest(target=target): + self.assertEqual(gth(target.bar), {'return': LiteralString}) + self.assertIs(get_origin(LiteralString), None) + + def test_repr(self): + self.assertEqual(repr(LiteralString), 'typing.LiteralString') + + def test_cannot_subscript(self): with self.assertRaises(TypeError): - class A(type(NoReturn)): + LiteralString[int] + + def test_cannot_subclass(self): + with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): + class C(type(LiteralString)): + pass + with self.assertRaisesRegex(TypeError, + r'Cannot subclass typing\.LiteralString'): + class D(LiteralString): pass - def test_cannot_instantiate(self): + def test_cannot_init(self): with self.assertRaises(TypeError): - NoReturn() + LiteralString() with self.assertRaises(TypeError): - type(NoReturn)() + type(LiteralString)() + def test_no_isinstance(self): + with self.assertRaises(TypeError): + isinstance(1, LiteralString) + with self.assertRaises(TypeError): + issubclass(int, LiteralString) -class TypeVarTests(BaseTestCase): + def test_alias(self): + alias_1 = Tuple[LiteralString, LiteralString] + alias_2 = List[LiteralString] + alias_3 = ClassVar[LiteralString] + self.assertEqual(get_args(alias_1), (LiteralString, LiteralString)) + self.assertEqual(get_args(alias_2), (LiteralString,)) + self.assertEqual(get_args(alias_3), (LiteralString,)) +class TypeVarTests(BaseTestCase): def test_basic_plain(self): T = TypeVar('T') # T equals itself. self.assertEqual(T, T) # T is an instance of TypeVar self.assertIsInstance(T, TypeVar) + self.assertEqual(T.__name__, 'T') + self.assertEqual(T.__constraints__, ()) + self.assertIs(T.__bound__, None) + self.assertIs(T.__covariant__, False) + self.assertIs(T.__contravariant__, False) + self.assertIs(T.__infer_variance__, False) + self.assertEqual(T.__module__, __name__) + + def test_basic_with_exec(self): + ns = {} + exec('from typing import TypeVar; T = TypeVar("T", bound=float)', ns, ns) + T = ns['T'] + self.assertIsInstance(T, TypeVar) + self.assertEqual(T.__name__, 'T') + self.assertEqual(T.__constraints__, ()) + self.assertIs(T.__bound__, float) + self.assertIs(T.__covariant__, False) + self.assertIs(T.__contravariant__, False) + self.assertIs(T.__infer_variance__, False) + self.assertIs(T.__module__, None) + + def test_attributes(self): + T_bound = TypeVar('T_bound', bound=int) + self.assertEqual(T_bound.__name__, 'T_bound') + self.assertEqual(T_bound.__constraints__, ()) + self.assertIs(T_bound.__bound__, int) + + T_constraints = TypeVar('T_constraints', int, str) + self.assertEqual(T_constraints.__name__, 'T_constraints') + self.assertEqual(T_constraints.__constraints__, (int, str)) + self.assertIs(T_constraints.__bound__, None) + + T_co = TypeVar('T_co', covariant=True) + self.assertEqual(T_co.__name__, 'T_co') + self.assertIs(T_co.__covariant__, True) + self.assertIs(T_co.__contravariant__, False) + self.assertIs(T_co.__infer_variance__, False) + + T_contra = TypeVar('T_contra', contravariant=True) + self.assertEqual(T_contra.__name__, 'T_contra') + self.assertIs(T_contra.__covariant__, False) + self.assertIs(T_contra.__contravariant__, True) + self.assertIs(T_contra.__infer_variance__, False) + + T_infer = TypeVar('T_infer', infer_variance=True) + self.assertEqual(T_infer.__name__, 'T_infer') + self.assertIs(T_infer.__covariant__, False) + self.assertIs(T_infer.__contravariant__, False) + self.assertIs(T_infer.__infer_variance__, True) def test_typevar_instance_type_error(self): T = TypeVar('T') @@ -218,15 
+498,13 @@ def test_no_redefinition(self): self.assertNotEqual(TypeVar('T'), TypeVar('T')) self.assertNotEqual(TypeVar('T', int, str), TypeVar('T', int, str)) - def test_cannot_subclass_vars(self): - with self.assertRaises(TypeError): - class V(TypeVar('T')): - pass - - def test_cannot_subclass_var_itself(self): - with self.assertRaises(TypeError): - class V(TypeVar): - pass + def test_cannot_subclass(self): + with self.assertRaisesRegex(TypeError, NOT_A_BASE_TYPE % 'TypeVar'): + class V(TypeVar): pass + T = TypeVar("T") + with self.assertRaisesRegex(TypeError, + CANNOT_SUBCLASS_INSTANCE % 'TypeVar'): + class W(T): pass def test_cannot_instantiate_vars(self): with self.assertRaises(TypeError): @@ -234,9 +512,12 @@ def test_cannot_instantiate_vars(self): def test_bound_errors(self): with self.assertRaises(TypeError): - TypeVar('X', bound=42) + TypeVar('X', bound=Union) with self.assertRaises(TypeError): TypeVar('X', str, float, bound=Employee) + with self.assertRaisesRegex(TypeError, + r"Bound must be a type\. Got \(1, 2\)\."): + TypeVar('X', bound=(1, 2)) def test_missing__name__(self): # See bpo-39942 @@ -249,15 +530,1479 @@ def test_no_bivariant(self): with self.assertRaises(ValueError): TypeVar('T', covariant=True, contravariant=True) + def test_cannot_combine_explicit_and_infer(self): + with self.assertRaises(ValueError): + TypeVar('T', covariant=True, infer_variance=True) + with self.assertRaises(ValueError): + TypeVar('T', contravariant=True, infer_variance=True) + + def test_var_substitution(self): + T = TypeVar('T') + subst = T.__typing_subst__ + self.assertIs(subst(int), int) + self.assertEqual(subst(list[int]), list[int]) + self.assertEqual(subst(List[int]), List[int]) + self.assertEqual(subst(List), List) + self.assertIs(subst(Any), Any) + self.assertIs(subst(None), type(None)) + self.assertIs(subst(T), T) + self.assertEqual(subst(int|str), int|str) + self.assertEqual(subst(Union[int, str]), Union[int, str]) + def test_bad_var_substitution(self): T = TypeVar('T') - for arg in (), (int, str): + bad_args = ( + (), (int, str), Union, + Generic, Generic[T], Protocol, Protocol[T], + Final, Final[int], ClassVar, ClassVar[int], + ) + for arg in bad_args: with self.subTest(arg=arg): + with self.assertRaises(TypeError): + T.__typing_subst__(arg) with self.assertRaises(TypeError): List[T][arg] with self.assertRaises(TypeError): list[T][arg] + def test_many_weakrefs(self): + # gh-108295: this used to segfault + for cls in (ParamSpec, TypeVarTuple, TypeVar): + with self.subTest(cls=cls): + vals = weakref.WeakValueDictionary() + + for x in range(10): + vals[x] = cls(str(x)) + del vals + + def test_constructor(self): + T = TypeVar(name="T") + self.assertEqual(T.__name__, "T") + self.assertEqual(T.__constraints__, ()) + self.assertIs(T.__bound__, None) + self.assertIs(T.__default__, typing.NoDefault) + self.assertIs(T.__covariant__, False) + self.assertIs(T.__contravariant__, False) + self.assertIs(T.__infer_variance__, False) + + T = TypeVar(name="T", bound=type) + self.assertEqual(T.__name__, "T") + self.assertEqual(T.__constraints__, ()) + self.assertIs(T.__bound__, type) + self.assertIs(T.__default__, typing.NoDefault) + self.assertIs(T.__covariant__, False) + self.assertIs(T.__contravariant__, False) + self.assertIs(T.__infer_variance__, False) + + T = TypeVar(name="T", default=()) + self.assertEqual(T.__name__, "T") + self.assertEqual(T.__constraints__, ()) + self.assertIs(T.__bound__, None) + self.assertIs(T.__default__, ()) + self.assertIs(T.__covariant__, False) + 
self.assertIs(T.__contravariant__, False) + self.assertIs(T.__infer_variance__, False) + + T = TypeVar(name="T", covariant=True) + self.assertEqual(T.__name__, "T") + self.assertEqual(T.__constraints__, ()) + self.assertIs(T.__bound__, None) + self.assertIs(T.__default__, typing.NoDefault) + self.assertIs(T.__covariant__, True) + self.assertIs(T.__contravariant__, False) + self.assertIs(T.__infer_variance__, False) + + T = TypeVar(name="T", contravariant=True) + self.assertEqual(T.__name__, "T") + self.assertEqual(T.__constraints__, ()) + self.assertIs(T.__bound__, None) + self.assertIs(T.__default__, typing.NoDefault) + self.assertIs(T.__covariant__, False) + self.assertIs(T.__contravariant__, True) + self.assertIs(T.__infer_variance__, False) + + T = TypeVar(name="T", infer_variance=True) + self.assertEqual(T.__name__, "T") + self.assertEqual(T.__constraints__, ()) + self.assertIs(T.__bound__, None) + self.assertIs(T.__default__, typing.NoDefault) + self.assertIs(T.__covariant__, False) + self.assertIs(T.__contravariant__, False) + self.assertIs(T.__infer_variance__, True) + + +class TypeParameterDefaultsTests(BaseTestCase): + def test_typevar(self): + T = TypeVar('T', default=int) + self.assertEqual(T.__default__, int) + self.assertTrue(T.has_default()) + self.assertIsInstance(T, TypeVar) + + class A(Generic[T]): ... + Alias = Optional[T] + + def test_typevar_none(self): + U = TypeVar('U') + U_None = TypeVar('U_None', default=None) + self.assertIs(U.__default__, NoDefault) + self.assertFalse(U.has_default()) + self.assertIs(U_None.__default__, None) + self.assertTrue(U_None.has_default()) + + class X[T]: ... + T, = X.__type_params__ + self.assertIs(T.__default__, NoDefault) + self.assertFalse(T.has_default()) + + def test_paramspec(self): + P = ParamSpec('P', default=(str, int)) + self.assertEqual(P.__default__, (str, int)) + self.assertTrue(P.has_default()) + self.assertIsInstance(P, ParamSpec) + + class A(Generic[P]): ... + Alias = typing.Callable[P, None] + + P_default = ParamSpec('P_default', default=...) + self.assertIs(P_default.__default__, ...) + + def test_paramspec_none(self): + U = ParamSpec('U') + U_None = ParamSpec('U_None', default=None) + self.assertIs(U.__default__, NoDefault) + self.assertFalse(U.has_default()) + self.assertIs(U_None.__default__, None) + self.assertTrue(U_None.has_default()) + + class X[**P]: ... + P, = X.__type_params__ + self.assertIs(P.__default__, NoDefault) + self.assertFalse(P.has_default()) + + def test_typevartuple(self): + Ts = TypeVarTuple('Ts', default=Unpack[Tuple[str, int]]) + self.assertEqual(Ts.__default__, Unpack[Tuple[str, int]]) + self.assertTrue(Ts.has_default()) + self.assertIsInstance(Ts, TypeVarTuple) + + class A(Generic[Unpack[Ts]]): ... + Alias = Optional[Unpack[Ts]] + + def test_typevartuple_specialization(self): + T = TypeVar("T") + Ts = TypeVarTuple('Ts', default=Unpack[Tuple[str, int]]) + self.assertEqual(Ts.__default__, Unpack[Tuple[str, int]]) + class A(Generic[T, Unpack[Ts]]): ... + self.assertEqual(A[float].__args__, (float, str, int)) + self.assertEqual(A[float, range].__args__, (float, range)) + self.assertEqual(A[float, *tuple[int, ...]].__args__, (float, *tuple[int, ...])) + + def test_typevar_and_typevartuple_specialization(self): + T = TypeVar("T") + U = TypeVar("U", default=float) + Ts = TypeVarTuple('Ts', default=Unpack[Tuple[str, int]]) + self.assertEqual(Ts.__default__, Unpack[Tuple[str, int]]) + class A(Generic[T, U, Unpack[Ts]]): ... 
+ self.assertEqual(A[int].__args__, (int, float, str, int)) + self.assertEqual(A[int, str].__args__, (int, str, str, int)) + self.assertEqual(A[int, str, range].__args__, (int, str, range)) + self.assertEqual(A[int, str, *tuple[int, ...]].__args__, (int, str, *tuple[int, ...])) + + def test_no_default_after_typevar_tuple(self): + T = TypeVar("T", default=int) + Ts = TypeVarTuple("Ts") + Ts_default = TypeVarTuple("Ts_default", default=Unpack[Tuple[str, int]]) + + with self.assertRaises(TypeError): + class X(Generic[*Ts, T]): ... + + with self.assertRaises(TypeError): + class Y(Generic[*Ts_default, T]): ... + + def test_allow_default_after_non_default_in_alias(self): + T_default = TypeVar('T_default', default=int) + T = TypeVar('T') + Ts = TypeVarTuple('Ts') + + a1 = Callable[[T_default], T] + self.assertEqual(a1.__args__, (T_default, T)) + + a2 = dict[T_default, T] + self.assertEqual(a2.__args__, (T_default, T)) + + a3 = typing.Dict[T_default, T] + self.assertEqual(a3.__args__, (T_default, T)) + + a4 = Callable[*Ts, T] + self.assertEqual(a4.__args__, (*Ts, T)) + + def test_paramspec_specialization(self): + T = TypeVar("T") + P = ParamSpec('P', default=[str, int]) + self.assertEqual(P.__default__, [str, int]) + class A(Generic[T, P]): ... + self.assertEqual(A[float].__args__, (float, (str, int))) + self.assertEqual(A[float, [range]].__args__, (float, (range,))) + + def test_typevar_and_paramspec_specialization(self): + T = TypeVar("T") + U = TypeVar("U", default=float) + P = ParamSpec('P', default=[str, int]) + self.assertEqual(P.__default__, [str, int]) + class A(Generic[T, U, P]): ... + self.assertEqual(A[float].__args__, (float, float, (str, int))) + self.assertEqual(A[float, int].__args__, (float, int, (str, int))) + self.assertEqual(A[float, int, [range]].__args__, (float, int, (range,))) + + def test_paramspec_and_typevar_specialization(self): + T = TypeVar("T") + P = ParamSpec('P', default=[str, int]) + U = TypeVar("U", default=float) + self.assertEqual(P.__default__, [str, int]) + class A(Generic[T, P, U]): ... + self.assertEqual(A[float].__args__, (float, (str, int), float)) + self.assertEqual(A[float, [range]].__args__, (float, (range,), float)) + self.assertEqual(A[float, [range], int].__args__, (float, (range,), int)) + + def test_typevartuple_none(self): + U = TypeVarTuple('U') + U_None = TypeVarTuple('U_None', default=None) + self.assertIs(U.__default__, NoDefault) + self.assertFalse(U.has_default()) + self.assertIs(U_None.__default__, None) + self.assertTrue(U_None.has_default()) + + class X[**Ts]: ... + Ts, = X.__type_params__ + self.assertIs(Ts.__default__, NoDefault) + self.assertFalse(Ts.has_default()) + + def test_no_default_after_non_default(self): + DefaultStrT = TypeVar('DefaultStrT', default=str) + T = TypeVar('T') + + with self.assertRaisesRegex( + TypeError, r"Type parameter ~T without a default follows type parameter with a default" + ): + Test = Generic[DefaultStrT, T] + + def test_need_more_params(self): + DefaultStrT = TypeVar('DefaultStrT', default=str) + T = TypeVar('T') + U = TypeVar('U') + + class A(Generic[T, U, DefaultStrT]): ... 
+ A[int, bool] + A[int, bool, str] + + with self.assertRaisesRegex( + TypeError, r"Too few arguments for .+; actual 1, expected at least 2" + ): + Test = A[int] + + def test_pickle(self): + global U, U_co, U_contra, U_default # pickle wants to reference the class by name + U = TypeVar('U') + U_co = TypeVar('U_co', covariant=True) + U_contra = TypeVar('U_contra', contravariant=True) + U_default = TypeVar('U_default', default=int) + for proto in range(pickle.HIGHEST_PROTOCOL): + for typevar in (U, U_co, U_contra, U_default): + z = pickle.loads(pickle.dumps(typevar, proto)) + self.assertEqual(z.__name__, typevar.__name__) + self.assertEqual(z.__covariant__, typevar.__covariant__) + self.assertEqual(z.__contravariant__, typevar.__contravariant__) + self.assertEqual(z.__bound__, typevar.__bound__) + self.assertEqual(z.__default__, typevar.__default__) + + +def template_replace(templates: list[str], replacements: dict[str, list[str]]) -> list[tuple[str]]: + """Renders templates with possible combinations of replacements. + + Example 1: Suppose that: + templates = ["dog_breed are awesome", "dog_breed are cool"] + replacements = ["dog_breed": ["Huskies", "Beagles"]] + Then we would return: + [ + ("Huskies are awesome", "Huskies are cool"), + ("Beagles are awesome", "Beagles are cool") + ] + + Example 2: Suppose that: + templates = ["Huskies are word1 but also word2"] + replacements = {"word1": ["playful", "cute"], + "word2": ["feisty", "tiring"]} + Then we would return: + [ + ("Huskies are playful but also feisty"), + ("Huskies are playful but also tiring"), + ("Huskies are cute but also feisty"), + ("Huskies are cute but also tiring") + ] + + Note that if any of the replacements do not occur in any template: + templates = ["Huskies are word1", "Beagles!"] + replacements = {"word1": ["playful", "cute"], + "word2": ["feisty", "tiring"]} + Then we do not generate duplicates, returning: + [ + ("Huskies are playful", "Beagles!"), + ("Huskies are cute", "Beagles!") + ] + """ + # First, build a structure like: + # [ + # [("word1", "playful"), ("word1", "cute")], + # [("word2", "feisty"), ("word2", "tiring")] + # ] + replacement_combos = [] + for original, possible_replacements in replacements.items(): + original_replacement_tuples = [] + for replacement in possible_replacements: + original_replacement_tuples.append((original, replacement)) + replacement_combos.append(original_replacement_tuples) + + # Second, generate rendered templates, including possible duplicates. + rendered_templates = [] + for replacement_combo in itertools.product(*replacement_combos): + # replacement_combo would be e.g. + # [("word1", "playful"), ("word2", "feisty")] + templates_with_replacements = [] + for template in templates: + for original, replacement in replacement_combo: + template = template.replace(original, replacement) + templates_with_replacements.append(template) + rendered_templates.append(tuple(templates_with_replacements)) + + # Finally, remove the duplicates (but keep the order). + rendered_templates_no_duplicates = [] + for x in rendered_templates: + # Inefficient, but should be fine for our purposes. 
+ if x not in rendered_templates_no_duplicates: + rendered_templates_no_duplicates.append(x) + + return rendered_templates_no_duplicates + + +class TemplateReplacementTests(BaseTestCase): + + def test_two_templates_two_replacements_yields_correct_renders(self): + actual = template_replace( + templates=["Cats are word1", "Dogs are word2"], + replacements={ + "word1": ["small", "cute"], + "word2": ["big", "fluffy"], + }, + ) + expected = [ + ("Cats are small", "Dogs are big"), + ("Cats are small", "Dogs are fluffy"), + ("Cats are cute", "Dogs are big"), + ("Cats are cute", "Dogs are fluffy"), + ] + self.assertEqual(actual, expected) + + def test_no_duplicates_if_replacement_not_in_templates(self): + actual = template_replace( + templates=["Cats are word1", "Dogs!"], + replacements={ + "word1": ["small", "cute"], + "word2": ["big", "fluffy"], + }, + ) + expected = [ + ("Cats are small", "Dogs!"), + ("Cats are cute", "Dogs!"), + ] + self.assertEqual(actual, expected) + + +class GenericAliasSubstitutionTests(BaseTestCase): + """Tests for type variable substitution in generic aliases. + + For variadic cases, these tests should be regarded as the source of truth, + since we hadn't realised the full complexity of variadic substitution + at the time of finalizing PEP 646. For full discussion, see + https://github.com/python/cpython/issues/91162. + """ + + def test_one_parameter(self): + T = TypeVar('T') + Ts = TypeVarTuple('Ts') + Ts2 = TypeVarTuple('Ts2') + + class C(Generic[T]): pass + + generics = ['C', 'list', 'List'] + tuple_types = ['tuple', 'Tuple'] + + tests = [ + # Alias # Args # Expected result + ('generic[T]', '[()]', 'TypeError'), + ('generic[T]', '[int]', 'generic[int]'), + ('generic[T]', '[int, str]', 'TypeError'), + ('generic[T]', '[tuple_type[int, ...]]', 'generic[tuple_type[int, ...]]'), + ('generic[T]', '[*tuple_type[int]]', 'generic[int]'), + ('generic[T]', '[*tuple_type[()]]', 'TypeError'), + ('generic[T]', '[*tuple_type[int, str]]', 'TypeError'), + ('generic[T]', '[*tuple_type[int, ...]]', 'TypeError'), + ('generic[T]', '[*Ts]', 'TypeError'), + ('generic[T]', '[T, *Ts]', 'TypeError'), + ('generic[T]', '[*Ts, T]', 'TypeError'), + # Raises TypeError because C is not variadic. + # (If C _were_ variadic, it'd be fine.) + ('C[T, *tuple_type[int, ...]]', '[int]', 'TypeError'), + # Should definitely raise TypeError: list only takes one argument. + ('list[T, *tuple_type[int, ...]]', '[int]', 'list[int, *tuple_type[int, ...]]'), + ('List[T, *tuple_type[int, ...]]', '[int]', 'TypeError'), + # Should raise, because more than one `TypeVarTuple` is not supported. 
+ ('generic[*Ts, *Ts2]', '[int]', 'TypeError'), + ] + + for alias_template, args_template, expected_template in tests: + rendered_templates = template_replace( + templates=[alias_template, args_template, expected_template], + replacements={'generic': generics, 'tuple_type': tuple_types} + ) + for alias_str, args_str, expected_str in rendered_templates: + with self.subTest(alias=alias_str, args=args_str, expected=expected_str): + if expected_str == 'TypeError': + with self.assertRaises(TypeError): + eval(alias_str + args_str) + else: + self.assertEqual( + eval(alias_str + args_str), + eval(expected_str) + ) + + + def test_two_parameters(self): + T1 = TypeVar('T1') + T2 = TypeVar('T2') + Ts = TypeVarTuple('Ts') + + class C(Generic[T1, T2]): pass + + generics = ['C', 'dict', 'Dict'] + tuple_types = ['tuple', 'Tuple'] + + tests = [ + # Alias # Args # Expected result + ('generic[T1, T2]', '[()]', 'TypeError'), + ('generic[T1, T2]', '[int]', 'TypeError'), + ('generic[T1, T2]', '[int, str]', 'generic[int, str]'), + ('generic[T1, T2]', '[int, str, bool]', 'TypeError'), + ('generic[T1, T2]', '[*tuple_type[int]]', 'TypeError'), + ('generic[T1, T2]', '[*tuple_type[int, str]]', 'generic[int, str]'), + ('generic[T1, T2]', '[*tuple_type[int, str, bool]]', 'TypeError'), + + ('generic[T1, T2]', '[int, *tuple_type[str]]', 'generic[int, str]'), + ('generic[T1, T2]', '[*tuple_type[int], str]', 'generic[int, str]'), + ('generic[T1, T2]', '[*tuple_type[int], *tuple_type[str]]', 'generic[int, str]'), + ('generic[T1, T2]', '[*tuple_type[int, str], *tuple_type[()]]', 'generic[int, str]'), + ('generic[T1, T2]', '[*tuple_type[()], *tuple_type[int, str]]', 'generic[int, str]'), + ('generic[T1, T2]', '[*tuple_type[int], *tuple_type[()]]', 'TypeError'), + ('generic[T1, T2]', '[*tuple_type[()], *tuple_type[int]]', 'TypeError'), + ('generic[T1, T2]', '[*tuple_type[int, str], *tuple_type[float]]', 'TypeError'), + ('generic[T1, T2]', '[*tuple_type[int], *tuple_type[str, float]]', 'TypeError'), + ('generic[T1, T2]', '[*tuple_type[int, str], *tuple_type[float, bool]]', 'TypeError'), + + ('generic[T1, T2]', '[tuple_type[int, ...]]', 'TypeError'), + ('generic[T1, T2]', '[tuple_type[int, ...], tuple_type[str, ...]]', 'generic[tuple_type[int, ...], tuple_type[str, ...]]'), + ('generic[T1, T2]', '[*tuple_type[int, ...]]', 'TypeError'), + ('generic[T1, T2]', '[int, *tuple_type[str, ...]]', 'TypeError'), + ('generic[T1, T2]', '[*tuple_type[int, ...], str]', 'TypeError'), + ('generic[T1, T2]', '[*tuple_type[int, ...], *tuple_type[str, ...]]', 'TypeError'), + ('generic[T1, T2]', '[*Ts]', 'TypeError'), + ('generic[T1, T2]', '[T, *Ts]', 'TypeError'), + ('generic[T1, T2]', '[*Ts, T]', 'TypeError'), + # This one isn't technically valid - none of the things that + # `generic` can be (defined in `generics` above) are variadic, so we + # shouldn't really be able to do `generic[T1, *tuple_type[int, ...]]`. + # So even if type checkers shouldn't allow it, we allow it at + # runtime, in accordance with a general philosophy of "Keep the + # runtime lenient so people can experiment with typing constructs". 
+ ('generic[T1, *tuple_type[int, ...]]', '[str]', 'generic[str, *tuple_type[int, ...]]'), + ] + + for alias_template, args_template, expected_template in tests: + rendered_templates = template_replace( + templates=[alias_template, args_template, expected_template], + replacements={'generic': generics, 'tuple_type': tuple_types} + ) + for alias_str, args_str, expected_str in rendered_templates: + with self.subTest(alias=alias_str, args=args_str, expected=expected_str): + if expected_str == 'TypeError': + with self.assertRaises(TypeError): + eval(alias_str + args_str) + else: + self.assertEqual( + eval(alias_str + args_str), + eval(expected_str) + ) + + def test_three_parameters(self): + T1 = TypeVar('T1') + T2 = TypeVar('T2') + T3 = TypeVar('T3') + + class C(Generic[T1, T2, T3]): pass + + generics = ['C'] + tuple_types = ['tuple', 'Tuple'] + + tests = [ + # Alias # Args # Expected result + ('generic[T1, bool, T2]', '[int, str]', 'generic[int, bool, str]'), + ('generic[T1, bool, T2]', '[*tuple_type[int, str]]', 'generic[int, bool, str]'), + ] + + for alias_template, args_template, expected_template in tests: + rendered_templates = template_replace( + templates=[alias_template, args_template, expected_template], + replacements={'generic': generics, 'tuple_type': tuple_types} + ) + for alias_str, args_str, expected_str in rendered_templates: + with self.subTest(alias=alias_str, args=args_str, expected=expected_str): + if expected_str == 'TypeError': + with self.assertRaises(TypeError): + eval(alias_str + args_str) + else: + self.assertEqual( + eval(alias_str + args_str), + eval(expected_str) + ) + + def test_variadic_parameters(self): + T1 = TypeVar('T1') + T2 = TypeVar('T2') + Ts = TypeVarTuple('Ts') + + class C(Generic[*Ts]): pass + + generics = ['C', 'tuple', 'Tuple'] + tuple_types = ['tuple', 'Tuple'] + + tests = [ + # Alias # Args # Expected result + ('generic[*Ts]', '[()]', 'generic[()]'), + ('generic[*Ts]', '[int]', 'generic[int]'), + ('generic[*Ts]', '[int, str]', 'generic[int, str]'), + ('generic[*Ts]', '[*tuple_type[int]]', 'generic[int]'), + ('generic[*Ts]', '[*tuple_type[*Ts]]', 'generic[*Ts]'), + ('generic[*Ts]', '[*tuple_type[int, str]]', 'generic[int, str]'), + ('generic[*Ts]', '[str, *tuple_type[int, ...], bool]', 'generic[str, *tuple_type[int, ...], bool]'), + ('generic[*Ts]', '[tuple_type[int, ...]]', 'generic[tuple_type[int, ...]]'), + ('generic[*Ts]', '[tuple_type[int, ...], tuple_type[str, ...]]', 'generic[tuple_type[int, ...], tuple_type[str, ...]]'), + ('generic[*Ts]', '[*tuple_type[int, ...]]', 'generic[*tuple_type[int, ...]]'), + ('generic[*Ts]', '[*tuple_type[int, ...], *tuple_type[str, ...]]', 'TypeError'), + + ('generic[*Ts]', '[*Ts]', 'generic[*Ts]'), + ('generic[*Ts]', '[T, *Ts]', 'generic[T, *Ts]'), + ('generic[*Ts]', '[*Ts, T]', 'generic[*Ts, T]'), + ('generic[T, *Ts]', '[()]', 'TypeError'), + ('generic[T, *Ts]', '[int]', 'generic[int]'), + ('generic[T, *Ts]', '[int, str]', 'generic[int, str]'), + ('generic[T, *Ts]', '[int, str, bool]', 'generic[int, str, bool]'), + ('generic[list[T], *Ts]', '[()]', 'TypeError'), + ('generic[list[T], *Ts]', '[int]', 'generic[list[int]]'), + ('generic[list[T], *Ts]', '[int, str]', 'generic[list[int], str]'), + ('generic[list[T], *Ts]', '[int, str, bool]', 'generic[list[int], str, bool]'), + + ('generic[*Ts, T]', '[()]', 'TypeError'), + ('generic[*Ts, T]', '[int]', 'generic[int]'), + ('generic[*Ts, T]', '[int, str]', 'generic[int, str]'), + ('generic[*Ts, T]', '[int, str, bool]', 'generic[int, str, bool]'), + ('generic[*Ts, 
list[T]]', '[()]', 'TypeError'), + ('generic[*Ts, list[T]]', '[int]', 'generic[list[int]]'), + ('generic[*Ts, list[T]]', '[int, str]', 'generic[int, list[str]]'), + ('generic[*Ts, list[T]]', '[int, str, bool]', 'generic[int, str, list[bool]]'), + + ('generic[T1, T2, *Ts]', '[()]', 'TypeError'), + ('generic[T1, T2, *Ts]', '[int]', 'TypeError'), + ('generic[T1, T2, *Ts]', '[int, str]', 'generic[int, str]'), + ('generic[T1, T2, *Ts]', '[int, str, bool]', 'generic[int, str, bool]'), + ('generic[T1, T2, *Ts]', '[int, str, bool, bytes]', 'generic[int, str, bool, bytes]'), + + ('generic[*Ts, T1, T2]', '[()]', 'TypeError'), + ('generic[*Ts, T1, T2]', '[int]', 'TypeError'), + ('generic[*Ts, T1, T2]', '[int, str]', 'generic[int, str]'), + ('generic[*Ts, T1, T2]', '[int, str, bool]', 'generic[int, str, bool]'), + ('generic[*Ts, T1, T2]', '[int, str, bool, bytes]', 'generic[int, str, bool, bytes]'), + + ('generic[T1, *Ts, T2]', '[()]', 'TypeError'), + ('generic[T1, *Ts, T2]', '[int]', 'TypeError'), + ('generic[T1, *Ts, T2]', '[int, str]', 'generic[int, str]'), + ('generic[T1, *Ts, T2]', '[int, str, bool]', 'generic[int, str, bool]'), + ('generic[T1, *Ts, T2]', '[int, str, bool, bytes]', 'generic[int, str, bool, bytes]'), + + ('generic[T, *Ts]', '[*tuple_type[int, ...]]', 'generic[int, *tuple_type[int, ...]]'), + ('generic[T, *Ts]', '[str, *tuple_type[int, ...]]', 'generic[str, *tuple_type[int, ...]]'), + ('generic[T, *Ts]', '[*tuple_type[int, ...], str]', 'generic[int, *tuple_type[int, ...], str]'), + ('generic[*Ts, T]', '[*tuple_type[int, ...]]', 'generic[*tuple_type[int, ...], int]'), + ('generic[*Ts, T]', '[str, *tuple_type[int, ...]]', 'generic[str, *tuple_type[int, ...], int]'), + ('generic[*Ts, T]', '[*tuple_type[int, ...], str]', 'generic[*tuple_type[int, ...], str]'), + ('generic[T1, *Ts, T2]', '[*tuple_type[int, ...]]', 'generic[int, *tuple_type[int, ...], int]'), + ('generic[T, str, *Ts]', '[*tuple_type[int, ...]]', 'generic[int, str, *tuple_type[int, ...]]'), + ('generic[*Ts, str, T]', '[*tuple_type[int, ...]]', 'generic[*tuple_type[int, ...], str, int]'), + ('generic[list[T], *Ts]', '[*tuple_type[int, ...]]', 'generic[list[int], *tuple_type[int, ...]]'), + ('generic[*Ts, list[T]]', '[*tuple_type[int, ...]]', 'generic[*tuple_type[int, ...], list[int]]'), + + ('generic[T, *tuple_type[int, ...]]', '[str]', 'generic[str, *tuple_type[int, ...]]'), + ('generic[T1, T2, *tuple_type[int, ...]]', '[str, bool]', 'generic[str, bool, *tuple_type[int, ...]]'), + ('generic[T1, *tuple_type[int, ...], T2]', '[str, bool]', 'generic[str, *tuple_type[int, ...], bool]'), + ('generic[T1, *tuple_type[int, ...], T2]', '[str, bool, float]', 'TypeError'), + + ('generic[T1, *tuple_type[T2, ...]]', '[int, str]', 'generic[int, *tuple_type[str, ...]]'), + ('generic[*tuple_type[T1, ...], T2]', '[int, str]', 'generic[*tuple_type[int, ...], str]'), + ('generic[T1, *tuple_type[generic[*Ts], ...]]', '[int, str, bool]', 'generic[int, *tuple_type[generic[str, bool], ...]]'), + ('generic[*tuple_type[generic[*Ts], ...], T1]', '[int, str, bool]', 'generic[*tuple_type[generic[int, str], ...], bool]'), + ] + + for alias_template, args_template, expected_template in tests: + rendered_templates = template_replace( + templates=[alias_template, args_template, expected_template], + replacements={'generic': generics, 'tuple_type': tuple_types} + ) + for alias_str, args_str, expected_str in rendered_templates: + with self.subTest(alias=alias_str, args=args_str, expected=expected_str): + if expected_str == 'TypeError': + with 
self.assertRaises(TypeError): + eval(alias_str + args_str) + else: + self.assertEqual( + eval(alias_str + args_str), + eval(expected_str) + ) + + + +class UnpackTests(BaseTestCase): + + def test_accepts_single_type(self): + (*tuple[int],) + Unpack[Tuple[int]] + + def test_dir(self): + dir_items = set(dir(Unpack[Tuple[int]])) + for required_item in [ + '__args__', '__parameters__', '__origin__', + ]: + with self.subTest(required_item=required_item): + self.assertIn(required_item, dir_items) + + def test_rejects_multiple_types(self): + with self.assertRaises(TypeError): + Unpack[Tuple[int], Tuple[str]] + # We can't do the equivalent for `*` here - + # *(Tuple[int], Tuple[str]) is just plain tuple unpacking, + # which is valid. + + def test_rejects_multiple_parameterization(self): + with self.assertRaises(TypeError): + (*tuple[int],)[0][tuple[int]] + with self.assertRaises(TypeError): + Unpack[Tuple[int]][Tuple[int]] + + def test_cannot_be_called(self): + with self.assertRaises(TypeError): + Unpack() + + def test_usage_with_kwargs(self): + Movie = TypedDict('Movie', {'name': str, 'year': int}) + def foo(**kwargs: Unpack[Movie]): ... + self.assertEqual(repr(foo.__annotations__['kwargs']), + f"typing.Unpack[{__name__}.Movie]") + + def test_builtin_tuple(self): + Ts = TypeVarTuple("Ts") + + class Old(Generic[*Ts]): ... + class New[*Ts]: ... + + PartOld = Old[int, *Ts] + self.assertEqual(PartOld[str].__args__, (int, str)) + self.assertEqual(PartOld[*tuple[str]].__args__, (int, str)) + self.assertEqual(PartOld[*Tuple[str]].__args__, (int, str)) + self.assertEqual(PartOld[Unpack[tuple[str]]].__args__, (int, str)) + self.assertEqual(PartOld[Unpack[Tuple[str]]].__args__, (int, str)) + + PartNew = New[int, *Ts] + self.assertEqual(PartNew[str].__args__, (int, str)) + self.assertEqual(PartNew[*tuple[str]].__args__, (int, str)) + self.assertEqual(PartNew[*Tuple[str]].__args__, (int, str)) + self.assertEqual(PartNew[Unpack[tuple[str]]].__args__, (int, str)) + self.assertEqual(PartNew[Unpack[Tuple[str]]].__args__, (int, str)) + + def test_unpack_wrong_type(self): + Ts = TypeVarTuple("Ts") + class Gen[*Ts]: ... + PartGen = Gen[int, *Ts] + + bad_unpack_param = re.escape("Unpack[...] 
must be used with a tuple type") + with self.assertRaisesRegex(TypeError, bad_unpack_param): + PartGen[Unpack[list[int]]] + with self.assertRaisesRegex(TypeError, bad_unpack_param): + PartGen[Unpack[List[int]]] + + +class TypeVarTupleTests(BaseTestCase): + + def assertEndsWith(self, string, tail): + if not string.endswith(tail): + self.fail(f"String {string!r} does not end with {tail!r}") + + def test_name(self): + Ts = TypeVarTuple('Ts') + self.assertEqual(Ts.__name__, 'Ts') + Ts2 = TypeVarTuple('Ts2') + self.assertEqual(Ts2.__name__, 'Ts2') + + def test_module(self): + Ts = TypeVarTuple('Ts') + self.assertEqual(Ts.__module__, __name__) + + def test_exec(self): + ns = {} + exec('from typing import TypeVarTuple; Ts = TypeVarTuple("Ts")', ns) + Ts = ns['Ts'] + self.assertEqual(Ts.__name__, 'Ts') + self.assertIs(Ts.__module__, None) + + def test_instance_is_equal_to_itself(self): + Ts = TypeVarTuple('Ts') + self.assertEqual(Ts, Ts) + + def test_different_instances_are_different(self): + self.assertNotEqual(TypeVarTuple('Ts'), TypeVarTuple('Ts')) + + def test_instance_isinstance_of_typevartuple(self): + Ts = TypeVarTuple('Ts') + self.assertIsInstance(Ts, TypeVarTuple) + + def test_cannot_call_instance(self): + Ts = TypeVarTuple('Ts') + with self.assertRaises(TypeError): + Ts() + + def test_unpacked_typevartuple_is_equal_to_itself(self): + Ts = TypeVarTuple('Ts') + self.assertEqual((*Ts,)[0], (*Ts,)[0]) + self.assertEqual(Unpack[Ts], Unpack[Ts]) + + def test_parameterised_tuple_is_equal_to_itself(self): + Ts = TypeVarTuple('Ts') + self.assertEqual(tuple[*Ts], tuple[*Ts]) + self.assertEqual(Tuple[Unpack[Ts]], Tuple[Unpack[Ts]]) + + def tests_tuple_arg_ordering_matters(self): + Ts1 = TypeVarTuple('Ts1') + Ts2 = TypeVarTuple('Ts2') + self.assertNotEqual( + tuple[*Ts1, *Ts2], + tuple[*Ts2, *Ts1], + ) + self.assertNotEqual( + Tuple[Unpack[Ts1], Unpack[Ts2]], + Tuple[Unpack[Ts2], Unpack[Ts1]], + ) + + def test_tuple_args_and_parameters_are_correct(self): + Ts = TypeVarTuple('Ts') + t1 = tuple[*Ts] + self.assertEqual(t1.__args__, (*Ts,)) + self.assertEqual(t1.__parameters__, (Ts,)) + t2 = Tuple[Unpack[Ts]] + self.assertEqual(t2.__args__, (Unpack[Ts],)) + self.assertEqual(t2.__parameters__, (Ts,)) + + def test_var_substitution(self): + Ts = TypeVarTuple('Ts') + T = TypeVar('T') + T2 = TypeVar('T2') + class G1(Generic[*Ts]): pass + class G2(Generic[Unpack[Ts]]): pass + + for A in G1, G2, Tuple, tuple: + B = A[*Ts] + self.assertEqual(B[()], A[()]) + self.assertEqual(B[float], A[float]) + self.assertEqual(B[float, str], A[float, str]) + + C = A[Unpack[Ts]] + self.assertEqual(C[()], A[()]) + self.assertEqual(C[float], A[float]) + self.assertEqual(C[float, str], A[float, str]) + + D = list[A[*Ts]] + self.assertEqual(D[()], list[A[()]]) + self.assertEqual(D[float], list[A[float]]) + self.assertEqual(D[float, str], list[A[float, str]]) + + E = List[A[Unpack[Ts]]] + self.assertEqual(E[()], List[A[()]]) + self.assertEqual(E[float], List[A[float]]) + self.assertEqual(E[float, str], List[A[float, str]]) + + F = A[T, *Ts, T2] + with self.assertRaises(TypeError): + F[()] + with self.assertRaises(TypeError): + F[float] + self.assertEqual(F[float, str], A[float, str]) + self.assertEqual(F[float, str, int], A[float, str, int]) + self.assertEqual(F[float, str, int, bytes], A[float, str, int, bytes]) + + G = A[T, Unpack[Ts], T2] + with self.assertRaises(TypeError): + G[()] + with self.assertRaises(TypeError): + G[float] + self.assertEqual(G[float, str], A[float, str]) + self.assertEqual(G[float, str, int], A[float, 
str, int]) + self.assertEqual(G[float, str, int, bytes], A[float, str, int, bytes]) + + H = tuple[list[T], A[*Ts], list[T2]] + with self.assertRaises(TypeError): + H[()] + with self.assertRaises(TypeError): + H[float] + if A != Tuple: + self.assertEqual(H[float, str], + tuple[list[float], A[()], list[str]]) + self.assertEqual(H[float, str, int], + tuple[list[float], A[str], list[int]]) + self.assertEqual(H[float, str, int, bytes], + tuple[list[float], A[str, int], list[bytes]]) + + I = Tuple[List[T], A[Unpack[Ts]], List[T2]] + with self.assertRaises(TypeError): + I[()] + with self.assertRaises(TypeError): + I[float] + if A != Tuple: + self.assertEqual(I[float, str], + Tuple[List[float], A[()], List[str]]) + self.assertEqual(I[float, str, int], + Tuple[List[float], A[str], List[int]]) + self.assertEqual(I[float, str, int, bytes], + Tuple[List[float], A[str, int], List[bytes]]) + + def test_bad_var_substitution(self): + Ts = TypeVarTuple('Ts') + T = TypeVar('T') + T2 = TypeVar('T2') + class G1(Generic[*Ts]): pass + class G2(Generic[Unpack[Ts]]): pass + + for A in G1, G2, Tuple, tuple: + B = A[Ts] + with self.assertRaises(TypeError): + B[int, str] + + C = A[T, T2] + with self.assertRaises(TypeError): + C[*Ts] + with self.assertRaises(TypeError): + C[Unpack[Ts]] + + B = A[T, *Ts, str, T2] + with self.assertRaises(TypeError): + B[int, *Ts] + with self.assertRaises(TypeError): + B[int, *Ts, *Ts] + + C = A[T, Unpack[Ts], str, T2] + with self.assertRaises(TypeError): + C[int, Unpack[Ts]] + with self.assertRaises(TypeError): + C[int, Unpack[Ts], Unpack[Ts]] + + def test_repr_is_correct(self): + Ts = TypeVarTuple('Ts') + + class G1(Generic[*Ts]): pass + class G2(Generic[Unpack[Ts]]): pass + + self.assertEqual(repr(Ts), 'Ts') + + self.assertEqual(repr((*Ts,)[0]), 'typing.Unpack[Ts]') + self.assertEqual(repr(Unpack[Ts]), 'typing.Unpack[Ts]') + + self.assertEqual(repr(tuple[*Ts]), 'tuple[typing.Unpack[Ts]]') + self.assertEqual(repr(Tuple[Unpack[Ts]]), 'typing.Tuple[typing.Unpack[Ts]]') + + self.assertEqual(repr(*tuple[*Ts]), '*tuple[typing.Unpack[Ts]]') + self.assertEqual(repr(Unpack[Tuple[Unpack[Ts]]]), 'typing.Unpack[typing.Tuple[typing.Unpack[Ts]]]') + + def test_variadic_class_repr_is_correct(self): + Ts = TypeVarTuple('Ts') + class A(Generic[*Ts]): pass + class B(Generic[Unpack[Ts]]): pass + + self.assertEndsWith(repr(A[()]), 'A[()]') + self.assertEndsWith(repr(B[()]), 'B[()]') + self.assertEndsWith(repr(A[float]), 'A[float]') + self.assertEndsWith(repr(B[float]), 'B[float]') + self.assertEndsWith(repr(A[float, str]), 'A[float, str]') + self.assertEndsWith(repr(B[float, str]), 'B[float, str]') + + self.assertEndsWith(repr(A[*tuple[int, ...]]), + 'A[*tuple[int, ...]]') + self.assertEndsWith(repr(B[Unpack[Tuple[int, ...]]]), + 'B[typing.Unpack[typing.Tuple[int, ...]]]') + + self.assertEndsWith(repr(A[float, *tuple[int, ...]]), + 'A[float, *tuple[int, ...]]') + self.assertEndsWith(repr(A[float, Unpack[Tuple[int, ...]]]), + 'A[float, typing.Unpack[typing.Tuple[int, ...]]]') + + self.assertEndsWith(repr(A[*tuple[int, ...], str]), + 'A[*tuple[int, ...], str]') + self.assertEndsWith(repr(B[Unpack[Tuple[int, ...]], str]), + 'B[typing.Unpack[typing.Tuple[int, ...]], str]') + + self.assertEndsWith(repr(A[float, *tuple[int, ...], str]), + 'A[float, *tuple[int, ...], str]') + self.assertEndsWith(repr(B[float, Unpack[Tuple[int, ...]], str]), + 'B[float, typing.Unpack[typing.Tuple[int, ...]], str]') + + def test_variadic_class_alias_repr_is_correct(self): + Ts = TypeVarTuple('Ts') + class 
A(Generic[Unpack[Ts]]): pass + + B = A[*Ts] + self.assertEndsWith(repr(B), 'A[typing.Unpack[Ts]]') + self.assertEndsWith(repr(B[()]), 'A[()]') + self.assertEndsWith(repr(B[float]), 'A[float]') + self.assertEndsWith(repr(B[float, str]), 'A[float, str]') + + C = A[Unpack[Ts]] + self.assertEndsWith(repr(C), 'A[typing.Unpack[Ts]]') + self.assertEndsWith(repr(C[()]), 'A[()]') + self.assertEndsWith(repr(C[float]), 'A[float]') + self.assertEndsWith(repr(C[float, str]), 'A[float, str]') + + D = A[*Ts, int] + self.assertEndsWith(repr(D), 'A[typing.Unpack[Ts], int]') + self.assertEndsWith(repr(D[()]), 'A[int]') + self.assertEndsWith(repr(D[float]), 'A[float, int]') + self.assertEndsWith(repr(D[float, str]), 'A[float, str, int]') + + E = A[Unpack[Ts], int] + self.assertEndsWith(repr(E), 'A[typing.Unpack[Ts], int]') + self.assertEndsWith(repr(E[()]), 'A[int]') + self.assertEndsWith(repr(E[float]), 'A[float, int]') + self.assertEndsWith(repr(E[float, str]), 'A[float, str, int]') + + F = A[int, *Ts] + self.assertEndsWith(repr(F), 'A[int, typing.Unpack[Ts]]') + self.assertEndsWith(repr(F[()]), 'A[int]') + self.assertEndsWith(repr(F[float]), 'A[int, float]') + self.assertEndsWith(repr(F[float, str]), 'A[int, float, str]') + + G = A[int, Unpack[Ts]] + self.assertEndsWith(repr(G), 'A[int, typing.Unpack[Ts]]') + self.assertEndsWith(repr(G[()]), 'A[int]') + self.assertEndsWith(repr(G[float]), 'A[int, float]') + self.assertEndsWith(repr(G[float, str]), 'A[int, float, str]') + + H = A[int, *Ts, str] + self.assertEndsWith(repr(H), 'A[int, typing.Unpack[Ts], str]') + self.assertEndsWith(repr(H[()]), 'A[int, str]') + self.assertEndsWith(repr(H[float]), 'A[int, float, str]') + self.assertEndsWith(repr(H[float, str]), 'A[int, float, str, str]') + + I = A[int, Unpack[Ts], str] + self.assertEndsWith(repr(I), 'A[int, typing.Unpack[Ts], str]') + self.assertEndsWith(repr(I[()]), 'A[int, str]') + self.assertEndsWith(repr(I[float]), 'A[int, float, str]') + self.assertEndsWith(repr(I[float, str]), 'A[int, float, str, str]') + + J = A[*Ts, *tuple[str, ...]] + self.assertEndsWith(repr(J), 'A[typing.Unpack[Ts], *tuple[str, ...]]') + self.assertEndsWith(repr(J[()]), 'A[*tuple[str, ...]]') + self.assertEndsWith(repr(J[float]), 'A[float, *tuple[str, ...]]') + self.assertEndsWith(repr(J[float, str]), 'A[float, str, *tuple[str, ...]]') + + K = A[Unpack[Ts], Unpack[Tuple[str, ...]]] + self.assertEndsWith(repr(K), 'A[typing.Unpack[Ts], typing.Unpack[typing.Tuple[str, ...]]]') + self.assertEndsWith(repr(K[()]), 'A[typing.Unpack[typing.Tuple[str, ...]]]') + self.assertEndsWith(repr(K[float]), 'A[float, typing.Unpack[typing.Tuple[str, ...]]]') + self.assertEndsWith(repr(K[float, str]), 'A[float, str, typing.Unpack[typing.Tuple[str, ...]]]') + + def test_cannot_subclass(self): + with self.assertRaisesRegex(TypeError, NOT_A_BASE_TYPE % 'TypeVarTuple'): + class C(TypeVarTuple): pass + Ts = TypeVarTuple('Ts') + with self.assertRaisesRegex(TypeError, + CANNOT_SUBCLASS_INSTANCE % 'TypeVarTuple'): + class D(Ts): pass + with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): + class E(type(Unpack)): pass + with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): + class F(type(*Ts)): pass + with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): + class G(type(Unpack[Ts])): pass + with self.assertRaisesRegex(TypeError, + r'Cannot subclass typing\.Unpack'): + class H(Unpack): pass + with self.assertRaisesRegex(TypeError, r'Cannot subclass typing.Unpack\[Ts\]'): + class I(*Ts): pass + with self.assertRaisesRegex(TypeError, 
r'Cannot subclass typing.Unpack\[Ts\]'): + class J(Unpack[Ts]): pass + + def test_variadic_class_args_are_correct(self): + T = TypeVar('T') + Ts = TypeVarTuple('Ts') + class A(Generic[*Ts]): pass + class B(Generic[Unpack[Ts]]): pass + + C = A[()] + D = B[()] + self.assertEqual(C.__args__, ()) + self.assertEqual(D.__args__, ()) + + E = A[int] + F = B[int] + self.assertEqual(E.__args__, (int,)) + self.assertEqual(F.__args__, (int,)) + + G = A[int, str] + H = B[int, str] + self.assertEqual(G.__args__, (int, str)) + self.assertEqual(H.__args__, (int, str)) + + I = A[T] + J = B[T] + self.assertEqual(I.__args__, (T,)) + self.assertEqual(J.__args__, (T,)) + + K = A[*Ts] + L = B[Unpack[Ts]] + self.assertEqual(K.__args__, (*Ts,)) + self.assertEqual(L.__args__, (Unpack[Ts],)) + + M = A[T, *Ts] + N = B[T, Unpack[Ts]] + self.assertEqual(M.__args__, (T, *Ts)) + self.assertEqual(N.__args__, (T, Unpack[Ts])) + + O = A[*Ts, T] + P = B[Unpack[Ts], T] + self.assertEqual(O.__args__, (*Ts, T)) + self.assertEqual(P.__args__, (Unpack[Ts], T)) + + def test_variadic_class_origin_is_correct(self): + Ts = TypeVarTuple('Ts') + + class C(Generic[*Ts]): pass + self.assertIs(C[int].__origin__, C) + self.assertIs(C[T].__origin__, C) + self.assertIs(C[Unpack[Ts]].__origin__, C) + + class D(Generic[Unpack[Ts]]): pass + self.assertIs(D[int].__origin__, D) + self.assertIs(D[T].__origin__, D) + self.assertIs(D[Unpack[Ts]].__origin__, D) + + def test_get_type_hints_on_unpack_args(self): + Ts = TypeVarTuple('Ts') + + def func1(*args: *Ts): pass + self.assertEqual(gth(func1), {'args': Unpack[Ts]}) + + def func2(*args: *tuple[int, str]): pass + self.assertEqual(gth(func2), {'args': Unpack[tuple[int, str]]}) + + class CustomVariadic(Generic[*Ts]): pass + + def func3(*args: *CustomVariadic[int, str]): pass + self.assertEqual(gth(func3), {'args': Unpack[CustomVariadic[int, str]]}) + + def test_get_type_hints_on_unpack_args_string(self): + Ts = TypeVarTuple('Ts') + + def func1(*args: '*Ts'): pass + self.assertEqual(gth(func1, localns={'Ts': Ts}), + {'args': Unpack[Ts]}) + + def func2(*args: '*tuple[int, str]'): pass + self.assertEqual(gth(func2), {'args': Unpack[tuple[int, str]]}) + + class CustomVariadic(Generic[*Ts]): pass + + def func3(*args: '*CustomVariadic[int, str]'): pass + self.assertEqual(gth(func3, localns={'CustomVariadic': CustomVariadic}), + {'args': Unpack[CustomVariadic[int, str]]}) + + def test_tuple_args_are_correct(self): + Ts = TypeVarTuple('Ts') + + self.assertEqual(tuple[*Ts].__args__, (*Ts,)) + self.assertEqual(Tuple[Unpack[Ts]].__args__, (Unpack[Ts],)) + + self.assertEqual(tuple[*Ts, int].__args__, (*Ts, int)) + self.assertEqual(Tuple[Unpack[Ts], int].__args__, (Unpack[Ts], int)) + + self.assertEqual(tuple[int, *Ts].__args__, (int, *Ts)) + self.assertEqual(Tuple[int, Unpack[Ts]].__args__, (int, Unpack[Ts])) + + self.assertEqual(tuple[int, *Ts, str].__args__, + (int, *Ts, str)) + self.assertEqual(Tuple[int, Unpack[Ts], str].__args__, + (int, Unpack[Ts], str)) + + self.assertEqual(tuple[*Ts, int].__args__, (*Ts, int)) + self.assertEqual(Tuple[Unpack[Ts]].__args__, (Unpack[Ts],)) + + def test_callable_args_are_correct(self): + Ts = TypeVarTuple('Ts') + Ts1 = TypeVarTuple('Ts1') + Ts2 = TypeVarTuple('Ts2') + + # TypeVarTuple in the arguments + + a = Callable[[*Ts], None] + b = Callable[[Unpack[Ts]], None] + self.assertEqual(a.__args__, (*Ts, type(None))) + self.assertEqual(b.__args__, (Unpack[Ts], type(None))) + + c = Callable[[int, *Ts], None] + d = Callable[[int, Unpack[Ts]], None] + 
self.assertEqual(c.__args__, (int, *Ts, type(None))) + self.assertEqual(d.__args__, (int, Unpack[Ts], type(None))) + + e = Callable[[*Ts, int], None] + f = Callable[[Unpack[Ts], int], None] + self.assertEqual(e.__args__, (*Ts, int, type(None))) + self.assertEqual(f.__args__, (Unpack[Ts], int, type(None))) + + g = Callable[[str, *Ts, int], None] + h = Callable[[str, Unpack[Ts], int], None] + self.assertEqual(g.__args__, (str, *Ts, int, type(None))) + self.assertEqual(h.__args__, (str, Unpack[Ts], int, type(None))) + + # TypeVarTuple as the return + + i = Callable[[None], *Ts] + j = Callable[[None], Unpack[Ts]] + self.assertEqual(i.__args__, (type(None), *Ts)) + self.assertEqual(j.__args__, (type(None), Unpack[Ts])) + + k = Callable[[None], tuple[int, *Ts]] + l = Callable[[None], Tuple[int, Unpack[Ts]]] + self.assertEqual(k.__args__, (type(None), tuple[int, *Ts])) + self.assertEqual(l.__args__, (type(None), Tuple[int, Unpack[Ts]])) + + m = Callable[[None], tuple[*Ts, int]] + n = Callable[[None], Tuple[Unpack[Ts], int]] + self.assertEqual(m.__args__, (type(None), tuple[*Ts, int])) + self.assertEqual(n.__args__, (type(None), Tuple[Unpack[Ts], int])) + + o = Callable[[None], tuple[str, *Ts, int]] + p = Callable[[None], Tuple[str, Unpack[Ts], int]] + self.assertEqual(o.__args__, (type(None), tuple[str, *Ts, int])) + self.assertEqual(p.__args__, (type(None), Tuple[str, Unpack[Ts], int])) + + # TypeVarTuple in both + + q = Callable[[*Ts], *Ts] + r = Callable[[Unpack[Ts]], Unpack[Ts]] + self.assertEqual(q.__args__, (*Ts, *Ts)) + self.assertEqual(r.__args__, (Unpack[Ts], Unpack[Ts])) + + s = Callable[[*Ts1], *Ts2] + u = Callable[[Unpack[Ts1]], Unpack[Ts2]] + self.assertEqual(s.__args__, (*Ts1, *Ts2)) + self.assertEqual(u.__args__, (Unpack[Ts1], Unpack[Ts2])) + + def test_variadic_class_with_duplicate_typevartuples_fails(self): + Ts1 = TypeVarTuple('Ts1') + Ts2 = TypeVarTuple('Ts2') + + with self.assertRaises(TypeError): + class C(Generic[*Ts1, *Ts1]): pass + with self.assertRaises(TypeError): + class D(Generic[Unpack[Ts1], Unpack[Ts1]]): pass + + with self.assertRaises(TypeError): + class E(Generic[*Ts1, *Ts2, *Ts1]): pass + with self.assertRaises(TypeError): + class F(Generic[Unpack[Ts1], Unpack[Ts2], Unpack[Ts1]]): pass + + def test_type_concatenation_in_variadic_class_argument_list_succeeds(self): + Ts = TypeVarTuple('Ts') + class C(Generic[Unpack[Ts]]): pass + + C[int, *Ts] + C[int, Unpack[Ts]] + + C[*Ts, int] + C[Unpack[Ts], int] + + C[int, *Ts, str] + C[int, Unpack[Ts], str] + + C[int, bool, *Ts, float, str] + C[int, bool, Unpack[Ts], float, str] + + def test_type_concatenation_in_tuple_argument_list_succeeds(self): + Ts = TypeVarTuple('Ts') + + tuple[int, *Ts] + tuple[*Ts, int] + tuple[int, *Ts, str] + tuple[int, bool, *Ts, float, str] + + Tuple[int, Unpack[Ts]] + Tuple[Unpack[Ts], int] + Tuple[int, Unpack[Ts], str] + Tuple[int, bool, Unpack[Ts], float, str] + + def test_variadic_class_definition_using_packed_typevartuple_fails(self): + Ts = TypeVarTuple('Ts') + with self.assertRaises(TypeError): + class C(Generic[Ts]): pass + + def test_variadic_class_definition_using_concrete_types_fails(self): + Ts = TypeVarTuple('Ts') + with self.assertRaises(TypeError): + class F(Generic[*Ts, int]): pass + with self.assertRaises(TypeError): + class E(Generic[Unpack[Ts], int]): pass + + def test_variadic_class_with_2_typevars_accepts_2_or_more_args(self): + Ts = TypeVarTuple('Ts') + T1 = TypeVar('T1') + T2 = TypeVar('T2') + + class A(Generic[T1, T2, *Ts]): pass + A[int, str] + A[int, str, float] + A[int, 
str, float, bool] + + class B(Generic[T1, T2, Unpack[Ts]]): pass + B[int, str] + B[int, str, float] + B[int, str, float, bool] + + class C(Generic[T1, *Ts, T2]): pass + C[int, str] + C[int, str, float] + C[int, str, float, bool] + + class D(Generic[T1, Unpack[Ts], T2]): pass + D[int, str] + D[int, str, float] + D[int, str, float, bool] + + class E(Generic[*Ts, T1, T2]): pass + E[int, str] + E[int, str, float] + E[int, str, float, bool] + + class F(Generic[Unpack[Ts], T1, T2]): pass + F[int, str] + F[int, str, float] + F[int, str, float, bool] + + def test_variadic_args_annotations_are_correct(self): + Ts = TypeVarTuple('Ts') + + def f(*args: Unpack[Ts]): pass + def g(*args: *Ts): pass + self.assertEqual(f.__annotations__, {'args': Unpack[Ts]}) + self.assertEqual(g.__annotations__, {'args': (*Ts,)[0]}) + + def test_variadic_args_with_ellipsis_annotations_are_correct(self): + def a(*args: *tuple[int, ...]): pass + self.assertEqual(a.__annotations__, + {'args': (*tuple[int, ...],)[0]}) + + def b(*args: Unpack[Tuple[int, ...]]): pass + self.assertEqual(b.__annotations__, + {'args': Unpack[Tuple[int, ...]]}) + + def test_concatenation_in_variadic_args_annotations_are_correct(self): + Ts = TypeVarTuple('Ts') + + # Unpacking using `*`, native `tuple` type + + def a(*args: *tuple[int, *Ts]): pass + self.assertEqual( + a.__annotations__, + {'args': (*tuple[int, *Ts],)[0]}, + ) + + def b(*args: *tuple[*Ts, int]): pass + self.assertEqual( + b.__annotations__, + {'args': (*tuple[*Ts, int],)[0]}, + ) + + def c(*args: *tuple[str, *Ts, int]): pass + self.assertEqual( + c.__annotations__, + {'args': (*tuple[str, *Ts, int],)[0]}, + ) + + def d(*args: *tuple[int, bool, *Ts, float, str]): pass + self.assertEqual( + d.__annotations__, + {'args': (*tuple[int, bool, *Ts, float, str],)[0]}, + ) + + # Unpacking using `Unpack`, `Tuple` type from typing.py + + def e(*args: Unpack[Tuple[int, Unpack[Ts]]]): pass + self.assertEqual( + e.__annotations__, + {'args': Unpack[Tuple[int, Unpack[Ts]]]}, + ) + + def f(*args: Unpack[Tuple[Unpack[Ts], int]]): pass + self.assertEqual( + f.__annotations__, + {'args': Unpack[Tuple[Unpack[Ts], int]]}, + ) + + def g(*args: Unpack[Tuple[str, Unpack[Ts], int]]): pass + self.assertEqual( + g.__annotations__, + {'args': Unpack[Tuple[str, Unpack[Ts], int]]}, + ) + + def h(*args: Unpack[Tuple[int, bool, Unpack[Ts], float, str]]): pass + self.assertEqual( + h.__annotations__, + {'args': Unpack[Tuple[int, bool, Unpack[Ts], float, str]]}, + ) + + def test_variadic_class_same_args_results_in_equalty(self): + Ts = TypeVarTuple('Ts') + class C(Generic[*Ts]): pass + class D(Generic[Unpack[Ts]]): pass + + self.assertEqual(C[int], C[int]) + self.assertEqual(D[int], D[int]) + + Ts1 = TypeVarTuple('Ts1') + Ts2 = TypeVarTuple('Ts2') + + self.assertEqual( + C[*Ts1], + C[*Ts1], + ) + self.assertEqual( + D[Unpack[Ts1]], + D[Unpack[Ts1]], + ) + + self.assertEqual( + C[*Ts1, *Ts2], + C[*Ts1, *Ts2], + ) + self.assertEqual( + D[Unpack[Ts1], Unpack[Ts2]], + D[Unpack[Ts1], Unpack[Ts2]], + ) + + self.assertEqual( + C[int, *Ts1, *Ts2], + C[int, *Ts1, *Ts2], + ) + self.assertEqual( + D[int, Unpack[Ts1], Unpack[Ts2]], + D[int, Unpack[Ts1], Unpack[Ts2]], + ) + + def test_variadic_class_arg_ordering_matters(self): + Ts = TypeVarTuple('Ts') + class C(Generic[*Ts]): pass + class D(Generic[Unpack[Ts]]): pass + + self.assertNotEqual( + C[int, str], + C[str, int], + ) + self.assertNotEqual( + D[int, str], + D[str, int], + ) + + Ts1 = TypeVarTuple('Ts1') + Ts2 = TypeVarTuple('Ts2') + + self.assertNotEqual( + C[*Ts1, 
*Ts2], + C[*Ts2, *Ts1], + ) + self.assertNotEqual( + D[Unpack[Ts1], Unpack[Ts2]], + D[Unpack[Ts2], Unpack[Ts1]], + ) + + def test_variadic_class_arg_typevartuple_identity_matters(self): + Ts = TypeVarTuple('Ts') + Ts1 = TypeVarTuple('Ts1') + Ts2 = TypeVarTuple('Ts2') + + class C(Generic[*Ts]): pass + class D(Generic[Unpack[Ts]]): pass + + self.assertNotEqual(C[*Ts1], C[*Ts2]) + self.assertNotEqual(D[Unpack[Ts1]], D[Unpack[Ts2]]) + + +class TypeVarTuplePicklingTests(BaseTestCase): + # These are slightly awkward tests to run, because TypeVarTuples are only + # picklable if defined in the global scope. We therefore need to push + # various things defined in these tests into the global scope with `global` + # statements at the start of each test. + + @all_pickle_protocols + def test_pickling_then_unpickling_results_in_same_identity(self, proto): + global global_Ts1 # See explanation at start of class. + global_Ts1 = TypeVarTuple('global_Ts1') + global_Ts2 = pickle.loads(pickle.dumps(global_Ts1, proto)) + self.assertIs(global_Ts1, global_Ts2) + + @all_pickle_protocols + def test_pickling_then_unpickling_unpacked_results_in_same_identity(self, proto): + global global_Ts # See explanation at start of class. + global_Ts = TypeVarTuple('global_Ts') + + unpacked1 = (*global_Ts,)[0] + unpacked2 = pickle.loads(pickle.dumps(unpacked1, proto)) + self.assertIs(unpacked1, unpacked2) + + unpacked3 = Unpack[global_Ts] + unpacked4 = pickle.loads(pickle.dumps(unpacked3, proto)) + self.assertIs(unpacked3, unpacked4) + + @all_pickle_protocols + def test_pickling_then_unpickling_tuple_with_typevartuple_equality( + self, proto + ): + global global_T, global_Ts # See explanation at start of class. + global_T = TypeVar('global_T') + global_Ts = TypeVarTuple('global_Ts') + + tuples = [ + tuple[*global_Ts], + Tuple[Unpack[global_Ts]], + + tuple[T, *global_Ts], + Tuple[T, Unpack[global_Ts]], + + tuple[int, *global_Ts], + Tuple[int, Unpack[global_Ts]], + ] + for t in tuples: + t2 = pickle.loads(pickle.dumps(t, proto)) + self.assertEqual(t, t2) + class UnionTests(BaseTestCase): @@ -313,6 +2058,26 @@ def test_union_union(self): v = Union[u, Employee] self.assertEqual(v, Union[int, float, Employee]) + def test_union_of_unhashable(self): + class UnhashableMeta(type): + __hash__ = None + + class A(metaclass=UnhashableMeta): ... + class B(metaclass=UnhashableMeta): ... 
+ + self.assertEqual(Union[A, B].__args__, (A, B)) + union1 = Union[A, B] + with self.assertRaises(TypeError): + hash(union1) + + union2 = Union[int, B] + with self.assertRaises(TypeError): + hash(union2) + + union3 = Union[A, int] + with self.assertRaises(TypeError): + hash(union3) + def test_repr(self): self.assertEqual(repr(Union), 'typing.Union') u = Union[Employee, int] @@ -338,15 +2103,25 @@ def test_repr(self): u = Optional[str] self.assertEqual(repr(u), 'typing.Optional[str]') + def test_dir(self): + dir_items = set(dir(Union[str, int])) + for required_item in [ + '__args__', '__parameters__', '__origin__', + ]: + with self.subTest(required_item=required_item): + self.assertIn(required_item, dir_items) + def test_cannot_subclass(self): - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, + r'Cannot subclass typing\.Union'): class C(Union): pass - with self.assertRaises(TypeError): - class C(type(Union)): + with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): + class D(type(Union)): pass - with self.assertRaises(TypeError): - class C(Union[int, str]): + with self.assertRaisesRegex(TypeError, + r'Cannot subclass typing\.Union\[int, str\]'): + class E(Union[int, str]): pass def test_cannot_instantiate(self): @@ -410,6 +2185,35 @@ def Elem(*args): Union[Elem, str] # Nor should this + def test_union_of_literals(self): + self.assertEqual(Union[Literal[1], Literal[2]].__args__, + (Literal[1], Literal[2])) + self.assertEqual(Union[Literal[1], Literal[1]], + Literal[1]) + + self.assertEqual(Union[Literal[False], Literal[0]].__args__, + (Literal[False], Literal[0])) + self.assertEqual(Union[Literal[True], Literal[1]].__args__, + (Literal[True], Literal[1])) + + import enum + class Ints(enum.IntEnum): + A = 0 + B = 1 + + self.assertEqual(Union[Literal[Ints.A], Literal[Ints.A]], + Literal[Ints.A]) + self.assertEqual(Union[Literal[Ints.B], Literal[Ints.B]], + Literal[Ints.B]) + + self.assertEqual(Union[Literal[Ints.A], Literal[Ints.B]].__args__, + (Literal[Ints.A], Literal[Ints.B])) + + self.assertEqual(Union[Literal[0], Literal[Ints.A], Literal[False]].__args__, + (Literal[0], Literal[Ints.A], Literal[False])) + self.assertEqual(Union[Literal[1], Literal[Ints.B], Literal[True]].__args__, + (Literal[1], Literal[Ints.B], Literal[True])) + class TupleTests(BaseTestCase): @@ -476,6 +2280,15 @@ def test_eq_hash(self): self.assertNotEqual(C, Callable[..., int]) self.assertNotEqual(C, Callable) + def test_dir(self): + Callable = self.Callable + dir_items = set(dir(Callable[..., int])) + for required_item in [ + '__args__', '__parameters__', '__origin__', + ]: + with self.subTest(required_item=required_item): + self.assertIn(required_item, dir_items) + def test_cannot_instantiate(self): Callable = self.Callable with self.assertRaises(TypeError): @@ -505,13 +2318,13 @@ def test_callable_instance_type_error(self): def f(): pass with self.assertRaises(TypeError): - self.assertIsInstance(f, Callable[[], None]) + isinstance(f, Callable[[], None]) with self.assertRaises(TypeError): - self.assertIsInstance(f, Callable[[], Any]) + isinstance(f, Callable[[], Any]) with self.assertRaises(TypeError): - self.assertNotIsInstance(None, Callable[[], None]) + isinstance(None, Callable[[], None]) with self.assertRaises(TypeError): - self.assertNotIsInstance(None, Callable[[], Any]) + isinstance(None, Callable[[], Any]) def test_repr(self): Callable = self.Callable @@ -558,14 +2371,29 @@ def test_weakref(self): self.assertEqual(weakref.ref(alias)(), alias) def test_pickle(self): + 
global T_pickle, P_pickle, TS_pickle # needed for pickling Callable = self.Callable - alias = Callable[[int, str], float] - for proto in range(pickle.HIGHEST_PROTOCOL + 1): - s = pickle.dumps(alias, proto) - loaded = pickle.loads(s) - self.assertEqual(alias.__origin__, loaded.__origin__) - self.assertEqual(alias.__args__, loaded.__args__) - self.assertEqual(alias.__parameters__, loaded.__parameters__) + T_pickle = TypeVar('T_pickle') + P_pickle = ParamSpec('P_pickle') + TS_pickle = TypeVarTuple('TS_pickle') + + samples = [ + Callable[[int, str], float], + Callable[P_pickle, int], + Callable[P_pickle, T_pickle], + Callable[Concatenate[int, P_pickle], int], + Callable[Concatenate[*TS_pickle, P_pickle], int], + ] + for alias in samples: + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(alias=alias, proto=proto): + s = pickle.dumps(alias, proto) + loaded = pickle.loads(s) + self.assertEqual(alias.__origin__, loaded.__origin__) + self.assertEqual(alias.__args__, loaded.__args__) + self.assertEqual(alias.__parameters__, loaded.__parameters__) + + del T_pickle, P_pickle, TS_pickle # cleaning up global state def test_var_substitution(self): Callable = self.Callable @@ -574,8 +2402,7 @@ def test_var_substitution(self): C2 = Callable[[KT, T], VT] C3 = Callable[..., T] self.assertEqual(C1[str], Callable[[int, str], str]) - if Callable is typing.Callable: - self.assertEqual(C1[None], Callable[[int, type(None)], type(None)]) + self.assertEqual(C1[None], Callable[[int, type(None)], type(None)]) self.assertEqual(C2[int, float, str], Callable[[int, float], str]) self.assertEqual(C3[int], Callable[..., int]) self.assertEqual(C3[NoReturn], Callable[..., NoReturn]) @@ -592,6 +2419,16 @@ def test_var_substitution(self): self.assertEqual(C5[int, str, float], Callable[[typing.List[int], tuple[str, int], float], int]) + def test_type_subst_error(self): + Callable = self.Callable + P = ParamSpec('P') + T = TypeVar('T') + + pat = "Expected a list of types, an ellipsis, ParamSpec, or Concatenate." + + with self.assertRaisesRegex(TypeError, pat): + Callable[P, T][0, int] + def test_type_erasure(self): Callable = self.Callable class C1(Callable): @@ -649,8 +2486,7 @@ def test_concatenate(self): self.assertEqual(C[[], int], Callable[[int], int]) self.assertEqual(C[Concatenate[str, P2], int], Callable[Concatenate[int, str, P2], int]) - with self.assertRaises(TypeError): - C[..., int] + self.assertEqual(C[..., int], Callable[Concatenate[int, ...], int]) C = Callable[Concatenate[int, P], int] self.assertEqual(repr(C), @@ -661,8 +2497,49 @@ def test_concatenate(self): self.assertEqual(C[[]], Callable[[int], int]) self.assertEqual(C[Concatenate[str, P2]], Callable[Concatenate[int, str, P2], int]) - with self.assertRaises(TypeError): - C[...] 
+ self.assertEqual(C[...], Callable[Concatenate[int, ...], int]) + + def test_nested_paramspec(self): + # Since Callable has some special treatment, we want to be sure + # that substituion works correctly, see gh-103054 + Callable = self.Callable + P = ParamSpec('P') + P2 = ParamSpec('P2') + T = TypeVar('T') + T2 = TypeVar('T2') + Ts = TypeVarTuple('Ts') + class My(Generic[P, T]): + pass + + self.assertEqual(My.__parameters__, (P, T)) + + C1 = My[[int, T2], Callable[P2, T2]] + self.assertEqual(C1.__args__, ((int, T2), Callable[P2, T2])) + self.assertEqual(C1.__parameters__, (T2, P2)) + self.assertEqual(C1[str, [list[int], bytes]], + My[[int, str], Callable[[list[int], bytes], str]]) + + C2 = My[[Callable[[T2], int], list[T2]], str] + self.assertEqual(C2.__args__, ((Callable[[T2], int], list[T2]), str)) + self.assertEqual(C2.__parameters__, (T2,)) + self.assertEqual(C2[list[str]], + My[[Callable[[list[str]], int], list[list[str]]], str]) + + C3 = My[[Callable[P2, T2], T2], T2] + self.assertEqual(C3.__args__, ((Callable[P2, T2], T2), T2)) + self.assertEqual(C3.__parameters__, (P2, T2)) + self.assertEqual(C3[[], int], + My[[Callable[[], int], int], int]) + self.assertEqual(C3[[str, bool], int], + My[[Callable[[str, bool], int], int], int]) + self.assertEqual(C3[[str, bool], T][int], + My[[Callable[[str, bool], int], int], int]) + + C4 = My[[Callable[[int, *Ts, str], T2], T2], T2] + self.assertEqual(C4.__args__, ((Callable[[int, *Ts, str], T2], T2), T2)) + self.assertEqual(C4.__parameters__, (Ts, T2)) + self.assertEqual(C4[bool, bytes, float], + My[[Callable[[int, bool, bytes, str], float], float], float]) def test_errors(self): Callable = self.Callable @@ -723,9 +2600,16 @@ def test_basics(self): Literal[Literal[1, 2], Literal[4, 5]] Literal[b"foo", u"bar"] + def test_enum(self): + import enum + class My(enum.Enum): + A = 'A' + + self.assertEqual(Literal[My.A].__args__, (My.A,)) + def test_illegal_parameters_do_not_raise_runtime_errors(self): # Type checkers should reject these types, but we do not - # raise errors at runtime to maintain maximium flexibility. + # raise errors at runtime to maintain maximum flexibility. Literal[int] Literal[3j + 2, ..., ()] Literal[{"foo": 3, "bar": 4}] @@ -743,6 +2627,14 @@ def test_repr(self): self.assertEqual(repr(Literal[None]), "typing.Literal[None]") self.assertEqual(repr(Literal[1, 2, 3, 3]), "typing.Literal[1, 2, 3]") + def test_dir(self): + dir_items = set(dir(Literal[1, 2, 3])) + for required_item in [ + '__args__', '__parameters__', '__origin__', + ]: + with self.subTest(required_item=required_item): + self.assertIn(required_item, dir_items) + def test_cannot_init(self): with self.assertRaises(TypeError): Literal() @@ -804,6 +2696,20 @@ def test_flatten(self): self.assertEqual(l, Literal[1, 2, 3]) self.assertEqual(l.__args__, (1, 2, 3)) + def test_does_not_flatten_enum(self): + import enum + class Ints(enum.IntEnum): + A = 1 + B = 2 + + l = Literal[ + Literal[Ints.A], + Literal[Ints.B], + Literal[1], + Literal[2], + ] + self.assertEqual(l.__args__, (Ints.A, Ints.B, 1, 2)) + XK = TypeVar('XK', str, bytes) XV = TypeVar('XV') @@ -914,6 +2820,48 @@ def f(): # TODO: RUSTPYTHON @unittest.expectedFailure + def test_runtime_checkable_generic_non_protocol(self): + # Make sure this doesn't raise AttributeError + with self.assertRaisesRegex( + TypeError, + "@runtime_checkable can be only applied to protocol classes", + ): + @runtime_checkable + class Foo[T]: ... 
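# A minimal illustrative sketch (assumed names, not taken from the test suite) of the
# structural checks that the surrounding @runtime_checkable tests exercise:
# isinstance() and issubclass() match on the protocol's members rather than on
# inheritance. `SupportsClose`, `Resource` and `Plain` are hypothetical names.
from typing import Protocol, runtime_checkable

@runtime_checkable
class SupportsClose(Protocol):
    def close(self) -> None: ...

class Resource:                      # never subclasses SupportsClose
    def close(self) -> None:
        pass

class Plain:
    pass

assert isinstance(Resource(), SupportsClose)      # structural match on close()
assert issubclass(Resource, SupportsClose)        # allowed: only method members
assert not isinstance(Plain(), SupportsClose)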
+ + def test_runtime_checkable_generic(self): + @runtime_checkable + class Foo[T](Protocol): + def meth(self) -> T: ... + + class Impl: + def meth(self) -> int: ... + + self.assertIsSubclass(Impl, Foo) + + class NotImpl: + def method(self) -> int: ... + + self.assertNotIsSubclass(NotImpl, Foo) + + def test_pep695_generics_can_be_runtime_checkable(self): + @runtime_checkable + class HasX(Protocol): + x: int + + class Bar[T]: + x: T + def __init__(self, x): + self.x = x + + class Capybara[T]: + y: str + def __init__(self, y): + self.y = y + + self.assertIsInstance(Bar(1), HasX) + self.assertNotIsInstance(Capybara('a'), HasX) + def test_everything_implements_empty_protocol(self): @runtime_checkable class Empty(Protocol): @@ -945,10 +2893,10 @@ class BP(Protocol): pass class P(C, Protocol): pass with self.assertRaises(TypeError): - class P(Protocol, C): + class Q(Protocol, C): pass with self.assertRaises(TypeError): - class P(BP, C, Protocol): + class R(BP, C, Protocol): pass class D(BP, C): pass @@ -989,6 +2937,32 @@ class CG(PG[T]): pass with self.assertRaises(TypeError): CG[int](42) + def test_protocol_defining_init_does_not_get_overridden(self): + # check that P.__init__ doesn't get clobbered + # see https://bugs.python.org/issue44807 + + class P(Protocol): + x: int + def __init__(self, x: int) -> None: + self.x = x + class C: pass + + c = C() + P.__init__(c, 1) + self.assertEqual(c.x, 1) + + def test_concrete_class_inheriting_init_from_protocol(self): + class P(Protocol): + x: int + def __init__(self, x: int) -> None: + self.x = x + + class C(P): pass + + c = C(1) + self.assertIsInstance(c, C) + self.assertEqual(c.x, 1) + def test_cannot_instantiate_abstract(self): @runtime_checkable class P(Protocol): @@ -1007,8 +2981,6 @@ def ameth(self) -> int: B() self.assertIsInstance(C(), P) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_subprotocols_extending(self): class P1(Protocol): def meth1(self): @@ -1104,19 +3076,187 @@ def x(self): ... 
self.assertIsSubclass(C, PG) self.assertIsSubclass(BadP, PG) - with self.assertRaises(TypeError): + no_subscripted_generics = ( + "Subscripted generics cannot be used with class and instance checks" + ) + + with self.assertRaisesRegex(TypeError, no_subscripted_generics): issubclass(C, PG[T]) - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, no_subscripted_generics): issubclass(C, PG[C]) - with self.assertRaises(TypeError): + + only_runtime_checkable_protocols = ( + "Instance and class checks can only be used with " + "@runtime_checkable protocols" + ) + + with self.assertRaisesRegex(TypeError, only_runtime_checkable_protocols): issubclass(C, BadP) - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, only_runtime_checkable_protocols): issubclass(C, BadPG) - with self.assertRaises(TypeError): + + with self.assertRaisesRegex(TypeError, no_subscripted_generics): issubclass(P, PG[T]) - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, no_subscripted_generics): issubclass(PG, PG[int]) + only_classes_allowed = r"issubclass\(\) arg 1 must be a class" + + with self.assertRaisesRegex(TypeError, only_classes_allowed): + issubclass(1, P) + with self.assertRaisesRegex(TypeError, only_classes_allowed): + issubclass(1, PG) + with self.assertRaisesRegex(TypeError, only_classes_allowed): + issubclass(1, BadP) + with self.assertRaisesRegex(TypeError, only_classes_allowed): + issubclass(1, BadPG) + + def test_implicit_issubclass_between_two_protocols(self): + @runtime_checkable + class CallableMembersProto(Protocol): + def meth(self): ... + + # All the below protocols should be considered "subclasses" + # of CallableMembersProto at runtime, + # even though none of them explicitly subclass CallableMembersProto + + class IdenticalProto(Protocol): + def meth(self): ... + + class SupersetProto(Protocol): + def meth(self): ... + def meth2(self): ... + + class NonCallableMembersProto(Protocol): + meth: Callable[[], None] + + class NonCallableMembersSupersetProto(Protocol): + meth: Callable[[], None] + meth2: Callable[[str, int], bool] + + class MixedMembersProto1(Protocol): + meth: Callable[[], None] + def meth2(self): ... + + class MixedMembersProto2(Protocol): + def meth(self): ... + meth2: Callable[[str, int], bool] + + for proto in ( + IdenticalProto, SupersetProto, NonCallableMembersProto, + NonCallableMembersSupersetProto, MixedMembersProto1, MixedMembersProto2 + ): + with self.subTest(proto=proto.__name__): + self.assertIsSubclass(proto, CallableMembersProto) + + # These two shouldn't be considered subclasses of CallableMembersProto, however, + # since they don't have the `meth` protocol member + + class EmptyProtocol(Protocol): ... + class UnrelatedProtocol(Protocol): + def wut(self): ... 
+ + self.assertNotIsSubclass(EmptyProtocol, CallableMembersProto) + self.assertNotIsSubclass(UnrelatedProtocol, CallableMembersProto) + + # These aren't protocols at all (despite having annotations), + # so they should only be considered subclasses of CallableMembersProto + # if they *actually have an attribute* matching the `meth` member + # (just having an annotation is insufficient) + + class AnnotatedButNotAProtocol: + meth: Callable[[], None] + + class NotAProtocolButAnImplicitSubclass: + def meth(self): pass + + class NotAProtocolButAnImplicitSubclass2: + meth: Callable[[], None] + def meth(self): pass + + class NotAProtocolButAnImplicitSubclass3: + meth: Callable[[], None] + meth2: Callable[[int, str], bool] + def meth(self): pass + def meth2(self, x, y): return True + + self.assertNotIsSubclass(AnnotatedButNotAProtocol, CallableMembersProto) + self.assertIsSubclass(NotAProtocolButAnImplicitSubclass, CallableMembersProto) + self.assertIsSubclass(NotAProtocolButAnImplicitSubclass2, CallableMembersProto) + self.assertIsSubclass(NotAProtocolButAnImplicitSubclass3, CallableMembersProto) + + def test_isinstance_checks_not_at_whim_of_gc(self): + self.addCleanup(gc.enable) + gc.disable() + + with self.assertRaisesRegex( + TypeError, + "Protocols can only inherit from other protocols" + ): + class Foo(collections.abc.Mapping, Protocol): + pass + + self.assertNotIsInstance([], collections.abc.Mapping) + + def test_issubclass_and_isinstance_on_Protocol_itself(self): + class C: + def x(self): pass + + self.assertNotIsSubclass(object, Protocol) + self.assertNotIsInstance(object(), Protocol) + + self.assertNotIsSubclass(str, Protocol) + self.assertNotIsInstance('foo', Protocol) + + self.assertNotIsSubclass(C, Protocol) + self.assertNotIsInstance(C(), Protocol) + + only_classes_allowed = r"issubclass\(\) arg 1 must be a class" + + with self.assertRaisesRegex(TypeError, only_classes_allowed): + issubclass(1, Protocol) + with self.assertRaisesRegex(TypeError, only_classes_allowed): + issubclass('foo', Protocol) + with self.assertRaisesRegex(TypeError, only_classes_allowed): + issubclass(C(), Protocol) + + T = TypeVar('T') + + @runtime_checkable + class EmptyProtocol(Protocol): pass + + @runtime_checkable + class SupportsStartsWith(Protocol): + def startswith(self, x: str) -> bool: ... + + @runtime_checkable + class SupportsX(Protocol[T]): + def x(self): ... 
+ + for proto in EmptyProtocol, SupportsStartsWith, SupportsX: + with self.subTest(proto=proto.__name__): + self.assertIsSubclass(proto, Protocol) + + # gh-105237 / PR #105239: + # check that the presence of Protocol subclasses + # where `issubclass(X, )` evaluates to True + # doesn't influence the result of `issubclass(X, Protocol)` + + self.assertIsSubclass(object, EmptyProtocol) + self.assertIsInstance(object(), EmptyProtocol) + self.assertNotIsSubclass(object, Protocol) + self.assertNotIsInstance(object(), Protocol) + + self.assertIsSubclass(str, SupportsStartsWith) + self.assertIsInstance('foo', SupportsStartsWith) + self.assertNotIsSubclass(str, Protocol) + self.assertNotIsInstance('foo', Protocol) + + self.assertIsSubclass(C, SupportsX) + self.assertIsInstance(C(), SupportsX) + self.assertNotIsSubclass(C, Protocol) + self.assertNotIsInstance(C(), Protocol) + # TODO: RUSTPYTHON @unittest.expectedFailure def test_protocols_issubclass_non_callable(self): @@ -1127,55 +3267,359 @@ class C: class PNonCall(Protocol): x = 1 - with self.assertRaises(TypeError): + non_callable_members_illegal = ( + "Protocols with non-method members don't support issubclass()" + ) + + with self.assertRaisesRegex(TypeError, non_callable_members_illegal): issubclass(C, PNonCall) + self.assertIsInstance(C(), PNonCall) PNonCall.register(C) - with self.assertRaises(TypeError): + + with self.assertRaisesRegex(TypeError, non_callable_members_illegal): issubclass(C, PNonCall) + self.assertIsInstance(C(), PNonCall) - # check that non-protocol subclasses are not affected - class D(PNonCall): ... + # check that non-protocol subclasses are not affected + class D(PNonCall): ... + + self.assertNotIsSubclass(C, D) + self.assertNotIsInstance(C(), D) + D.register(C) + self.assertIsSubclass(C, D) + self.assertIsInstance(C(), D) + + with self.assertRaisesRegex(TypeError, non_callable_members_illegal): + issubclass(D, PNonCall) + + def test_no_weird_caching_with_issubclass_after_isinstance(self): + @runtime_checkable + class Spam(Protocol): + x: int + + class Eggs: + def __init__(self) -> None: + self.x = 42 + + self.assertIsInstance(Eggs(), Spam) + + # gh-104555: If we didn't override ABCMeta.__subclasscheck__ in _ProtocolMeta, + # TypeError wouldn't be raised here, + # as the cached result of the isinstance() check immediately above + # would mean the issubclass() call would short-circuit + # before we got to the "raise TypeError" line + with self.assertRaisesRegex( + TypeError, + "Protocols with non-method members don't support issubclass()" + ): + issubclass(Eggs, Spam) + + def test_no_weird_caching_with_issubclass_after_isinstance_2(self): + @runtime_checkable + class Spam(Protocol): + x: int + + class Eggs: ... 
+ + self.assertNotIsInstance(Eggs(), Spam) + + # gh-104555: If we didn't override ABCMeta.__subclasscheck__ in _ProtocolMeta, + # TypeError wouldn't be raised here, + # as the cached result of the isinstance() check immediately above + # would mean the issubclass() call would short-circuit + # before we got to the "raise TypeError" line + with self.assertRaisesRegex( + TypeError, + "Protocols with non-method members don't support issubclass()" + ): + issubclass(Eggs, Spam) + + def test_no_weird_caching_with_issubclass_after_isinstance_3(self): + @runtime_checkable + class Spam(Protocol): + x: int + + class Eggs: + def __getattr__(self, attr): + if attr == "x": + return 42 + raise AttributeError(attr) + + self.assertNotIsInstance(Eggs(), Spam) + + # gh-104555: If we didn't override ABCMeta.__subclasscheck__ in _ProtocolMeta, + # TypeError wouldn't be raised here, + # as the cached result of the isinstance() check immediately above + # would mean the issubclass() call would short-circuit + # before we got to the "raise TypeError" line + with self.assertRaisesRegex( + TypeError, + "Protocols with non-method members don't support issubclass()" + ): + issubclass(Eggs, Spam) + + def test_no_weird_caching_with_issubclass_after_isinstance_pep695(self): + @runtime_checkable + class Spam[T](Protocol): + x: T + + class Eggs[T]: + def __init__(self, x: T) -> None: + self.x = x + + self.assertIsInstance(Eggs(42), Spam) + + # gh-104555: If we didn't override ABCMeta.__subclasscheck__ in _ProtocolMeta, + # TypeError wouldn't be raised here, + # as the cached result of the isinstance() check immediately above + # would mean the issubclass() call would short-circuit + # before we got to the "raise TypeError" line + with self.assertRaisesRegex( + TypeError, + "Protocols with non-method members don't support issubclass()" + ): + issubclass(Eggs, Spam) + + def test_protocols_isinstance(self): + T = TypeVar('T') + + @runtime_checkable + class P(Protocol): + def meth(x): ... + + @runtime_checkable + class PG(Protocol[T]): + def meth(x): ... + + @runtime_checkable + class WeirdProto(Protocol): + meth = str.maketrans + + @runtime_checkable + class WeirdProto2(Protocol): + meth = lambda *args, **kwargs: None + + class CustomCallable: + def __call__(self, *args, **kwargs): + pass + + @runtime_checkable + class WeirderProto(Protocol): + meth = CustomCallable() + + class BadP(Protocol): + def meth(x): ... + + class BadPG(Protocol[T]): + def meth(x): ... + + class C: + def meth(x): ... 
+ + class C2: + def __init__(self): + self.meth = lambda: None + + for klass in C, C2: + for proto in P, PG, WeirdProto, WeirdProto2, WeirderProto: + with self.subTest(klass=klass.__name__, proto=proto.__name__): + self.assertIsInstance(klass(), proto) + + no_subscripted_generics = "Subscripted generics cannot be used with class and instance checks" + + with self.assertRaisesRegex(TypeError, no_subscripted_generics): + isinstance(C(), PG[T]) + with self.assertRaisesRegex(TypeError, no_subscripted_generics): + isinstance(C(), PG[C]) + + only_runtime_checkable_msg = ( + "Instance and class checks can only be used " + "with @runtime_checkable protocols" + ) + + with self.assertRaisesRegex(TypeError, only_runtime_checkable_msg): + isinstance(C(), BadP) + with self.assertRaisesRegex(TypeError, only_runtime_checkable_msg): + isinstance(C(), BadPG) + + def test_protocols_isinstance_properties_and_descriptors(self): + class C: + @property + def attr(self): + return 42 + + class CustomDescriptor: + def __get__(self, obj, objtype=None): + return 42 + + class D: + attr = CustomDescriptor() + + # Check that properties set on superclasses + # are still found by the isinstance() logic + class E(C): ... + class F(D): ... + + class Empty: ... + + T = TypeVar('T') + + @runtime_checkable + class P(Protocol): + @property + def attr(self): ... + + @runtime_checkable + class P1(Protocol): + attr: int + + @runtime_checkable + class PG(Protocol[T]): + @property + def attr(self): ... + + @runtime_checkable + class PG1(Protocol[T]): + attr: T + + @runtime_checkable + class MethodP(Protocol): + def attr(self): ... + + @runtime_checkable + class MethodPG(Protocol[T]): + def attr(self) -> T: ... + + for protocol_class in P, P1, PG, PG1, MethodP, MethodPG: + for klass in C, D, E, F: + with self.subTest( + klass=klass.__name__, + protocol_class=protocol_class.__name__ + ): + self.assertIsInstance(klass(), protocol_class) + + with self.subTest(klass="Empty", protocol_class=protocol_class.__name__): + self.assertNotIsInstance(Empty(), protocol_class) + + class BadP(Protocol): + @property + def attr(self): ... + + class BadP1(Protocol): + attr: int + + class BadPG(Protocol[T]): + @property + def attr(self): ... + + class BadPG1(Protocol[T]): + attr: T + + cases = ( + PG[T], PG[C], PG1[T], PG1[C], MethodPG[T], + MethodPG[C], BadP, BadP1, BadPG, BadPG1 + ) + + for obj in cases: + for klass in C, D, E, F, Empty: + with self.subTest(klass=klass.__name__, obj=obj): + with self.assertRaises(TypeError): + isinstance(klass(), obj) + + def test_protocols_isinstance_not_fooled_by_custom_dir(self): + @runtime_checkable + class HasX(Protocol): + x: int + + class CustomDirWithX: + x = 10 + def __dir__(self): + return [] + + class CustomDirWithoutX: + def __dir__(self): + return ["x"] + + self.assertIsInstance(CustomDirWithX(), HasX) + self.assertNotIsInstance(CustomDirWithoutX(), HasX) + + def test_protocols_isinstance_attribute_access_with_side_effects(self): + class C: + @property + def attr(self): + raise AttributeError('no') + + class CustomDescriptor: + def __get__(self, obj, objtype=None): + raise RuntimeError("NO") - self.assertNotIsSubclass(C, D) - self.assertNotIsInstance(C(), D) - D.register(C) - self.assertIsSubclass(C, D) - self.assertIsInstance(C(), D) - with self.assertRaises(TypeError): - issubclass(D, PNonCall) + class D: + attr = CustomDescriptor() + + # Check that properties set on superclasses + # are still found by the isinstance() logic + class E(C): ... + class F(D): ... 
+ + class WhyWouldYouDoThis: + def __getattr__(self, name): + raise RuntimeError("wut") - def test_protocols_isinstance(self): T = TypeVar('T') @runtime_checkable class P(Protocol): - def meth(x): ... + @property + def attr(self): ... + + @runtime_checkable + class P1(Protocol): + attr: int @runtime_checkable class PG(Protocol[T]): - def meth(x): ... + @property + def attr(self): ... - class BadP(Protocol): - def meth(x): ... + @runtime_checkable + class PG1(Protocol[T]): + attr: T - class BadPG(Protocol[T]): - def meth(x): ... + @runtime_checkable + class MethodP(Protocol): + def attr(self): ... - class C: - def meth(x): ... + @runtime_checkable + class MethodPG(Protocol[T]): + def attr(self) -> T: ... + + for protocol_class in P, P1, PG, PG1, MethodP, MethodPG: + for klass in C, D, E, F: + with self.subTest( + klass=klass.__name__, + protocol_class=protocol_class.__name__ + ): + self.assertIsInstance(klass(), protocol_class) - self.assertIsInstance(C(), P) - self.assertIsInstance(C(), PG) - with self.assertRaises(TypeError): - isinstance(C(), PG[T]) - with self.assertRaises(TypeError): - isinstance(C(), PG[C]) - with self.assertRaises(TypeError): - isinstance(C(), BadP) - with self.assertRaises(TypeError): - isinstance(C(), BadPG) + with self.subTest( + klass="WhyWouldYouDoThis", + protocol_class=protocol_class.__name__ + ): + self.assertNotIsInstance(WhyWouldYouDoThis(), protocol_class) + + def test_protocols_isinstance___slots__(self): + # As per the consensus in https://github.com/python/typing/issues/1367, + # this is desirable behaviour + @runtime_checkable + class HasX(Protocol): + x: int + + class HasNothingButSlots: + __slots__ = ("x",) + + self.assertIsInstance(HasNothingButSlots(), HasX) # TODO: RUSTPYTHON @unittest.expectedFailure @@ -1234,6 +3678,20 @@ def __init__(self, x): self.assertIsInstance(C(1), P) self.assertIsInstance(C(1), PG) + def test_protocols_isinstance_monkeypatching(self): + @runtime_checkable + class HasX(Protocol): + x: int + + class Foo: ... 
+ + f = Foo() + self.assertNotIsInstance(f, HasX) + f.x = 42 + self.assertIsInstance(f, HasX) + del f.x + self.assertNotIsInstance(f, HasX) + def test_protocol_checks_after_subscript(self): class P(Protocol[T]): pass class C(P[T]): pass @@ -1323,10 +3781,10 @@ class NonP(P): class NonPR(PR): pass - class C: + class C(metaclass=abc.ABCMeta): x = 1 - class D: + class D(metaclass=abc.ABCMeta): def meth(self): pass self.assertNotIsInstance(C(), NonP) @@ -1336,6 +3794,32 @@ def meth(self): pass self.assertIsInstance(NonPR(), PR) self.assertIsSubclass(NonPR, PR) + self.assertNotIn("__protocol_attrs__", vars(NonP)) + self.assertNotIn("__protocol_attrs__", vars(NonPR)) + self.assertNotIn("__non_callable_proto_members__", vars(NonP)) + self.assertNotIn("__non_callable_proto_members__", vars(NonPR)) + + self.assertEqual(get_protocol_members(P), {"x"}) + self.assertEqual(get_protocol_members(PR), {"meth"}) + + # the returned object should be immutable, + # and should be a different object to the original attribute + # to prevent users from (accidentally or deliberately) + # mutating the attribute on the original class + self.assertIsInstance(get_protocol_members(P), frozenset) + self.assertIsNot(get_protocol_members(P), P.__protocol_attrs__) + self.assertIsInstance(get_protocol_members(PR), frozenset) + self.assertIsNot(get_protocol_members(PR), P.__protocol_attrs__) + + acceptable_extra_attrs = { + '_is_protocol', '_is_runtime_protocol', '__parameters__', + '__init__', '__annotations__', '__subclasshook__', + } + self.assertLessEqual(vars(NonP).keys(), vars(C).keys() | acceptable_extra_attrs) + self.assertLessEqual( + vars(NonPR).keys(), vars(D).keys() | acceptable_extra_attrs + ) + def test_custom_subclasshook(self): class P(Protocol): x = 1 @@ -1355,15 +3839,68 @@ def __subclasshook__(cls, other): self.assertIsSubclass(OKClass, C) self.assertNotIsSubclass(BadClass, C) + def test_custom_subclasshook_2(self): + @runtime_checkable + class HasX(Protocol): + # The presence of a non-callable member + # would mean issubclass() checks would fail with TypeError + # if it weren't for the custom `__subclasshook__` method + x = 1 + + @classmethod + def __subclasshook__(cls, other): + return hasattr(other, 'x') + + class Empty: pass + + class ImplementsHasX: + x = 1 + + self.assertIsInstance(ImplementsHasX(), HasX) + self.assertNotIsInstance(Empty(), HasX) + self.assertIsSubclass(ImplementsHasX, HasX) + self.assertNotIsSubclass(Empty, HasX) + + # isinstance() and issubclass() checks against this still raise TypeError, + # despite the presence of the custom __subclasshook__ method, + # as it's not decorated with @runtime_checkable + class NotRuntimeCheckable(Protocol): + @classmethod + def __subclasshook__(cls, other): + return hasattr(other, 'x') + + must_be_runtime_checkable = ( + "Instance and class checks can only be used " + "with @runtime_checkable protocols" + ) + + with self.assertRaisesRegex(TypeError, must_be_runtime_checkable): + issubclass(object, NotRuntimeCheckable) + with self.assertRaisesRegex(TypeError, must_be_runtime_checkable): + isinstance(object(), NotRuntimeCheckable) + def test_issubclass_fails_correctly(self): @runtime_checkable - class P(Protocol): + class NonCallableMembers(Protocol): x = 1 + class NotRuntimeCheckable(Protocol): + def callable_member(self) -> int: ... + + @runtime_checkable + class RuntimeCheckable(Protocol): + def callable_member(self) -> int: ... 
+ class C: pass - with self.assertRaises(TypeError): - issubclass(C(), P) + # These three all exercise different code paths, + # but should result in the same error message: + for protocol in NonCallableMembers, NotRuntimeCheckable, RuntimeCheckable: + with self.subTest(proto_name=protocol.__name__): + with self.assertRaisesRegex( + TypeError, r"issubclass\(\) arg 1 must be a class" + ): + issubclass(C(), protocol) def test_defining_generic_protocols(self): T = TypeVar('T') @@ -1422,6 +3959,24 @@ def bar(self, x: str) -> str: self.assertIsInstance(Test(), PSub) + def test_pep695_generic_protocol_callable_members(self): + @runtime_checkable + class Foo[T](Protocol): + def meth(self, x: T) -> None: ... + + class Bar[T]: + def meth(self, x: T) -> None: ... + + self.assertIsInstance(Bar(), Foo) + self.assertIsSubclass(Bar, Foo) + + @runtime_checkable + class SupportsTrunc[T](Protocol): + def __trunc__(self) -> T: ... + + self.assertIsInstance(0.0, SupportsTrunc) + self.assertIsSubclass(float, SupportsTrunc) + def test_init_called(self): T = TypeVar('T') @@ -1476,11 +4031,11 @@ def test_protocols_bad_subscripts(self): with self.assertRaises(TypeError): class P(Protocol[T, T]): pass with self.assertRaises(TypeError): - class P(Protocol[int]): pass + class Q(Protocol[int]): pass with self.assertRaises(TypeError): - class P(Protocol[T], Protocol[S]): pass + class R(Protocol[T], Protocol[S]): pass with self.assertRaises(TypeError): - class P(typing.Mapping[T, S], Protocol[T]): pass + class S(typing.Mapping[T, S], Protocol[T]): pass def test_generic_protocols_repr(self): T = TypeVar('T') @@ -1583,8 +4138,8 @@ class DI: def __init__(self): self.x = None - self.assertIsInstance(C(), P) - self.assertIsInstance(D(), P) + self.assertIsInstance(CI(), P) + self.assertIsInstance(DI(), P) def test_protocols_in_unions(self): class P(Protocol): @@ -1615,7 +4170,7 @@ class CP(P[int]): self.assertEqual(x.bar, 'abc') self.assertEqual(x.x, 1) self.assertEqual(x.__dict__, {'foo': 42, 'bar': 'abc'}) - s = pickle.dumps(P) + s = pickle.dumps(P, proto) D = pickle.loads(s) class E: @@ -1625,6 +4180,39 @@ class E: # TODO: RUSTPYTHON @unittest.expectedFailure + def test_runtime_checkable_with_match_args(self): + @runtime_checkable + class P_regular(Protocol): + x: int + y: int + + @runtime_checkable + class P_match(Protocol): + __match_args__ = ('x', 'y') + x: int + y: int + + class Regular: + def __init__(self, x: int, y: int): + self.x = x + self.y = y + + class WithMatch: + __match_args__ = ('x', 'y', 'z') + def __init__(self, x: int, y: int, z: int): + self.x = x + self.y = y + self.z = z + + class Nope: ... + + self.assertIsInstance(Regular(1, 2), P_regular) + self.assertIsInstance(Regular(1, 2), P_match) + self.assertIsInstance(WithMatch(1, 2, 3), P_regular) + self.assertIsInstance(WithMatch(1, 2, 3), P_match) + self.assertNotIsInstance(Nope(), P_regular) + self.assertNotIsInstance(Nope(), P_match) + def test_supports_int(self): self.assertIsSubclass(int, typing.SupportsInt) self.assertNotIsSubclass(str, typing.SupportsInt) @@ -1639,11 +4227,11 @@ def test_supports_float(self): @unittest.expectedFailure def test_supports_complex(self): - # Note: complex itself doesn't have __complex__. 
class C: def __complex__(self): return 0j + self.assertIsSubclass(complex, typing.SupportsComplex) self.assertIsSubclass(C, typing.SupportsComplex) self.assertNotIsSubclass(str, typing.SupportsComplex) @@ -1651,11 +4239,11 @@ def __complex__(self): @unittest.expectedFailure def test_supports_bytes(self): - # Note: bytes itself doesn't have __bytes__. class B: def __bytes__(self): return b'' + self.assertIsSubclass(bytes, typing.SupportsBytes) self.assertIsSubclass(B, typing.SupportsBytes) self.assertNotIsSubclass(str, typing.SupportsBytes) @@ -1713,6 +4301,22 @@ def close(self): self.assertIsSubclass(B, Custom) self.assertNotIsSubclass(A, Custom) + @runtime_checkable + class ReleasableBuffer(collections.abc.Buffer, Protocol): + def __release_buffer__(self, mv: memoryview) -> None: ... + + class C: pass + class D: + def __buffer__(self, flags: int) -> memoryview: + return memoryview(b'') + def __release_buffer__(self, mv: memoryview) -> None: + pass + + self.assertIsSubclass(D, ReleasableBuffer) + self.assertIsInstance(D(), ReleasableBuffer) + self.assertNotIsSubclass(C, ReleasableBuffer) + self.assertNotIsInstance(C(), ReleasableBuffer) + def test_builtin_protocol_allowlist(self): with self.assertRaises(TypeError): class CustomProtocol(TestCase, Protocol): @@ -1721,6 +4325,9 @@ class CustomProtocol(TestCase, Protocol): class CustomContextManager(typing.ContextManager, Protocol): pass + class CustomAsyncIterator(typing.AsyncIterator, Protocol): + pass + def test_non_runtime_protocol_isinstance_check(self): class P(Protocol): x: int @@ -1738,6 +4345,196 @@ def __init__(self): Foo() # Previously triggered RecursionError + def test_get_protocol_members(self): + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(object) + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(object()) + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(Protocol) + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(Generic) + + class P(Protocol): + a: int + def b(self) -> str: ... + @property + def c(self) -> int: ... 
+ + self.assertEqual(get_protocol_members(P), {'a', 'b', 'c'}) + self.assertIsInstance(get_protocol_members(P), frozenset) + self.assertIsNot(get_protocol_members(P), P.__protocol_attrs__) + + class Concrete: + a: int + def b(self) -> str: return "capybara" + @property + def c(self) -> int: return 5 + + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(Concrete) + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(Concrete()) + + class ConcreteInherit(P): + a: int = 42 + def b(self) -> str: return "capybara" + @property + def c(self) -> int: return 5 + + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(ConcreteInherit) + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(ConcreteInherit()) + + def test_is_protocol(self): + self.assertTrue(is_protocol(Proto)) + self.assertTrue(is_protocol(Point)) + self.assertFalse(is_protocol(Concrete)) + self.assertFalse(is_protocol(Concrete())) + self.assertFalse(is_protocol(Generic)) + self.assertFalse(is_protocol(object)) + + # Protocol is not itself a protocol + self.assertFalse(is_protocol(Protocol)) + + def test_interaction_with_isinstance_checks_on_superclasses_with_ABCMeta(self): + # Ensure the cache is empty, or this test won't work correctly + collections.abc.Sized._abc_registry_clear() + + class Foo(collections.abc.Sized, Protocol): pass + + # gh-105144: this previously raised TypeError + # if a Protocol subclass of Sized had been created + # before any isinstance() checks against Sized + self.assertNotIsInstance(1, collections.abc.Sized) + + def test_interaction_with_isinstance_checks_on_superclasses_with_ABCMeta_2(self): + # Ensure the cache is empty, or this test won't work correctly + collections.abc.Sized._abc_registry_clear() + + class Foo(typing.Sized, Protocol): pass + + # gh-105144: this previously raised TypeError + # if a Protocol subclass of Sized had been created + # before any isinstance() checks against Sized + self.assertNotIsInstance(1, typing.Sized) + + def test_empty_protocol_decorated_with_final(self): + @final + @runtime_checkable + class EmptyProtocol(Protocol): ... + + self.assertIsSubclass(object, EmptyProtocol) + self.assertIsInstance(object(), EmptyProtocol) + + def test_protocol_decorated_with_final_callable_members(self): + @final + @runtime_checkable + class ProtocolWithMethod(Protocol): + def startswith(self, string: str) -> bool: ... + + self.assertIsSubclass(str, ProtocolWithMethod) + self.assertNotIsSubclass(int, ProtocolWithMethod) + self.assertIsInstance('foo', ProtocolWithMethod) + self.assertNotIsInstance(42, ProtocolWithMethod) + + def test_protocol_decorated_with_final_noncallable_members(self): + @final + @runtime_checkable + class ProtocolWithNonCallableMember(Protocol): + x: int + + class Foo: + x = 42 + + only_callable_members_please = ( + r"Protocols with non-method members don't support issubclass()" + ) + + with self.assertRaisesRegex(TypeError, only_callable_members_please): + issubclass(Foo, ProtocolWithNonCallableMember) + + with self.assertRaisesRegex(TypeError, only_callable_members_please): + issubclass(int, ProtocolWithNonCallableMember) + + self.assertIsInstance(Foo(), ProtocolWithNonCallableMember) + self.assertNotIsInstance(42, ProtocolWithNonCallableMember) + + def test_protocol_decorated_with_final_mixed_members(self): + @final + @runtime_checkable + class ProtocolWithMixedMembers(Protocol): + x: int + def method(self) -> None: ... 
+ + class Foo: + x = 42 + def method(self) -> None: ... + + only_callable_members_please = ( + r"Protocols with non-method members don't support issubclass()" + ) + + with self.assertRaisesRegex(TypeError, only_callable_members_please): + issubclass(Foo, ProtocolWithMixedMembers) + + with self.assertRaisesRegex(TypeError, only_callable_members_please): + issubclass(int, ProtocolWithMixedMembers) + + self.assertIsInstance(Foo(), ProtocolWithMixedMembers) + self.assertNotIsInstance(42, ProtocolWithMixedMembers) + + def test_protocol_issubclass_error_message(self): + @runtime_checkable + class Vec2D(Protocol): + x: float + y: float + + def square_norm(self) -> float: + return self.x ** 2 + self.y ** 2 + + self.assertEqual(Vec2D.__protocol_attrs__, {'x', 'y', 'square_norm'}) + expected_error_message = ( + "Protocols with non-method members don't support issubclass()." + " Non-method members: 'x', 'y'." + ) + with self.assertRaisesRegex(TypeError, re.escape(expected_error_message)): + issubclass(int, Vec2D) + + def test_nonruntime_protocol_interaction_with_evil_classproperty(self): + class classproperty: + def __get__(self, instance, type): + raise RuntimeError("NO") + + class Commentable(Protocol): + evil = classproperty() + + # recognised as a protocol attr, + # but not actually accessed by the protocol metaclass + # (which would raise RuntimeError) for non-runtime protocols. + # See gh-113320 + self.assertEqual(get_protocol_members(Commentable), {"evil"}) + + def test_runtime_protocol_interaction_with_evil_classproperty(self): + class CustomError(Exception): pass + + class classproperty: + def __get__(self, instance, type): + raise CustomError + + with self.assertRaises(TypeError) as cm: + @runtime_checkable + class Commentable(Protocol): + evil = classproperty() + + exc = cm.exception + self.assertEqual( + exc.args[0], + "Failed to determine whether protocol member 'evil' is a method member" + ) + self.assertIs(type(exc.__cause__), CustomError) + class GenericTests(BaseTestCase): @@ -1778,7 +4575,28 @@ class NewGeneric(Generic): ... with self.assertRaises(TypeError): class MyGeneric(Generic[T], Generic[S]): ... with self.assertRaises(TypeError): - class MyGeneric(List[T], Generic[S]): ... + class MyGeneric2(List[T], Generic[S]): ... 
+ with self.assertRaises(TypeError): + Generic[()] + class D(Generic[T]): pass + with self.assertRaises(TypeError): + D[()] + + def test_generic_subclass_checks(self): + for typ in [list[int], List[int], + tuple[int, str], Tuple[int, str], + typing.Callable[..., None], + collections.abc.Callable[..., None]]: + with self.subTest(typ=typ): + self.assertRaises(TypeError, issubclass, typ, object) + self.assertRaises(TypeError, issubclass, typ, type) + self.assertRaises(TypeError, issubclass, typ, typ) + self.assertRaises(TypeError, issubclass, object, typ) + + # isinstance is fine: + self.assertTrue(isinstance(typ, object)) + # but, not when the right arg is also a generic: + self.assertRaises(TypeError, isinstance, typ, typ) def test_init(self): T = TypeVar('T') @@ -1886,6 +4704,16 @@ class C(B[int]): c.bar = 'abc' self.assertEqual(c.__dict__, {'bar': 'abc'}) + def test_setattr_exceptions(self): + class Immutable[T]: + def __setattr__(self, key, value): + raise RuntimeError("immutable") + + # gh-115165: This used to cause RuntimeError to be raised + # when we tried to set `__orig_class__` on the `Immutable` instance + # returned by the `Immutable[int]()` call + self.assertIsInstance(Immutable[int](), Immutable) + def test_subscripted_generics_as_proxies(self): T = TypeVar('T') class C(Generic[T]): @@ -2054,13 +4882,12 @@ class A(Generic[T]): self.assertNotEqual(typing.FrozenSet[A[str]], typing.FrozenSet[mod_generics_cache.B.A[str]]) - if sys.version_info[:2] > (3, 2): - self.assertTrue(repr(Tuple[A[str]]).endswith('.A[str]]')) - self.assertTrue(repr(Tuple[B.A[str]]).endswith('.B.A[str]]')) - self.assertTrue(repr(Tuple[mod_generics_cache.A[str]]) - .endswith('mod_generics_cache.A[str]]')) - self.assertTrue(repr(Tuple[mod_generics_cache.B.A[str]]) - .endswith('mod_generics_cache.B.A[str]]')) + self.assertTrue(repr(Tuple[A[str]]).endswith('.A[str]]')) + self.assertTrue(repr(Tuple[B.A[str]]).endswith('.B.A[str]]')) + self.assertTrue(repr(Tuple[mod_generics_cache.A[str]]) + .endswith('mod_generics_cache.A[str]]')) + self.assertTrue(repr(Tuple[mod_generics_cache.B.A[str]]) + .endswith('mod_generics_cache.B.A[str]]')) def test_extended_generic_rules_eq(self): T = TypeVar('T') @@ -2075,9 +4902,6 @@ def test_extended_generic_rules_eq(self): class Base: ... class Derived(Base): ... self.assertEqual(Union[T, Base][Union[Base, Derived]], Union[Base, Derived]) - with self.assertRaises(TypeError): - Union[T, int][1] - self.assertEqual(Callable[[T], T][KT], Callable[[KT], KT]) self.assertEqual(Callable[..., List[T]][int], Callable[..., List[int]]) @@ -2118,6 +4942,127 @@ def barfoo(x: AT): ... def barfoo2(x: CT): ... self.assertIs(get_type_hints(barfoo2, globals(), locals())['x'], CT) + def test_generic_pep585_forward_ref(self): + # See https://bugs.python.org/issue41370 + + class C1: + a: list['C1'] + self.assertEqual( + get_type_hints(C1, globals(), locals()), + {'a': list[C1]} + ) + + class C2: + a: dict['C1', list[List[list['C2']]]] + self.assertEqual( + get_type_hints(C2, globals(), locals()), + {'a': dict[C1, list[List[list[C2]]]]} + ) + + # Test stringified annotations + scope = {} + exec(textwrap.dedent(''' + from __future__ import annotations + class C3: + a: List[list["C2"]] + '''), scope) + C3 = scope['C3'] + self.assertEqual(C3.__annotations__['a'], "List[list['C2']]") + self.assertEqual( + get_type_hints(C3, globals(), locals()), + {'a': List[list[C2]]} + ) + + # Test recursive types + X = list["X"] + def f(x: X): ... 
+ self.assertEqual( + get_type_hints(f, globals(), locals()), + {'x': list[list[ForwardRef('X')]]} + ) + + def test_pep695_generic_class_with_future_annotations(self): + original_globals = dict(ann_module695.__dict__) + + hints_for_A = get_type_hints(ann_module695.A) + A_type_params = ann_module695.A.__type_params__ + self.assertIs(hints_for_A["x"], A_type_params[0]) + self.assertEqual(hints_for_A["y"].__args__[0], Unpack[A_type_params[1]]) + self.assertIs(hints_for_A["z"].__args__[0], A_type_params[2]) + + # should not have changed as a result of the get_type_hints() calls! + self.assertEqual(ann_module695.__dict__, original_globals) + + def test_pep695_generic_class_with_future_annotations_and_local_shadowing(self): + hints_for_B = get_type_hints(ann_module695.B) + self.assertEqual(hints_for_B, {"x": int, "y": str, "z": bytes}) + + def test_pep695_generic_class_with_future_annotations_name_clash_with_global_vars(self): + hints_for_C = get_type_hints(ann_module695.C) + self.assertEqual( + set(hints_for_C.values()), + set(ann_module695.C.__type_params__) + ) + + def test_pep_695_generic_function_with_future_annotations(self): + hints_for_generic_function = get_type_hints(ann_module695.generic_function) + func_t_params = ann_module695.generic_function.__type_params__ + self.assertEqual( + hints_for_generic_function.keys(), {"x", "y", "z", "zz", "return"} + ) + self.assertIs(hints_for_generic_function["x"], func_t_params[0]) + self.assertEqual(hints_for_generic_function["y"], Unpack[func_t_params[1]]) + self.assertIs(hints_for_generic_function["z"].__origin__, func_t_params[2]) + self.assertIs(hints_for_generic_function["zz"].__origin__, func_t_params[2]) + + def test_pep_695_generic_function_with_future_annotations_name_clash_with_global_vars(self): + self.assertEqual( + set(get_type_hints(ann_module695.generic_function_2).values()), + set(ann_module695.generic_function_2.__type_params__) + ) + + def test_pep_695_generic_method_with_future_annotations(self): + hints_for_generic_method = get_type_hints(ann_module695.D.generic_method) + params = { + param.__name__: param + for param in ann_module695.D.generic_method.__type_params__ + } + self.assertEqual( + hints_for_generic_method, + {"x": params["Foo"], "y": params["Bar"], "return": types.NoneType} + ) + + def test_pep_695_generic_method_with_future_annotations_name_clash_with_global_vars(self): + self.assertEqual( + set(get_type_hints(ann_module695.D.generic_method_2).values()), + set(ann_module695.D.generic_method_2.__type_params__) + ) + + def test_pep_695_generics_with_future_annotations_nested_in_function(self): + results = ann_module695.nested() + + self.assertEqual( + set(results.hints_for_E.values()), + set(results.E.__type_params__) + ) + self.assertEqual( + set(results.hints_for_E_meth.values()), + set(results.E.generic_method.__type_params__) + ) + self.assertNotEqual( + set(results.hints_for_E_meth.values()), + set(results.E.__type_params__) + ) + self.assertEqual( + set(results.hints_for_E_meth.values()).intersection(results.E.__type_params__), + set() + ) + + self.assertEqual( + set(results.hints_for_generic_func.values()), + set(results.generic_func.__type_params__) + ) + def test_extended_generic_rules_subclassing(self): class T1(Tuple[T, KT]): ... class T2(Tuple[T, ...]): ... 
@@ -2152,8 +5097,6 @@ def test_fail_with_bare_union(self): List[Union] with self.assertRaises(TypeError): Tuple[Optional] - with self.assertRaises(TypeError): - ClassVar[ClassVar] with self.assertRaises(TypeError): List[ClassVar[int]] @@ -2179,18 +5122,19 @@ class MyDict(typing.Dict[T, T]): ... class MyDef(typing.DefaultDict[str, T]): ... self.assertIs(MyDef[int]().__class__, MyDef) self.assertEqual(MyDef[int]().__orig_class__, MyDef[int]) - # ChainMap was added in 3.3 - if sys.version_info >= (3, 3): - class MyChain(typing.ChainMap[str, T]): ... - self.assertIs(MyChain[int]().__class__, MyChain) - self.assertEqual(MyChain[int]().__orig_class__, MyChain[int]) + class MyChain(typing.ChainMap[str, T]): ... + self.assertIs(MyChain[int]().__class__, MyChain) + self.assertEqual(MyChain[int]().__orig_class__, MyChain[int]) def test_all_repr_eq_any(self): objs = (getattr(typing, el) for el in typing.__all__) for obj in objs: self.assertNotEqual(repr(obj), '') self.assertEqual(obj, obj) - if getattr(obj, '__parameters__', None) and len(obj.__parameters__) == 1: + if (getattr(obj, '__parameters__', None) + and not isinstance(obj, typing.TypeVar) + and isinstance(obj.__parameters__, tuple) + and len(obj.__parameters__) == 1): self.assertEqual(obj[Any].__args__, (Any,)) if isinstance(obj, type): for base in obj.__mro__: @@ -2217,7 +5161,8 @@ class C(B[int]): self.assertEqual(x.bar, 'abc') self.assertEqual(x.__dict__, {'foo': 42, 'bar': 'abc'}) samples = [Any, Union, Tuple, Callable, ClassVar, - Union[int, str], ClassVar[List], Tuple[int, ...], Callable[[str], bytes], + Union[int, str], ClassVar[List], Tuple[int, ...], Tuple[()], + Callable[[str], bytes], typing.DefaultDict, typing.FrozenSet[int]] for s in samples: for proto in range(pickle.HIGHEST_PROTOCOL + 1): @@ -2235,7 +5180,8 @@ class C(B[int]): def test_copy_and_deepcopy(self): T = TypeVar('T') class Node(Generic[T]): ... - things = [Union[T, int], Tuple[T, int], Callable[..., T], Callable[[int], int], + things = [Union[T, int], Tuple[T, int], Tuple[()], + Callable[..., T], Callable[[int], int], Tuple[Any, Any], Node[T], Node[int], Node[Any], typing.Iterable[T], typing.Iterable[Any], typing.Iterable[int], typing.Dict[int, str], typing.Dict[T, Any], ClassVar[int], ClassVar[List[T]], Tuple['T', 'T'], @@ -2247,22 +5193,30 @@ class Node(Generic[T]): ... def test_immutability_by_copy_and_pickle(self): # Special forms like Union, Any, etc., generic aliases to containers like List, # Mapping, etc., and type variabcles are considered immutable by copy and pickle. - global TP, TPB, TPV # for pickle + global TP, TPB, TPV, PP # for pickle TP = TypeVar('TP') TPB = TypeVar('TPB', bound=int) TPV = TypeVar('TPV', bytes, str) - for X in [TP, TPB, TPV, List, typing.Mapping, ClassVar, typing.Iterable, + PP = ParamSpec('PP') + for X in [TP, TPB, TPV, PP, + List, typing.Mapping, ClassVar, typing.Iterable, Union, Any, Tuple, Callable]: - self.assertIs(copy(X), X) - self.assertIs(deepcopy(X), X) - self.assertIs(pickle.loads(pickle.dumps(X)), X) + with self.subTest(thing=X): + self.assertIs(copy(X), X) + self.assertIs(deepcopy(X), X) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + self.assertIs(pickle.loads(pickle.dumps(X, proto)), X) + del TP, TPB, TPV, PP + # Check that local type variables are copyable. 
TL = TypeVar('TL') TLB = TypeVar('TLB', bound=int) TLV = TypeVar('TLV', bytes, str) - for X in [TL, TLB, TLV]: - self.assertIs(copy(X), X) - self.assertIs(deepcopy(X), X) + PL = ParamSpec('PL') + for X in [TL, TLB, TLV, PL]: + with self.subTest(thing=X): + self.assertIs(copy(X), X) + self.assertIs(deepcopy(X), X) def test_copy_generic_instances(self): T = TypeVar('T') @@ -2292,7 +5246,7 @@ def test_weakref_all(self): T = TypeVar('T') things = [Any, Union[T, int], Callable[..., T], Tuple[Any, Any], Optional[List[int]], typing.Mapping[int, str], - typing.re.Match[bytes], typing.Iterable['whatever']] + typing.Match[bytes], typing.Iterable['whatever']] for t in things: self.assertEqual(weakref.ref(t)(), t) @@ -2359,6 +5313,51 @@ class Y(C[int]): self.assertEqual(Y.__qualname__, 'GenericTests.test_repr_2..Y') + def test_repr_3(self): + T = TypeVar('T') + T1 = TypeVar('T1') + P = ParamSpec('P') + P2 = ParamSpec('P2') + Ts = TypeVarTuple('Ts') + + class MyCallable(Generic[P, T]): + pass + + class DoubleSpec(Generic[P, P2, T]): + pass + + class TsP(Generic[*Ts, P]): + pass + + object_to_expected_repr = { + MyCallable[P, T]: "MyCallable[~P, ~T]", + MyCallable[Concatenate[T1, P], T]: "MyCallable[typing.Concatenate[~T1, ~P], ~T]", + MyCallable[[], bool]: "MyCallable[[], bool]", + MyCallable[[int], bool]: "MyCallable[[int], bool]", + MyCallable[[int, str], bool]: "MyCallable[[int, str], bool]", + MyCallable[[int, list[int]], bool]: "MyCallable[[int, list[int]], bool]", + MyCallable[Concatenate[*Ts, P], T]: "MyCallable[typing.Concatenate[typing.Unpack[Ts], ~P], ~T]", + + DoubleSpec[P2, P, T]: "DoubleSpec[~P2, ~P, ~T]", + DoubleSpec[[int], [str], bool]: "DoubleSpec[[int], [str], bool]", + DoubleSpec[[int, int], [str, str], bool]: "DoubleSpec[[int, int], [str, str], bool]", + + TsP[*Ts, P]: "TsP[typing.Unpack[Ts], ~P]", + TsP[int, str, list[int], []]: "TsP[int, str, list[int], []]", + TsP[int, [str, list[int]]]: "TsP[int, [str, list[int]]]", + + # These lines are just too long to fit: + MyCallable[Concatenate[*Ts, P], int][int, str, [bool, float]]: + "MyCallable[[int, str, bool, float], int]", + } + + for obj, expected_repr in object_to_expected_repr.items(): + with self.subTest(obj=obj, expected_repr=expected_repr): + self.assertRegex( + repr(obj), + fr"^{re.escape(MyCallable.__module__)}.*\.{re.escape(expected_repr)}$", + ) + def test_eq_1(self): self.assertEqual(Generic, Generic) self.assertEqual(Generic[T], Generic[T]) @@ -2379,22 +5378,91 @@ class B(Generic[T]): def test_multiple_inheritance(self): - class A(Generic[T, VT]): - pass + class A(Generic[T, VT]): + pass + + class B(Generic[KT, T]): + pass + + class C(A[T, VT], Generic[VT, T, KT], B[KT, T]): + pass + + self.assertEqual(C.__parameters__, (VT, T, KT)) + + def test_multiple_inheritance_special(self): + S = TypeVar('S') + class B(Generic[S]): ... + class C(List[int], B): ... + self.assertEqual(C.__mro__, (C, list, B, Generic, object)) + + def test_multiple_inheritance_non_type_with___mro_entries__(self): + class GoodEntries: + def __mro_entries__(self, bases): + return (object,) + + class A(List[int], GoodEntries()): ... + + self.assertEqual(A.__mro__, (A, list, Generic, object)) + + def test_multiple_inheritance_non_type_without___mro_entries__(self): + # Error should be from the type machinery, not from typing.py + with self.assertRaisesRegex(TypeError, r"^bases must be types"): + class A(List[int], object()): ... 
+ + def test_multiple_inheritance_non_type_bad___mro_entries__(self): + class BadEntries: + def __mro_entries__(self, bases): + return None + + # Error should be from the type machinery, not from typing.py + with self.assertRaisesRegex( + TypeError, + r"^__mro_entries__ must return a tuple", + ): + class A(List[int], BadEntries()): ... + + def test_multiple_inheritance___mro_entries___returns_non_type(self): + class BadEntries: + def __mro_entries__(self, bases): + return (object(),) + + # Error should be from the type machinery, not from typing.py + with self.assertRaisesRegex( + TypeError, + r"^bases must be types", + ): + class A(List[int], BadEntries()): ... + + def test_multiple_inheritance_with_genericalias(self): + class A(typing.Sized, list[int]): ... - class B(Generic[KT, T]): - pass + self.assertEqual( + A.__mro__, + (A, collections.abc.Sized, Generic, list, object), + ) - class C(A[T, VT], Generic[VT, T, KT], B[KT, T]): - pass + def test_multiple_inheritance_with_genericalias_2(self): + T = TypeVar("T") - self.assertEqual(C.__parameters__, (VT, T, KT)) + class BaseSeq(typing.Sequence[T]): ... + class MySeq(List[T], BaseSeq[T]): ... - def test_multiple_inheritance_special(self): - S = TypeVar('S') - class B(Generic[S]): ... - class C(List[int], B): ... - self.assertEqual(C.__mro__, (C, list, B, Generic, object)) + self.assertEqual( + MySeq.__mro__, + ( + MySeq, + list, + BaseSeq, + collections.abc.Sequence, + collections.abc.Reversible, + collections.abc.Collection, + collections.abc.Sized, + collections.abc.Iterable, + collections.abc.Container, + Generic, + object, + ), + ) def test_init_subclass_super_called(self): class FinalException(Exception): @@ -2412,7 +5480,7 @@ class Test(Generic[T], Final): class Subclass(Test): pass with self.assertRaises(FinalException): - class Subclass(Test[int]): + class Subclass2(Test[int]): pass def test_nested(self): @@ -2480,11 +5548,11 @@ class D(C): self.assertEqual(D.__parameters__, ()) - with self.assertRaises(Exception): + with self.assertRaises(TypeError): D[int] - with self.assertRaises(Exception): + with self.assertRaises(TypeError): D[Any] - with self.assertRaises(Exception): + with self.assertRaises(TypeError): D[T] def test_new_with_args(self): @@ -2567,6 +5635,7 @@ def test_subclass_special_form(self): Literal[1, 2], Concatenate[int, ParamSpec("P")], TypeGuard[int], + TypeIs[range], ): with self.subTest(msg=obj): with self.assertRaisesRegex( @@ -2575,11 +5644,66 @@ def test_subclass_special_form(self): class Foo(obj): pass + def test_complex_subclasses(self): + T_co = TypeVar("T_co", covariant=True) + + class Base(Generic[T_co]): + ... + + T = TypeVar("T") + + # see gh-94607: this fails in that bug + class Sub(Base, Generic[T]): + ... + + def test_parameter_detection(self): + self.assertEqual(List[T].__parameters__, (T,)) + self.assertEqual(List[List[T]].__parameters__, (T,)) + class A: + __parameters__ = (T,) + # Bare classes should be skipped + for a in (List, list): + for b in (A, int, TypeVar, TypeVarTuple, ParamSpec, types.GenericAlias, types.UnionType): + with self.subTest(generic=a, sub=b): + with self.assertRaisesRegex(TypeError, '.* is not a generic class'): + a[b][str] + # Duck-typing anything that looks like it has __parameters__. + # These tests are optional and failure is okay. 
+ self.assertEqual(List[A()].__parameters__, (T,)) + # C version of GenericAlias + self.assertEqual(list[A()].__parameters__, (T,)) + + def test_non_generic_subscript(self): + T = TypeVar('T') + class G(Generic[T]): + pass + class A: + __parameters__ = (T,) + + for s in (int, G, A, List, list, + TypeVar, TypeVarTuple, ParamSpec, + types.GenericAlias, types.UnionType): + + for t in Tuple, tuple: + with self.subTest(tuple=t, sub=s): + self.assertEqual(t[s, T][int], t[s, int]) + self.assertEqual(t[T, s][int], t[int, s]) + a = t[s] + with self.assertRaises(TypeError): + a[int] + + for c in Callable, collections.abc.Callable: + with self.subTest(callable=c, sub=s): + self.assertEqual(c[[s], T][int], c[[s], int]) + self.assertEqual(c[[T], s][int], c[[int], s]) + a = c[[s], s] + with self.assertRaises(TypeError): + a[int] + + class ClassVarTests(BaseTestCase): def test_basics(self): - with self.assertRaises(TypeError): - ClassVar[1] with self.assertRaises(TypeError): ClassVar[int, str] with self.assertRaises(TypeError): @@ -2593,11 +5717,19 @@ def test_repr(self): self.assertEqual(repr(cv), 'typing.ClassVar[%s.Employee]' % __name__) def test_cannot_subclass(self): - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): class C(type(ClassVar)): pass - with self.assertRaises(TypeError): - class C(type(ClassVar[int])): + with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): + class D(type(ClassVar[int])): + pass + with self.assertRaisesRegex(TypeError, + r'Cannot subclass typing\.ClassVar'): + class E(ClassVar): + pass + with self.assertRaisesRegex(TypeError, + r'Cannot subclass typing\.ClassVar\[int\]'): + class F(ClassVar[int]): pass def test_cannot_init(self): @@ -2618,8 +5750,6 @@ class FinalTests(BaseTestCase): def test_basics(self): Final[int] # OK - with self.assertRaises(TypeError): - Final[1] with self.assertRaises(TypeError): Final[int, str] with self.assertRaises(TypeError): @@ -2637,11 +5767,19 @@ def test_repr(self): self.assertEqual(repr(cv), 'typing.Final[tuple[int]]') def test_cannot_subclass(self): - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): class C(type(Final)): pass - with self.assertRaises(TypeError): - class C(type(Final[int])): + with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): + class D(type(Final[int])): + pass + with self.assertRaisesRegex(TypeError, + r'Cannot subclass typing\.Final'): + class E(Final): + pass + with self.assertRaisesRegex(TypeError, + r'Cannot subclass typing\.Final\[int\]'): + class F(Final[int]): pass def test_cannot_init(self): @@ -2658,10 +5796,206 @@ def test_no_isinstance(self): with self.assertRaises(TypeError): issubclass(int, Final) + +class FinalDecoratorTests(BaseTestCase): def test_final_unmodified(self): def func(x): ... self.assertIs(func, final(func)) + def test_dunder_final(self): + @final + def func(): ... + @final + class Cls: ... + self.assertIs(True, func.__final__) + self.assertIs(True, Cls.__final__) + + class Wrapper: + __slots__ = ("func",) + def __init__(self, func): + self.func = func + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) + + # Check that no error is thrown if the attribute + # is not writable. + @final + @Wrapper + def wrapped(): ... + self.assertIsInstance(wrapped, Wrapper) + self.assertIs(False, hasattr(wrapped, "__final__")) + + class Meta(type): + @property + def __final__(self): return "can't set me" + @final + class WithMeta(metaclass=Meta): ... 
+ self.assertEqual(WithMeta.__final__, "can't set me") + + # Builtin classes throw TypeError if you try to set an + # attribute. + final(int) + self.assertIs(False, hasattr(int, "__final__")) + + # Make sure it works with common builtin decorators + class Methods: + @final + @classmethod + def clsmethod(cls): ... + + @final + @staticmethod + def stmethod(): ... + + # The other order doesn't work because property objects + # don't allow attribute assignment. + @property + @final + def prop(self): ... + + @final + @lru_cache() + def cached(self): ... + + # Use getattr_static because the descriptor returns the + # underlying function, which doesn't have __final__. + self.assertIs( + True, + inspect.getattr_static(Methods, "clsmethod").__final__ + ) + self.assertIs( + True, + inspect.getattr_static(Methods, "stmethod").__final__ + ) + self.assertIs(True, Methods.prop.fget.__final__) + self.assertIs(True, Methods.cached.__final__) + + +class OverrideDecoratorTests(BaseTestCase): + def test_override(self): + class Base: + def normal_method(self): ... + @classmethod + def class_method_good_order(cls): ... + @classmethod + def class_method_bad_order(cls): ... + @staticmethod + def static_method_good_order(): ... + @staticmethod + def static_method_bad_order(): ... + + class Derived(Base): + @override + def normal_method(self): + return 42 + + @classmethod + @override + def class_method_good_order(cls): + return 42 + @override + @classmethod + def class_method_bad_order(cls): + return 42 + + @staticmethod + @override + def static_method_good_order(): + return 42 + @override + @staticmethod + def static_method_bad_order(): + return 42 + + self.assertIsSubclass(Derived, Base) + instance = Derived() + self.assertEqual(instance.normal_method(), 42) + self.assertIs(True, Derived.normal_method.__override__) + self.assertIs(True, instance.normal_method.__override__) + + self.assertEqual(Derived.class_method_good_order(), 42) + self.assertIs(True, Derived.class_method_good_order.__override__) + self.assertEqual(Derived.class_method_bad_order(), 42) + self.assertIs(False, hasattr(Derived.class_method_bad_order, "__override__")) + + self.assertEqual(Derived.static_method_good_order(), 42) + self.assertIs(True, Derived.static_method_good_order.__override__) + self.assertEqual(Derived.static_method_bad_order(), 42) + self.assertIs(False, hasattr(Derived.static_method_bad_order, "__override__")) + + # Base object is not changed: + self.assertIs(False, hasattr(Base.normal_method, "__override__")) + self.assertIs(False, hasattr(Base.class_method_good_order, "__override__")) + self.assertIs(False, hasattr(Base.class_method_bad_order, "__override__")) + self.assertIs(False, hasattr(Base.static_method_good_order, "__override__")) + self.assertIs(False, hasattr(Base.static_method_bad_order, "__override__")) + + def test_property(self): + class Base: + @property + def correct(self) -> int: + return 1 + @property + def wrong(self) -> int: + return 1 + + class Child(Base): + @property + @override + def correct(self) -> int: + return 2 + @override + @property + def wrong(self) -> int: + return 2 + + instance = Child() + self.assertEqual(instance.correct, 2) + self.assertTrue(Child.correct.fget.__override__) + self.assertEqual(instance.wrong, 2) + self.assertFalse(hasattr(Child.wrong, "__override__")) + self.assertFalse(hasattr(Child.wrong.fset, "__override__")) + + def test_silent_failure(self): + class CustomProp: + __slots__ = ('fget',) + def __init__(self, fget): + self.fget = fget + def __get__(self, obj, 
objtype=None): + return self.fget(obj) + + class WithOverride: + @override # must not fail on object with `__slots__` + @CustomProp + def some(self): + return 1 + + self.assertEqual(WithOverride.some, 1) + self.assertFalse(hasattr(WithOverride.some, "__override__")) + + def test_multiple_decorators(self): + def with_wraps(f): # similar to `lru_cache` definition + @wraps(f) + def wrapper(*args, **kwargs): + return f(*args, **kwargs) + return wrapper + + class WithOverride: + @override + @with_wraps + def on_top(self, a: int) -> int: + return a + 1 + @with_wraps + @override + def on_bottom(self, a: int) -> int: + return a + 2 + + instance = WithOverride() + self.assertEqual(instance.on_top(1), 2) + self.assertTrue(instance.on_top.__override__) + self.assertEqual(instance.on_bottom(1), 3) + self.assertTrue(instance.on_bottom.__override__) + class CastTests(BaseTestCase): @@ -2681,6 +6015,34 @@ def test_errors(self): cast('hello', 42) +class AssertTypeTests(BaseTestCase): + + def test_basics(self): + arg = 42 + self.assertIs(assert_type(arg, int), arg) + self.assertIs(assert_type(arg, str | float), arg) + self.assertIs(assert_type(arg, AnyStr), arg) + self.assertIs(assert_type(arg, None), arg) + + def test_errors(self): + # Bogus calls are not expected to fail. + arg = 42 + self.assertIs(assert_type(arg, 42), arg) + self.assertIs(assert_type(arg, 'hello'), arg) + + +# We need this to make sure that `@no_type_check` respects `__module__` attr: +from test.typinganndata import ann_module8 + +@no_type_check +class NoTypeCheck_Outer: + Inner = ann_module8.NoTypeCheck_Outer.Inner + +@no_type_check +class NoTypeCheck_WithFunction: + NoTypeCheck_function = ann_module8.NoTypeCheck_function + + class ForwardRefTests(BaseTestCase): def test_basics(self): @@ -2708,16 +6070,15 @@ def add_right(self, node: 'Node[T]' = None): t = Node[int] both_hints = get_type_hints(t.add_both, globals(), locals()) self.assertEqual(both_hints['left'], Optional[Node[T]]) - self.assertEqual(both_hints['right'], Optional[Node[T]]) - self.assertEqual(both_hints['left'], both_hints['right']) - self.assertEqual(both_hints['stuff'], Optional[int]) + self.assertEqual(both_hints['right'], Node[T]) + self.assertEqual(both_hints['stuff'], int) self.assertNotIn('blah', both_hints) left_hints = get_type_hints(t.add_left, globals(), locals()) self.assertEqual(left_hints['node'], Optional[Node[T]]) right_hints = get_type_hints(t.add_right, globals(), locals()) - self.assertEqual(right_hints['node'], Optional[Node[T]]) + self.assertEqual(right_hints['node'], Node[T]) def test_forwardref_instance_type_error(self): fr = typing.ForwardRef('int') @@ -2811,6 +6172,8 @@ def fun(x: a): def test_forward_repr(self): self.assertEqual(repr(List['int']), "typing.List[ForwardRef('int')]") + self.assertEqual(repr(List[ForwardRef('int', module='mod')]), + "typing.List[ForwardRef('int', module='mod')]") def test_union_forward(self): @@ -2866,10 +6229,11 @@ def fun(x: a): pass def cmp(o1, o2): return o1 == o2 - r1 = namespace1() - r2 = namespace2() - self.assertIsNot(r1, r2) - self.assertRaises(RecursionError, cmp, r1, r2) + with infinite_recursion(25): + r1 = namespace1() + r2 = namespace2() + self.assertIsNot(r1, r2) + self.assertRaises(RecursionError, cmp, r1, r2) def test_union_forward_recursion(self): ValueList = List['Value'] @@ -2924,12 +6288,16 @@ def test_special_forms_forward(self): class C: a: Annotated['ClassVar[int]', (3, 5)] = 4 b: Annotated['Final[int]', "const"] = 4 + x: 'ClassVar' = 4 + y: 'Final' = 4 class CF: b: List['Final[int]'] = 4 
self.assertEqual(get_type_hints(C, globals())['a'], ClassVar[int]) self.assertEqual(get_type_hints(C, globals())['b'], Final[int]) + self.assertEqual(get_type_hints(C, globals())['x'], ClassVar) + self.assertEqual(get_type_hints(C, globals())['y'], Final) with self.assertRaises(TypeError): get_type_hints(CF, globals()), @@ -2946,13 +6314,11 @@ def foo(a: 'Node[T'): with self.assertRaises(SyntaxError): get_type_hints(foo) - def test_type_error(self): - - def foo(a: Tuple['42']): - pass - - with self.assertRaises(TypeError): - get_type_hints(foo) + def test_syntax_error_empty_string(self): + for form in [typing.List, typing.Set, typing.Type, typing.Deque]: + with self.subTest(form=form): + with self.assertRaises(SyntaxError): + form[''] def test_name_error(self): @@ -2989,9 +6355,98 @@ def meth(self, x: int): ... @no_type_check class D(C): c = C + # verify that @no_type_check never affects bases self.assertEqual(get_type_hints(C.meth), {'x': int}) + # and never child classes: + class Child(D): + def foo(self, x: int): ... + + self.assertEqual(get_type_hints(Child.foo), {'x': int}) + + def test_no_type_check_nested_types(self): + # See https://bugs.python.org/issue46571 + class Other: + o: int + class B: # Has the same `__name__`` as `A.B` and different `__qualname__` + o: int + @no_type_check + class A: + a: int + class B: + b: int + class C: + c: int + class D: + d: int + + Other = Other + + for klass in [A, A.B, A.B.C, A.D]: + with self.subTest(klass=klass): + self.assertTrue(klass.__no_type_check__) + self.assertEqual(get_type_hints(klass), {}) + + for not_modified in [Other, B]: + with self.subTest(not_modified=not_modified): + with self.assertRaises(AttributeError): + not_modified.__no_type_check__ + self.assertNotEqual(get_type_hints(not_modified), {}) + + def test_no_type_check_class_and_static_methods(self): + @no_type_check + class Some: + @staticmethod + def st(x: int) -> int: ... + @classmethod + def cl(cls, y: int) -> int: ... + + self.assertTrue(Some.st.__no_type_check__) + self.assertEqual(get_type_hints(Some.st), {}) + self.assertTrue(Some.cl.__no_type_check__) + self.assertEqual(get_type_hints(Some.cl), {}) + + def test_no_type_check_other_module(self): + self.assertTrue(NoTypeCheck_Outer.__no_type_check__) + with self.assertRaises(AttributeError): + ann_module8.NoTypeCheck_Outer.__no_type_check__ + with self.assertRaises(AttributeError): + ann_module8.NoTypeCheck_Outer.Inner.__no_type_check__ + + self.assertTrue(NoTypeCheck_WithFunction.__no_type_check__) + with self.assertRaises(AttributeError): + ann_module8.NoTypeCheck_function.__no_type_check__ + + def test_no_type_check_foreign_functions(self): + # We should not modify this function: + def some(*args: int) -> int: + ... 
+ + @no_type_check + class A: + some_alias = some + some_class = classmethod(some) + some_static = staticmethod(some) + + with self.assertRaises(AttributeError): + some.__no_type_check__ + self.assertEqual(get_type_hints(some), {'args': int, 'return': int}) + + def test_no_type_check_lambda(self): + @no_type_check + class A: + # Corner case: `lambda` is both an assignment and a function: + bar: Callable[[int], int] = lambda arg: arg + + self.assertTrue(A.bar.__no_type_check__) + self.assertEqual(get_type_hints(A.bar), {}) + + def test_no_type_check_TypeError(self): + # This simply should not fail with + # `TypeError: can't set attributes of built-in/extension type 'dict'` + no_type_check(dict) + def test_no_type_check_forward_ref_as_string(self): class C: foo: typing.ClassVar[int] = 7 @@ -3006,21 +6461,15 @@ class F: for clazz in [C, D, E, F]: self.assertEqual(get_type_hints(clazz), expected_result) - def test_nested_classvar_fails_forward_ref_check(self): - class E: - foo: 'typing.ClassVar[typing.ClassVar[int]]' = 7 - class F: - foo: ClassVar['ClassVar[int]'] = 7 - - for clazz in [E, F]: - with self.assertRaises(TypeError): - get_type_hints(clazz) - def test_meta_no_type_check(self): - - @no_type_check_decorator - def magic_decorator(func): - return func + depr_msg = ( + "'typing.no_type_check_decorator' is deprecated " + "and slated for removal in Python 3.15" + ) + with self.assertWarnsRegex(DeprecationWarning, depr_msg): + @no_type_check_decorator + def magic_decorator(func): + return func self.assertEqual(magic_decorator.__name__, 'magic_decorator') @@ -3057,13 +6506,66 @@ def test_final_forward_ref(self): self.assertNotEqual(gth(Loop, globals())['attr'], Final[int]) self.assertNotEqual(gth(Loop, globals())['attr'], Final) + def test_or(self): + X = ForwardRef('X') + # __or__/__ror__ itself + self.assertEqual(X | "x", Union[X, "x"]) + self.assertEqual("x" | X, Union["x", X]) + + +class InternalsTests(BaseTestCase): + def test_deprecation_for_no_type_params_passed_to__evaluate(self): + with self.assertWarnsRegex( + DeprecationWarning, + ( + "Failing to pass a value to the 'type_params' parameter " + "of 'typing._eval_type' is deprecated" + ) + ) as cm: + self.assertEqual(typing._eval_type(list["int"], globals(), {}), list[int]) + + self.assertEqual(cm.filename, __file__) + + f = ForwardRef("int") + + with self.assertWarnsRegex( + DeprecationWarning, + ( + "Failing to pass a value to the 'type_params' parameter " + "of 'typing.ForwardRef._evaluate' is deprecated" + ) + ) as cm: + self.assertIs(f._evaluate(globals(), {}, recursive_guard=frozenset()), int) + + self.assertEqual(cm.filename, __file__) + + def test_collect_parameters(self): + typing = import_helper.import_fresh_module("typing") + with self.assertWarnsRegex( + DeprecationWarning, + "The private _collect_parameters function is deprecated" + ) as cm: + typing._collect_parameters + self.assertEqual(cm.filename, __file__) + + +@lru_cache() +def cached_func(x, y): + return 3 * x + y + + +class MethodHolder: + @classmethod + def clsmethod(cls): ... + @staticmethod + def stmethod(): ... + def method(self): ... 
+ class OverloadTests(BaseTestCase): def test_overload_fails(self): - from typing import overload - - with self.assertRaises(RuntimeError): + with self.assertRaises(NotImplementedError): @overload def blah(): @@ -3072,8 +6574,6 @@ def blah(): blah() def test_overload_succeeds(self): - from typing import overload - @overload def blah(): pass @@ -3083,9 +6583,80 @@ def blah(): blah() + @cpython_only # gh-98713 + def test_overload_on_compiled_functions(self): + with patch("typing._overload_registry", + defaultdict(lambda: defaultdict(dict))): + # The registry starts out empty: + self.assertEqual(typing._overload_registry, {}) + + # This should just not fail: + overload(sum) + overload(print) + + # No overloads are recorded (but, it still has a side-effect): + self.assertEqual(typing.get_overloads(sum), []) + self.assertEqual(typing.get_overloads(print), []) + + def set_up_overloads(self): + def blah(): + pass + + overload1 = blah + overload(blah) + + def blah(): + pass + + overload2 = blah + overload(blah) + + def blah(): + pass + + return blah, [overload1, overload2] + + # Make sure we don't clear the global overload registry + @patch("typing._overload_registry", + defaultdict(lambda: defaultdict(dict))) + def test_overload_registry(self): + # The registry starts out empty + self.assertEqual(typing._overload_registry, {}) -ASYNCIO_TESTS = """ -import asyncio + impl, overloads = self.set_up_overloads() + self.assertNotEqual(typing._overload_registry, {}) + self.assertEqual(list(get_overloads(impl)), overloads) + + def some_other_func(): pass + overload(some_other_func) + other_overload = some_other_func + def some_other_func(): pass + self.assertEqual(list(get_overloads(some_other_func)), [other_overload]) + # Unrelated function still has no overloads: + def not_overloaded(): pass + self.assertEqual(list(get_overloads(not_overloaded)), []) + + # Make sure that after we clear all overloads, the registry is + # completely empty. + clear_overloads() + self.assertEqual(typing._overload_registry, {}) + self.assertEqual(get_overloads(impl), []) + + # Querying a function with no overloads shouldn't change the registry. + def the_only_one(): pass + self.assertEqual(get_overloads(the_only_one), []) + self.assertEqual(typing._overload_registry, {}) + + def test_overload_registry_repeated(self): + for _ in range(2): + impl, overloads = self.set_up_overloads() + + self.assertEqual(list(get_overloads(impl)), overloads) + + +from test.typinganndata import ( + ann_module, ann_module2, ann_module3, ann_module5, ann_module6, +) T_a = TypeVar('T_a') @@ -3118,19 +6689,6 @@ async def __aenter__(self) -> int: return 42 async def __aexit__(self, etype, eval, tb): return None -""" - -try: - exec(ASYNCIO_TESTS) -except ImportError: - ASYNCIO = False # multithreading is not enabled -else: - ASYNCIO = True - -# Definitions needed for features introduced in Python 3.6 - -from test import ann_module, ann_module2, ann_module3, ann_module5, ann_module6 -from typing import AsyncContextManager class A: y: float @@ -3177,20 +6735,59 @@ class Point2D(TypedDict): x: int y: int +class Point2DGeneric(Generic[T], TypedDict): + a: T + b: T + class Bar(_typed_dict_helper.Foo, total=False): b: int +class BarGeneric(_typed_dict_helper.FooGeneric[T], total=False): + b: int + class LabelPoint2D(Point2D, Label): ... 
class Options(TypedDict, total=False): log_level: int log_path: str +class TotalMovie(TypedDict): + title: str + year: NotRequired[int] + +class NontotalMovie(TypedDict, total=False): + title: Required[str] + year: int + +class ParentNontotalMovie(TypedDict, total=False): + title: Required[str] + +class ChildTotalMovie(ParentNontotalMovie): + year: NotRequired[int] + +class ParentDeeplyAnnotatedMovie(TypedDict): + title: Annotated[Annotated[Required[str], "foobar"], "another level"] + +class ChildDeeplyAnnotatedMovie(ParentDeeplyAnnotatedMovie): + year: NotRequired[Annotated[int, 2000]] + +class AnnotatedMovie(TypedDict): + title: Annotated[Required[str], "foobar"] + year: NotRequired[Annotated[int, 2000]] + +class DeeplyAnnotatedMovie(TypedDict): + title: Annotated[Annotated[Required[str], "foobar"], "another level"] + year: NotRequired[Annotated[int, 2000]] + +class WeirdlyQuotedMovie(TypedDict): + title: Annotated['Annotated[Required[str], "foobar"]', "another level"] + year: NotRequired['Annotated[int, 2000]'] + class HasForeignBaseClass(mod_generics_cache.A): some_xrepr: 'XRepr' other_a: 'mod_generics_cache.A' -async def g_with(am: AsyncContextManager[int]): +async def g_with(am: typing.AsyncContextManager[int]): x: int async with am as x: return x @@ -3349,7 +6946,7 @@ def foobar(x: list[ForwardRef('X')]): ... BA = Tuple[Annotated[T, (1, 0)], ...] def barfoo(x: BA): ... self.assertEqual(get_type_hints(barfoo, globals(), locals())['x'], Tuple[T, ...]) - self.assertIs( + self.assertEqual( get_type_hints(barfoo, globals(), locals(), include_extras=True)['x'], BA ) @@ -3357,7 +6954,7 @@ def barfoo(x: BA): ... BA = tuple[Annotated[T, (1, 0)], ...] def barfoo(x: BA): ... self.assertEqual(get_type_hints(barfoo, globals(), locals())['x'], tuple[T, ...]) - self.assertIs( + self.assertEqual( get_type_hints(barfoo, globals(), locals(), include_extras=True)['x'], BA ) @@ -3422,6 +7019,18 @@ def __iand__(self, other: Const["MySet[T]"]) -> "MySet[T]": {'other': MySet[T], 'return': MySet[T]} ) + def test_get_type_hints_annotated_with_none_default(self): + # See: https://bugs.python.org/issue46195 + def annotated_with_none_default(x: Annotated[int, 'data'] = None): ... 
+ self.assertEqual( + get_type_hints(annotated_with_none_default), + {'x': int}, + ) + self.assertEqual( + get_type_hints(annotated_with_none_default, include_extras=True), + {'x': Annotated[int, 'data']}, + ) + def test_get_type_hints_classes_str_annotations(self): class Foo: y = str @@ -3447,10 +7056,79 @@ class BadType(BadBase): self.assertNotIn('bad', sys.modules) self.assertEqual(get_type_hints(BadType), {'foo': tuple, 'bar': list}) + def test_forward_ref_and_final(self): + # https://bugs.python.org/issue45166 + hints = get_type_hints(ann_module5) + self.assertEqual(hints, {'name': Final[str]}) + + hints = get_type_hints(ann_module5.MyClass) + self.assertEqual(hints, {'value': Final}) + + def test_top_level_class_var(self): + # https://bugs.python.org/issue45166 + with self.assertRaisesRegex( + TypeError, + r'typing.ClassVar\[int\] is not valid as type argument', + ): + get_type_hints(ann_module6) + + def test_get_type_hints_typeddict(self): + self.assertEqual(get_type_hints(TotalMovie), {'title': str, 'year': int}) + self.assertEqual(get_type_hints(TotalMovie, include_extras=True), { + 'title': str, + 'year': NotRequired[int], + }) + + self.assertEqual(get_type_hints(AnnotatedMovie), {'title': str, 'year': int}) + self.assertEqual(get_type_hints(AnnotatedMovie, include_extras=True), { + 'title': Annotated[Required[str], "foobar"], + 'year': NotRequired[Annotated[int, 2000]], + }) + + self.assertEqual(get_type_hints(DeeplyAnnotatedMovie), {'title': str, 'year': int}) + self.assertEqual(get_type_hints(DeeplyAnnotatedMovie, include_extras=True), { + 'title': Annotated[Required[str], "foobar", "another level"], + 'year': NotRequired[Annotated[int, 2000]], + }) + + self.assertEqual(get_type_hints(WeirdlyQuotedMovie), {'title': str, 'year': int}) + self.assertEqual(get_type_hints(WeirdlyQuotedMovie, include_extras=True), { + 'title': Annotated[Required[str], "foobar", "another level"], + 'year': NotRequired[Annotated[int, 2000]], + }) + + self.assertEqual(get_type_hints(_typed_dict_helper.VeryAnnotated), {'a': int}) + self.assertEqual(get_type_hints(_typed_dict_helper.VeryAnnotated, include_extras=True), { + 'a': Annotated[Required[int], "a", "b", "c"] + }) + + self.assertEqual(get_type_hints(ChildTotalMovie), {"title": str, "year": int}) + self.assertEqual(get_type_hints(ChildTotalMovie, include_extras=True), { + "title": Required[str], "year": NotRequired[int] + }) + + self.assertEqual(get_type_hints(ChildDeeplyAnnotatedMovie), {"title": str, "year": int}) + self.assertEqual(get_type_hints(ChildDeeplyAnnotatedMovie, include_extras=True), { + "title": Annotated[Required[str], "foobar", "another level"], + "year": NotRequired[Annotated[int, 2000]] + }) + + def test_get_type_hints_collections_abc_callable(self): + # https://github.com/python/cpython/issues/91621 + P = ParamSpec('P') + def f(x: collections.abc.Callable[[int], int]): ... + def g(x: collections.abc.Callable[..., int]): ... + def h(x: collections.abc.Callable[P, int]): ... 
+ + self.assertEqual(get_type_hints(f), {'x': collections.abc.Callable[[int], int]}) + self.assertEqual(get_type_hints(g), {'x': collections.abc.Callable[..., int]}) + self.assertEqual(get_type_hints(h), {'x': collections.abc.Callable[P, int]}) + class GetUtilitiesTestCase(TestCase): def test_get_origin(self): T = TypeVar('T') + Ts = TypeVarTuple('Ts') P = ParamSpec('P') class C(Generic[T]): pass self.assertIs(get_origin(C[int]), C) @@ -3472,27 +7150,44 @@ class C(Generic[T]): pass self.assertIs(get_origin(list | str), types.UnionType) self.assertIs(get_origin(P.args), P) self.assertIs(get_origin(P.kwargs), P) + self.assertIs(get_origin(Required[int]), Required) + self.assertIs(get_origin(NotRequired[int]), NotRequired) + self.assertIs(get_origin((*Ts,)[0]), Unpack) + self.assertIs(get_origin(Unpack[Ts]), Unpack) + self.assertIs(get_origin((*tuple[*Ts],)[0]), tuple) + self.assertIs(get_origin(Unpack[Tuple[Unpack[Ts]]]), Unpack) def test_get_args(self): T = TypeVar('T') class C(Generic[T]): pass self.assertEqual(get_args(C[int]), (int,)) self.assertEqual(get_args(C[T]), (T,)) + self.assertEqual(get_args(typing.SupportsAbs[int]), (int,)) # Protocol + self.assertEqual(get_args(typing.SupportsAbs[T]), (T,)) + self.assertEqual(get_args(Point2DGeneric[int]), (int,)) # TypedDict + self.assertEqual(get_args(Point2DGeneric[T]), (T,)) + self.assertEqual(get_args(T), ()) self.assertEqual(get_args(int), ()) + self.assertEqual(get_args(Any), ()) + self.assertEqual(get_args(Self), ()) + self.assertEqual(get_args(LiteralString), ()) self.assertEqual(get_args(ClassVar[int]), (int,)) self.assertEqual(get_args(Union[int, str]), (int, str)) self.assertEqual(get_args(Literal[42, 43]), (42, 43)) self.assertEqual(get_args(Final[List[int]]), (List[int],)) + self.assertEqual(get_args(Optional[int]), (int, type(None))) + self.assertEqual(get_args(Union[int, None]), (int, type(None))) self.assertEqual(get_args(Union[int, Tuple[T, int]][str]), (int, Tuple[str, int])) self.assertEqual(get_args(typing.Dict[int, Tuple[T, T]][Optional[int]]), (int, Tuple[Optional[int], Optional[int]])) self.assertEqual(get_args(Callable[[], T][int]), ([], int)) self.assertEqual(get_args(Callable[..., int]), (..., int)) + self.assertEqual(get_args(Callable[[int], str]), ([int], str)) self.assertEqual(get_args(Union[int, Callable[[Tuple[T, ...]], str]]), (int, Callable[[Tuple[T, ...]], str])) self.assertEqual(get_args(Tuple[int, ...]), (int, ...)) - self.assertEqual(get_args(Tuple[()]), ((),)) + self.assertEqual(get_args(Tuple[()]), ()) self.assertEqual(get_args(Annotated[T, 'one', 2, ['three']]), (T, 'one', 2, ['three'])) self.assertEqual(get_args(List), ()) self.assertEqual(get_args(Tuple), ()) @@ -3505,26 +7200,30 @@ class C(Generic[T]): pass self.assertEqual(get_args(collections.abc.Callable[[int], str]), get_args(Callable[[int], str])) P = ParamSpec('P') + self.assertEqual(get_args(P), ()) + self.assertEqual(get_args(P.args), ()) + self.assertEqual(get_args(P.kwargs), ()) self.assertEqual(get_args(Callable[P, int]), (P, int)) + self.assertEqual(get_args(collections.abc.Callable[P, int]), (P, int)) self.assertEqual(get_args(Callable[Concatenate[int, P], int]), (Concatenate[int, P], int)) + self.assertEqual(get_args(collections.abc.Callable[Concatenate[int, P], int]), + (Concatenate[int, P], int)) + self.assertEqual(get_args(Concatenate[int, str, P]), (int, str, P)) self.assertEqual(get_args(list | str), (list, str)) - - def test_forward_ref_and_final(self): - # https://bugs.python.org/issue45166 - hints = get_type_hints(ann_module5) - 
self.assertEqual(hints, {'name': Final[str]}) - - hints = get_type_hints(ann_module5.MyClass) - self.assertEqual(hints, {'value': Final}) - - def test_top_level_class_var(self): - # https://bugs.python.org/issue45166 - with self.assertRaisesRegex( - TypeError, - r'typing.ClassVar\[int\] is not valid as type argument', - ): - get_type_hints(ann_module6) + self.assertEqual(get_args(Required[int]), (int,)) + self.assertEqual(get_args(NotRequired[int]), (int,)) + self.assertEqual(get_args(TypeAlias), ()) + self.assertEqual(get_args(TypeGuard[int]), (int,)) + self.assertEqual(get_args(TypeIs[range]), (range,)) + Ts = TypeVarTuple('Ts') + self.assertEqual(get_args(Ts), ()) + self.assertEqual(get_args((*Ts,)[0]), (Ts,)) + self.assertEqual(get_args(Unpack[Ts]), (Ts,)) + self.assertEqual(get_args(tuple[*Ts]), (*Ts,)) + self.assertEqual(get_args(tuple[Unpack[Ts]]), (Unpack[Ts],)) + self.assertEqual(get_args((*tuple[*Ts],)[0]), (*Ts,)) + self.assertEqual(get_args(Unpack[tuple[Unpack[Ts]]]), (tuple[Unpack[Ts]],)) class CollectionsAbcTests(BaseTestCase): @@ -3549,27 +7248,17 @@ def test_iterator(self): self.assertIsInstance(it, typing.Iterator) self.assertNotIsInstance(42, typing.Iterator) - @skipUnless(ASYNCIO, 'Python 3.5 and multithreading required') def test_awaitable(self): - ns = {} - exec( - "async def foo() -> typing.Awaitable[int]:\n" - " return await AwaitableWrapper(42)\n", - globals(), ns) - foo = ns['foo'] + async def foo() -> typing.Awaitable[int]: + return await AwaitableWrapper(42) g = foo() self.assertIsInstance(g, typing.Awaitable) self.assertNotIsInstance(foo, typing.Awaitable) g.send(None) # Run foo() till completion, to avoid warning. - @skipUnless(ASYNCIO, 'Python 3.5 and multithreading required') def test_coroutine(self): - ns = {} - exec( - "async def foo():\n" - " return\n", - globals(), ns) - foo = ns['foo'] + async def foo(): + return g = foo() self.assertIsInstance(g, typing.Coroutine) with self.assertRaises(TypeError): @@ -3580,7 +7269,6 @@ def test_coroutine(self): except StopIteration: pass - @skipUnless(ASYNCIO, 'Python 3.5 and multithreading required') def test_async_iterable(self): base_it = range(10) # type: Iterator[int] it = AsyncIteratorWrapper(base_it) @@ -3588,7 +7276,6 @@ def test_async_iterable(self): self.assertIsInstance(it, typing.AsyncIterable) self.assertNotIsInstance(42, typing.AsyncIterable) - @skipUnless(ASYNCIO, 'Python 3.5 and multithreading required') def test_async_iterator(self): base_it = range(10) # type: Iterator[int] it = AsyncIteratorWrapper(base_it) @@ -3634,8 +7321,14 @@ def test_mutablesequence(self): self.assertNotIsInstance((), typing.MutableSequence) def test_bytestring(self): - self.assertIsInstance(b'', typing.ByteString) - self.assertIsInstance(bytearray(b''), typing.ByteString) + with self.assertWarns(DeprecationWarning): + self.assertIsInstance(b'', typing.ByteString) + with self.assertWarns(DeprecationWarning): + self.assertIsInstance(bytearray(b''), typing.ByteString) + with self.assertWarns(DeprecationWarning): + class Foo(typing.ByteString): ... + with self.assertWarns(DeprecationWarning): + class Bar(typing.ByteString, typing.Awaitable): ... 
def test_list(self): self.assertIsSubclass(list, typing.List) @@ -3742,7 +7435,6 @@ class MyOrdDict(typing.OrderedDict[str, int]): self.assertIsSubclass(MyOrdDict, collections.OrderedDict) self.assertNotIsSubclass(collections.OrderedDict, MyOrdDict) - @skipUnless(sys.version_info >= (3, 3), 'ChainMap was added in 3.3') def test_chainmap_instantiation(self): self.assertIs(type(typing.ChainMap()), collections.ChainMap) self.assertIs(type(typing.ChainMap[KT, VT]()), collections.ChainMap) @@ -3750,7 +7442,6 @@ def test_chainmap_instantiation(self): class CM(typing.ChainMap[KT, VT]): ... self.assertIs(type(CM[int, str]()), CM) - @skipUnless(sys.version_info >= (3, 3), 'ChainMap was added in 3.3') def test_chainmap_subclass(self): class MyChainMap(typing.ChainMap[str, int]): @@ -3832,6 +7523,17 @@ def foo(): g = foo() self.assertIsSubclass(type(g), typing.Generator) + def test_generator_default(self): + g1 = typing.Generator[int] + g2 = typing.Generator[int, None, None] + self.assertEqual(get_args(g1), (int, type(None), type(None))) + self.assertEqual(get_args(g1), get_args(g2)) + + g3 = typing.Generator[int, float] + g4 = typing.Generator[int, float, None] + self.assertEqual(get_args(g3), (int, float, type(None))) + self.assertEqual(get_args(g3), get_args(g4)) + def test_no_generator_instantiation(self): with self.assertRaises(TypeError): typing.Generator() @@ -3841,10 +7543,9 @@ def test_no_generator_instantiation(self): typing.Generator[int, int, int]() def test_async_generator(self): - ns = {} - exec("async def f():\n" - " yield 42\n", globals(), ns) - g = ns['f']() + async def f(): + yield 42 + g = f() self.assertIsSubclass(type(g), typing.AsyncGenerator) def test_no_async_generator_instantiation(self): @@ -3878,7 +7579,7 @@ def __len__(self): return 0 self.assertEqual(len(MMC()), 0) - assert callable(MMC.update) + self.assertTrue(callable(MMC.update)) self.assertIsInstance(MMC(), typing.Mapping) class MMB(typing.MutableMapping[KT, VT]): @@ -3933,9 +7634,8 @@ def asend(self, value): def athrow(self, typ, val=None, tb=None): pass - ns = {} - exec('async def g(): yield 0', globals(), ns) - g = ns['g'] + async def g(): yield 0 + self.assertIsSubclass(G, typing.AsyncGenerator) self.assertIsSubclass(G, typing.AsyncIterable) self.assertIsSubclass(G, collections.abc.AsyncGenerator) @@ -4020,7 +7720,15 @@ def manager(): self.assertIsInstance(cm, typing.ContextManager) self.assertNotIsInstance(42, typing.ContextManager) - @skipUnless(ASYNCIO, 'Python 3.5 required') + def test_contextmanager_type_params(self): + cm1 = typing.ContextManager[int] + self.assertEqual(get_args(cm1), (int, bool | None)) + cm2 = typing.ContextManager[int, None] + self.assertEqual(get_args(cm2), (int, types.NoneType)) + + type gen_cm[T1, T2] = typing.ContextManager[T1, T2] + self.assertEqual(get_args(gen_cm.__value__[int, None]), (int, types.NoneType)) + def test_async_contextmanager(self): class NotACM: pass @@ -4032,11 +7740,17 @@ def manager(): cm = manager() self.assertNotIsInstance(cm, typing.AsyncContextManager) - self.assertEqual(typing.AsyncContextManager[int].__args__, (int,)) + self.assertEqual(typing.AsyncContextManager[int].__args__, (int, bool | None)) with self.assertRaises(TypeError): isinstance(42, typing.AsyncContextManager[int]) with self.assertRaises(TypeError): - typing.AsyncContextManager[int, str] + typing.AsyncContextManager[int, str, float] + + def test_asynccontextmanager_type_params(self): + cm1 = typing.AsyncContextManager[int] + self.assertEqual(get_args(cm1), (int, bool | None)) + cm2 = 
typing.AsyncContextManager[int, None] + self.assertEqual(get_args(cm2), (int, types.NoneType)) class TypeTests(BaseTestCase): @@ -4074,16 +7788,24 @@ def foo(a: A) -> Optional[BaseException]: else: return a() - assert isinstance(foo(KeyboardInterrupt), KeyboardInterrupt) - assert foo(None) is None + self.assertIsInstance(foo(KeyboardInterrupt), KeyboardInterrupt) + self.assertIsNone(foo(None)) + + +class TestModules(TestCase): + func_names = ['_idfunc'] + + def test_c_functions(self): + for fname in self.func_names: + self.assertEqual(getattr(typing, fname).__module__, '_typing') class NewTypeTests(BaseTestCase): @classmethod def setUpClass(cls): global UserId - UserId = NewType('UserId', int) - cls.UserName = NewType(cls.__qualname__ + '.UserName', str) + UserId = typing.NewType('UserId', int) + cls.UserName = typing.NewType(cls.__qualname__ + '.UserName', str) @classmethod def tearDownClass(cls): @@ -4091,9 +7813,6 @@ def tearDownClass(cls): del UserId del cls.UserName - def tearDown(self): - self.clear_caches() - def test_basic(self): self.assertIsInstance(UserId(5), int) self.assertIsInstance(self.UserName('Joe'), str) @@ -4109,11 +7828,11 @@ class D(UserId): def test_or(self): for cls in (int, self.UserName): with self.subTest(cls=cls): - self.assertEqual(UserId | cls, Union[UserId, cls]) - self.assertEqual(cls | UserId, Union[cls, UserId]) + self.assertEqual(UserId | cls, typing.Union[UserId, cls]) + self.assertEqual(cls | UserId, typing.Union[cls, UserId]) - self.assertEqual(get_args(UserId | cls), (UserId, cls)) - self.assertEqual(get_args(cls | UserId), (cls, UserId)) + self.assertEqual(typing.get_args(UserId | cls), (UserId, cls)) + self.assertEqual(typing.get_args(cls | UserId), (cls, UserId)) def test_special_attrs(self): self.assertEqual(UserId.__name__, 'UserId') @@ -4134,7 +7853,7 @@ def test_repr(self): f'{__name__}.{self.__class__.__qualname__}.UserName') def test_pickle(self): - UserAge = NewType('UserAge', float) + UserAge = typing.NewType('UserAge', float) for proto in range(pickle.HIGHEST_PROTOCOL + 1): with self.subTest(proto=proto): pickled = pickle.dumps(UserId, proto) @@ -4154,6 +7873,17 @@ def test_missing__name__(self): ) exec(code, {}) + def test_error_message_when_subclassing(self): + with self.assertRaisesRegex( + TypeError, + re.escape( + "Cannot subclass an instance of NewType. Perhaps you were looking for: " + "`ProUserId = NewType('ProUserId', UserId)`" + ) + ): + class ProUserId(UserId): + ... 
+ class NamedTupleTests(BaseTestCase): class NestedEmployee(NamedTuple): @@ -4176,14 +7906,6 @@ def test_basics(self): self.assertEqual(Emp.__annotations__, collections.OrderedDict([('name', str), ('id', int)])) - def test_namedtuple_pyversion(self): - if sys.version_info[:2] < (3, 6): - with self.assertRaises(TypeError): - NamedTuple('Name', one=int, other=str) - with self.assertRaises(TypeError): - class NotYet(NamedTuple): - whatever = 0 - def test_annotation_usage(self): tim = CoolEmployee('Tim', 9000) self.assertIsInstance(tim, CoolEmployee) @@ -4239,22 +7961,120 @@ class A: with self.assertRaises(TypeError): class X(NamedTuple, A): x: int + with self.assertRaises(TypeError): + class Y(NamedTuple, tuple): + x: int + with self.assertRaises(TypeError): + class Z(NamedTuple, NamedTuple): + x: int + class B(NamedTuple): + x: int + with self.assertRaises(TypeError): + class C(NamedTuple, B): + y: str + + def test_generic(self): + class X(NamedTuple, Generic[T]): + x: T + self.assertEqual(X.__bases__, (tuple, Generic)) + self.assertEqual(X.__orig_bases__, (NamedTuple, Generic[T])) + self.assertEqual(X.__mro__, (X, tuple, Generic, object)) + + class Y(Generic[T], NamedTuple): + x: T + self.assertEqual(Y.__bases__, (Generic, tuple)) + self.assertEqual(Y.__orig_bases__, (Generic[T], NamedTuple)) + self.assertEqual(Y.__mro__, (Y, Generic, tuple, object)) + + for G in X, Y: + with self.subTest(type=G): + self.assertEqual(G.__parameters__, (T,)) + self.assertEqual(G[T].__args__, (T,)) + self.assertEqual(get_args(G[T]), (T,)) + A = G[int] + self.assertIs(A.__origin__, G) + self.assertEqual(A.__args__, (int,)) + self.assertEqual(get_args(A), (int,)) + self.assertEqual(A.__parameters__, ()) + + a = A(3) + self.assertIs(type(a), G) + self.assertEqual(a.x, 3) + + with self.assertRaises(TypeError): + G[int, str] + + def test_generic_pep695(self): + class X[T](NamedTuple): + x: T + T, = X.__type_params__ + self.assertIsInstance(T, TypeVar) + self.assertEqual(T.__name__, 'T') + self.assertEqual(X.__bases__, (tuple, Generic)) + self.assertEqual(X.__orig_bases__, (NamedTuple, Generic[T])) + self.assertEqual(X.__mro__, (X, tuple, Generic, object)) + self.assertEqual(X.__parameters__, (T,)) + self.assertEqual(X[str].__args__, (str,)) + self.assertEqual(X[str].__parameters__, ()) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_non_generic_subscript(self): + # For backward compatibility, subscription works + # on arbitrary NamedTuple types. 
+ class Group(NamedTuple): + key: T + group: list[T] + A = Group[int] + self.assertEqual(A.__origin__, Group) + self.assertEqual(A.__parameters__, ()) + self.assertEqual(A.__args__, (int,)) + a = A(1, [2]) + self.assertIs(type(a), Group) + self.assertEqual(a, (1, [2])) def test_namedtuple_keyword_usage(self): - LocalEmployee = NamedTuple("LocalEmployee", name=str, age=int) + with self.assertWarnsRegex( + DeprecationWarning, + "Creating NamedTuple classes using keyword arguments is deprecated" + ): + LocalEmployee = NamedTuple("LocalEmployee", name=str, age=int) + nick = LocalEmployee('Nick', 25) self.assertIsInstance(nick, tuple) self.assertEqual(nick.name, 'Nick') self.assertEqual(LocalEmployee.__name__, 'LocalEmployee') self.assertEqual(LocalEmployee._fields, ('name', 'age')) self.assertEqual(LocalEmployee.__annotations__, dict(name=str, age=int)) - with self.assertRaises(TypeError): + + with self.assertRaisesRegex( + TypeError, + "Either list of fields or keywords can be provided to NamedTuple, not both" + ): NamedTuple('Name', [('x', int)], y=str) - with self.assertRaises(TypeError): - NamedTuple('Name', x=1, y='a') + + with self.assertRaisesRegex( + TypeError, + "Either list of fields or keywords can be provided to NamedTuple, not both" + ): + NamedTuple('Name', [], y=str) + + with self.assertRaisesRegex( + TypeError, + ( + r"Cannot pass `None` as the 'fields' parameter " + r"and also specify fields using keyword arguments" + ) + ): + NamedTuple('Name', None, x=int) def test_namedtuple_special_keyword_names(self): - NT = NamedTuple("NT", cls=type, self=object, typename=str, fields=list) + with self.assertWarnsRegex( + DeprecationWarning, + "Creating NamedTuple classes using keyword arguments is deprecated" + ): + NT = NamedTuple("NT", cls=type, self=object, typename=str, fields=list) + self.assertEqual(NT.__name__, 'NT') self.assertEqual(NT._fields, ('cls', 'self', 'typename', 'fields')) a = NT(cls=str, self=42, typename='foo', fields=[('bar', tuple)]) @@ -4264,12 +8084,32 @@ def test_namedtuple_special_keyword_names(self): self.assertEqual(a.fields, [('bar', tuple)]) def test_empty_namedtuple(self): - NT = NamedTuple('NT') + expected_warning = re.escape( + "Failing to pass a value for the 'fields' parameter is deprecated " + "and will be disallowed in Python 3.15. " + "To create a NamedTuple class with 0 fields " + "using the functional syntax, " + "pass an empty list, e.g. `NT1 = NamedTuple('NT1', [])`." + ) + with self.assertWarnsRegex(DeprecationWarning, fr"^{expected_warning}$"): + NT1 = NamedTuple('NT1') + + expected_warning = re.escape( + "Passing `None` as the 'fields' parameter is deprecated " + "and will be disallowed in Python 3.15. " + "To create a NamedTuple class with 0 fields " + "using the functional syntax, " + "pass an empty list, e.g. `NT2 = NamedTuple('NT2', [])`." 
+ ) + with self.assertWarnsRegex(DeprecationWarning, fr"^{expected_warning}$"): + NT2 = NamedTuple('NT2', None) + + NT3 = NamedTuple('NT2', []) class CNT(NamedTuple): pass # empty body - for struct in [NT, CNT]: + for struct in NT1, NT2, NT3, CNT: with self.subTest(struct=struct): self.assertEqual(struct._fields, ()) self.assertEqual(struct._field_defaults, {}) @@ -4279,16 +8119,30 @@ class CNT(NamedTuple): def test_namedtuple_errors(self): with self.assertRaises(TypeError): NamedTuple.__new__() - with self.assertRaises(TypeError): + + with self.assertRaisesRegex( + TypeError, + "missing 1 required positional argument" + ): NamedTuple() - with self.assertRaises(TypeError): + + with self.assertRaisesRegex( + TypeError, + "takes from 1 to 2 positional arguments but 3 were given" + ): NamedTuple('Emp', [('name', str)], None) - with self.assertRaises(ValueError): + + with self.assertRaisesRegex( + ValueError, + "Field names cannot start with an underscore" + ): NamedTuple('Emp', [('_name', str)]) - with self.assertRaises(TypeError): + + with self.assertRaisesRegex( + TypeError, + "missing 1 required positional argument: 'typename'" + ): NamedTuple(typename='Emp', name=str, id=int) - with self.assertRaises(TypeError): - NamedTuple('Emp', fields=[('name', str), ('id', int)]) def test_copy_and_pickle(self): global Emp # pickle wants to reference the class by name @@ -4310,6 +8164,99 @@ def test_copy_and_pickle(self): self.assertEqual(jane2, jane) self.assertIsInstance(jane2, cls) + def test_orig_bases(self): + T = TypeVar('T') + + class SimpleNamedTuple(NamedTuple): + pass + + class GenericNamedTuple(NamedTuple, Generic[T]): + pass + + self.assertEqual(SimpleNamedTuple.__orig_bases__, (NamedTuple,)) + self.assertEqual(GenericNamedTuple.__orig_bases__, (NamedTuple, Generic[T])) + + CallNamedTuple = NamedTuple('CallNamedTuple', []) + + self.assertEqual(CallNamedTuple.__orig_bases__, (NamedTuple,)) + + def test_setname_called_on_values_in_class_dictionary(self): + class Vanilla: + def __set_name__(self, owner, name): + self.name = name + + class Foo(NamedTuple): + attr = Vanilla() + + foo = Foo() + self.assertEqual(len(foo), 0) + self.assertNotIn('attr', Foo._fields) + self.assertIsInstance(foo.attr, Vanilla) + self.assertEqual(foo.attr.name, "attr") + + class Bar(NamedTuple): + attr: Vanilla = Vanilla() + + bar = Bar() + self.assertEqual(len(bar), 1) + self.assertIn('attr', Bar._fields) + self.assertIsInstance(bar.attr, Vanilla) + self.assertEqual(bar.attr.name, "attr") + + def test_setname_raises_the_same_as_on_other_classes(self): + class CustomException(BaseException): pass + + class Annoying: + def __set_name__(self, owner, name): + raise CustomException + + annoying = Annoying() + + with self.assertRaises(CustomException) as cm: + class NormalClass: + attr = annoying + normal_exception = cm.exception + + with self.assertRaises(CustomException) as cm: + class NamedTupleClass(NamedTuple): + attr = annoying + namedtuple_exception = cm.exception + + self.assertIs(type(namedtuple_exception), CustomException) + self.assertIs(type(namedtuple_exception), type(normal_exception)) + + self.assertEqual(len(namedtuple_exception.__notes__), 1) + self.assertEqual( + len(namedtuple_exception.__notes__), len(normal_exception.__notes__) + ) + + expected_note = ( + "Error calling __set_name__ on 'Annoying' instance " + "'attr' in 'NamedTupleClass'" + ) + self.assertEqual(namedtuple_exception.__notes__[0], expected_note) + self.assertEqual( + namedtuple_exception.__notes__[0], + 
normal_exception.__notes__[0].replace("NormalClass", "NamedTupleClass") + ) + + def test_strange_errors_when_accessing_set_name_itself(self): + class CustomException(Exception): pass + + class Meta(type): + def __getattribute__(self, attr): + if attr == "__set_name__": + raise CustomException + return object.__getattribute__(self, attr) + + class VeryAnnoying(metaclass=Meta): pass + + very_annoying = VeryAnnoying() + + with self.assertRaises(CustomException): + class Foo(NamedTuple): + attr = very_annoying + class TypedDictTests(BaseTestCase): def test_basics_functional_syntax(self): @@ -4326,33 +8273,10 @@ def test_basics_functional_syntax(self): self.assertEqual(Emp.__bases__, (dict,)) self.assertEqual(Emp.__annotations__, {'name': str, 'id': int}) self.assertEqual(Emp.__total__, True) - - def test_basics_keywords_syntax(self): - Emp = TypedDict('Emp', name=str, id=int) - self.assertIsSubclass(Emp, dict) - self.assertIsSubclass(Emp, typing.MutableMapping) - self.assertNotIsSubclass(Emp, collections.abc.Sequence) - jim = Emp(name='Jim', id=1) - self.assertIs(type(jim), dict) - self.assertEqual(jim['name'], 'Jim') - self.assertEqual(jim['id'], 1) - self.assertEqual(Emp.__name__, 'Emp') - self.assertEqual(Emp.__module__, __name__) - self.assertEqual(Emp.__bases__, (dict,)) - self.assertEqual(Emp.__annotations__, {'name': str, 'id': int}) - self.assertEqual(Emp.__total__, True) - - def test_typeddict_special_keyword_names(self): - TD = TypedDict("TD", cls=type, self=object, typename=str, _typename=int, fields=list, _fields=dict) - self.assertEqual(TD.__name__, 'TD') - self.assertEqual(TD.__annotations__, {'cls': type, 'self': object, 'typename': str, '_typename': int, 'fields': list, '_fields': dict}) - a = TD(cls=str, self=42, typename='foo', _typename=53, fields=[('bar', tuple)], _fields={'baz', set}) - self.assertEqual(a['cls'], str) - self.assertEqual(a['self'], 42) - self.assertEqual(a['typename'], 'foo') - self.assertEqual(a['_typename'], 53) - self.assertEqual(a['fields'], [('bar', tuple)]) - self.assertEqual(a['_fields'], {'baz', set}) + self.assertEqual(Emp.__required_keys__, {'name', 'id'}) + self.assertIsInstance(Emp.__required_keys__, frozenset) + self.assertEqual(Emp.__optional_keys__, set()) + self.assertIsInstance(Emp.__optional_keys__, frozenset) def test_typeddict_create_errors(self): with self.assertRaises(TypeError): @@ -4361,11 +8285,10 @@ def test_typeddict_create_errors(self): TypedDict() with self.assertRaises(TypeError): TypedDict('Emp', [('name', str)], None) - with self.assertRaises(TypeError): - TypedDict(_typename='Emp', name=str, id=int) + TypedDict(_typename='Emp') with self.assertRaises(TypeError): - TypedDict('Emp', _fields={'name': str, 'id': int}) + TypedDict('Emp', name=str, id=int) def test_typeddict_errors(self): Emp = TypedDict('Emp', {'name': str, 'id': int}) @@ -4377,10 +8300,6 @@ def test_typeddict_errors(self): isinstance(jim, Emp) with self.assertRaises(TypeError): issubclass(dict, Emp) - with self.assertRaises(TypeError): - TypedDict('Hi', x=1) - with self.assertRaises(TypeError): - TypedDict('Hi', [('x', int), ('y', 1)]) with self.assertRaises(TypeError): TypedDict('Hi', [('x', int)], y=int) @@ -4399,7 +8318,7 @@ def test_py36_class_syntax_usage(self): def test_pickle(self): global EmpD # pickle wants to reference the class by name - EmpD = TypedDict('EmpD', name=str, id=int) + EmpD = TypedDict('EmpD', {'name': str, 'id': int}) jane = EmpD({'name': 'jane', 'id': 37}) for proto in range(pickle.HIGHEST_PROTOCOL + 1): z = pickle.dumps(jane, proto) @@ 
-4410,8 +8329,19 @@ def test_pickle(self): EmpDnew = pickle.loads(ZZ) self.assertEqual(EmpDnew({'name': 'jane', 'id': 37}), jane) + def test_pickle_generic(self): + point = Point2DGeneric(a=5.0, b=3.0) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + z = pickle.dumps(point, proto) + point2 = pickle.loads(z) + self.assertEqual(point2, point) + self.assertEqual(point2, {'a': 5.0, 'b': 3.0}) + ZZ = pickle.dumps(Point2DGeneric, proto) + Point2DGenericNew = pickle.loads(ZZ) + self.assertEqual(Point2DGenericNew({'a': 5.0, 'b': 3.0}), point) + def test_optional(self): - EmpD = TypedDict('EmpD', name=str, id=int) + EmpD = TypedDict('EmpD', {'name': str, 'id': int}) self.assertEqual(typing.Optional[EmpD], typing.Union[None, EmpD]) self.assertNotEqual(typing.List[EmpD], typing.Tuple[EmpD]) @@ -4422,7 +8352,9 @@ def test_total(self): self.assertEqual(D(x=1), {'x': 1}) self.assertEqual(D.__total__, False) self.assertEqual(D.__required_keys__, frozenset()) + self.assertIsInstance(D.__required_keys__, frozenset) self.assertEqual(D.__optional_keys__, {'x'}) + self.assertIsInstance(D.__optional_keys__, frozenset) self.assertEqual(Options(), {}) self.assertEqual(Options(log_level=2), {'log_level': 2}) @@ -4430,12 +8362,25 @@ def test_total(self): self.assertEqual(Options.__required_keys__, frozenset()) self.assertEqual(Options.__optional_keys__, {'log_level', 'log_path'}) + def test_total_inherits_non_total(self): + class TD1(TypedDict, total=False): + a: int + + self.assertIs(TD1.__total__, False) + + class TD2(TD1): + b: str + + self.assertIs(TD2.__total__, True) + def test_optional_keys(self): class Point2Dor3D(Point2D, total=False): z: int - assert Point2Dor3D.__required_keys__ == frozenset(['x', 'y']) - assert Point2Dor3D.__optional_keys__ == frozenset(['z']) + self.assertEqual(Point2Dor3D.__required_keys__, frozenset(['x', 'y'])) + self.assertIsInstance(Point2Dor3D.__required_keys__, frozenset) + self.assertEqual(Point2Dor3D.__optional_keys__, frozenset(['z'])) + self.assertIsInstance(Point2Dor3D.__optional_keys__, frozenset) def test_keys_inheritance(self): class BaseAnimal(TypedDict): @@ -4448,26 +8393,102 @@ class Animal(BaseAnimal, total=False): class Cat(Animal): fur_color: str - assert BaseAnimal.__required_keys__ == frozenset(['name']) - assert BaseAnimal.__optional_keys__ == frozenset([]) - assert BaseAnimal.__annotations__ == {'name': str} + self.assertEqual(BaseAnimal.__required_keys__, frozenset(['name'])) + self.assertEqual(BaseAnimal.__optional_keys__, frozenset([])) + self.assertEqual(BaseAnimal.__annotations__, {'name': str}) - assert Animal.__required_keys__ == frozenset(['name']) - assert Animal.__optional_keys__ == frozenset(['tail', 'voice']) - assert Animal.__annotations__ == { + self.assertEqual(Animal.__required_keys__, frozenset(['name'])) + self.assertEqual(Animal.__optional_keys__, frozenset(['tail', 'voice'])) + self.assertEqual(Animal.__annotations__, { 'name': str, 'tail': bool, 'voice': str, - } + }) - assert Cat.__required_keys__ == frozenset(['name', 'fur_color']) - assert Cat.__optional_keys__ == frozenset(['tail', 'voice']) - assert Cat.__annotations__ == { + self.assertEqual(Cat.__required_keys__, frozenset(['name', 'fur_color'])) + self.assertEqual(Cat.__optional_keys__, frozenset(['tail', 'voice'])) + self.assertEqual(Cat.__annotations__, { 'fur_color': str, 'name': str, 'tail': bool, 'voice': str, - } + }) + + def test_keys_inheritance_with_same_name(self): + class NotTotal(TypedDict, total=False): + a: int + + class Total(NotTotal): + a: int + + 
self.assertEqual(NotTotal.__required_keys__, frozenset()) + self.assertEqual(NotTotal.__optional_keys__, frozenset(['a'])) + self.assertEqual(Total.__required_keys__, frozenset(['a'])) + self.assertEqual(Total.__optional_keys__, frozenset()) + + class Base(TypedDict): + a: NotRequired[int] + b: Required[int] + + class Child(Base): + a: Required[int] + b: NotRequired[int] + + self.assertEqual(Base.__required_keys__, frozenset(['b'])) + self.assertEqual(Base.__optional_keys__, frozenset(['a'])) + self.assertEqual(Child.__required_keys__, frozenset(['a'])) + self.assertEqual(Child.__optional_keys__, frozenset(['b'])) + + def test_multiple_inheritance_with_same_key(self): + class Base1(TypedDict): + a: NotRequired[int] + + class Base2(TypedDict): + a: Required[str] + + class Child(Base1, Base2): + pass + + # Last base wins + self.assertEqual(Child.__annotations__, {'a': Required[str]}) + self.assertEqual(Child.__required_keys__, frozenset(['a'])) + self.assertEqual(Child.__optional_keys__, frozenset()) + + def test_required_notrequired_keys(self): + self.assertEqual(NontotalMovie.__required_keys__, + frozenset({"title"})) + self.assertEqual(NontotalMovie.__optional_keys__, + frozenset({"year"})) + + self.assertEqual(TotalMovie.__required_keys__, + frozenset({"title"})) + self.assertEqual(TotalMovie.__optional_keys__, + frozenset({"year"})) + + self.assertEqual(_typed_dict_helper.VeryAnnotated.__required_keys__, + frozenset()) + self.assertEqual(_typed_dict_helper.VeryAnnotated.__optional_keys__, + frozenset({"a"})) + + self.assertEqual(AnnotatedMovie.__required_keys__, + frozenset({"title"})) + self.assertEqual(AnnotatedMovie.__optional_keys__, + frozenset({"year"})) + + self.assertEqual(WeirdlyQuotedMovie.__required_keys__, + frozenset({"title"})) + self.assertEqual(WeirdlyQuotedMovie.__optional_keys__, + frozenset({"year"})) + + self.assertEqual(ChildTotalMovie.__required_keys__, + frozenset({"title"})) + self.assertEqual(ChildTotalMovie.__optional_keys__, + frozenset({"year"})) + + self.assertEqual(ChildDeeplyAnnotatedMovie.__required_keys__, + frozenset({"title"})) + self.assertEqual(ChildDeeplyAnnotatedMovie.__optional_keys__, + frozenset({"year"})) def test_multiple_inheritance(self): class One(TypedDict): @@ -4556,32 +8577,419 @@ class ChildWithInlineAndOptional(Untotal, Inline): class Wrong(*bases): pass - def test_is_typeddict(self): - assert is_typeddict(Point2D) is True - assert is_typeddict(Union[str, int]) is False - # classes, not instances - assert is_typeddict(Point2D()) is False + def test_is_typeddict(self): + self.assertIs(is_typeddict(Point2D), True) + self.assertIs(is_typeddict(Union[str, int]), False) + # classes, not instances + self.assertIs(is_typeddict(Point2D()), False) + call_based = TypedDict('call_based', {'a': int}) + self.assertIs(is_typeddict(call_based), True) + self.assertIs(is_typeddict(call_based()), False) + + T = TypeVar("T") + class BarGeneric(TypedDict, Generic[T]): + a: T + self.assertIs(is_typeddict(BarGeneric), True) + self.assertIs(is_typeddict(BarGeneric[int]), False) + self.assertIs(is_typeddict(BarGeneric()), False) + + class NewGeneric[T](TypedDict): + a: T + self.assertIs(is_typeddict(NewGeneric), True) + self.assertIs(is_typeddict(NewGeneric[int]), False) + self.assertIs(is_typeddict(NewGeneric()), False) + + # The TypedDict constructor is not itself a TypedDict + self.assertIs(is_typeddict(TypedDict), False) + + def test_get_type_hints(self): + self.assertEqual( + get_type_hints(Bar), + {'a': typing.Optional[int], 'b': int} + ) + + def 
test_get_type_hints_generic(self): + self.assertEqual( + get_type_hints(BarGeneric), + {'a': typing.Optional[T], 'b': int} + ) + + class FooBarGeneric(BarGeneric[int]): + c: str + + self.assertEqual( + get_type_hints(FooBarGeneric), + {'a': typing.Optional[T], 'b': int, 'c': str} + ) + + def test_pep695_generic_typeddict(self): + class A[T](TypedDict): + a: T + + T, = A.__type_params__ + self.assertIsInstance(T, TypeVar) + self.assertEqual(T.__name__, 'T') + self.assertEqual(A.__bases__, (Generic, dict)) + self.assertEqual(A.__orig_bases__, (TypedDict, Generic[T])) + self.assertEqual(A.__mro__, (A, Generic, dict, object)) + self.assertEqual(A.__parameters__, (T,)) + self.assertEqual(A[str].__parameters__, ()) + self.assertEqual(A[str].__args__, (str,)) + + def test_generic_inheritance(self): + class A(TypedDict, Generic[T]): + a: T + + self.assertEqual(A.__bases__, (Generic, dict)) + self.assertEqual(A.__orig_bases__, (TypedDict, Generic[T])) + self.assertEqual(A.__mro__, (A, Generic, dict, object)) + self.assertEqual(A.__parameters__, (T,)) + self.assertEqual(A[str].__parameters__, ()) + self.assertEqual(A[str].__args__, (str,)) + + class A2(Generic[T], TypedDict): + a: T + + self.assertEqual(A2.__bases__, (Generic, dict)) + self.assertEqual(A2.__orig_bases__, (Generic[T], TypedDict)) + self.assertEqual(A2.__mro__, (A2, Generic, dict, object)) + self.assertEqual(A2.__parameters__, (T,)) + self.assertEqual(A2[str].__parameters__, ()) + self.assertEqual(A2[str].__args__, (str,)) + + class B(A[KT], total=False): + b: KT + + self.assertEqual(B.__bases__, (Generic, dict)) + self.assertEqual(B.__orig_bases__, (A[KT],)) + self.assertEqual(B.__mro__, (B, Generic, dict, object)) + self.assertEqual(B.__parameters__, (KT,)) + self.assertEqual(B.__total__, False) + self.assertEqual(B.__optional_keys__, frozenset(['b'])) + self.assertEqual(B.__required_keys__, frozenset(['a'])) + + self.assertEqual(B[str].__parameters__, ()) + self.assertEqual(B[str].__args__, (str,)) + self.assertEqual(B[str].__origin__, B) + + class C(B[int]): + c: int + + self.assertEqual(C.__bases__, (Generic, dict)) + self.assertEqual(C.__orig_bases__, (B[int],)) + self.assertEqual(C.__mro__, (C, Generic, dict, object)) + self.assertEqual(C.__parameters__, ()) + self.assertEqual(C.__total__, True) + self.assertEqual(C.__optional_keys__, frozenset(['b'])) + self.assertEqual(C.__required_keys__, frozenset(['a', 'c'])) + self.assertEqual(C.__annotations__, { + 'a': T, + 'b': KT, + 'c': int, + }) + with self.assertRaises(TypeError): + C[str] + + + class Point3D(Point2DGeneric[T], Generic[T, KT]): + c: KT + + self.assertEqual(Point3D.__bases__, (Generic, dict)) + self.assertEqual(Point3D.__orig_bases__, (Point2DGeneric[T], Generic[T, KT])) + self.assertEqual(Point3D.__mro__, (Point3D, Generic, dict, object)) + self.assertEqual(Point3D.__parameters__, (T, KT)) + self.assertEqual(Point3D.__total__, True) + self.assertEqual(Point3D.__optional_keys__, frozenset()) + self.assertEqual(Point3D.__required_keys__, frozenset(['a', 'b', 'c'])) + self.assertEqual(Point3D.__annotations__, { + 'a': T, + 'b': T, + 'c': KT, + }) + self.assertEqual(Point3D[int, str].__origin__, Point3D) + + with self.assertRaises(TypeError): + Point3D[int] + + with self.assertRaises(TypeError): + class Point3D(Point2DGeneric[T], Generic[KT]): + c: KT + + def test_implicit_any_inheritance(self): + class A(TypedDict, Generic[T]): + a: T + + class B(A[KT], total=False): + b: KT + + class WithImplicitAny(B): + c: int + + self.assertEqual(WithImplicitAny.__bases__, 
(Generic, dict,)) + self.assertEqual(WithImplicitAny.__mro__, (WithImplicitAny, Generic, dict, object)) + # Consistent with GenericTests.test_implicit_any + self.assertEqual(WithImplicitAny.__parameters__, ()) + self.assertEqual(WithImplicitAny.__total__, True) + self.assertEqual(WithImplicitAny.__optional_keys__, frozenset(['b'])) + self.assertEqual(WithImplicitAny.__required_keys__, frozenset(['a', 'c'])) + self.assertEqual(WithImplicitAny.__annotations__, { + 'a': T, + 'b': KT, + 'c': int, + }) + with self.assertRaises(TypeError): + WithImplicitAny[str] + + def test_non_generic_subscript(self): + # For backward compatibility, subscription works + # on arbitrary TypedDict types. + class TD(TypedDict): + a: T + A = TD[int] + self.assertEqual(A.__origin__, TD) + self.assertEqual(A.__parameters__, ()) + self.assertEqual(A.__args__, (int,)) + a = A(a = 1) + self.assertIs(type(a), dict) + self.assertEqual(a, {'a': 1}) + + def test_orig_bases(self): + T = TypeVar('T') + + class Parent(TypedDict): + pass + + class Child(Parent): + pass + + class OtherChild(Parent): + pass + + class MixedChild(Child, OtherChild, Parent): + pass + + class GenericParent(TypedDict, Generic[T]): + pass + + class GenericChild(GenericParent[int]): + pass + + class OtherGenericChild(GenericParent[str]): + pass + + class MixedGenericChild(GenericChild, OtherGenericChild, GenericParent[float]): + pass + + class MultipleGenericBases(GenericParent[int], GenericParent[float]): + pass + + CallTypedDict = TypedDict('CallTypedDict', {}) + + self.assertEqual(Parent.__orig_bases__, (TypedDict,)) + self.assertEqual(Child.__orig_bases__, (Parent,)) + self.assertEqual(OtherChild.__orig_bases__, (Parent,)) + self.assertEqual(MixedChild.__orig_bases__, (Child, OtherChild, Parent,)) + self.assertEqual(GenericParent.__orig_bases__, (TypedDict, Generic[T])) + self.assertEqual(GenericChild.__orig_bases__, (GenericParent[int],)) + self.assertEqual(OtherGenericChild.__orig_bases__, (GenericParent[str],)) + self.assertEqual(MixedGenericChild.__orig_bases__, (GenericChild, OtherGenericChild, GenericParent[float])) + self.assertEqual(MultipleGenericBases.__orig_bases__, (GenericParent[int], GenericParent[float])) + self.assertEqual(CallTypedDict.__orig_bases__, (TypedDict,)) + + def test_zero_fields_typeddicts(self): + T1 = TypedDict("T1", {}) + class T2(TypedDict): pass + class T3[tvar](TypedDict): pass + S = TypeVar("S") + class T4(TypedDict, Generic[S]): pass + + expected_warning = re.escape( + "Failing to pass a value for the 'fields' parameter is deprecated " + "and will be disallowed in Python 3.15. " + "To create a TypedDict class with 0 fields " + "using the functional syntax, " + "pass an empty dictionary, e.g. `T5 = TypedDict('T5', {})`." + ) + with self.assertWarnsRegex(DeprecationWarning, fr"^{expected_warning}$"): + T5 = TypedDict('T5') + + expected_warning = re.escape( + "Passing `None` as the 'fields' parameter is deprecated " + "and will be disallowed in Python 3.15. " + "To create a TypedDict class with 0 fields " + "using the functional syntax, " + "pass an empty dictionary, e.g. `T6 = TypedDict('T6', {})`." 
+ ) + with self.assertWarnsRegex(DeprecationWarning, fr"^{expected_warning}$"): + T6 = TypedDict('T6', None) + + for klass in T1, T2, T3, T4, T5, T6: + with self.subTest(klass=klass.__name__): + self.assertEqual(klass.__annotations__, {}) + self.assertEqual(klass.__required_keys__, set()) + self.assertEqual(klass.__optional_keys__, set()) + self.assertIsInstance(klass(), dict) + + def test_readonly_inheritance(self): + class Base1(TypedDict): + a: ReadOnly[int] + + class Child1(Base1): + b: str + + self.assertEqual(Child1.__readonly_keys__, frozenset({'a'})) + self.assertEqual(Child1.__mutable_keys__, frozenset({'b'})) + + class Base2(TypedDict): + a: ReadOnly[int] + + class Child2(Base2): + b: str + + self.assertEqual(Child1.__readonly_keys__, frozenset({'a'})) + self.assertEqual(Child1.__mutable_keys__, frozenset({'b'})) + + def test_cannot_make_mutable_key_readonly(self): + class Base(TypedDict): + a: int + + with self.assertRaises(TypeError): + class Child(Base): + a: ReadOnly[int] + + def test_can_make_readonly_key_mutable(self): + class Base(TypedDict): + a: ReadOnly[int] + + class Child(Base): + a: int + + self.assertEqual(Child.__readonly_keys__, frozenset()) + self.assertEqual(Child.__mutable_keys__, frozenset({'a'})) + + def test_combine_qualifiers(self): + class AllTheThings(TypedDict): + a: Annotated[Required[ReadOnly[int]], "why not"] + b: Required[Annotated[ReadOnly[int], "why not"]] + c: ReadOnly[NotRequired[Annotated[int, "why not"]]] + d: NotRequired[Annotated[int, "why not"]] + + self.assertEqual(AllTheThings.__required_keys__, frozenset({'a', 'b'})) + self.assertEqual(AllTheThings.__optional_keys__, frozenset({'c', 'd'})) + self.assertEqual(AllTheThings.__readonly_keys__, frozenset({'a', 'b', 'c'})) + self.assertEqual(AllTheThings.__mutable_keys__, frozenset({'d'})) + + self.assertEqual( + get_type_hints(AllTheThings, include_extras=False), + {'a': int, 'b': int, 'c': int, 'd': int}, + ) + self.assertEqual( + get_type_hints(AllTheThings, include_extras=True), + { + 'a': Annotated[Required[ReadOnly[int]], 'why not'], + 'b': Required[Annotated[ReadOnly[int], 'why not']], + 'c': ReadOnly[NotRequired[Annotated[int, 'why not']]], + 'd': NotRequired[Annotated[int, 'why not']], + }, + ) + + +class RequiredTests(BaseTestCase): + + def test_basics(self): + with self.assertRaises(TypeError): + Required[NotRequired] + with self.assertRaises(TypeError): + Required[int, str] + with self.assertRaises(TypeError): + Required[int][str] + + def test_repr(self): + self.assertEqual(repr(Required), 'typing.Required') + cv = Required[int] + self.assertEqual(repr(cv), 'typing.Required[int]') + cv = Required[Employee] + self.assertEqual(repr(cv), f'typing.Required[{__name__}.Employee]') + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_cannot_subclass(self): + with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): + class C(type(Required)): + pass + with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): + class D(type(Required[int])): + pass + with self.assertRaisesRegex(TypeError, + r'Cannot subclass typing\.Required'): + class E(Required): + pass + with self.assertRaisesRegex(TypeError, + r'Cannot subclass typing\.Required\[int\]'): + class F(Required[int]): + pass + + def test_cannot_init(self): + with self.assertRaises(TypeError): + Required() + with self.assertRaises(TypeError): + type(Required)() + with self.assertRaises(TypeError): + type(Required[Optional[int]])() + + def test_no_isinstance(self): + with self.assertRaises(TypeError): + isinstance(1, 
Required[int]) + with self.assertRaises(TypeError): + issubclass(int, Required) + + +class NotRequiredTests(BaseTestCase): + + def test_basics(self): + with self.assertRaises(TypeError): + NotRequired[Required] + with self.assertRaises(TypeError): + NotRequired[int, str] + with self.assertRaises(TypeError): + NotRequired[int][str] + + def test_repr(self): + self.assertEqual(repr(NotRequired), 'typing.NotRequired') + cv = NotRequired[int] + self.assertEqual(repr(cv), 'typing.NotRequired[int]') + cv = NotRequired[Employee] + self.assertEqual(repr(cv), f'typing.NotRequired[{__name__}.Employee]') - def test_get_type_hints(self): - self.assertEqual( - get_type_hints(Bar), - {'a': typing.Optional[int], 'b': int} - ) + def test_cannot_subclass(self): + with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): + class C(type(NotRequired)): + pass + with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): + class D(type(NotRequired[int])): + pass + with self.assertRaisesRegex(TypeError, + r'Cannot subclass typing\.NotRequired'): + class E(NotRequired): + pass + with self.assertRaisesRegex(TypeError, + r'Cannot subclass typing\.NotRequired\[int\]'): + class F(NotRequired[int]): + pass - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_non_generic_subscript(self): - # For backward compatibility, subscription works - # on arbitrary TypedDict types. - class TD(TypedDict): - a: T - A = TD[int] - self.assertEqual(A.__origin__, TD) - self.assertEqual(A.__parameters__, ()) - self.assertEqual(A.__args__, (int,)) - a = A(a = 1) - self.assertIs(type(a), dict) - self.assertEqual(a, {'a': 1}) + def test_cannot_init(self): + with self.assertRaises(TypeError): + NotRequired() + with self.assertRaises(TypeError): + type(NotRequired)() + with self.assertRaises(TypeError): + type(NotRequired[Optional[int]])() + + def test_no_isinstance(self): + with self.assertRaises(TypeError): + isinstance(1, NotRequired[int]) + with self.assertRaises(TypeError): + issubclass(int, NotRequired) class IOTests(BaseTestCase): @@ -4610,14 +9018,6 @@ def stuff(a: BinaryIO) -> bytes: a = stuff.__annotations__['a'] self.assertEqual(a.__parameters__, ()) - def test_io_submodule(self): - from typing.io import IO, TextIO, BinaryIO, __all__, __name__ - self.assertIs(IO, typing.IO) - self.assertIs(TextIO, typing.TextIO) - self.assertIs(BinaryIO, typing.BinaryIO) - self.assertEqual(set(__all__), set(['IO', 'TextIO', 'BinaryIO'])) - self.assertEqual(__name__, 'typing.io') - class RETests(BaseTestCase): # Much of this is really testing _TypeAlias. 
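For context only, a minimal illustrative sketch (not part of the diff itself) of the PEP 655 semantics that the Required/NotRequired tests above rely on; the Movie and Film names here are hypothetical:

from typing import TypedDict, Required, NotRequired

class Movie(TypedDict, total=False):
    title: Required[str]    # required even though total=False
    year: int               # optional, because total=False

class Film(TypedDict):
    title: str              # required, total defaults to True
    year: NotRequired[int]  # optional despite total=True

assert Movie.__required_keys__ == frozenset({'title'})
assert Movie.__optional_keys__ == frozenset({'year'})
assert Film.__required_keys__ == frozenset({'title'})
assert Film.__optional_keys__ == frozenset({'year'})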
@@ -4662,31 +9062,26 @@ def test_repr(self): self.assertEqual(repr(Match[str]), 'typing.Match[str]') self.assertEqual(repr(Match[bytes]), 'typing.Match[bytes]') - def test_re_submodule(self): - from typing.re import Match, Pattern, __all__, __name__ - self.assertIs(Match, typing.Match) - self.assertIs(Pattern, typing.Pattern) - self.assertEqual(set(__all__), set(['Match', 'Pattern'])) - self.assertEqual(__name__, 'typing.re') - - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_cannot_subclass(self): - with self.assertRaises(TypeError) as ex: - + with self.assertRaisesRegex( + TypeError, + r"type 're\.Match' is not an acceptable base type", + ): class A(typing.Match): pass - - self.assertEqual(str(ex.exception), - "type 're.Match' is not an acceptable base type") + with self.assertRaisesRegex( + TypeError, + r"type 're\.Pattern' is not an acceptable base type", + ): + class B(typing.Pattern): + pass class AnnotatedTests(BaseTestCase): def test_new(self): with self.assertRaisesRegex( - TypeError, - 'Type Annotated cannot be instantiated', + TypeError, 'Cannot instantiate typing.Annotated', ): Annotated() @@ -4700,12 +9095,91 @@ def test_repr(self): "typing.Annotated[typing.List[int], 4, 5]" ) + def test_dir(self): + dir_items = set(dir(Annotated[int, 4])) + for required_item in [ + '__args__', '__parameters__', '__origin__', + '__metadata__', + ]: + with self.subTest(required_item=required_item): + self.assertIn(required_item, dir_items) + def test_flatten(self): A = Annotated[Annotated[int, 4], 5] self.assertEqual(A, Annotated[int, 4, 5]) self.assertEqual(A.__metadata__, (4, 5)) self.assertEqual(A.__origin__, int) + def test_deduplicate_from_union(self): + # Regular: + self.assertEqual(get_args(Annotated[int, 1] | int), + (Annotated[int, 1], int)) + self.assertEqual(get_args(Union[Annotated[int, 1], int]), + (Annotated[int, 1], int)) + self.assertEqual(get_args(Annotated[int, 1] | Annotated[int, 2] | int), + (Annotated[int, 1], Annotated[int, 2], int)) + self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[int, 2], int]), + (Annotated[int, 1], Annotated[int, 2], int)) + self.assertEqual(get_args(Annotated[int, 1] | Annotated[str, 1] | int), + (Annotated[int, 1], Annotated[str, 1], int)) + self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[str, 1], int]), + (Annotated[int, 1], Annotated[str, 1], int)) + + # Duplicates: + self.assertEqual(Annotated[int, 1] | Annotated[int, 1] | int, + Annotated[int, 1] | int) + self.assertEqual(Union[Annotated[int, 1], Annotated[int, 1], int], + Union[Annotated[int, 1], int]) + + # Unhashable metadata: + self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[int, set()] | int), + (str, Annotated[int, {}], Annotated[int, set()], int)) + self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[int, set()], int]), + (str, Annotated[int, {}], Annotated[int, set()], int)) + self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[str, {}] | int), + (str, Annotated[int, {}], Annotated[str, {}], int)) + self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[str, {}], int]), + (str, Annotated[int, {}], Annotated[str, {}], int)) + + self.assertEqual(get_args(Annotated[int, 1] | str | Annotated[str, {}] | int), + (Annotated[int, 1], str, Annotated[str, {}], int)) + self.assertEqual(get_args(Union[Annotated[int, 1], str, Annotated[str, {}], int]), + (Annotated[int, 1], str, Annotated[str, {}], int)) + + import dataclasses + @dataclasses.dataclass + class ValueRange: + lo: int + hi: int + v = ValueRange(1, 2) 
+ self.assertEqual(get_args(Annotated[int, v] | None), + (Annotated[int, v], types.NoneType)) + self.assertEqual(get_args(Union[Annotated[int, v], None]), + (Annotated[int, v], types.NoneType)) + self.assertEqual(get_args(Optional[Annotated[int, v]]), + (Annotated[int, v], types.NoneType)) + + # Unhashable metadata duplicated: + self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int, + Annotated[int, {}] | int) + self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int, + int | Annotated[int, {}]) + self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int], + Union[Annotated[int, {}], int]) + self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int], + Union[int, Annotated[int, {}]]) + + def test_order_in_union(self): + expr1 = Annotated[int, 1] | str | Annotated[str, {}] | int + for args in itertools.permutations(get_args(expr1)): + with self.subTest(args=args): + self.assertEqual(expr1, reduce(operator.or_, args)) + + expr2 = Union[Annotated[int, 1], str, Annotated[str, {}], int] + for args in itertools.permutations(get_args(expr2)): + with self.subTest(args=args): + self.assertEqual(expr2, Union[args]) + def test_specialize(self): L = Annotated[List[T], "my decoration"] LI = Annotated[List[int], "my decoration"] @@ -4726,6 +9200,16 @@ def test_hash_eq(self): {Annotated[int, 4, 5], Annotated[int, 4, 5], Annotated[T, 4, 5]}, {Annotated[int, 4, 5], Annotated[T, 4, 5]} ) + # Unhashable `metadata` raises `TypeError`: + a1 = Annotated[int, []] + with self.assertRaises(TypeError): + hash(a1) + + class A: + __hash__ = None + a2 = Annotated[int, A()] + with self.assertRaises(TypeError): + hash(a2) def test_instantiate(self): class C: @@ -4751,6 +9235,17 @@ def test_instantiate_generic(self): self.assertEqual(MyCount([4, 4, 5]), {4: 2, 5: 1}) self.assertEqual(MyCount[int]([4, 4, 5]), {4: 2, 5: 1}) + def test_instantiate_immutable(self): + class C: + def __setattr__(self, key, value): + raise Exception("should be ignored") + + A = Annotated[C, "a decoration"] + # gh-115165: This used to cause RuntimeError to be raised + # when we tried to set `__orig_class__` on the `C` instance + # returned by the `A()` call + self.assertIsInstance(A(), C) + def test_cannot_instantiate_forward(self): A = Annotated["int", (5, 6)] with self.assertRaises(TypeError): @@ -4782,15 +9277,33 @@ class C: self.assertEqual(get_type_hints(C, globals())['classvar'], ClassVar[int]) self.assertEqual(get_type_hints(C, globals())['const'], Final[int]) - def test_hash_eq(self): - self.assertEqual(len({Annotated[int, 4, 5], Annotated[int, 4, 5]}), 1) - self.assertNotEqual(Annotated[int, 4, 5], Annotated[int, 5, 4]) - self.assertNotEqual(Annotated[int, 4, 5], Annotated[str, 4, 5]) - self.assertNotEqual(Annotated[int, 4], Annotated[int, 4, 4]) - self.assertEqual( - {Annotated[int, 4, 5], Annotated[int, 4, 5], Annotated[T, 4, 5]}, - {Annotated[int, 4, 5], Annotated[T, 4, 5]} - ) + def test_special_forms_nesting(self): + # These are uncommon types and are to ensure runtime + # is lax on validation. See gh-89547 for more context. 
+ class CF: + x: ClassVar[Final[int]] + + class FC: + x: Final[ClassVar[int]] + + class ACF: + x: Annotated[ClassVar[Final[int]], "a decoration"] + + class CAF: + x: ClassVar[Annotated[Final[int], "a decoration"]] + + class AFC: + x: Annotated[Final[ClassVar[int]], "a decoration"] + + class FAC: + x: Final[Annotated[ClassVar[int], "a decoration"]] + + self.assertEqual(get_type_hints(CF, globals())['x'], ClassVar[Final[int]]) + self.assertEqual(get_type_hints(FC, globals())['x'], Final[ClassVar[int]]) + self.assertEqual(get_type_hints(ACF, globals())['x'], ClassVar[Final[int]]) + self.assertEqual(get_type_hints(CAF, globals())['x'], ClassVar[Final[int]]) + self.assertEqual(get_type_hints(AFC, globals())['x'], Final[ClassVar[int]]) + self.assertEqual(get_type_hints(FAC, globals())['x'], Final[ClassVar[int]]) def test_cannot_subclass(self): with self.assertRaisesRegex(TypeError, "Cannot subclass .*Annotated"): @@ -4868,6 +9381,120 @@ def test_subst(self): with self.assertRaises(TypeError): LI[None] + def test_typevar_subst(self): + dec = "a decoration" + Ts = TypeVarTuple('Ts') + T = TypeVar('T') + T1 = TypeVar('T1') + T2 = TypeVar('T2') + + A = Annotated[tuple[*Ts], dec] + self.assertEqual(A[int], Annotated[tuple[int], dec]) + self.assertEqual(A[str, int], Annotated[tuple[str, int], dec]) + with self.assertRaises(TypeError): + Annotated[*Ts, dec] + + B = Annotated[Tuple[Unpack[Ts]], dec] + self.assertEqual(B[int], Annotated[Tuple[int], dec]) + self.assertEqual(B[str, int], Annotated[Tuple[str, int], dec]) + with self.assertRaises(TypeError): + Annotated[Unpack[Ts], dec] + + C = Annotated[tuple[T, *Ts], dec] + self.assertEqual(C[int], Annotated[tuple[int], dec]) + self.assertEqual(C[int, str], Annotated[tuple[int, str], dec]) + self.assertEqual( + C[int, str, float], + Annotated[tuple[int, str, float], dec] + ) + with self.assertRaises(TypeError): + C[()] + + D = Annotated[Tuple[T, Unpack[Ts]], dec] + self.assertEqual(D[int], Annotated[Tuple[int], dec]) + self.assertEqual(D[int, str], Annotated[Tuple[int, str], dec]) + self.assertEqual( + D[int, str, float], + Annotated[Tuple[int, str, float], dec] + ) + with self.assertRaises(TypeError): + D[()] + + E = Annotated[tuple[*Ts, T], dec] + self.assertEqual(E[int], Annotated[tuple[int], dec]) + self.assertEqual(E[int, str], Annotated[tuple[int, str], dec]) + self.assertEqual( + E[int, str, float], + Annotated[tuple[int, str, float], dec] + ) + with self.assertRaises(TypeError): + E[()] + + F = Annotated[Tuple[Unpack[Ts], T], dec] + self.assertEqual(F[int], Annotated[Tuple[int], dec]) + self.assertEqual(F[int, str], Annotated[Tuple[int, str], dec]) + self.assertEqual( + F[int, str, float], + Annotated[Tuple[int, str, float], dec] + ) + with self.assertRaises(TypeError): + F[()] + + G = Annotated[tuple[T1, *Ts, T2], dec] + self.assertEqual(G[int, str], Annotated[tuple[int, str], dec]) + self.assertEqual( + G[int, str, float], + Annotated[tuple[int, str, float], dec] + ) + self.assertEqual( + G[int, str, bool, float], + Annotated[tuple[int, str, bool, float], dec] + ) + with self.assertRaises(TypeError): + G[int] + + H = Annotated[Tuple[T1, Unpack[Ts], T2], dec] + self.assertEqual(H[int, str], Annotated[Tuple[int, str], dec]) + self.assertEqual( + H[int, str, float], + Annotated[Tuple[int, str, float], dec] + ) + self.assertEqual( + H[int, str, bool, float], + Annotated[Tuple[int, str, bool, float], dec] + ) + with self.assertRaises(TypeError): + H[int] + + # Now let's try creating an alias from an alias. 
+ + Ts2 = TypeVarTuple('Ts2') + T3 = TypeVar('T3') + T4 = TypeVar('T4') + + # G is Annotated[tuple[T1, *Ts, T2], dec]. + I = G[T3, *Ts2, T4] + J = G[T3, Unpack[Ts2], T4] + + for x, y in [ + (I, Annotated[tuple[T3, *Ts2, T4], dec]), + (J, Annotated[tuple[T3, Unpack[Ts2], T4], dec]), + (I[int, str], Annotated[tuple[int, str], dec]), + (J[int, str], Annotated[tuple[int, str], dec]), + (I[int, str, float], Annotated[tuple[int, str, float], dec]), + (J[int, str, float], Annotated[tuple[int, str, float], dec]), + (I[int, str, bool, float], + Annotated[tuple[int, str, bool, float], dec]), + (J[int, str, bool, float], + Annotated[tuple[int, str, bool, float], dec]), + ]: + self.assertEqual(x, y) + + with self.assertRaises(TypeError): + I[int] + with self.assertRaises(TypeError): + J[int] + def test_annotated_in_other_types(self): X = List[Annotated[T, 5]] self.assertEqual(X[int], List[Annotated[int, 5]]) @@ -4877,6 +9504,40 @@ class X(Annotated[int, (1, 10)]): ... self.assertEqual(X.__mro__, (X, int, object), "Annotated should be transparent.") + def test_annotated_cached_with_types(self): + class A(str): ... + class B(str): ... + + field_a1 = Annotated[str, A("X")] + field_a2 = Annotated[str, B("X")] + a1_metadata = field_a1.__metadata__[0] + a2_metadata = field_a2.__metadata__[0] + + self.assertIs(type(a1_metadata), A) + self.assertEqual(a1_metadata, A("X")) + self.assertIs(type(a2_metadata), B) + self.assertEqual(a2_metadata, B("X")) + self.assertIsNot(type(a1_metadata), type(a2_metadata)) + + field_b1 = Annotated[str, A("Y")] + field_b2 = Annotated[str, B("Y")] + b1_metadata = field_b1.__metadata__[0] + b2_metadata = field_b2.__metadata__[0] + + self.assertIs(type(b1_metadata), A) + self.assertEqual(b1_metadata, A("Y")) + self.assertIs(type(b2_metadata), B) + self.assertEqual(b2_metadata, B("Y")) + self.assertIsNot(type(b1_metadata), type(b2_metadata)) + + field_c1 = Annotated[int, 1] + field_c2 = Annotated[int, 1.0] + field_c3 = Annotated[int, True] + + self.assertIs(type(field_c1.__metadata__[0]), int) + self.assertIs(type(field_c2.__metadata__[0]), float) + self.assertIs(type(field_c3.__metadata__[0]), bool) + class TypeAliasTests(BaseTestCase): def test_canonical_usage_with_variable_annotation(self): @@ -4906,12 +9567,13 @@ def test_no_issubclass(self): issubclass(TypeAlias, Employee) def test_cannot_subclass(self): - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, + r'Cannot subclass typing\.TypeAlias'): class C(TypeAlias): pass with self.assertRaises(TypeError): - class C(type(TypeAlias)): + class D(type(TypeAlias)): pass def test_repr(self): @@ -4928,6 +9590,16 @@ def test_basic_plain(self): P = ParamSpec('P') self.assertEqual(P, P) self.assertIsInstance(P, ParamSpec) + self.assertEqual(P.__name__, 'P') + self.assertEqual(P.__module__, __name__) + + def test_basic_with_exec(self): + ns = {} + exec('from typing import ParamSpec; P = ParamSpec("P")', ns, ns) + P = ns['P'] + self.assertIsInstance(P, ParamSpec) + self.assertEqual(P.__name__, 'P') + self.assertIs(P.__module__, None) # TODO: RUSTPYTHON @unittest.expectedFailure @@ -5047,32 +9719,208 @@ class X(Generic[P, P2]): self.assertEqual(G1.__args__, ((int, str), (bytes,))) self.assertEqual(G2.__args__, ((int,), (str, bytes))) + def test_typevartuple_and_paramspecs_in_user_generics(self): + Ts = TypeVarTuple("Ts") + P = ParamSpec("P") + + class X(Generic[*Ts, P]): + f: Callable[P, int] + g: Tuple[*Ts] + + G1 = X[int, [bytes]] + self.assertEqual(G1.__args__, (int, (bytes,))) + G2 = X[int, str, [bytes]] + 
self.assertEqual(G2.__args__, (int, str, (bytes,))) + G3 = X[[bytes]] + self.assertEqual(G3.__args__, ((bytes,),)) + G4 = X[[]] + self.assertEqual(G4.__args__, ((),)) + with self.assertRaises(TypeError): + X[()] + + class Y(Generic[P, *Ts]): + f: Callable[P, int] + g: Tuple[*Ts] + + G1 = Y[[bytes], int] + self.assertEqual(G1.__args__, ((bytes,), int)) + G2 = Y[[bytes], int, str] + self.assertEqual(G2.__args__, ((bytes,), int, str)) + G3 = Y[[bytes]] + self.assertEqual(G3.__args__, ((bytes,),)) + G4 = Y[[]] + self.assertEqual(G4.__args__, ((),)) + with self.assertRaises(TypeError): + Y[()] + + def test_typevartuple_and_paramspecs_in_generic_aliases(self): + P = ParamSpec('P') + T = TypeVar('T') + Ts = TypeVarTuple('Ts') + + for C in Callable, collections.abc.Callable: + with self.subTest(generic=C): + A = C[P, Tuple[*Ts]] + B = A[[int, str], bytes, float] + self.assertEqual(B.__args__, (int, str, Tuple[bytes, float])) + + class X(Generic[T, P]): + pass + + A = X[Tuple[*Ts], P] + B = A[bytes, float, [int, str]] + self.assertEqual(B.__args__, (Tuple[bytes, float], (int, str,))) + + class Y(Generic[P, T]): + pass + + A = Y[P, Tuple[*Ts]] + B = A[[int, str], bytes, float] + self.assertEqual(B.__args__, ((int, str,), Tuple[bytes, float])) + + def test_var_substitution(self): + P = ParamSpec("P") + subst = P.__typing_subst__ + self.assertEqual(subst((int, str)), (int, str)) + self.assertEqual(subst([int, str]), (int, str)) + self.assertEqual(subst([None]), (type(None),)) + self.assertIs(subst(...), ...) + self.assertIs(subst(P), P) + self.assertEqual(subst(Concatenate[int, P]), Concatenate[int, P]) + def test_bad_var_substitution(self): T = TypeVar('T') P = ParamSpec('P') bad_args = (42, int, None, T, int|str, Union[int, str]) for arg in bad_args: with self.subTest(arg=arg): + with self.assertRaises(TypeError): + P.__typing_subst__(arg) with self.assertRaises(TypeError): typing.Callable[P, T][arg, str] with self.assertRaises(TypeError): collections.abc.Callable[P, T][arg, str] - def test_no_paramspec_in__parameters__(self): - # ParamSpec should not be found in __parameters__ - # of generics. Usages outside Callable, Concatenate - # and Generic are invalid. 
- T = TypeVar("T") - P = ParamSpec("P") - self.assertNotIn(P, List[P].__parameters__) - self.assertIn(T, Tuple[T, P].__parameters__) + def test_type_var_subst_for_other_type_vars(self): + T = TypeVar('T') + T2 = TypeVar('T2') + P = ParamSpec('P') + P2 = ParamSpec('P2') + Ts = TypeVarTuple('Ts') + + class Base(Generic[P]): + pass + + A1 = Base[T] + self.assertEqual(A1.__parameters__, (T,)) + self.assertEqual(A1.__args__, ((T,),)) + self.assertEqual(A1[int], Base[int]) + + A2 = Base[[T]] + self.assertEqual(A2.__parameters__, (T,)) + self.assertEqual(A2.__args__, ((T,),)) + self.assertEqual(A2[int], Base[int]) + + A3 = Base[[int, T]] + self.assertEqual(A3.__parameters__, (T,)) + self.assertEqual(A3.__args__, ((int, T),)) + self.assertEqual(A3[str], Base[[int, str]]) + + A4 = Base[[T, int, T2]] + self.assertEqual(A4.__parameters__, (T, T2)) + self.assertEqual(A4.__args__, ((T, int, T2),)) + self.assertEqual(A4[str, bool], Base[[str, int, bool]]) + + A5 = Base[[*Ts, int]] + self.assertEqual(A5.__parameters__, (Ts,)) + self.assertEqual(A5.__args__, ((*Ts, int),)) + self.assertEqual(A5[str, bool], Base[[str, bool, int]]) + + A5_2 = Base[[int, *Ts]] + self.assertEqual(A5_2.__parameters__, (Ts,)) + self.assertEqual(A5_2.__args__, ((int, *Ts),)) + self.assertEqual(A5_2[str, bool], Base[[int, str, bool]]) + + A6 = Base[[T, *Ts]] + self.assertEqual(A6.__parameters__, (T, Ts)) + self.assertEqual(A6.__args__, ((T, *Ts),)) + self.assertEqual(A6[int, str, bool], Base[[int, str, bool]]) + + A7 = Base[[T, T]] + self.assertEqual(A7.__parameters__, (T,)) + self.assertEqual(A7.__args__, ((T, T),)) + self.assertEqual(A7[int], Base[[int, int]]) + + A8 = Base[[T, list[T]]] + self.assertEqual(A8.__parameters__, (T,)) + self.assertEqual(A8.__args__, ((T, list[T]),)) + self.assertEqual(A8[int], Base[[int, list[int]]]) + + A9 = Base[[Tuple[*Ts], *Ts]] + self.assertEqual(A9.__parameters__, (Ts,)) + self.assertEqual(A9.__args__, ((Tuple[*Ts], *Ts),)) + self.assertEqual(A9[int, str], Base[Tuple[int, str], int, str]) + + A10 = Base[P2] + self.assertEqual(A10.__parameters__, (P2,)) + self.assertEqual(A10.__args__, (P2,)) + self.assertEqual(A10[[int, str]], Base[[int, str]]) + + class DoubleP(Generic[P, P2]): + pass + + B1 = DoubleP[P, P2] + self.assertEqual(B1.__parameters__, (P, P2)) + self.assertEqual(B1.__args__, (P, P2)) + self.assertEqual(B1[[int, str], [bool]], DoubleP[[int, str], [bool]]) + self.assertEqual(B1[[], []], DoubleP[[], []]) + + B2 = DoubleP[[int, str], P2] + self.assertEqual(B2.__parameters__, (P2,)) + self.assertEqual(B2.__args__, ((int, str), P2)) + self.assertEqual(B2[[bool, bool]], DoubleP[[int, str], [bool, bool]]) + self.assertEqual(B2[[]], DoubleP[[int, str], []]) + + B3 = DoubleP[P, [bool, bool]] + self.assertEqual(B3.__parameters__, (P,)) + self.assertEqual(B3.__args__, (P, (bool, bool))) + self.assertEqual(B3[[int, str]], DoubleP[[int, str], [bool, bool]]) + self.assertEqual(B3[[]], DoubleP[[], [bool, bool]]) + + B4 = DoubleP[[T, int], [bool, T2]] + self.assertEqual(B4.__parameters__, (T, T2)) + self.assertEqual(B4.__args__, ((T, int), (bool, T2))) + self.assertEqual(B4[str, float], DoubleP[[str, int], [bool, float]]) + + B5 = DoubleP[[*Ts, int], [bool, T2]] + self.assertEqual(B5.__parameters__, (Ts, T2)) + self.assertEqual(B5.__args__, ((*Ts, int), (bool, T2))) + self.assertEqual(B5[str, bytes, float], + DoubleP[[str, bytes, int], [bool, float]]) + + B6 = DoubleP[[T, int], [bool, *Ts]] + self.assertEqual(B6.__parameters__, (T, Ts)) + self.assertEqual(B6.__args__, ((T, int), (bool, *Ts))) + 
self.assertEqual(B6[str, bytes, float], + DoubleP[[str, int], [bool, bytes, float]]) + + class PandT(Generic[P, T]): + pass - # Test for consistency with builtin generics. - self.assertNotIn(P, list[P].__parameters__) - self.assertIn(T, tuple[T, P].__parameters__) + C1 = PandT[P, T] + self.assertEqual(C1.__parameters__, (P, T)) + self.assertEqual(C1.__args__, (P, T)) + self.assertEqual(C1[[int, str], bool], PandT[[int, str], bool]) - self.assertNotIn(P, (list[P] | int).__parameters__) - self.assertIn(T, (tuple[T, P] | int).__parameters__) + C2 = PandT[[int, T], T] + self.assertEqual(C2.__parameters__, (T,)) + self.assertEqual(C2.__args__, ((int, T), T)) + self.assertEqual(C2[str], PandT[[int, str], str]) + + C3 = PandT[[int, *Ts], T] + self.assertEqual(C3.__parameters__, (Ts, T)) + self.assertEqual(C3.__args__, ((int, *Ts), T)) + self.assertEqual(C3[str, bool, bytes], PandT[[int, str, bool], bytes]) def test_paramspec_in_nested_generics(self): # Although ParamSpec should not be found in __parameters__ of most @@ -5113,6 +9961,24 @@ def test_paramspec_gets_copied(self): self.assertEqual(C2[Concatenate[str, P2]].__parameters__, (P2,)) self.assertEqual(C2[Concatenate[T, P2]].__parameters__, (T, P2)) + def test_cannot_subclass(self): + with self.assertRaisesRegex(TypeError, NOT_A_BASE_TYPE % 'ParamSpec'): + class C(ParamSpec): pass + with self.assertRaisesRegex(TypeError, NOT_A_BASE_TYPE % 'ParamSpecArgs'): + class D(ParamSpecArgs): pass + with self.assertRaisesRegex(TypeError, NOT_A_BASE_TYPE % 'ParamSpecKwargs'): + class E(ParamSpecKwargs): pass + P = ParamSpec('P') + with self.assertRaisesRegex(TypeError, + CANNOT_SUBCLASS_INSTANCE % 'ParamSpec'): + class F(P): pass + with self.assertRaisesRegex(TypeError, + CANNOT_SUBCLASS_INSTANCE % 'ParamSpecArgs'): + class G(P.args): pass + with self.assertRaisesRegex(TypeError, + CANNOT_SUBCLASS_INSTANCE % 'ParamSpecKwargs'): + class H(P.kwargs): pass + class ConcatenateTests(BaseTestCase): def test_basics(self): @@ -5121,6 +9987,15 @@ class MyClass: ... c = Concatenate[MyClass, P] self.assertNotEqual(c, Concatenate) + def test_dir(self): + P = ParamSpec('P') + dir_items = set(dir(Concatenate[int, P])) + for required_item in [ + '__args__', '__parameters__', '__origin__', + ]: + with self.subTest(required_item=required_item): + self.assertIn(required_item, dir_items) + def test_valid_uses(self): P = ParamSpec('P') T = TypeVar('T') @@ -5149,8 +10024,7 @@ def test_var_substitution(self): self.assertEqual(C[int, []], (int,)) self.assertEqual(C[int, Concatenate[str, P2]], Concatenate[int, str, P2]) - with self.assertRaises(TypeError): - C[int, ...] + self.assertEqual(C[int, ...], Concatenate[int, ...]) C = Concatenate[int, P] self.assertEqual(C[P2], Concatenate[int, P2]) @@ -5158,8 +10032,7 @@ def test_var_substitution(self): self.assertEqual(C[str, float], (int, str, float)) self.assertEqual(C[[]], (int,)) self.assertEqual(C[Concatenate[str, P2]], Concatenate[int, str, P2]) - with self.assertRaises(TypeError): - C[...] + self.assertEqual(C[...], Concatenate[int, ...]) class TypeGuardTests(BaseTestCase): def test_basics(self): @@ -5168,6 +10041,9 @@ def test_basics(self): def foo(arg) -> TypeGuard[int]: ... 
self.assertEqual(gth(foo), {'return': TypeGuard[int]}) + with self.assertRaises(TypeError): + TypeGuard[int, str] + def test_repr(self): self.assertEqual(repr(TypeGuard), 'typing.TypeGuard') cv = TypeGuard[int] @@ -5178,11 +10054,19 @@ def test_repr(self): self.assertEqual(repr(cv), 'typing.TypeGuard[tuple[int]]') def test_cannot_subclass(self): - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): class C(type(TypeGuard)): pass - with self.assertRaises(TypeError): - class C(type(TypeGuard[int])): + with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): + class D(type(TypeGuard[int])): + pass + with self.assertRaisesRegex(TypeError, + r'Cannot subclass typing\.TypeGuard'): + class E(TypeGuard): + pass + with self.assertRaisesRegex(TypeError, + r'Cannot subclass typing\.TypeGuard\[int\]'): + class F(TypeGuard[int]): pass def test_cannot_init(self): @@ -5200,6 +10084,56 @@ def test_no_isinstance(self): issubclass(int, TypeGuard) +class TypeIsTests(BaseTestCase): + def test_basics(self): + TypeIs[int] # OK + + def foo(arg) -> TypeIs[int]: ... + self.assertEqual(gth(foo), {'return': TypeIs[int]}) + + with self.assertRaises(TypeError): + TypeIs[int, str] + + def test_repr(self): + self.assertEqual(repr(TypeIs), 'typing.TypeIs') + cv = TypeIs[int] + self.assertEqual(repr(cv), 'typing.TypeIs[int]') + cv = TypeIs[Employee] + self.assertEqual(repr(cv), 'typing.TypeIs[%s.Employee]' % __name__) + cv = TypeIs[tuple[int]] + self.assertEqual(repr(cv), 'typing.TypeIs[tuple[int]]') + + def test_cannot_subclass(self): + with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): + class C(type(TypeIs)): + pass + with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): + class D(type(TypeIs[int])): + pass + with self.assertRaisesRegex(TypeError, + r'Cannot subclass typing\.TypeIs'): + class E(TypeIs): + pass + with self.assertRaisesRegex(TypeError, + r'Cannot subclass typing\.TypeIs\[int\]'): + class F(TypeIs[int]): + pass + + def test_cannot_init(self): + with self.assertRaises(TypeError): + TypeIs() + with self.assertRaises(TypeError): + type(TypeIs)() + with self.assertRaises(TypeError): + type(TypeIs[Optional[int]])() + + def test_no_isinstance(self): + with self.assertRaises(TypeError): + isinstance(1, TypeIs[int]) + with self.assertRaises(TypeError): + issubclass(int, TypeIs) + + SpecialAttrsP = typing.ParamSpec('SpecialAttrsP') SpecialAttrsT = typing.TypeVar('SpecialAttrsT', int, float, complex) @@ -5251,7 +10185,7 @@ def test_special_attrs(self): typing.ValuesView: 'ValuesView', # Subscribed ABC classes typing.AbstractSet[Any]: 'AbstractSet', - typing.AsyncContextManager[Any]: 'AsyncContextManager', + typing.AsyncContextManager[Any, Any]: 'AsyncContextManager', typing.AsyncGenerator[Any, Any]: 'AsyncGenerator', typing.AsyncIterable[Any]: 'AsyncIterable', typing.AsyncIterator[Any]: 'AsyncIterator', @@ -5261,7 +10195,7 @@ def test_special_attrs(self): typing.ChainMap[Any, Any]: 'ChainMap', typing.Collection[Any]: 'Collection', typing.Container[Any]: 'Container', - typing.ContextManager[Any]: 'ContextManager', + typing.ContextManager[Any, Any]: 'ContextManager', typing.Coroutine[Any, Any, Any]: 'Coroutine', typing.Counter[Any]: 'Counter', typing.DefaultDict[Any, Any]: 'DefaultDict', @@ -5297,13 +10231,17 @@ def test_special_attrs(self): typing.Literal: 'Literal', typing.NewType: 'NewType', typing.NoReturn: 'NoReturn', + typing.Never: 'Never', typing.Optional: 'Optional', typing.TypeAlias: 'TypeAlias', typing.TypeGuard: 'TypeGuard', + 
typing.TypeIs: 'TypeIs', typing.TypeVar: 'TypeVar', typing.Union: 'Union', + typing.Self: 'Self', # Subscribed special forms typing.Annotated[Any, "Annotation"]: 'Annotated', + typing.Annotated[int, 'Annotation']: 'Annotated', typing.ClassVar[Any]: 'ClassVar', typing.Concatenate[Any, SpecialAttrsP]: 'Concatenate', typing.Final[Any]: 'Final', @@ -5312,6 +10250,7 @@ def test_special_attrs(self): typing.Literal[True, 2]: 'Literal', typing.Optional[Any]: 'Optional', typing.TypeGuard[Any]: 'TypeGuard', + typing.TypeIs[Any]: 'TypeIs', typing.Union[Any]: 'Any', typing.Union[int, float]: 'Union', # Incompatible special forms (tested in test_special_attrs2) @@ -5345,7 +10284,7 @@ def test_special_attrs2(self): self.assertEqual(fr.__module__, 'typing') # Forward refs are currently unpicklable. for proto in range(pickle.HIGHEST_PROTOCOL + 1): - with self.assertRaises(TypeError) as exc: + with self.assertRaises(TypeError): pickle.dumps(fr, proto) self.assertEqual(SpecialAttrsTests.TypeName.__name__, 'TypeName') @@ -5390,10 +10329,153 @@ class Foo(Generic[T]): def bar(self): pass baz = 3 + __magic__ = 4 + # The class attributes of the original class should be visible even # in dir() of the GenericAlias. See bpo-45755. - self.assertIn('bar', dir(Foo[int])) - self.assertIn('baz', dir(Foo[int])) + dir_items = set(dir(Foo[int])) + for required_item in [ + 'bar', 'baz', + '__args__', '__parameters__', '__origin__', + ]: + with self.subTest(required_item=required_item): + self.assertIn(required_item, dir_items) + self.assertNotIn('__magic__', dir_items) + + +class RevealTypeTests(BaseTestCase): + def test_reveal_type(self): + obj = object() + with captured_stderr() as stderr: + self.assertIs(obj, reveal_type(obj)) + self.assertEqual(stderr.getvalue(), "Runtime type is 'object'\n") + + +class DataclassTransformTests(BaseTestCase): + def test_decorator(self): + def create_model(*, frozen: bool = False, kw_only: bool = True): + return lambda cls: cls + + decorated = dataclass_transform(kw_only_default=True, order_default=False)(create_model) + + class CustomerModel: + id: int + + self.assertIs(decorated, create_model) + self.assertEqual( + decorated.__dataclass_transform__, + { + "eq_default": True, + "order_default": False, + "kw_only_default": True, + "frozen_default": False, + "field_specifiers": (), + "kwargs": {}, + } + ) + self.assertIs( + decorated(frozen=True, kw_only=False)(CustomerModel), + CustomerModel + ) + + def test_base_class(self): + class ModelBase: + def __init_subclass__(cls, *, frozen: bool = False): ... + + Decorated = dataclass_transform( + eq_default=True, + order_default=True, + # Arbitrary unrecognized kwargs are accepted at runtime. + make_everything_awesome=True, + )(ModelBase) + + class CustomerModel(Decorated, frozen=True): + id: int + + self.assertIs(Decorated, ModelBase) + self.assertEqual( + Decorated.__dataclass_transform__, + { + "eq_default": True, + "order_default": True, + "kw_only_default": False, + "frozen_default": False, + "field_specifiers": (), + "kwargs": {"make_everything_awesome": True}, + } + ) + self.assertIsSubclass(CustomerModel, Decorated) + + def test_metaclass(self): + class Field: ... + + class ModelMeta(type): + def __new__( + cls, name, bases, namespace, *, init: bool = True, + ): + return super().__new__(cls, name, bases, namespace) + + Decorated = dataclass_transform( + order_default=True, frozen_default=True, field_specifiers=(Field,) + )(ModelMeta) + + class ModelBase(metaclass=Decorated): ... 
+ + class CustomerModel(ModelBase, init=False): + id: int + + self.assertIs(Decorated, ModelMeta) + self.assertEqual( + Decorated.__dataclass_transform__, + { + "eq_default": True, + "order_default": True, + "kw_only_default": False, + "frozen_default": True, + "field_specifiers": (Field,), + "kwargs": {}, + } + ) + self.assertIsInstance(CustomerModel, Decorated) + + +class NoDefaultTests(BaseTestCase): + def test_pickling(self): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + s = pickle.dumps(NoDefault, proto) + loaded = pickle.loads(s) + self.assertIs(NoDefault, loaded) + + def test_constructor(self): + self.assertIs(NoDefault, type(NoDefault)()) + with self.assertRaises(TypeError): + type(NoDefault)(1) + + def test_repr(self): + self.assertEqual(repr(NoDefault), 'typing.NoDefault') + + @requires_docstrings + def test_doc(self): + self.assertIsInstance(NoDefault.__doc__, str) + + def test_class(self): + self.assertIs(NoDefault.__class__, type(NoDefault)) + + def test_no_call(self): + with self.assertRaises(TypeError): + NoDefault() + + def test_no_attributes(self): + with self.assertRaises(AttributeError): + NoDefault.foo = 3 + with self.assertRaises(AttributeError): + NoDefault.foo + + # TypeError is consistent with the behavior of NoneType + with self.assertRaises(TypeError): + type(NoDefault).foo = 3 + with self.assertRaises(AttributeError): + type(NoDefault).foo class AllTests(BaseTestCase): @@ -5409,7 +10491,7 @@ def test_all(self): # Context managers. self.assertIn('ContextManager', a) self.assertIn('AsyncContextManager', a) - # Check that io and re are not exported. + # Check that former namespaces io and re are not exported. self.assertNotIn('io', a) self.assertNotIn('re', a) # Spot-check that stdlib modules aren't exported. @@ -5422,6 +10504,10 @@ def test_all(self): self.assertIn('SupportsComplex', a) def test_all_exported_names(self): + # ensure all dynamically created objects are actualised + for name in typing.__all__: + getattr(typing, name) + actual_all = set(typing.__all__) computed_all = { k for k, v in vars(typing).items() @@ -5429,10 +10515,6 @@ def test_all_exported_names(self): if k in actual_all or ( # avoid private names not k.startswith('_') and - # avoid things in the io / re typing submodules - k not in typing.io.__all__ and - k not in typing.re.__all__ and - k not in {'io', 're'} and # there's a few types and metaclasses that aren't exported not k.endswith(('Meta', '_contra', '_co')) and not k.upper() == k and @@ -5443,6 +10525,43 @@ def test_all_exported_names(self): self.assertSetEqual(computed_all, actual_all) +class TypeIterationTests(BaseTestCase): + _UNITERABLE_TYPES = ( + Any, + Union, + Union[str, int], + Union[str, T], + List, + Tuple, + Callable, + Callable[..., T], + Callable[[T], str], + Annotated, + Annotated[T, ''], + ) + + def test_cannot_iterate(self): + expected_error_regex = "object is not iterable" + for test_type in self._UNITERABLE_TYPES: + with self.subTest(type=test_type): + with self.assertRaisesRegex(TypeError, expected_error_regex): + iter(test_type) + with self.assertRaisesRegex(TypeError, expected_error_regex): + list(test_type) + with self.assertRaisesRegex(TypeError, expected_error_regex): + for _ in test_type: + pass + + def test_is_not_instance_of_iterable(self): + for type_to_test in self._UNITERABLE_TYPES: + self.assertNotIsInstance(type_to_test, collections.abc.Iterable) + + +def load_tests(loader, tests, pattern): + import doctest + tests.addTests(doctest.DocTestSuite(typing)) + return tests + if __name__ == '__main__': 
main() diff --git a/Lib/typing.py b/Lib/typing.py index 086d0f3f95..f5c904271d 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -1,35 +1,46 @@ """ -The typing module: Support for gradual typing as defined by PEP 484. - -At large scale, the structure of the module is following: -* Imports and exports, all public names should be explicitly added to __all__. -* Internal helper functions: these should never be used in code outside this module. -* _SpecialForm and its instances (special forms): - Any, NoReturn, ClassVar, Union, Optional, Concatenate -* Classes whose instances can be type arguments in addition to types: - ForwardRef, TypeVar and ParamSpec -* The core of internal generics API: _GenericAlias and _VariadicGenericAlias, the latter is - currently only used by Tuple and Callable. All subscripted types like X[int], Union[int, str], - etc., are instances of either of these classes. -* The public counterpart of the generics API consists of two classes: Generic and Protocol. -* Public helper functions: get_type_hints, overload, cast, no_type_check, - no_type_check_decorator. -* Generic aliases for collections.abc ABCs and few additional protocols. +The typing module: Support for gradual typing as defined by PEP 484 and subsequent PEPs. + +Among other things, the module includes the following: +* Generic, Protocol, and internal machinery to support generic aliases. + All subscripted types like X[int], Union[int, str] are generic aliases. +* Various "special forms" that have unique meanings in type annotations: + NoReturn, Never, ClassVar, Self, Concatenate, Unpack, and others. +* Classes whose instances can be type arguments to generic classes and functions: + TypeVar, ParamSpec, TypeVarTuple. +* Public helper functions: get_type_hints, overload, cast, final, and others. +* Several protocols to support duck-typing: + SupportsFloat, SupportsIndex, SupportsAbs, and others. * Special types: NewType, NamedTuple, TypedDict. -* Wrapper submodules for re and io related types. +* Deprecated aliases for builtin types and collections.abc ABCs. + +Any name not present in __all__ is an implementation detail +that may be changed without notice. Use at your own risk! """ from abc import abstractmethod, ABCMeta import collections +from collections import defaultdict import collections.abc -import contextlib +import copyreg import functools import operator -import re as stdlib_re # Avoid confusion with the re we export. import sys import types from types import WrapperDescriptorType, MethodWrapperType, MethodDescriptorType, GenericAlias +from _typing import ( + _idfunc, + TypeVar, + ParamSpec, + TypeVarTuple, + ParamSpecArgs, + ParamSpecKwargs, + TypeAliasType, + Generic, + NoDefault, +) + # Please keep __all__ alphabetized within each category. __all__ = [ # Super-special typing primitives. @@ -48,6 +59,7 @@ 'Tuple', 'Type', 'TypeVar', + 'TypeVarTuple', 'Union', # ABCs (from collections.abc). @@ -109,30 +121,45 @@ # One-off things. 
'AnyStr', + 'assert_type', + 'assert_never', 'cast', + 'clear_overloads', + 'dataclass_transform', 'final', 'get_args', 'get_origin', + 'get_overloads', + 'get_protocol_members', 'get_type_hints', + 'is_protocol', 'is_typeddict', + 'LiteralString', + 'Never', 'NewType', 'no_type_check', 'no_type_check_decorator', + 'NoDefault', 'NoReturn', + 'NotRequired', 'overload', + 'override', 'ParamSpecArgs', 'ParamSpecKwargs', + 'ReadOnly', + 'Required', + 'reveal_type', 'runtime_checkable', + 'Self', 'Text', 'TYPE_CHECKING', 'TypeAlias', 'TypeGuard', + 'TypeIs', + 'TypeAliasType', + 'Unpack', ] -# The pseudo-submodules 're' and 'io' are part of the public -# namespace, but excluded from __all__ because they might stomp on -# legitimate imports of those modules. - def _type_convert(arg, module=None, *, allow_special_forms=False): """For converting None to type(None), and strings to ForwardRef.""" @@ -149,7 +176,7 @@ def _type_check(arg, msg, is_argument=True, module=None, *, allow_special_forms= As a special case, accept None and return type(None) instead. Also wrap strings into ForwardRef instances. Consider several corner cases, for example plain special forms like Union are not valid, while Union[int, str] is OK, etc. - The msg argument is a human-readable error message, e.g:: + The msg argument is a human-readable error message, e.g.:: "Union[arg, ...]: arg should be a type." @@ -165,14 +192,13 @@ def _type_check(arg, msg, is_argument=True, module=None, *, allow_special_forms= if (isinstance(arg, _GenericAlias) and arg.__origin__ in invalid_generic_forms): raise TypeError(f"{arg} is not valid as type argument") - if arg in (Any, NoReturn, Final, TypeAlias): + if arg in (Any, LiteralString, NoReturn, Never, Self, TypeAlias): + return arg + if allow_special_forms and arg in (ClassVar, Final): return arg if isinstance(arg, _SpecialForm) or arg in (Generic, Protocol): raise TypeError(f"Plain {arg} is not valid as type argument") - if isinstance(arg, (type, TypeVar, ForwardRef, types.UnionType, ParamSpec, - ParamSpecArgs, ParamSpecKwargs)): - return arg - if not callable(arg): + if type(arg) is tuple: raise TypeError(f"{msg} Got {arg!r:.100}.") return arg @@ -182,6 +208,30 @@ def _is_param_expr(arg): (tuple, list, ParamSpec, _ConcatenateGenericAlias)) +def _should_unflatten_callable_args(typ, args): + """Internal helper for munging collections.abc.Callable's __args__. + + The canonical representation for a Callable's __args__ flattens the + argument types, see https://github.com/python/cpython/issues/86361. + + For example:: + + >>> import collections.abc + >>> P = ParamSpec('P') + >>> collections.abc.Callable[[int, int], str].__args__ == (int, int, str) + True + >>> collections.abc.Callable[P, str].__args__ == (P, str) + True + + As a result, if we need to reconstruct the Callable from its __args__, + we need to unflatten it. + """ + return ( + typ.__origin__ is collections.abc.Callable + and not (len(args) == 2 and _is_param_expr(args[0])) + ) + + def _type_repr(obj): """Return the repr() of an object, special-casing types (internal helper). @@ -190,99 +240,160 @@ def _type_repr(obj): typically enough to uniquely identify a type. For everything else, we fall back on repr(obj). """ - if isinstance(obj, types.GenericAlias): - return repr(obj) + # When changing this function, don't forget about + # `_collections_abc._type_repr`, which does the same thing + # and must be consistent with this one. 
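# A minimal sketch of the repr rules coded below; it assumes the private helper
# stays importable as `typing._type_repr` and that the tuple special-case added
# in this patch is present (illustration only, not part of the patch):
import collections
from typing import _type_repr

assert _type_repr(int) == 'int'                                    # builtins drop the module prefix
assert _type_repr(collections.OrderedDict) == 'collections.OrderedDict'
assert _type_repr(...) == '...'
assert _type_repr((int, str)) == '[int, str]'                      # tuples render as ParamSpec-style argument lists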
if isinstance(obj, type): if obj.__module__ == 'builtins': return obj.__qualname__ return f'{obj.__module__}.{obj.__qualname__}' if obj is ...: - return('...') + return '...' if isinstance(obj, types.FunctionType): return obj.__name__ + if isinstance(obj, tuple): + # Special case for `repr` of types with `ParamSpec`: + return '[' + ', '.join(_type_repr(t) for t in obj) + ']' return repr(obj) -def _collect_type_vars(types_, typevar_types=None): - """Collect all type variable contained - in types in order of first appearance (lexicographic order). For example:: +def _collect_type_parameters(args, *, enforce_default_ordering: bool = True): + """Collect all type parameters in args + in order of first appearance (lexicographic order). - _collect_type_vars((T, List[S, T])) == (T, S) + For example:: + + >>> P = ParamSpec('P') + >>> T = TypeVar('T') + >>> _collect_type_parameters((T, Callable[P, T])) + (~T, ~P) """ - if typevar_types is None: - typevar_types = TypeVar - tvars = [] - for t in types_: - if isinstance(t, typevar_types) and t not in tvars: - tvars.append(t) - if isinstance(t, (_GenericAlias, GenericAlias, types.UnionType)): - tvars.extend([t for t in t.__parameters__ if t not in tvars]) - return tuple(tvars) + # required type parameter cannot appear after parameter with default + default_encountered = False + # or after TypeVarTuple + type_var_tuple_encountered = False + parameters = [] + for t in args: + if isinstance(t, type): + # We don't want __parameters__ descriptor of a bare Python class. + pass + elif isinstance(t, tuple): + # `t` might be a tuple, when `ParamSpec` is substituted with + # `[T, int]`, or `[int, *Ts]`, etc. + for x in t: + for collected in _collect_type_parameters([x]): + if collected not in parameters: + parameters.append(collected) + elif hasattr(t, '__typing_subst__'): + if t not in parameters: + if enforce_default_ordering: + if type_var_tuple_encountered and t.has_default(): + raise TypeError('Type parameter with a default' + ' follows TypeVarTuple') + + if t.has_default(): + default_encountered = True + elif default_encountered: + raise TypeError(f'Type parameter {t!r} without a default' + ' follows type parameter with a default') + + parameters.append(t) + else: + if _is_unpacked_typevartuple(t): + type_var_tuple_encountered = True + for x in getattr(t, '__parameters__', ()): + if x not in parameters: + parameters.append(x) + return tuple(parameters) -def _check_generic(cls, parameters, elen): +def _check_generic_specialization(cls, arguments): """Check correct count for parameters of a generic cls (internal helper). + This gives a nice error message in case of count mismatch. """ - if not elen: + expected_len = len(cls.__parameters__) + if not expected_len: raise TypeError(f"{cls} is not a generic class") - alen = len(parameters) - if alen != elen: - raise TypeError(f"Too {'many' if alen > elen else 'few'} arguments for {cls};" - f" actual {alen}, expected {elen}") - -def _prepare_paramspec_params(cls, params): - """Prepares the parameters for a Generic containing ParamSpec - variables (internal helper). - """ - # Special case where Z[[int, str, bool]] == Z[int, str, bool] in PEP 612. - if (len(cls.__parameters__) == 1 - and params and not _is_param_expr(params[0])): - assert isinstance(cls.__parameters__[0], ParamSpec) - return (params,) - else: - _check_generic(cls, params, len(cls.__parameters__)) - _params = [] - # Convert lists to tuples to help other libraries cache the results. 
- for p, tvar in zip(params, cls.__parameters__): - if isinstance(tvar, ParamSpec) and isinstance(p, list): - p = tuple(p) - _params.append(p) - return tuple(_params) - -def _deduplicate(params): - # Weed out strict duplicates, preserving the first of each occurrence. - all_params = set(params) - if len(all_params) < len(params): - new_params = [] - for t in params: - if t in all_params: - new_params.append(t) - all_params.remove(t) - params = new_params - assert not all_params, all_params - return params + actual_len = len(arguments) + if actual_len != expected_len: + # deal with defaults + if actual_len < expected_len: + # If the parameter at index `actual_len` in the parameters list + # has a default, then all parameters after it must also have + # one, because we validated as much in _collect_type_parameters(). + # That means that no error needs to be raised here, despite + # the number of arguments being passed not matching the number + # of parameters: all parameters that aren't explicitly + # specialized in this call are parameters with default values. + if cls.__parameters__[actual_len].has_default(): + return + + expected_len -= sum(p.has_default() for p in cls.__parameters__) + expect_val = f"at least {expected_len}" + else: + expect_val = expected_len + raise TypeError(f"Too {'many' if actual_len > expected_len else 'few'} arguments" + f" for {cls}; actual {actual_len}, expected {expect_val}") + + +def _unpack_args(*args): + newargs = [] + for arg in args: + subargs = getattr(arg, '__typing_unpacked_tuple_args__', None) + if subargs is not None and not (subargs and subargs[-1] is ...): + newargs.extend(subargs) + else: + newargs.append(arg) + return newargs + +def _deduplicate(params, *, unhashable_fallback=False): + # Weed out strict duplicates, preserving the first of each occurrence. + try: + return dict.fromkeys(params) + except TypeError: + if not unhashable_fallback: + raise + # Happens for cases like `Annotated[dict, {'x': IntValidator()}]` + return _deduplicate_unhashable(params) + +def _deduplicate_unhashable(unhashable_params): + new_unhashable = [] + for t in unhashable_params: + if t not in new_unhashable: + new_unhashable.append(t) + return new_unhashable + +def _compare_args_orderless(first_args, second_args): + first_unhashable = _deduplicate_unhashable(first_args) + second_unhashable = _deduplicate_unhashable(second_args) + t = list(second_unhashable) + try: + for elem in first_unhashable: + t.remove(elem) + except ValueError: + return False + return not t def _remove_dups_flatten(parameters): - """An internal helper for Union creation and substitution: flatten Unions - among parameters, then remove duplicates. + """Internal helper for Union creation and substitution. + + Flatten Unions among parameters, then remove duplicates. """ # Flatten out Union[Union[...], ...]. 
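# A minimal sketch of the observable effect of this helper, exercised through the
# public Union API; the same equalities are asserted in the Union docstring below:
from typing import Union

assert Union[Union[int, str], float] == Union[int, str, float]    # nested unions are flattened
assert Union[int, str, int] == Union[int, str]                    # duplicates are removed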
params = [] for p in parameters: if isinstance(p, (_UnionGenericAlias, types.UnionType)): params.extend(p.__args__) - elif isinstance(p, tuple) and len(p) > 0 and p[0] is Union: - params.extend(p[1:]) else: params.append(p) - return tuple(_deduplicate(params)) + return tuple(_deduplicate(params, unhashable_fallback=True)) def _flatten_literal_params(parameters): - """An internal helper for Literal creation: flatten Literals among parameters""" + """Internal helper for Literal creation: flatten Literals among parameters.""" params = [] for p in parameters: if isinstance(p, _LiteralGenericAlias): @@ -293,20 +404,29 @@ def _flatten_literal_params(parameters): _cleanups = [] +_caches = {} def _tp_cache(func=None, /, *, typed=False): - """Internal wrapper caching __getitem__ of generic types with a fallback to - original function for non-hashable arguments. + """Internal wrapper caching __getitem__ of generic types. + + For non-hashable arguments, the original function is used as a fallback. """ def decorator(func): - cached = functools.lru_cache(typed=typed)(func) - _cleanups.append(cached.cache_clear) + # The callback 'inner' references the newly created lru_cache + # indirectly by performing a lookup in the global '_caches' dictionary. + # This breaks a reference that can be problematic when combined with + # C API extensions that leak references to types. See GH-98253. + + cache = functools.lru_cache(typed=typed)(func) + _caches[func] = cache + _cleanups.append(cache.cache_clear) + del cache @functools.wraps(func) def inner(*args, **kwds): try: - return cached(*args, **kwds) + return _caches[func](*args, **kwds) except TypeError: pass # All real errors (not unhashable args) are raised below. return func(*args, **kwds) @@ -317,16 +437,61 @@ def inner(*args, **kwds): return decorator -def _eval_type(t, globalns, localns, recursive_guard=frozenset()): + +def _deprecation_warning_for_no_type_params_passed(funcname: str) -> None: + import warnings + + depr_message = ( + f"Failing to pass a value to the 'type_params' parameter " + f"of {funcname!r} is deprecated, as it leads to incorrect behaviour " + f"when calling {funcname} on a stringified annotation " + f"that references a PEP 695 type parameter. " + f"It will be disallowed in Python 3.15." + ) + warnings.warn(depr_message, category=DeprecationWarning, stacklevel=3) + + +class _Sentinel: + __slots__ = () + def __repr__(self): + return '' + + +_sentinel = _Sentinel() + + +def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=frozenset()): """Evaluate all forward references in the given type t. + For use of globalns and localns see the docstring for get_type_hints(). recursive_guard is used to prevent infinite recursion with a recursive ForwardRef. 
""" + if type_params is _sentinel: + _deprecation_warning_for_no_type_params_passed("typing._eval_type") + type_params = () if isinstance(t, ForwardRef): - return t._evaluate(globalns, localns, recursive_guard) + return t._evaluate(globalns, localns, type_params, recursive_guard=recursive_guard) if isinstance(t, (_GenericAlias, GenericAlias, types.UnionType)): - ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__) + if isinstance(t, GenericAlias): + args = tuple( + ForwardRef(arg) if isinstance(arg, str) else arg + for arg in t.__args__ + ) + is_unpacked = t.__unpacked__ + if _should_unflatten_callable_args(t, args): + t = t.__origin__[(args[:-1], args[-1])] + else: + t = t.__origin__[args] + if is_unpacked: + t = Unpack[t] + + ev_args = tuple( + _eval_type( + a, globalns, localns, type_params, recursive_guard=recursive_guard + ) + for a in t.__args__ + ) if ev_args == t.__args__: return t if isinstance(t, GenericAlias): @@ -339,28 +504,36 @@ def _eval_type(t, globalns, localns, recursive_guard=frozenset()): class _Final: - """Mixin to prohibit subclassing""" + """Mixin to prohibit subclassing.""" __slots__ = ('__weakref__',) - def __init_subclass__(self, /, *args, **kwds): + def __init_subclass__(cls, /, *args, **kwds): if '_root' not in kwds: raise TypeError("Cannot subclass special typing classes") -class _Immutable: - """Mixin to indicate that object should not be copied.""" - __slots__ = () - def __copy__(self): - return self +class _NotIterable: + """Mixin to prevent iteration, without being compatible with Iterable. + + That is, we could do:: - def __deepcopy__(self, memo): - return self + def __iter__(self): raise TypeError() + + But this would make users of this mixin duck type-compatible with + collections.abc.Iterable - isinstance(foo, Iterable) would be True. + + Luckily, we can instead prevent iteration by setting __iter__ to None, which + is treated specially. + """ + + __slots__ = () + __iter__ = None # Internal indicator of special typing constructs. # See __doc__ instance attribute for specific docs. -class _SpecialForm(_Final, _root=True): +class _SpecialForm(_Final, _NotIterable, _root=True): __slots__ = ('_name', '__doc__', '_getitem') def __init__(self, getitem): @@ -403,15 +576,26 @@ def __getitem__(self, parameters): return self._getitem(self, parameters) -class _LiteralSpecialForm(_SpecialForm, _root=True): +class _TypedCacheSpecialForm(_SpecialForm, _root=True): def __getitem__(self, parameters): if not isinstance(parameters, tuple): parameters = (parameters,) return self._getitem(self, *parameters) -@_SpecialForm -def Any(self, parameters): +class _AnyMeta(type): + def __instancecheck__(self, obj): + if self is Any: + raise TypeError("typing.Any cannot be used with isinstance()") + return super().__instancecheck__(obj) + + def __repr__(self): + if self is Any: + return "typing.Any" + return super().__repr__() # respect to subclasses + + +class Any(metaclass=_AnyMeta): """Special type indicating an unconstrained type. - Any is compatible with every type. @@ -420,43 +604,128 @@ def Any(self, parameters): Note that all the above statements are true from the point of view of static type checkers. At runtime, Any should not be used with instance - or class checks. + checks. 
""" - raise TypeError(f"{self} is not subscriptable") + + def __new__(cls, *args, **kwargs): + if cls is Any: + raise TypeError("Any cannot be instantiated") + return super().__new__(cls) + @_SpecialForm def NoReturn(self, parameters): """Special type indicating functions that never return. + Example:: - from typing import NoReturn + from typing import NoReturn + + def stop() -> NoReturn: + raise Exception('no way') + + NoReturn can also be used as a bottom type, a type that + has no values. Starting in Python 3.11, the Never type should + be used for this concept instead. Type checkers should treat the two + equivalently. + """ + raise TypeError(f"{self} is not subscriptable") + +# This is semantically identical to NoReturn, but it is implemented +# separately so that type checkers can distinguish between the two +# if they want. +@_SpecialForm +def Never(self, parameters): + """The bottom type, a type that has no members. + + This can be used to define a function that should never be + called, or a function that never returns:: - def stop() -> NoReturn: - raise Exception('no way') + from typing import Never + + def never_call_me(arg: Never) -> None: + pass - This type is invalid in other positions, e.g., ``List[NoReturn]`` - will fail in static type checkers. + def int_or_str(arg: int | str) -> None: + never_call_me(arg) # type checker error + match arg: + case int(): + print("It's an int") + case str(): + print("It's a str") + case _: + never_call_me(arg) # OK, arg is of type Never """ raise TypeError(f"{self} is not subscriptable") + +@_SpecialForm +def Self(self, parameters): + """Used to spell the type of "self" in classes. + + Example:: + + from typing import Self + + class Foo: + def return_self(self) -> Self: + ... + return self + + This is especially useful for: + - classmethods that are used as alternative constructors + - annotating an `__enter__` method which returns self + """ + raise TypeError(f"{self} is not subscriptable") + + +@_SpecialForm +def LiteralString(self, parameters): + """Represents an arbitrary literal string. + + Example:: + + from typing import LiteralString + + def run_query(sql: LiteralString) -> None: + ... + + def caller(arbitrary_string: str, literal_string: LiteralString) -> None: + run_query("SELECT * FROM students") # OK + run_query(literal_string) # OK + run_query("SELECT * FROM " + literal_string) # OK + run_query(arbitrary_string) # type checker error + run_query( # type checker error + f"SELECT * FROM students WHERE name = {arbitrary_string}" + ) + + Only string literals and other LiteralStrings are compatible + with LiteralString. This provides a tool to help prevent + security issues such as SQL injection. + """ + raise TypeError(f"{self} is not subscriptable") + + @_SpecialForm def ClassVar(self, parameters): """Special type construct to mark class variables. An annotation wrapped in ClassVar indicates that a given attribute is intended to be used as a class variable and - should not be set on instances of that class. Usage:: + should not be set on instances of that class. - class Starship: - stats: ClassVar[Dict[str, int]] = {} # class variable - damage: int = 10 # instance variable + Usage:: + + class Starship: + stats: ClassVar[dict[str, int]] = {} # class variable + damage: int = 10 # instance variable ClassVar accepts only types and cannot be further subscribed. Note that ClassVar is not a class itself, and should not be used with isinstance() or issubclass(). 
""" - item = _type_check(parameters, f'{self} accepts only single type.') + item = _type_check(parameters, f'{self} accepts only single type.', allow_special_forms=True) return _GenericAlias(self, (item,)) @_SpecialForm @@ -464,45 +733,50 @@ def Final(self, parameters): """Special typing construct to indicate final names to type checkers. A final name cannot be re-assigned or overridden in a subclass. - For example: - MAX_SIZE: Final = 9000 - MAX_SIZE += 1 # Error reported by type checker + For example:: - class Connection: - TIMEOUT: Final[int] = 10 + MAX_SIZE: Final = 9000 + MAX_SIZE += 1 # Error reported by type checker - class FastConnector(Connection): - TIMEOUT = 1 # Error reported by type checker + class Connection: + TIMEOUT: Final[int] = 10 + + class FastConnector(Connection): + TIMEOUT = 1 # Error reported by type checker There is no runtime checking of these properties. """ - item = _type_check(parameters, f'{self} accepts only single type.') + item = _type_check(parameters, f'{self} accepts only single type.', allow_special_forms=True) return _GenericAlias(self, (item,)) @_SpecialForm def Union(self, parameters): """Union type; Union[X, Y] means either X or Y. - To define a union, use e.g. Union[int, str]. Details: + On Python 3.10 and higher, the | operator + can also be used to denote unions; + X | Y means the same thing to the type checker as Union[X, Y]. + + To define a union, use e.g. Union[int, str]. Details: - The arguments must be types and there must be at least one. - None as an argument is a special case and is replaced by type(None). - Unions of unions are flattened, e.g.:: - Union[Union[int, str], float] == Union[int, str, float] + assert Union[Union[int, str], float] == Union[int, str, float] - Unions of a single argument vanish, e.g.:: - Union[int] == int # The constructor actually returns int + assert Union[int] == int # The constructor actually returns int - Redundant arguments are skipped, e.g.:: - Union[int, str, int] == Union[int, str] + assert Union[int, str, int] == Union[int, str] - When comparing unions, the argument order is ignored, e.g.:: - Union[int, str] == Union[str, int] + assert Union[int, str] == Union[str, int] - You cannot subclass or instantiate a union. - You can use Optional[X] as a shorthand for Union[X, None]. @@ -520,33 +794,39 @@ def Union(self, parameters): return _UnionGenericAlias(self, parameters, name="Optional") return _UnionGenericAlias(self, parameters) -@_SpecialForm -def Optional(self, parameters): - """Optional type. +def _make_union(left, right): + """Used from the C implementation of TypeVar. - Optional[X] is equivalent to Union[X, None]. + TypeVar.__or__ calls this instead of returning types.UnionType + because we want to allow unions between TypeVars and strings + (forward references). """ + return Union[left, right] + +@_SpecialForm +def Optional(self, parameters): + """Optional[X] is equivalent to Union[X, None].""" arg = _type_check(parameters, f"{self} requires a single type.") return Union[arg, type(None)] -@_LiteralSpecialForm +@_TypedCacheSpecialForm @_tp_cache(typed=True) def Literal(self, *parameters): """Special typing form to define literal types (a.k.a. value types). This form can be used to indicate to type checkers that the corresponding variable or function parameter has a value equivalent to the provided - literal (or one of several literals): + literal (or one of several literals):: - def validate_simple(data: Any) -> Literal[True]: # always returns True - ... 
+ def validate_simple(data: Any) -> Literal[True]: # always returns True + ... - MODE = Literal['r', 'rb', 'w', 'wb'] - def open_helper(file: str, mode: MODE) -> str: - ... + MODE = Literal['r', 'rb', 'w', 'wb'] + def open_helper(file: str, mode: MODE) -> str: + ... - open_helper('/some/path', 'r') # Passes type check - open_helper('/other/path', 'typo') # Error in type checker + open_helper('/some/path', 'r') # Passes type check + open_helper('/other/path', 'typo') # Error in type checker Literal[...] cannot be subclassed. At runtime, an arbitrary value is allowed as type argument to Literal[...], but type checkers may @@ -566,7 +846,9 @@ def open_helper(file: str, mode: MODE) -> str: @_SpecialForm def TypeAlias(self, parameters): - """Special marker indicating that an assignment should + """Special form for marking type aliases. + + Use TypeAlias to indicate that an assignment should be recognized as a proper type alias definition by type checkers. @@ -581,13 +863,15 @@ def TypeAlias(self, parameters): @_SpecialForm def Concatenate(self, parameters): - """Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a - higher order function which adds, removes or transforms parameters of a - callable. + """Special form for annotating higher-order functions. + + ``Concatenate`` can be used in conjunction with ``ParamSpec`` and + ``Callable`` to represent a higher-order function which adds, removes or + transforms the parameters of a callable. For example:: - Callable[Concatenate[int, P], int] + Callable[Concatenate[int, P], int] See PEP 612 for detailed information. """ @@ -595,56 +879,62 @@ def Concatenate(self, parameters): raise TypeError("Cannot take a Concatenate of no types.") if not isinstance(parameters, tuple): parameters = (parameters,) - if not isinstance(parameters[-1], ParamSpec): + if not (parameters[-1] is ... or isinstance(parameters[-1], ParamSpec)): raise TypeError("The last parameter to Concatenate should be a " - "ParamSpec variable.") + "ParamSpec variable or ellipsis.") msg = "Concatenate[arg, ...]: each arg must be a type." parameters = (*(_type_check(p, msg) for p in parameters[:-1]), parameters[-1]) - return _ConcatenateGenericAlias(self, parameters, - _typevar_types=(TypeVar, ParamSpec), - _paramspec_tvars=True) + return _ConcatenateGenericAlias(self, parameters) @_SpecialForm def TypeGuard(self, parameters): - """Special typing form used to annotate the return type of a user-defined - type guard function. ``TypeGuard`` only accepts a single type argument. + """Special typing construct for marking user-defined type predicate functions. + + ``TypeGuard`` can be used to annotate the return type of a user-defined + type predicate function. ``TypeGuard`` only accepts a single type argument. At runtime, functions marked this way should return a boolean. ``TypeGuard`` aims to benefit *type narrowing* -- a technique used by static type checkers to determine a more precise type of an expression within a program's code flow. Usually type narrowing is done by analyzing conditional code flow and applying the narrowing to a block of code. The - conditional expression here is sometimes referred to as a "type guard". + conditional expression here is sometimes referred to as a "type predicate". Sometimes it would be convenient to use a user-defined boolean function - as a type guard. Such a function should use ``TypeGuard[...]`` as its - return type to alert static type checkers to this intention. + as a type predicate. 
Such a function should use ``TypeGuard[...]`` or + ``TypeIs[...]`` as its return type to alert static type checkers to + this intention. ``TypeGuard`` should be used over ``TypeIs`` when narrowing + from an incompatible type (e.g., ``list[object]`` to ``list[int]``) or when + the function does not return ``True`` for all instances of the narrowed type. - Using ``-> TypeGuard`` tells the static type checker that for a given - function: + Using ``-> TypeGuard[NarrowedType]`` tells the static type checker that + for a given function: 1. The return value is a boolean. 2. If the return value is ``True``, the type of its argument - is the type inside ``TypeGuard``. + is ``NarrowedType``. + + For example:: - For example:: + def is_str_list(val: list[object]) -> TypeGuard[list[str]]: + '''Determines whether all objects in the list are strings''' + return all(isinstance(x, str) for x in val) - def is_str(val: Union[str, float]): - # "isinstance" type guard - if isinstance(val, str): - # Type of ``val`` is narrowed to ``str`` - ... - else: - # Else, type of ``val`` is narrowed to ``float``. - ... + def func1(val: list[object]): + if is_str_list(val): + # Type of ``val`` is narrowed to ``list[str]``. + print(" ".join(val)) + else: + # Type of ``val`` remains as ``list[object]``. + print("Not a list of strings!") Strict type narrowing is not enforced -- ``TypeB`` need not be a narrower form of ``TypeA`` (it can even be a wider form) and this may lead to type-unsafe results. The main reason is to allow for things like - narrowing ``List[object]`` to ``List[str]`` even though the latter is not - a subtype of the former, since ``List`` is invariant. The responsibility of - writing type-safe type guards is left to the user. + narrowing ``list[object]`` to ``list[str]`` even though the latter is not + a subtype of the former, since ``list`` is invariant. The responsibility of + writing type-safe type predicates is left to the user. ``TypeGuard`` also works with type variables. For more information, see PEP 647 (User-Defined Type Guards). @@ -653,6 +943,75 @@ def is_str(val: Union[str, float]): return _GenericAlias(self, (item,)) +@_SpecialForm +def TypeIs(self, parameters): + """Special typing construct for marking user-defined type predicate functions. + + ``TypeIs`` can be used to annotate the return type of a user-defined + type predicate function. ``TypeIs`` only accepts a single type argument. + At runtime, functions marked this way should return a boolean and accept + at least one argument. + + ``TypeIs`` aims to benefit *type narrowing* -- a technique used by static + type checkers to determine a more precise type of an expression within a + program's code flow. Usually type narrowing is done by analyzing + conditional code flow and applying the narrowing to a block of code. The + conditional expression here is sometimes referred to as a "type predicate". + + Sometimes it would be convenient to use a user-defined boolean function + as a type predicate. Such a function should use ``TypeIs[...]`` or + ``TypeGuard[...]`` as its return type to alert static type checkers to + this intention. ``TypeIs`` usually has more intuitive behavior than + ``TypeGuard``, but it cannot be used when the input and output types + are incompatible (e.g., ``list[object]`` to ``list[int]``) or when the + function does not return ``True`` for all instances of the narrowed type. + + Using ``-> TypeIs[NarrowedType]`` tells the static type checker that for + a given function: + + 1. The return value is a boolean. + 2. 
If the return value is ``True``, the type of its argument + is the intersection of the argument's original type and + ``NarrowedType``. + 3. If the return value is ``False``, the type of its argument + is narrowed to exclude ``NarrowedType``. + + For example:: + + from typing import assert_type, final, TypeIs + + class Parent: pass + class Child(Parent): pass + @final + class Unrelated: pass + + def is_parent(val: object) -> TypeIs[Parent]: + return isinstance(val, Parent) + + def run(arg: Child | Unrelated): + if is_parent(arg): + # Type of ``arg`` is narrowed to the intersection + # of ``Parent`` and ``Child``, which is equivalent to + # ``Child``. + assert_type(arg, Child) + else: + # Type of ``arg`` is narrowed to exclude ``Parent``, + # so only ``Unrelated`` is left. + assert_type(arg, Unrelated) + + The type inside ``TypeIs`` must be consistent with the type of the + function's argument; if it is not, static type checkers will raise + an error. An incorrectly written ``TypeIs`` function can lead to + unsound behavior in the type system; it is the user's responsibility + to write such functions in a type-safe manner. + + ``TypeIs`` also works with type variables. For more information, see + PEP 742 (Narrowing types with ``TypeIs``). + """ + item = _type_check(parameters, f'{self} accepts only single type.') + return _GenericAlias(self, (item,)) + + class ForwardRef(_Final, _root=True): """Internal wrapper to hold a forward reference.""" @@ -664,10 +1023,19 @@ class ForwardRef(_Final, _root=True): def __init__(self, arg, is_argument=True, module=None, *, is_class=False): if not isinstance(arg, str): raise TypeError(f"Forward reference must be a string -- got {arg!r}") + + # If we do `def f(*args: *Ts)`, then we'll have `arg = '*Ts'`. + # Unfortunately, this isn't a valid expression on its own, so we + # do the unpacking manually. + if arg.startswith('*'): + arg_to_compile = f'({arg},)[0]' # E.g. (*Ts,)[0] or (*tuple[int, int],)[0] + else: + arg_to_compile = arg try: - code = compile(arg, '', 'eval') + code = compile(arg_to_compile, '', 'eval') except SyntaxError: raise SyntaxError(f"Forward reference must be an expression -- got {arg!r}") + self.__forward_arg__ = arg self.__forward_code__ = code self.__forward_evaluated__ = False @@ -676,7 +1044,10 @@ def __init__(self, arg, is_argument=True, module=None, *, is_class=False): self.__forward_is_class__ = is_class self.__forward_module__ = module - def _evaluate(self, globalns, localns, recursive_guard): + def _evaluate(self, globalns, localns, type_params=_sentinel, *, recursive_guard): + if type_params is _sentinel: + _deprecation_warning_for_no_type_params_passed("typing.ForwardRef._evaluate") + type_params = () if self.__forward_arg__ in recursive_guard: return self if not self.__forward_evaluated__ or localns is not globalns: @@ -690,6 +1061,22 @@ def _evaluate(self, globalns, localns, recursive_guard): globalns = getattr( sys.modules.get(self.__forward_module__, None), '__dict__', globalns ) + + # type parameters require some special handling, + # as they exist in their own scope + # but `eval()` does not have a dedicated parameter for that scope. + # For classes, names in type parameter scopes should override + # names in the global scope (which here are called `localns`!), + # but should in turn be overridden by names in the class scope + # (which here are called `globalns`!) 
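# A minimal sketch of the scoping rule described above; it calls the private
# _evaluate method with the signature shown in this patch, purely for illustration:
import typing

_T = typing.TypeVar('_T')
_ref = typing.ForwardRef('_T', is_class=True)
# With the type parameter supplied, the string '_T' resolves to the TypeVar itself
# instead of raising NameError, because it is injected into the eval() globals.
assert _ref._evaluate({}, {}, (_T,), recursive_guard=frozenset()) is _T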
+ if type_params: + globalns, localns = dict(globalns), dict(localns) + for param in type_params: + param_name = param.__name__ + if not self.__forward_is_class__ or param_name not in globalns: + globalns[param_name] = param + localns.pop(param_name, None) + type_ = _type_check( eval(self.__forward_code__, globalns, localns), "Forward references must evaluate to types.", @@ -697,7 +1084,11 @@ def _evaluate(self, globalns, localns, recursive_guard): allow_special_forms=self.__forward_is_class__, ) self.__forward_value__ = _eval_type( - type_, globalns, localns, recursive_guard | {self.__forward_arg__} + type_, + globalns, + localns, + type_params, + recursive_guard=(recursive_guard | {self.__forward_arg__}), ) self.__forward_evaluated__ = True return self.__forward_value__ @@ -714,236 +1105,205 @@ def __eq__(self, other): def __hash__(self): return hash((self.__forward_arg__, self.__forward_module__)) - def __repr__(self): - return f'ForwardRef({self.__forward_arg__!r})' - -class _TypeVarLike: - """Mixin for TypeVar-like types (TypeVar and ParamSpec).""" - def __init__(self, bound, covariant, contravariant): - """Used to setup TypeVars and ParamSpec's bound, covariant and - contravariant attributes. - """ - if covariant and contravariant: - raise ValueError("Bivariant types are not supported.") - self.__covariant__ = bool(covariant) - self.__contravariant__ = bool(contravariant) - if bound: - self.__bound__ = _type_check(bound, "Bound must be a type.") - else: - self.__bound__ = None - - def __or__(self, right): - return Union[self, right] + def __or__(self, other): + return Union[self, other] - def __ror__(self, left): - return Union[left, self] + def __ror__(self, other): + return Union[other, self] def __repr__(self): - if self.__covariant__: - prefix = '+' - elif self.__contravariant__: - prefix = '-' + if self.__forward_module__ is None: + module_repr = '' else: - prefix = '~' - return prefix + self.__name__ + module_repr = f', module={self.__forward_module__!r}' + return f'ForwardRef({self.__forward_arg__!r}{module_repr})' - def __reduce__(self): - return self.__name__ +def _is_unpacked_typevartuple(x: Any) -> bool: + return ((not isinstance(x, type)) and + getattr(x, '__typing_is_unpacked_typevartuple__', False)) -class TypeVar( _Final, _Immutable, _TypeVarLike, _root=True): - """Type variable. - Usage:: - - T = TypeVar('T') # Can be anything - A = TypeVar('A', str, bytes) # Must be str or bytes - - Type variables exist primarily for the benefit of static type - checkers. They serve as the parameters for generic types as well - as for generic function definitions. See class Generic for more - information on generic types. Generic functions work as follows: - - def repeat(x: T, n: int) -> List[T]: - '''Return a list containing n references to x.''' - return [x]*n - - def longest(x: A, y: A) -> A: - '''Return the longest of two strings.''' - return x if len(x) >= len(y) else y - - The latter example's signature is essentially the overloading - of (str, str) -> str and (bytes, bytes) -> bytes. Also note - that if the arguments are instances of some subclass of str, - the return type is still plain str. - - At runtime, isinstance(x, T) and issubclass(C, T) will raise TypeError. - - Type variables defined with covariant=True or contravariant=True - can be used to declare covariant or contravariant generic types. - See PEP 484 for more details. By default generic types are invariant - in all type variables. - - Type variables can be introspected. 
e.g.: - - T.__name__ == 'T' - T.__constraints__ == () - T.__covariant__ == False - T.__contravariant__ = False - A.__constraints__ == (str, bytes) - - Note that only type variables defined in global scope can be pickled. - """ - - __slots__ = ('__name__', '__bound__', '__constraints__', - '__covariant__', '__contravariant__', '__dict__') - - def __init__(self, name, *constraints, bound=None, - covariant=False, contravariant=False): - self.__name__ = name - super().__init__(bound, covariant, contravariant) - if constraints and bound is not None: - raise TypeError("Constraints cannot be combined with bound=...") - if constraints and len(constraints) == 1: - raise TypeError("A single constraint is not allowed") - msg = "TypeVar(name, constraint, ...): constraints must be types." - self.__constraints__ = tuple(_type_check(t, msg) for t in constraints) - try: - def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') # for pickling - except (AttributeError, ValueError): - def_mod = None - if def_mod != 'typing': - self.__module__ = def_mod - - -class ParamSpecArgs(_Final, _Immutable, _root=True): - """The args for a ParamSpec object. - - Given a ParamSpec object P, P.args is an instance of ParamSpecArgs. +def _is_typevar_like(x: Any) -> bool: + return isinstance(x, (TypeVar, ParamSpec)) or _is_unpacked_typevartuple(x) - ParamSpecArgs objects have a reference back to their ParamSpec: - P.args.__origin__ is P - - This type is meant for runtime introspection and has no special meaning to - static type checkers. - """ - def __init__(self, origin): - self.__origin__ = origin +def _typevar_subst(self, arg): + msg = "Parameters to generic types must be types." + arg = _type_check(arg, msg, is_argument=True) + if ((isinstance(arg, _GenericAlias) and arg.__origin__ is Unpack) or + (isinstance(arg, GenericAlias) and getattr(arg, '__unpacked__', False))): + raise TypeError(f"{arg} is not valid as type argument") + return arg - def __repr__(self): - return f"{self.__origin__.__name__}.args" - def __eq__(self, other): - if not isinstance(other, ParamSpecArgs): - return NotImplemented - return self.__origin__ == other.__origin__ +def _typevartuple_prepare_subst(self, alias, args): + params = alias.__parameters__ + typevartuple_index = params.index(self) + for param in params[typevartuple_index + 1:]: + if isinstance(param, TypeVarTuple): + raise TypeError(f"More than one TypeVarTuple parameter in {alias}") + + alen = len(args) + plen = len(params) + left = typevartuple_index + right = plen - typevartuple_index - 1 + var_tuple_index = None + fillarg = None + for k, arg in enumerate(args): + if not isinstance(arg, type): + subargs = getattr(arg, '__typing_unpacked_tuple_args__', None) + if subargs and len(subargs) == 2 and subargs[-1] is ...: + if var_tuple_index is not None: + raise TypeError("More than one unpacked arbitrary-length tuple argument") + var_tuple_index = k + fillarg = subargs[0] + if var_tuple_index is not None: + left = min(left, var_tuple_index) + right = min(right, alen - var_tuple_index - 1) + elif left + right > alen: + raise TypeError(f"Too few arguments for {alias};" + f" actual {alen}, expected at least {plen-1}") + if left == alen - right and self.has_default(): + replacement = _unpack_args(self.__default__) + else: + replacement = args[left: alen - right] + + return ( + *args[:left], + *([fillarg]*(typevartuple_index - left)), + replacement, + *([fillarg]*(plen - right - left - typevartuple_index - 1)), + *args[alen - right:], + ) + + +def _paramspec_subst(self, arg): + if 
isinstance(arg, (list, tuple)): + arg = tuple(_type_check(a, "Expected a type.") for a in arg) + elif not _is_param_expr(arg): + raise TypeError(f"Expected a list of types, an ellipsis, " + f"ParamSpec, or Concatenate. Got {arg}") + return arg -class ParamSpecKwargs(_Final, _Immutable, _root=True): - """The kwargs for a ParamSpec object. +def _paramspec_prepare_subst(self, alias, args): + params = alias.__parameters__ + i = params.index(self) + if i == len(args) and self.has_default(): + args = [*args, self.__default__] + if i >= len(args): + raise TypeError(f"Too few arguments for {alias}") + # Special case where Z[[int, str, bool]] == Z[int, str, bool] in PEP 612. + if len(params) == 1 and not _is_param_expr(args[0]): + assert i == 0 + args = (args,) + # Convert lists to tuples to help other libraries cache the results. + elif isinstance(args[i], list): + args = (*args[:i], tuple(args[i]), *args[i+1:]) + return args - Given a ParamSpec object P, P.kwargs is an instance of ParamSpecKwargs. - ParamSpecKwargs objects have a reference back to their ParamSpec: +@_tp_cache +def _generic_class_getitem(cls, args): + """Parameterizes a generic class. - P.kwargs.__origin__ is P + At least, parameterizing a generic class is the *main* thing this method + does. For example, for some generic class `Foo`, this is called when we + do `Foo[int]` - there, with `cls=Foo` and `args=int`. - This type is meant for runtime introspection and has no special meaning to - static type checkers. + However, note that this method is also called when defining generic + classes in the first place with `class Foo(Generic[T]): ...`. """ - def __init__(self, origin): - self.__origin__ = origin - - def __repr__(self): - return f"{self.__origin__.__name__}.kwargs" - - def __eq__(self, other): - if not isinstance(other, ParamSpecKwargs): - return NotImplemented - return self.__origin__ == other.__origin__ - + if not isinstance(args, tuple): + args = (args,) -class ParamSpec(_Final, _Immutable, _TypeVarLike, _root=True): - """Parameter specification variable. - - Usage:: + args = tuple(_type_convert(p) for p in args) + is_generic_or_protocol = cls in (Generic, Protocol) - P = ParamSpec('P') - - Parameter specification variables exist primarily for the benefit of static - type checkers. They are used to forward the parameter types of one - callable to another callable, a pattern commonly found in higher order - functions and decorators. They are only valid when used in ``Concatenate``, - or as the first argument to ``Callable``, or as parameters for user-defined - Generics. See class Generic for more information on generic types. An - example for annotating a decorator:: - - T = TypeVar('T') - P = ParamSpec('P') - - def add_logging(f: Callable[P, T]) -> Callable[P, T]: - '''A type-safe decorator to add logging to a function.''' - def inner(*args: P.args, **kwargs: P.kwargs) -> T: - logging.info(f'{f.__name__} was called') - return f(*args, **kwargs) - return inner - - @add_logging - def add_two(x: float, y: float) -> float: - '''Add two numbers together.''' - return x + y - - Parameter specification variables defined with covariant=True or - contravariant=True can be used to declare covariant or contravariant - generic types. These keyword arguments are valid, but their actual semantics - are yet to be decided. See PEP 612 for details. - - Parameter specification variables can be introspected. 
e.g.: - - P.__name__ == 'T' - P.__bound__ == None - P.__covariant__ == False - P.__contravariant__ == False - - Note that only parameter specification variables defined in global scope can - be pickled. - """ + if is_generic_or_protocol: + # Generic and Protocol can only be subscripted with unique type variables. + if not args: + raise TypeError( + f"Parameter list to {cls.__qualname__}[...] cannot be empty" + ) + if not all(_is_typevar_like(p) for p in args): + raise TypeError( + f"Parameters to {cls.__name__}[...] must all be type variables " + f"or parameter specification variables.") + if len(set(args)) != len(args): + raise TypeError( + f"Parameters to {cls.__name__}[...] must all be unique") + else: + # Subscripting a regular Generic subclass. + for param in cls.__parameters__: + prepare = getattr(param, '__typing_prepare_subst__', None) + if prepare is not None: + args = prepare(cls, args) + _check_generic_specialization(cls, args) - __slots__ = ('__name__', '__bound__', '__covariant__', '__contravariant__', - '__dict__') + new_args = [] + for param, new_arg in zip(cls.__parameters__, args): + if isinstance(param, TypeVarTuple): + new_args.extend(new_arg) + else: + new_args.append(new_arg) + args = tuple(new_args) - @property - def args(self): - return ParamSpecArgs(self) + return _GenericAlias(cls, args) - @property - def kwargs(self): - return ParamSpecKwargs(self) - def __init__(self, name, *, bound=None, covariant=False, contravariant=False): - self.__name__ = name - super().__init__(bound, covariant, contravariant) - try: - def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): - def_mod = None - if def_mod != 'typing': - self.__module__ = def_mod +def _generic_init_subclass(cls, *args, **kwargs): + super(Generic, cls).__init_subclass__(*args, **kwargs) + tvars = [] + if '__orig_bases__' in cls.__dict__: + error = Generic in cls.__orig_bases__ + else: + error = (Generic in cls.__bases__ and + cls.__name__ != 'Protocol' and + type(cls) != _TypedDictMeta) + if error: + raise TypeError("Cannot inherit from plain Generic") + if '__orig_bases__' in cls.__dict__: + tvars = _collect_type_parameters(cls.__orig_bases__) + # Look for Generic[T1, ..., Tn]. + # If found, tvars must be a subset of it. + # If not found, tvars is it. + # Also check for and reject plain Generic, + # and reject multiple Generic[...]. + gvars = None + for base in cls.__orig_bases__: + if (isinstance(base, _GenericAlias) and + base.__origin__ is Generic): + if gvars is not None: + raise TypeError( + "Cannot inherit from Generic[...] multiple times.") + gvars = base.__parameters__ + if gvars is not None: + tvarset = set(tvars) + gvarset = set(gvars) + if not tvarset <= gvarset: + s_vars = ', '.join(str(t) for t in tvars if t not in gvarset) + s_args = ', '.join(str(g) for g in gvars) + raise TypeError(f"Some type variables ({s_vars}) are" + f" not listed in Generic[{s_args}]") + tvars = gvars + cls.__parameters__ = tuple(tvars) def _is_dunder(attr): return attr.startswith('__') and attr.endswith('__') class _BaseGenericAlias(_Final, _root=True): - """The central part of internal API. + """The central part of the internal API. This represents a generic version of type 'origin' with type arguments 'params'. There are two kind of these aliases: user defined and special. The special ones are wrappers around builtin collections and ABCs in collections.abc. These must - have 'name' always set. 
If 'inst' is False, then the alias can't be instantiated, + have 'name' always set. If 'inst' is False, then the alias can't be instantiated; this is used by e.g. typing.List and typing.Dict. """ + def __init__(self, origin, *, inst=True, name=None): self._inst = inst self._name = name @@ -957,7 +1317,9 @@ def __call__(self, *args, **kwargs): result = self.__origin__(*args, **kwargs) try: result.__orig_class__ = self - except AttributeError: + # Some objects raise TypeError (or something even more exotic) + # if you try to set attributes on them; we guard against that here + except Exception: pass return result @@ -965,9 +1327,29 @@ def __mro_entries__(self, bases): res = [] if self.__origin__ not in bases: res.append(self.__origin__) + + # Check if any base that occurs after us in `bases` is either itself a + # subclass of Generic, or something which will add a subclass of Generic + # to `__bases__` via its `__mro_entries__`. If not, add Generic + # ourselves. The goal is to ensure that Generic (or a subclass) will + # appear exactly once in the final bases tuple. If we let it appear + # multiple times, we risk "can't form a consistent MRO" errors. i = bases.index(self) for b in bases[i+1:]: - if isinstance(b, _BaseGenericAlias) or issubclass(b, Generic): + if isinstance(b, _BaseGenericAlias): + break + if not isinstance(b, type): + meth = getattr(b, "__mro_entries__", None) + new_bases = meth(bases) if meth else None + if ( + isinstance(new_bases, tuple) and + any( + isinstance(b2, type) and issubclass(b2, Generic) + for b2 in new_bases + ) + ): + break + elif issubclass(b, Generic): break else: res.append(Generic) @@ -984,8 +1366,7 @@ def __getattr__(self, attr): raise AttributeError(attr) def __setattr__(self, attr, val): - if _is_dunder(attr) or attr in {'_name', '_inst', '_nparams', - '_typevar_types', '_paramspec_tvars'}: + if _is_dunder(attr) or attr in {'_name', '_inst', '_nparams', '_defaults'}: super().__setattr__(attr, val) else: setattr(self.__origin__, attr, val) @@ -1001,6 +1382,7 @@ def __dir__(self): return list(set(super().__dir__() + [attr for attr in dir(self.__origin__) if not _is_dunder(attr)])) + # Special typing constructs Union, Optional, Generic, Callable and Tuple # use three special attributes for internal bookkeeping of generic types: # * __parameters__ is a tuple of unique free type parameters of a generic @@ -1013,18 +1395,42 @@ def __dir__(self): class _GenericAlias(_BaseGenericAlias, _root=True): - def __init__(self, origin, params, *, inst=True, name=None, - _typevar_types=TypeVar, - _paramspec_tvars=False): + # The type of parameterized generics. + # + # That is, for example, `type(List[int])` is `_GenericAlias`. + # + # Objects which are instances of this class include: + # * Parameterized container types, e.g. `Tuple[int]`, `List[int]`. + # * Note that native container types, e.g. `tuple`, `list`, use + # `types.GenericAlias` instead. + # * Parameterized classes: + # class C[T]: pass + # # C[int] is a _GenericAlias + # * `Callable` aliases, generic `Callable` aliases, and + # parameterized `Callable` aliases: + # T = TypeVar('T') + # # _CallableGenericAlias inherits from _GenericAlias. 
+ # A = Callable[[], None] # _CallableGenericAlias + # B = Callable[[T], None] # _CallableGenericAlias + # C = B[int] # _CallableGenericAlias + # * Parameterized `Final`, `ClassVar`, `TypeGuard`, and `TypeIs`: + # # All _GenericAlias + # Final[int] + # ClassVar[float] + # TypeGuard[bool] + # TypeIs[range] + + def __init__(self, origin, args, *, inst=True, name=None): super().__init__(origin, inst=inst, name=name) - if not isinstance(params, tuple): - params = (params,) + if not isinstance(args, tuple): + args = (args,) self.__args__ = tuple(... if a is _TypingEllipsis else - () if a is _TypingEmpty else - a for a in params) - self.__parameters__ = _collect_type_vars(params, typevar_types=_typevar_types) - self._typevar_types = _typevar_types - self._paramspec_tvars = _paramspec_tvars + a for a in args) + enforce_default_ordering = origin in (Generic, Protocol) + self.__parameters__ = _collect_type_parameters( + args, + enforce_default_ordering=enforce_default_ordering, + ) if not name: self.__module__ = origin.__module__ @@ -1044,53 +1450,140 @@ def __ror__(self, left): return Union[left, self] @_tp_cache - def __getitem__(self, params): + def __getitem__(self, args): + # Parameterizes an already-parameterized object. + # + # For example, we arrive here doing something like: + # T1 = TypeVar('T1') + # T2 = TypeVar('T2') + # T3 = TypeVar('T3') + # class A(Generic[T1]): pass + # B = A[T2] # B is a _GenericAlias + # C = B[T3] # Invokes _GenericAlias.__getitem__ + # + # We also arrive here when parameterizing a generic `Callable` alias: + # T = TypeVar('T') + # C = Callable[[T], None] + # C[int] # Invokes _GenericAlias.__getitem__ + if self.__origin__ in (Generic, Protocol): # Can't subscript Generic[...] or Protocol[...]. raise TypeError(f"Cannot subscript already-subscripted {self}") - if not isinstance(params, tuple): - params = (params,) - params = tuple(_type_convert(p) for p in params) - if (self._paramspec_tvars - and any(isinstance(t, ParamSpec) for t in self.__parameters__)): - params = _prepare_paramspec_params(self, params) - else: - _check_generic(self, params, len(self.__parameters__)) + if not self.__parameters__: + raise TypeError(f"{self} is not a generic class") - subst = dict(zip(self.__parameters__, params)) + # Preprocess `args`. + if not isinstance(args, tuple): + args = (args,) + args = _unpack_args(*(_type_convert(p) for p in args)) + new_args = self._determine_new_args(args) + r = self.copy_with(new_args) + return r + + def _determine_new_args(self, args): + # Determines new __args__ for __getitem__. + # + # For example, suppose we had: + # T1 = TypeVar('T1') + # T2 = TypeVar('T2') + # class A(Generic[T1, T2]): pass + # T3 = TypeVar('T3') + # B = A[int, T3] + # C = B[str] + # `B.__args__` is `(int, T3)`, so `C.__args__` should be `(int, str)`. + # Unfortunately, this is harder than it looks, because if `T3` is + # anything more exotic than a plain `TypeVar`, we need to consider + # edge cases. 
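# A minimal runnable sketch of the plain-TypeVar case described above, using only
# the public subscription API (the names T1, T2, T3, A, B, C mirror the comment):
from typing import Generic, TypeVar

T1 = TypeVar('T1')
T2 = TypeVar('T2')
T3 = TypeVar('T3')

class A(Generic[T1, T2]):
    pass

B = A[int, T3]         # B.__args__ == (int, T3); B.__parameters__ == (T3,)
C = B[str]             # substitution performed by _determine_new_args
assert C.__args__ == (int, str)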
+ + params = self.__parameters__ + # In the example above, this would be {T3: str} + for param in params: + prepare = getattr(param, '__typing_prepare_subst__', None) + if prepare is not None: + args = prepare(self, args) + alen = len(args) + plen = len(params) + if alen != plen: + raise TypeError(f"Too {'many' if alen > plen else 'few'} arguments for {self};" + f" actual {alen}, expected {plen}") + new_arg_by_param = dict(zip(params, args)) + return tuple(self._make_substitution(self.__args__, new_arg_by_param)) + + def _make_substitution(self, args, new_arg_by_param): + """Create a list of new type arguments.""" new_args = [] - for arg in self.__args__: - if isinstance(arg, self._typevar_types): - if isinstance(arg, ParamSpec): - arg = subst[arg] - if not _is_param_expr(arg): - raise TypeError(f"Expected a list of types, an ellipsis, " - f"ParamSpec, or Concatenate. Got {arg}") + for old_arg in args: + if isinstance(old_arg, type): + new_args.append(old_arg) + continue + + substfunc = getattr(old_arg, '__typing_subst__', None) + if substfunc: + new_arg = substfunc(new_arg_by_param[old_arg]) + else: + subparams = getattr(old_arg, '__parameters__', ()) + if not subparams: + new_arg = old_arg else: - arg = subst[arg] - elif isinstance(arg, (_GenericAlias, GenericAlias, types.UnionType)): - subparams = arg.__parameters__ - if subparams: - subargs = tuple(subst[x] for x in subparams) - arg = arg[subargs] - # Required to flatten out the args for CallableGenericAlias - if self.__origin__ == collections.abc.Callable and isinstance(arg, tuple): - new_args.extend(arg) + subargs = [] + for x in subparams: + if isinstance(x, TypeVarTuple): + subargs.extend(new_arg_by_param[x]) + else: + subargs.append(new_arg_by_param[x]) + new_arg = old_arg[tuple(subargs)] + + if self.__origin__ == collections.abc.Callable and isinstance(new_arg, tuple): + # Consider the following `Callable`. + # C = Callable[[int], str] + # Here, `C.__args__` should be (int, str) - NOT ([int], str). + # That means that if we had something like... + # P = ParamSpec('P') + # T = TypeVar('T') + # C = Callable[P, T] + # D = C[[int, str], float] + # ...we need to be careful; `new_args` should end up as + # `(int, str, float)` rather than `([int, str], float)`. + new_args.extend(new_arg) + elif _is_unpacked_typevartuple(old_arg): + # Consider the following `_GenericAlias`, `B`: + # class A(Generic[*Ts]): ... + # B = A[T, *Ts] + # If we then do: + # B[float, int, str] + # The `new_arg` corresponding to `T` will be `float`, and the + # `new_arg` corresponding to `*Ts` will be `(int, str)`. We + # should join all these types together in a flat list + # `(float, int, str)` - so again, we should `extend`. + new_args.extend(new_arg) + elif isinstance(old_arg, tuple): + # Corner case: + # P = ParamSpec('P') + # T = TypeVar('T') + # class Base(Generic[P]): ... + # Can be substituted like this: + # X = Base[[int, T]] + # In this case, `old_arg` will be a tuple: + new_args.append( + tuple(self._make_substitution(old_arg, new_arg_by_param)), + ) else: - new_args.append(arg) - return self.copy_with(tuple(new_args)) + new_args.append(new_arg) + return new_args - def copy_with(self, params): - return self.__class__(self.__origin__, params, name=self._name, inst=self._inst, - _typevar_types=self._typevar_types, - _paramspec_tvars=self._paramspec_tvars) + def copy_with(self, args): + return self.__class__(self.__origin__, args, name=self._name, inst=self._inst) def __repr__(self): if self._name: name = 'typing.' 
+ self._name else: name = _type_repr(self.__origin__) - args = ", ".join([_type_repr(a) for a in self.__args__]) + if self.__args__: + args = ", ".join([_type_repr(a) for a in self.__args__]) + else: + # To ensure the repr is eval-able. + args = "()" return f'{name}[{args}]' def __reduce__(self): @@ -1118,17 +1611,21 @@ def __mro_entries__(self, bases): return () return (self.__origin__,) + def __iter__(self): + yield Unpack[self] + # _nparams is the number of accepted parameters, e.g. 0 for Hashable, # 1 for List and 2 for Dict. It may be -1 if variable number of # parameters are accepted (needs custom __getitem__). -class _SpecialGenericAlias(_BaseGenericAlias, _root=True): - def __init__(self, origin, nparams, *, inst=True, name=None): +class _SpecialGenericAlias(_NotIterable, _BaseGenericAlias, _root=True): + def __init__(self, origin, nparams, *, inst=True, name=None, defaults=()): if name is None: name = origin.__name__ super().__init__(origin, inst=inst, name=name) self._nparams = nparams + self._defaults = defaults if origin.__module__ == 'builtins': self.__doc__ = f'A generic version of {origin.__qualname__}.' else: @@ -1140,7 +1637,22 @@ def __getitem__(self, params): params = (params,) msg = "Parameters to generic types must be types." params = tuple(_type_check(p, msg) for p in params) - _check_generic(self, params, self._nparams) + if (self._defaults + and len(params) < self._nparams + and len(params) + len(self._defaults) >= self._nparams + ): + params = (*params, *self._defaults[len(params) - self._nparams:]) + actual_len = len(params) + + if actual_len != self._nparams: + if self._defaults: + expected = f"at least {self._nparams - len(self._defaults)}" + else: + expected = str(self._nparams) + if not self._nparams: + raise TypeError(f"{self} is not a generic class") + raise TypeError(f"Too {'many' if actual_len > self._nparams else 'few'} arguments for {self};" + f" actual {actual_len}, expected {expected}") return self.copy_with(params) def copy_with(self, params): @@ -1166,7 +1678,23 @@ def __or__(self, right): def __ror__(self, left): return Union[left, self] -class _CallableGenericAlias(_GenericAlias, _root=True): + +class _DeprecatedGenericAlias(_SpecialGenericAlias, _root=True): + def __init__( + self, origin, nparams, *, removal_version, inst=True, name=None + ): + super().__init__(origin, nparams, inst=inst, name=name) + self._removal_version = removal_version + + def __instancecheck__(self, inst): + import warnings + warnings._deprecated( + f"{self.__module__}.{self._name}", remove=self._removal_version + ) + return super().__instancecheck__(inst) + + +class _CallableGenericAlias(_NotIterable, _GenericAlias, _root=True): def __repr__(self): assert self._name == 'Callable' args = self.__args__ @@ -1186,9 +1714,7 @@ def __reduce__(self): class _CallableType(_SpecialGenericAlias, _root=True): def copy_with(self, params): return _CallableGenericAlias(self.__origin__, params, - name=self._name, inst=self._inst, - _typevar_types=(TypeVar, ParamSpec), - _paramspec_tvars=True) + name=self._name, inst=self._inst) def __getitem__(self, params): if not isinstance(params, tuple) or len(params) != 2: @@ -1221,27 +1747,28 @@ def __getitem_inner__(self, params): class _TupleType(_SpecialGenericAlias, _root=True): @_tp_cache def __getitem__(self, params): - if params == (): - return self.copy_with((_TypingEmpty,)) if not isinstance(params, tuple): params = (params,) - if len(params) == 2 and params[1] is ...: + if len(params) >= 2 and params[-1] is ...: msg = "Tuple[t, ...]: t 
must be a type." - p = _type_check(params[0], msg) - return self.copy_with((p, _TypingEllipsis)) + params = tuple(_type_check(p, msg) for p in params[:-1]) + return self.copy_with((*params, _TypingEllipsis)) msg = "Tuple[t0, t1, ...]: each t must be a type." params = tuple(_type_check(p, msg) for p in params) return self.copy_with(params) -class _UnionGenericAlias(_GenericAlias, _root=True): +class _UnionGenericAlias(_NotIterable, _GenericAlias, _root=True): def copy_with(self, params): return Union[params] def __eq__(self, other): if not isinstance(other, (_UnionGenericAlias, types.UnionType)): return NotImplemented - return set(self.__args__) == set(other.__args__) + try: # fast path + return set(self.__args__) == set(other.__args__) + except TypeError: # not hashable, slow path + return _compare_args_orderless(self.__args__, other.__args__) def __hash__(self): return hash(frozenset(self.__args__)) @@ -1273,7 +1800,6 @@ def _value_and_type_iter(parameters): class _LiteralGenericAlias(_GenericAlias, _root=True): - def __eq__(self, other): if not isinstance(other, _LiteralGenericAlias): return NotImplemented @@ -1290,118 +1816,108 @@ def copy_with(self, params): return (*params[:-1], *params[-1]) if isinstance(params[-1], _ConcatenateGenericAlias): params = (*params[:-1], *params[-1].__args__) - elif not isinstance(params[-1], ParamSpec): - raise TypeError("The last parameter to Concatenate should be a " - "ParamSpec variable.") return super().copy_with(params) -class Generic: - """Abstract base class for generic types. +@_SpecialForm +def Unpack(self, parameters): + """Type unpack operator. - A generic type is typically declared by inheriting from - this class parameterized with one or more type variables. - For example, a generic mapping type might be defined as:: + The type unpack operator takes the child types from some container type, + such as `tuple[int, str]` or a `TypeVarTuple`, and 'pulls them out'. - class Mapping(Generic[KT, VT]): - def __getitem__(self, key: KT) -> VT: - ... - # Etc. + For example:: - This class can then be used as follows:: + # For some generic class `Foo`: + Foo[Unpack[tuple[int, str]]] # Equivalent to Foo[int, str] - def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: - try: - return mapping[key] - except KeyError: - return default - """ - __slots__ = () - _is_protocol = False + Ts = TypeVarTuple('Ts') + # Specifies that `Bar` is generic in an arbitrary number of types. + # (Think of `Ts` as a tuple of an arbitrary number of individual + # `TypeVar`s, which the `Unpack` is 'pulling out' directly into the + # `Generic[]`.) + class Bar(Generic[Unpack[Ts]]): ... + Bar[int] # Valid + Bar[int, str] # Also valid - @_tp_cache - def __class_getitem__(cls, params): - if not isinstance(params, tuple): - params = (params,) - if not params and cls is not Tuple: - raise TypeError( - f"Parameter list to {cls.__qualname__}[...] cannot be empty") - params = tuple(_type_convert(p) for p in params) - if cls in (Generic, Protocol): - # Generic and Protocol can only be subscripted with unique type variables. - if not all(isinstance(p, (TypeVar, ParamSpec)) for p in params): - raise TypeError( - f"Parameters to {cls.__name__}[...] must all be type variables " - f"or parameter specification variables.") - if len(set(params)) != len(params): - raise TypeError( - f"Parameters to {cls.__name__}[...] must all be unique") - else: - # Subscripting a regular Generic subclass. 
- if any(isinstance(t, ParamSpec) for t in cls.__parameters__): - params = _prepare_paramspec_params(cls, params) - else: - _check_generic(cls, params, len(cls.__parameters__)) - return _GenericAlias(cls, params, - _typevar_types=(TypeVar, ParamSpec), - _paramspec_tvars=True) + From Python 3.11, this can also be done using the `*` operator:: - def __init_subclass__(cls, *args, **kwargs): - super().__init_subclass__(*args, **kwargs) - tvars = [] - if '__orig_bases__' in cls.__dict__: - error = Generic in cls.__orig_bases__ - else: - error = Generic in cls.__bases__ and cls.__name__ != 'Protocol' - if error: - raise TypeError("Cannot inherit from plain Generic") - if '__orig_bases__' in cls.__dict__: - tvars = _collect_type_vars(cls.__orig_bases__, (TypeVar, ParamSpec)) - # Look for Generic[T1, ..., Tn]. - # If found, tvars must be a subset of it. - # If not found, tvars is it. - # Also check for and reject plain Generic, - # and reject multiple Generic[...]. - gvars = None - for base in cls.__orig_bases__: - if (isinstance(base, _GenericAlias) and - base.__origin__ is Generic): - if gvars is not None: - raise TypeError( - "Cannot inherit from Generic[...] multiple types.") - gvars = base.__parameters__ - if gvars is not None: - tvarset = set(tvars) - gvarset = set(gvars) - if not tvarset <= gvarset: - s_vars = ', '.join(str(t) for t in tvars if t not in gvarset) - s_args = ', '.join(str(g) for g in gvars) - raise TypeError(f"Some type variables ({s_vars}) are" - f" not listed in Generic[{s_args}]") - tvars = gvars - cls.__parameters__ = tuple(tvars) - - -class _TypingEmpty: - """Internal placeholder for () or []. Used by TupleMeta and CallableMeta - to allow empty list/tuple in specific places, without allowing them - to sneak in where prohibited. + Foo[*tuple[int, str]] + class Bar(Generic[*Ts]): ... + + And from Python 3.12, it can be done using built-in syntax for generics:: + + Foo[*tuple[int, str]] + class Bar[*Ts]: ... + + The operator can also be used along with a `TypedDict` to annotate + `**kwargs` in a function signature:: + + class Movie(TypedDict): + name: str + year: int + + # This function expects two keyword arguments - *name* of type `str` and + # *year* of type `int`. + def foo(**kwargs: Unpack[Movie]): ... + + Note that there is only some runtime checking of this operator. Not + everything the runtime allows may be accepted by static type checkers. + + For more information, see PEPs 646 and 692. """ + item = _type_check(parameters, f'{self} accepts only single type.') + return _UnpackGenericAlias(origin=self, args=(item,)) + + +class _UnpackGenericAlias(_GenericAlias, _root=True): + def __repr__(self): + # `Unpack` only takes one argument, so __args__ should contain only + # a single item. + return f'typing.Unpack[{_type_repr(self.__args__[0])}]' + + def __getitem__(self, args): + if self.__typing_is_unpacked_typevartuple__: + return args + return super().__getitem__(args) + + @property + def __typing_unpacked_tuple_args__(self): + assert self.__origin__ is Unpack + assert len(self.__args__) == 1 + arg, = self.__args__ + if isinstance(arg, (_GenericAlias, types.GenericAlias)): + if arg.__origin__ is not tuple: + raise TypeError("Unpack[...] must be used with a tuple type") + return arg.__args__ + return None + + @property + def __typing_is_unpacked_typevartuple__(self): + assert self.__origin__ is Unpack + assert len(self.__args__) == 1 + return isinstance(self.__args__[0], TypeVarTuple) class _TypingEllipsis: """Internal placeholder for ... 
(ellipsis).""" -_TYPING_INTERNALS = ['__parameters__', '__orig_bases__', '__orig_class__', - '_is_protocol', '_is_runtime_protocol'] +_TYPING_INTERNALS = frozenset({ + '__parameters__', '__orig_bases__', '__orig_class__', + '_is_protocol', '_is_runtime_protocol', '__protocol_attrs__', + '__non_callable_proto_members__', '__type_params__', +}) -_SPECIAL_NAMES = ['__abstractmethods__', '__annotations__', '__dict__', '__doc__', - '__init__', '__module__', '__new__', '__slots__', - '__subclasshook__', '__weakref__', '__class_getitem__'] +_SPECIAL_NAMES = frozenset({ + '__abstractmethods__', '__annotations__', '__dict__', '__doc__', + '__init__', '__module__', '__new__', '__slots__', + '__subclasshook__', '__weakref__', '__class_getitem__', + '__match_args__', '__static_attributes__', '__firstlineno__', +}) # These special attributes will be not collected as protocol members. -EXCLUDED_ATTRIBUTES = _TYPING_INTERNALS + _SPECIAL_NAMES + ['_MutableMapping__marker'] +EXCLUDED_ATTRIBUTES = _TYPING_INTERNALS | _SPECIAL_NAMES | {'_MutableMapping__marker'} def _get_protocol_attrs(cls): @@ -1412,20 +1928,15 @@ def _get_protocol_attrs(cls): """ attrs = set() for base in cls.__mro__[:-1]: # without object - if base.__name__ in ('Protocol', 'Generic'): + if base.__name__ in {'Protocol', 'Generic'}: continue annotations = getattr(base, '__annotations__', {}) - for attr in list(base.__dict__.keys()) + list(annotations.keys()): + for attr in (*base.__dict__, *annotations): if not attr.startswith('_abc_') and attr not in EXCLUDED_ATTRIBUTES: attrs.add(attr) return attrs -def _is_callable_members_only(cls): - # PEP 544 prohibits using issubclass() with protocols that have non-method members. - return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls)) - - def _no_init_or_replace_init(self, *args, **kwargs): cls = type(self) @@ -1456,59 +1967,192 @@ def _no_init_or_replace_init(self, *args, **kwargs): def _caller(depth=1, default='__main__'): + try: + return sys._getframemodulename(depth + 1) or default + except AttributeError: # For platforms without _getframemodulename() + pass try: return sys._getframe(depth + 1).f_globals.get('__name__', default) except (AttributeError, ValueError): # For platforms without _getframe() - return None - + pass + return None -def _allow_reckless_class_checks(depth=3): +def _allow_reckless_class_checks(depth=2): """Allow instance and class checks for special stdlib modules. The abc and functools modules indiscriminately call isinstance() and issubclass() on the whole MRO of a user class, which may contain protocols. """ - try: - return sys._getframe(depth).f_globals['__name__'] in ['abc', 'functools'] - except (AttributeError, ValueError): # For platforms without _getframe(). 
- return True + return _caller(depth) in {'abc', 'functools', None} _PROTO_ALLOWLIST = { 'collections.abc': [ 'Callable', 'Awaitable', 'Iterable', 'Iterator', 'AsyncIterable', - 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', + 'AsyncIterator', 'Hashable', 'Sized', 'Container', 'Collection', + 'Reversible', 'Buffer', ], 'contextlib': ['AbstractContextManager', 'AbstractAsyncContextManager'], } +@functools.cache +def _lazy_load_getattr_static(): + # Import getattr_static lazily so as not to slow down the import of typing.py + # Cache the result so we don't slow down _ProtocolMeta.__instancecheck__ unnecessarily + from inspect import getattr_static + return getattr_static + + +_cleanups.append(_lazy_load_getattr_static.cache_clear) + +def _pickle_psargs(psargs): + return ParamSpecArgs, (psargs.__origin__,) + +copyreg.pickle(ParamSpecArgs, _pickle_psargs) + +def _pickle_pskwargs(pskwargs): + return ParamSpecKwargs, (pskwargs.__origin__,) + +copyreg.pickle(ParamSpecKwargs, _pickle_pskwargs) + +del _pickle_psargs, _pickle_pskwargs + + +# Preload these once, as globals, as a micro-optimisation. +# This makes a significant difference to the time it takes +# to do `isinstance()`/`issubclass()` checks +# against runtime-checkable protocols with only one callable member. +_abc_instancecheck = ABCMeta.__instancecheck__ +_abc_subclasscheck = ABCMeta.__subclasscheck__ + + +def _type_check_issubclass_arg_1(arg): + """Raise TypeError if `arg` is not an instance of `type` + in `issubclass(arg, )`. + + In most cases, this is verified by type.__subclasscheck__. + Checking it again unnecessarily would slow down issubclass() checks, + so, we don't perform this check unless we absolutely have to. + + For various error paths, however, + we want to ensure that *this* error message is shown to the user + where relevant, rather than a typing.py-specific error message. + """ + if not isinstance(arg, type): + # Same error message as for issubclass(1, int). + raise TypeError('issubclass() arg 1 must be a class') + + class _ProtocolMeta(ABCMeta): - # This metaclass is really unfortunate and exists only because of - # the lack of __instancehook__. + # This metaclass is somewhat unfortunate, + # but is necessary for several reasons... 
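The `copyreg` registrations above make `ParamSpec.args` / `ParamSpec.kwargs` picklable. A minimal sketch, assuming the `ParamSpec` is a module-level global so pickle can resolve it by name::

    import pickle
    from typing import ParamSpec

    P = ParamSpec('P')   # must live at module level for pickling to work

    restored_args = pickle.loads(pickle.dumps(P.args))
    restored_kwargs = pickle.loads(pickle.dumps(P.kwargs))
    assert restored_args.__origin__ is P
    assert restored_kwargs.__origin__ is P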
+ def __new__(mcls, name, bases, namespace, /, **kwargs): + if name == "Protocol" and bases == (Generic,): + pass + elif Protocol in bases: + for base in bases: + if not ( + base in {object, Generic} + or base.__name__ in _PROTO_ALLOWLIST.get(base.__module__, []) + or ( + issubclass(base, Generic) + and getattr(base, "_is_protocol", False) + ) + ): + raise TypeError( + f"Protocols can only inherit from other protocols, " + f"got {base!r}" + ) + return super().__new__(mcls, name, bases, namespace, **kwargs) + + def __init__(cls, *args, **kwargs): + super().__init__(*args, **kwargs) + if getattr(cls, "_is_protocol", False): + cls.__protocol_attrs__ = _get_protocol_attrs(cls) + + def __subclasscheck__(cls, other): + if cls is Protocol: + return type.__subclasscheck__(cls, other) + if ( + getattr(cls, '_is_protocol', False) + and not _allow_reckless_class_checks() + ): + if not getattr(cls, '_is_runtime_protocol', False): + _type_check_issubclass_arg_1(other) + raise TypeError( + "Instance and class checks can only be used with " + "@runtime_checkable protocols" + ) + if ( + # this attribute is set by @runtime_checkable: + cls.__non_callable_proto_members__ + and cls.__dict__.get("__subclasshook__") is _proto_hook + ): + _type_check_issubclass_arg_1(other) + non_method_attrs = sorted(cls.__non_callable_proto_members__) + raise TypeError( + "Protocols with non-method members don't support issubclass()." + f" Non-method members: {str(non_method_attrs)[1:-1]}." + ) + return _abc_subclasscheck(cls, other) + def __instancecheck__(cls, instance): # We need this method for situations where attributes are # assigned in __init__. + if cls is Protocol: + return type.__instancecheck__(cls, instance) + if not getattr(cls, "_is_protocol", False): + # i.e., it's a concrete subclass of a protocol + return _abc_instancecheck(cls, instance) + if ( - getattr(cls, '_is_protocol', False) and not getattr(cls, '_is_runtime_protocol', False) and - not _allow_reckless_class_checks(depth=2) + not _allow_reckless_class_checks() ): raise TypeError("Instance and class checks can only be used with" " @runtime_checkable protocols") - if ((not getattr(cls, '_is_protocol', False) or - _is_callable_members_only(cls)) and - issubclass(instance.__class__, cls)): + if _abc_instancecheck(cls, instance): + return True + + getattr_static = _lazy_load_getattr_static() + for attr in cls.__protocol_attrs__: + try: + val = getattr_static(instance, attr) + except AttributeError: + break + # this attribute is set by @runtime_checkable: + if val is None and attr not in cls.__non_callable_proto_members__: + break + else: return True - if cls._is_protocol: - if all(hasattr(instance, attr) and - # All *methods* can be blocked by setting them to None. - (not callable(getattr(cls, attr, None)) or - getattr(instance, attr) is not None) - for attr in _get_protocol_attrs(cls)): - return True - return super().__instancecheck__(instance) + + return False + + +@classmethod +def _proto_hook(cls, other): + if not cls.__dict__.get('_is_protocol', False): + return NotImplemented + + for attr in cls.__protocol_attrs__: + for base in other.__mro__: + # Check if the members appears in the class dictionary... + if attr in base.__dict__: + if base.__dict__[attr] is None: + return NotImplemented + break + + # ...or in annotations, if it is a sub-protocol. 
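A short sketch of how these metaclass hooks behave for a runtime-checkable protocol: `isinstance()` performs the structural attribute check, while `issubclass()` is rejected once the protocol has non-method members (`Named` and `Point` are illustrative names)::

    from typing import Protocol, runtime_checkable

    @runtime_checkable
    class Named(Protocol):
        name: str
        def rename(self, new: str) -> None: ...

    class Point:
        def __init__(self) -> None:
            self.name = "origin"
        def rename(self, new: str) -> None:
            self.name = new

    assert isinstance(Point(), Named)                    # structural check passes
    assert {'name', 'rename'} <= Named.__protocol_attrs__
    try:
        issubclass(Point, Named)                         # data member 'name' forbids this
    except TypeError:
        pass
    else:
        raise AssertionError("expected TypeError")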
+ annotations = getattr(base, '__annotations__', {}) + if (isinstance(annotations, collections.abc.Mapping) and + attr in annotations and + issubclass(other, Generic) and getattr(other, '_is_protocol', False)): + break + else: + return NotImplemented + return True class Protocol(Generic, metaclass=_ProtocolMeta): @@ -1521,7 +2165,9 @@ def meth(self) -> int: ... Such classes are primarily used with static type checkers that recognize - structural subtyping (static duck-typing), for example:: + structural subtyping (static duck-typing). + + For example:: class C: def meth(self) -> int: @@ -1537,10 +2183,11 @@ def func(x: Proto) -> int: only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as:: - class GenProto(Protocol[T]): + class GenProto[T](Protocol): def meth(self) -> T: ... """ + __slots__ = () _is_protocol = True _is_runtime_protocol = False @@ -1553,75 +2200,30 @@ def __init_subclass__(cls, *args, **kwargs): cls._is_protocol = any(b is Protocol for b in cls.__bases__) # Set (or override) the protocol subclass hook. - def _proto_hook(other): - if not cls.__dict__.get('_is_protocol', False): - return NotImplemented - - # First, perform various sanity checks. - if not getattr(cls, '_is_runtime_protocol', False): - if _allow_reckless_class_checks(): - return NotImplemented - raise TypeError("Instance and class checks can only be used with" - " @runtime_checkable protocols") - if not _is_callable_members_only(cls): - if _allow_reckless_class_checks(): - return NotImplemented - raise TypeError("Protocols with non-method members" - " don't support issubclass()") - if not isinstance(other, type): - # Same error message as for issubclass(1, int). - raise TypeError('issubclass() arg 1 must be a class') - - # Second, perform the actual structural compatibility check. - for attr in _get_protocol_attrs(cls): - for base in other.__mro__: - # Check if the members appears in the class dictionary... - if attr in base.__dict__: - if base.__dict__[attr] is None: - return NotImplemented - break - - # ...or in annotations, if it is a sub-protocol. - annotations = getattr(base, '__annotations__', {}) - if (isinstance(annotations, collections.abc.Mapping) and - attr in annotations and - issubclass(other, Generic) and other._is_protocol): - break - else: - return NotImplemented - return True - if '__subclasshook__' not in cls.__dict__: cls.__subclasshook__ = _proto_hook - # We have nothing more to do for non-protocols... - if not cls._is_protocol: - return - - # ... otherwise check consistency of bases, and prohibit instantiation. - for base in cls.__bases__: - if not (base in (object, Generic) or - base.__module__ in _PROTO_ALLOWLIST and - base.__name__ in _PROTO_ALLOWLIST[base.__module__] or - issubclass(base, Generic) and base._is_protocol): - raise TypeError('Protocols can only inherit from other' - ' protocols, got %r' % base) - cls.__init__ = _no_init_or_replace_init + # Prohibit instantiation for protocol classes + if cls._is_protocol and cls.__init__ is Protocol.__init__: + cls.__init__ = _no_init_or_replace_init -class _AnnotatedAlias(_GenericAlias, _root=True): +class _AnnotatedAlias(_NotIterable, _GenericAlias, _root=True): """Runtime representation of an annotated type. At its core 'Annotated[t, dec1, dec2, ...]' is an alias for the type 't' - with extra annotations. The alias behaves like a normal typing alias, - instantiating is the same as instantiating the underlying type, binding + with extra annotations. 
The alias behaves like a normal typing alias. + Instantiating is the same as instantiating the underlying type; binding it to types is also the same. + + The metadata itself is stored in a '__metadata__' attribute as a tuple. """ + def __init__(self, origin, metadata): if isinstance(origin, _AnnotatedAlias): metadata = origin.__metadata__ + metadata origin = origin.__origin__ - super().__init__(origin, origin) + super().__init__(origin, origin, name='Annotated') self.__metadata__ = metadata def copy_with(self, params): @@ -1654,9 +2256,14 @@ def __getattr__(self, attr): return 'Annotated' return super().__getattr__(attr) + def __mro_entries__(self, bases): + return (self.__origin__,) + -class Annotated: - """Add context specific metadata to a type. +@_TypedCacheSpecialForm +@_tp_cache(typed=True) +def Annotated(self, *params): + """Add context-specific metadata to a type. Example: Annotated[int, runtime_check.Unsigned] indicates to the hypothetical runtime_check module that this type is an unsigned int. @@ -1668,44 +2275,51 @@ class Annotated: Details: - It's an error to call `Annotated` with less than two arguments. - - Nested Annotated are flattened:: + - Access the metadata via the ``__metadata__`` attribute:: + + assert Annotated[int, '$'].__metadata__ == ('$',) + + - Nested Annotated types are flattened:: - Annotated[Annotated[T, Ann1, Ann2], Ann3] == Annotated[T, Ann1, Ann2, Ann3] + assert Annotated[Annotated[T, Ann1, Ann2], Ann3] == Annotated[T, Ann1, Ann2, Ann3] - Instantiating an annotated type is equivalent to instantiating the underlying type:: - Annotated[C, Ann1](5) == C(5) + assert Annotated[C, Ann1](5) == C(5) - Annotated can be used as a generic type alias:: - Optimized = Annotated[T, runtime.Optimize()] - Optimized[int] == Annotated[int, runtime.Optimize()] + type Optimized[T] = Annotated[T, runtime.Optimize()] + # type checker will treat Optimized[int] + # as equivalent to Annotated[int, runtime.Optimize()] - OptimizedList = Annotated[List[T], runtime.Optimize()] - OptimizedList[int] == Annotated[List[int], runtime.Optimize()] - """ + type OptimizedList[T] = Annotated[list[T], runtime.Optimize()] + # type checker will treat OptimizedList[int] + # as equivalent to Annotated[list[int], runtime.Optimize()] - __slots__ = () + - Annotated cannot be used with an unpacked TypeVarTuple:: - def __new__(cls, *args, **kwargs): - raise TypeError("Type Annotated cannot be instantiated.") + type Variadic[*Ts] = Annotated[*Ts, Ann1] # NOT valid - @_tp_cache - def __class_getitem__(cls, params): - if not isinstance(params, tuple) or len(params) < 2: - raise TypeError("Annotated[...] should be used " - "with at least two arguments (a type and an " - "annotation).") - msg = "Annotated[t, ...]: t must be a type." - origin = _type_check(params[0], msg, allow_special_forms=True) - metadata = tuple(params[1:]) - return _AnnotatedAlias(origin, metadata) + This would be equivalent to:: - def __init_subclass__(cls, *args, **kwargs): - raise TypeError( - "Cannot subclass {}.Annotated".format(cls.__module__) - ) + Annotated[T1, T2, T3, ..., Ann1] + + where T1, T2 etc. are TypeVars, which would be invalid, because + only one type should be passed to Annotated. + """ + if len(params) < 2: + raise TypeError("Annotated[...] should be used " + "with at least two arguments (a type and an " + "annotation).") + if _is_unpacked_typevartuple(params[0]): + raise TypeError("Annotated[...] should not be used with an " + "unpacked TypeVarTuple") + msg = "Annotated[t, ...]: t must be a type." 
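A small sketch of how the metadata carried by `_AnnotatedAlias` surfaces at runtime: `get_type_hints()` strips it by default and keeps it with `include_extras=True` (`Metres` and `height` are illustrative names)::

    from typing import Annotated, get_args, get_origin, get_type_hints

    Metres = Annotated[float, "unit: m"]
    assert get_origin(Metres) is Annotated
    assert get_args(Metres) == (float, "unit: m")
    assert Metres.__metadata__ == ("unit: m",)

    def height(value: Metres) -> None: ...

    assert get_type_hints(height)["value"] is float
    assert get_type_hints(height, include_extras=True)["value"] == Metres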
+ origin = _type_check(params[0], msg, allow_special_forms=True) + metadata = tuple(params[1:]) + return _AnnotatedAlias(origin, metadata) def runtime_checkable(cls): @@ -1715,6 +2329,7 @@ def runtime_checkable(cls): Raise TypeError if applied to a non-protocol class. This allows a simple-minded structural check very similar to one trick ponies in collections.abc such as Iterable. + For example:: @runtime_checkable @@ -1726,10 +2341,26 @@ def close(self): ... Warning: this will check only the presence of the required methods, not their type signatures! """ - if not issubclass(cls, Generic) or not cls._is_protocol: + if not issubclass(cls, Generic) or not getattr(cls, '_is_protocol', False): raise TypeError('@runtime_checkable can be only applied to protocol classes,' ' got %r' % cls) cls._is_runtime_protocol = True + # PEP 544 prohibits using issubclass() + # with protocols that have non-method members. + # See gh-113320 for why we compute this attribute here, + # rather than in `_ProtocolMeta.__init__` + cls.__non_callable_proto_members__ = set() + for attr in cls.__protocol_attrs__: + try: + is_callable = callable(getattr(cls, attr, None)) + except Exception as e: + raise TypeError( + f"Failed to determine whether protocol member {attr!r} " + "is a method member" + ) from e + else: + if not is_callable: + cls.__non_callable_proto_members__.add(attr) return cls @@ -1744,24 +2375,20 @@ def cast(typ, val): return val -def _get_defaults(func): - """Internal helper to extract the default arguments, by name.""" - try: - code = func.__code__ - except AttributeError: - # Some built-in functions don't have __code__, __defaults__, etc. - return {} - pos_count = code.co_argcount - arg_names = code.co_varnames - arg_names = arg_names[:pos_count] - defaults = func.__defaults__ or () - kwdefaults = func.__kwdefaults__ - res = dict(kwdefaults) if kwdefaults else {} - pos_offset = pos_count - len(defaults) - for name, value in zip(arg_names[pos_offset:], defaults): - assert name not in res - res[name] = value - return res +def assert_type(val, typ, /): + """Ask a static type checker to confirm that the value is of the given type. + + At runtime this does nothing: it returns the first argument unchanged with no + checks or side effects, no matter the actual type of the argument. + + When a static type checker encounters a call to assert_type(), it + emits an error if the value is not of the specified type:: + + def greet(name: str) -> None: + assert_type(name, str) # OK + assert_type(name, int) # type checker error + """ + return val _allowed_types = (types.FunctionType, types.BuiltinFunctionType, @@ -1773,8 +2400,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False): """Return type hints for an object. This is often the same as obj.__annotations__, but it handles - forward references encoded as string literals, adds Optional[t] if a - default value equal to None is set and recursively replaces all + forward references encoded as string literals and recursively replaces all 'Annotated[T, ...]' with 'T' (unless 'include_extras=True'). The argument may be a module, class, method, or function. The annotations @@ -1801,7 +2427,6 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False): - If two dict arguments are passed, they specify globals and locals, respectively. """ - if getattr(obj, '__no_type_check__', None): return {} # Classes require a special treatment. 
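With `_get_defaults()` removed, `get_type_hints()` no longer turns a `None` default into an implicit `Optional`; the annotation is returned exactly as written. A quick sketch (illustrative function)::

    from typing import get_type_hints

    def greet(name: str = None):   # questionable default, but legal
        ...

    # Older typing versions reported Optional[str] here; with this change
    # the annotation is reported as written.
    assert get_type_hints(greet)["name"] is str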
@@ -1829,7 +2454,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False): value = type(None) if isinstance(value, str): value = ForwardRef(value, is_argument=False, is_class=True) - value = _eval_type(value, base_globals, base_locals) + value = _eval_type(value, base_globals, base_locals, base.__type_params__) hints[name] = value return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()} @@ -1854,8 +2479,8 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False): else: raise TypeError('{!r} is not a module, class, method, ' 'or function.'.format(obj)) - defaults = _get_defaults(obj) hints = dict(hints) + type_params = getattr(obj, "__type_params__", ()) for name, value in hints.items(): if value is None: value = type(None) @@ -1867,18 +2492,16 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False): is_argument=not isinstance(obj, types.ModuleType), is_class=False, ) - value = _eval_type(value, globalns, localns) - if name in defaults and defaults[name] is None: - value = Optional[value] - hints[name] = value + hints[name] = _eval_type(value, globalns, localns, type_params) return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()} def _strip_annotations(t): - """Strips the annotations from a given type. - """ + """Strip the annotations from a given type.""" if isinstance(t, _AnnotatedAlias): return _strip_annotations(t.__origin__) + if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired, ReadOnly): + return _strip_annotations(t.__args__[0]) if isinstance(t, _GenericAlias): stripped_args = tuple(_strip_annotations(a) for a in t.__args__) if stripped_args == t.__args__: @@ -1901,17 +2524,20 @@ def _strip_annotations(t): def get_origin(tp): """Get the unsubscripted version of a type. - This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar - and Annotated. Return None for unsupported types. Examples:: - - get_origin(Literal[42]) is Literal - get_origin(int) is None - get_origin(ClassVar[int]) is ClassVar - get_origin(Generic) is Generic - get_origin(Generic[T]) is Generic - get_origin(Union[T, int]) is Union - get_origin(List[Tuple[T, T]][int]) == list - get_origin(P.args) is P + This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar, + Annotated, and others. Return None for unsupported types. + + Examples:: + + >>> P = ParamSpec('P') + >>> assert get_origin(Literal[42]) is Literal + >>> assert get_origin(int) is None + >>> assert get_origin(ClassVar[int]) is ClassVar + >>> assert get_origin(Generic) is Generic + >>> assert get_origin(Generic[T]) is Generic + >>> assert get_origin(Union[T, int]) is Union + >>> assert get_origin(List[Tuple[T, T]][int]) is list + >>> assert get_origin(P.args) is P """ if isinstance(tp, _AnnotatedAlias): return Annotated @@ -1929,19 +2555,21 @@ def get_args(tp): """Get type arguments with all substitutions performed. For unions, basic simplifications used by Union constructor are performed. 
+ Examples:: - get_args(Dict[str, int]) == (str, int) - get_args(int) == () - get_args(Union[int, Union[T, int], str][int]) == (int, str) - get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int]) - get_args(Callable[[], T][int]) == ([], int) + + >>> T = TypeVar('T') + >>> assert get_args(Dict[str, int]) == (str, int) + >>> assert get_args(int) == () + >>> assert get_args(Union[int, Union[T, int], str][int]) == (int, str) + >>> assert get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int]) + >>> assert get_args(Callable[[], T][int]) == ([], int) """ if isinstance(tp, _AnnotatedAlias): return (tp.__origin__,) + tp.__metadata__ if isinstance(tp, (_GenericAlias, GenericAlias)): res = tp.__args__ - if (tp.__origin__ is collections.abc.Callable - and not (len(res) == 2 and _is_param_expr(res[0]))): + if _should_unflatten_callable_args(tp, res): res = (list(res[:-1]), res[-1]) return res if isinstance(tp, types.UnionType): @@ -1950,19 +2578,51 @@ def get_args(tp): def is_typeddict(tp): - """Check if an annotation is a TypedDict class + """Check if an annotation is a TypedDict class. For example:: - class Film(TypedDict): - title: str - year: int - is_typeddict(Film) # => True - is_typeddict(Union[list, str]) # => False + >>> from typing import TypedDict + >>> class Film(TypedDict): + ... title: str + ... year: int + ... + >>> is_typeddict(Film) + True + >>> is_typeddict(dict) + False """ return isinstance(tp, _TypedDictMeta) +_ASSERT_NEVER_REPR_MAX_LENGTH = 100 + + +def assert_never(arg: Never, /) -> Never: + """Statically assert that a line of code is unreachable. + + Example:: + + def int_or_str(arg: int | str) -> None: + match arg: + case int(): + print("It's an int") + case str(): + print("It's a str") + case _: + assert_never(arg) + + If a type checker finds that a call to assert_never() is + reachable, it will emit an error. + + At runtime, this throws an exception when called. + """ + value = repr(arg) + if len(value) > _ASSERT_NEVER_REPR_MAX_LENGTH: + value = value[:_ASSERT_NEVER_REPR_MAX_LENGTH] + '...' + raise AssertionError(f"Expected code to be unreachable, but got: {value}") + + def no_type_check(arg): """Decorator to indicate that annotations are not type hints. @@ -1973,13 +2633,23 @@ def no_type_check(arg): This mutates the function(s) or class(es) in place. """ if isinstance(arg, type): - arg_attrs = arg.__dict__.copy() - for attr, val in arg.__dict__.items(): - if val in arg.__bases__ + (arg,): - arg_attrs.pop(attr) - for obj in arg_attrs.values(): + for key in dir(arg): + obj = getattr(arg, key) + if ( + not hasattr(obj, '__qualname__') + or obj.__qualname__ != f'{arg.__qualname__}.{obj.__name__}' + or getattr(obj, '__module__', None) != arg.__module__ + ): + # We only modify objects that are defined in this type directly. + # If classes / methods are nested in multiple layers, + # we will modify them when processing their direct holders. + continue + # Instance, class, and static methods: if isinstance(obj, types.FunctionType): obj.__no_type_check__ = True + if isinstance(obj, types.MethodType): + obj.__func__.__no_type_check__ = True + # Nested types: if isinstance(obj, type): no_type_check(obj) try: @@ -1995,7 +2665,8 @@ def no_type_check_decorator(decorator): This wraps the decorator with something that wraps the decorated function in @no_type_check. 
""" - + import warnings + warnings._deprecated("typing.no_type_check_decorator", remove=(3, 15)) @functools.wraps(decorator) def wrapped_decorator(*args, **kwds): func = decorator(*args, **kwds) @@ -2014,63 +2685,107 @@ def _overload_dummy(*args, **kwds): "by an implementation that is not @overload-ed.") +# {module: {qualname: {firstlineno: func}}} +_overload_registry = defaultdict(functools.partial(defaultdict, dict)) + + def overload(func): """Decorator for overloaded functions/methods. In a stub file, place two or more stub definitions for the same - function in a row, each decorated with @overload. For example: + function in a row, each decorated with @overload. + + For example:: - @overload - def utf8(value: None) -> None: ... - @overload - def utf8(value: bytes) -> bytes: ... - @overload - def utf8(value: str) -> bytes: ... + @overload + def utf8(value: None) -> None: ... + @overload + def utf8(value: bytes) -> bytes: ... + @overload + def utf8(value: str) -> bytes: ... In a non-stub file (i.e. a regular .py file), do the same but follow it with an implementation. The implementation should *not* - be decorated with @overload. For example: - - @overload - def utf8(value: None) -> None: ... - @overload - def utf8(value: bytes) -> bytes: ... - @overload - def utf8(value: str) -> bytes: ... - def utf8(value): - # implementation goes here + be decorated with @overload:: + + @overload + def utf8(value: None) -> None: ... + @overload + def utf8(value: bytes) -> bytes: ... + @overload + def utf8(value: str) -> bytes: ... + def utf8(value): + ... # implementation goes here + + The overloads for a function can be retrieved at runtime using the + get_overloads() function. """ + # classmethod and staticmethod + f = getattr(func, "__func__", func) + try: + _overload_registry[f.__module__][f.__qualname__][f.__code__.co_firstlineno] = func + except AttributeError: + # Not a normal function; ignore. + pass return _overload_dummy +def get_overloads(func): + """Return all defined overloads for *func* as a sequence.""" + # classmethod and staticmethod + f = getattr(func, "__func__", func) + if f.__module__ not in _overload_registry: + return [] + mod_dict = _overload_registry[f.__module__] + if f.__qualname__ not in mod_dict: + return [] + return list(mod_dict[f.__qualname__].values()) + + +def clear_overloads(): + """Clear all overloads in the registry.""" + _overload_registry.clear() + + def final(f): - """A decorator to indicate final methods and final classes. + """Decorator to indicate final methods and final classes. Use this decorator to indicate to type checkers that the decorated method cannot be overridden, and decorated class cannot be subclassed. - For example: - - class Base: - @final - def done(self) -> None: - ... - class Sub(Base): - def done(self) -> None: # Error reported by type checker + + For example:: + + class Base: + @final + def done(self) -> None: + ... + class Sub(Base): + def done(self) -> None: # Error reported by type checker ... - @final - class Leaf: - ... - class Other(Leaf): # Error reported by type checker - ... + @final + class Leaf: + ... + class Other(Leaf): # Error reported by type checker + ... - There is no runtime checking of these properties. + There is no runtime checking of these properties. The decorator + attempts to set the ``__final__`` attribute to ``True`` on the decorated + object to allow runtime introspection. """ + try: + f.__final__ = True + except (AttributeError, TypeError): + # Skip the attribute silently if it is not writable. 
+ # AttributeError happens if the object has __slots__ or a + # read-only property, TypeError if it's a builtin class. + pass return f -# Some unconstrained type variables. These are used by the container types. -# (These are not for export.) +# Some unconstrained type variables. These were initially used by the container types. +# They were never meant for export and are now unused, but we keep them around to +# avoid breaking compatibility with users who import them. T = TypeVar('T') # Any type. KT = TypeVar('KT') # Key type. VT = TypeVar('VT') # Value type. @@ -2081,6 +2796,7 @@ class Other(Leaf): # Error reported by type checker # Internal type variable used for Type[]. CT_co = TypeVar('CT_co', covariant=True, bound=type) + # A useful type variable with constraints. This represents string types. # (This one *is* for export!) AnyStr = TypeVar('AnyStr', bytes, str) @@ -2102,13 +2818,17 @@ class Other(Leaf): # Error reported by type checker Collection = _alias(collections.abc.Collection, 1) Callable = _CallableType(collections.abc.Callable, 2) Callable.__doc__ = \ - """Callable type; Callable[[int], str] is a function of (int) -> str. + """Deprecated alias to collections.abc.Callable. + + Callable[[int], str] signifies a function that takes a single + parameter of type int and returns a str. The subscription syntax must always be used with exactly two - values: the argument list and the return type. The argument list - must be a list of types or ellipsis; the return type must be a single type. + values: the argument list and the return type. + The argument list must be a list of types, a ParamSpec, + Concatenate or ellipsis. The return type must be a single type. - There is no syntax to indicate optional or keyword arguments, + There is no syntax to indicate optional or keyword arguments; such function types are rarely used as callback types. """ AbstractSet = _alias(collections.abc.Set, 1, name='AbstractSet') @@ -2118,11 +2838,15 @@ class Other(Leaf): # Error reported by type checker MutableMapping = _alias(collections.abc.MutableMapping, 2) Sequence = _alias(collections.abc.Sequence, 1) MutableSequence = _alias(collections.abc.MutableSequence, 1) -ByteString = _alias(collections.abc.ByteString, 0) # Not generic +ByteString = _DeprecatedGenericAlias( + collections.abc.ByteString, 0, removal_version=(3, 14) # Not generic. +) # Tuple accepts variable number of parameters. Tuple = _TupleType(tuple, -1, inst=False, name='Tuple') Tuple.__doc__ = \ - """Tuple type; Tuple[X, Y] is the cross-product type of X and Y. + """Deprecated alias to builtins.tuple. + + Tuple[X, Y] is the cross-product type of X and Y. Example: Tuple[T1, T2] is a tuple of two elements corresponding to type variables T1 and T2. 
Tuple[int, float, str] is a tuple @@ -2138,36 +2862,34 @@ class Other(Leaf): # Error reported by type checker KeysView = _alias(collections.abc.KeysView, 1) ItemsView = _alias(collections.abc.ItemsView, 2) ValuesView = _alias(collections.abc.ValuesView, 1) -ContextManager = _alias(contextlib.AbstractContextManager, 1, name='ContextManager') -AsyncContextManager = _alias(contextlib.AbstractAsyncContextManager, 1, name='AsyncContextManager') Dict = _alias(dict, 2, inst=False, name='Dict') DefaultDict = _alias(collections.defaultdict, 2, name='DefaultDict') OrderedDict = _alias(collections.OrderedDict, 2) Counter = _alias(collections.Counter, 1) ChainMap = _alias(collections.ChainMap, 2) -Generator = _alias(collections.abc.Generator, 3) -AsyncGenerator = _alias(collections.abc.AsyncGenerator, 2) +Generator = _alias(collections.abc.Generator, 3, defaults=(types.NoneType, types.NoneType)) +AsyncGenerator = _alias(collections.abc.AsyncGenerator, 2, defaults=(types.NoneType,)) Type = _alias(type, 1, inst=False, name='Type') Type.__doc__ = \ - """A special construct usable to annotate class objects. + """Deprecated alias to builtins.type. + builtins.type or typing.Type can be used to annotate class objects. For example, suppose we have the following classes:: - class User: ... # Abstract base for User classes - class BasicUser(User): ... - class ProUser(User): ... - class TeamUser(User): ... + class User: ... # Abstract base for User classes + class BasicUser(User): ... + class ProUser(User): ... + class TeamUser(User): ... And a function that takes a class argument that's a subclass of User and returns an instance of the corresponding class:: - U = TypeVar('U', bound=User) - def new_user(user_class: Type[U]) -> U: - user = user_class() - # (Here we could write the user object to a database) - return user + def new_user[U](user_class: Type[U]) -> U: + user = user_class() + # (Here we could write the user object to a database) + return user - joe = new_user(BasicUser) + joe = new_user(BasicUser) At this point the type checker knows that joe has type BasicUser. 
""" @@ -2176,6 +2898,7 @@ def new_user(user_class: Type[U]) -> U: @runtime_checkable class SupportsInt(Protocol): """An ABC with one abstract method __int__.""" + __slots__ = () @abstractmethod @@ -2186,6 +2909,7 @@ def __int__(self) -> int: @runtime_checkable class SupportsFloat(Protocol): """An ABC with one abstract method __float__.""" + __slots__ = () @abstractmethod @@ -2196,6 +2920,7 @@ def __float__(self) -> float: @runtime_checkable class SupportsComplex(Protocol): """An ABC with one abstract method __complex__.""" + __slots__ = () @abstractmethod @@ -2206,6 +2931,7 @@ def __complex__(self) -> complex: @runtime_checkable class SupportsBytes(Protocol): """An ABC with one abstract method __bytes__.""" + __slots__ = () @abstractmethod @@ -2216,6 +2942,7 @@ def __bytes__(self) -> bytes: @runtime_checkable class SupportsIndex(Protocol): """An ABC with one abstract method __index__.""" + __slots__ = () @abstractmethod @@ -2224,22 +2951,24 @@ def __index__(self) -> int: @runtime_checkable -class SupportsAbs(Protocol[T_co]): +class SupportsAbs[T](Protocol): """An ABC with one abstract method __abs__ that is covariant in its return type.""" + __slots__ = () @abstractmethod - def __abs__(self) -> T_co: + def __abs__(self) -> T: pass @runtime_checkable -class SupportsRound(Protocol[T_co]): +class SupportsRound[T](Protocol): """An ABC with one abstract method __round__ that is covariant in its return type.""" + __slots__ = () @abstractmethod - def __round__(self, ndigits: int = 0) -> T_co: + def __round__(self, ndigits: int = 0) -> T: pass @@ -2262,9 +2991,13 @@ def _make_nmtuple(name, types, module, defaults = ()): class NamedTupleMeta(type): - def __new__(cls, typename, bases, ns): - assert bases[0] is _NamedTuple + assert _NamedTuple in bases + for base in bases: + if base is not _NamedTuple and base is not Generic: + raise TypeError( + 'can only inherit from a NamedTuple type and Generic') + bases = tuple(tuple if base is _NamedTuple else base for base in bases) types = ns.get('__annotations__', {}) default_names = [] for field_name in types: @@ -2278,19 +3011,40 @@ def __new__(cls, typename, bases, ns): nm_tpl = _make_nmtuple(typename, types.items(), defaults=[ns[n] for n in default_names], module=ns['__module__']) + nm_tpl.__bases__ = bases + if Generic in bases: + class_getitem = _generic_class_getitem + nm_tpl.__class_getitem__ = classmethod(class_getitem) # update from user namespace without overriding special namedtuple attributes - for key in ns: + for key, val in ns.items(): if key in _prohibited: raise AttributeError("Cannot overwrite NamedTuple attribute " + key) - elif key not in _special and key not in nm_tpl._fields: - setattr(nm_tpl, key, ns[key]) + elif key not in _special: + if key not in nm_tpl._fields: + setattr(nm_tpl, key, val) + try: + set_name = type(val).__set_name__ + except AttributeError: + pass + else: + try: + set_name(val, nm_tpl, key) + except BaseException as e: + e.add_note( + f"Error calling __set_name__ on {type(val).__name__!r} " + f"instance {key!r} in {typename!r}" + ) + raise + + if Generic in bases: + nm_tpl.__init_subclass__() return nm_tpl -def NamedTuple(typename, fields=None, /, **kwargs): +def NamedTuple(typename, fields=_sentinel, /, **kwargs): """Typed version of namedtuple. - Usage in Python versions >= 3.6:: + Usage:: class Employee(NamedTuple): name: str @@ -2303,39 +3057,86 @@ class Employee(NamedTuple): The resulting class has an extra __annotations__ attribute, giving a dict that maps field names to types. 
(The field names are also in the _fields attribute, which is part of the namedtuple API.) - Alternative equivalent keyword syntax is also accepted:: - - Employee = NamedTuple('Employee', name=str, id=int) - - In Python versions <= 3.5 use:: + An alternative equivalent functional syntax is also accepted:: Employee = NamedTuple('Employee', [('name', str), ('id', int)]) """ - if fields is None: - fields = kwargs.items() + if fields is _sentinel: + if kwargs: + deprecated_thing = "Creating NamedTuple classes using keyword arguments" + deprecation_msg = ( + "{name} is deprecated and will be disallowed in Python {remove}. " + "Use the class-based or functional syntax instead." + ) + else: + deprecated_thing = "Failing to pass a value for the 'fields' parameter" + example = f"`{typename} = NamedTuple({typename!r}, [])`" + deprecation_msg = ( + "{name} is deprecated and will be disallowed in Python {remove}. " + "To create a NamedTuple class with 0 fields " + "using the functional syntax, " + "pass an empty list, e.g. " + ) + example + "." + elif fields is None: + if kwargs: + raise TypeError( + "Cannot pass `None` as the 'fields' parameter " + "and also specify fields using keyword arguments" + ) + else: + deprecated_thing = "Passing `None` as the 'fields' parameter" + example = f"`{typename} = NamedTuple({typename!r}, [])`" + deprecation_msg = ( + "{name} is deprecated and will be disallowed in Python {remove}. " + "To create a NamedTuple class with 0 fields " + "using the functional syntax, " + "pass an empty list, e.g. " + ) + example + "." elif kwargs: raise TypeError("Either list of fields or keywords" " can be provided to NamedTuple, not both") - try: - module = sys._getframe(1).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): - module = None - return _make_nmtuple(typename, fields, module=module) + if fields is _sentinel or fields is None: + import warnings + warnings._deprecated(deprecated_thing, message=deprecation_msg, remove=(3, 15)) + fields = kwargs.items() + nt = _make_nmtuple(typename, fields, module=_caller()) + nt.__orig_bases__ = (NamedTuple,) + return nt _NamedTuple = type.__new__(NamedTupleMeta, 'NamedTuple', (), {}) def _namedtuple_mro_entries(bases): - if len(bases) > 1: - raise TypeError("Multiple inheritance with NamedTuple is not supported") - assert bases[0] is NamedTuple + assert NamedTuple in bases return (_NamedTuple,) NamedTuple.__mro_entries__ = _namedtuple_mro_entries +def _get_typeddict_qualifiers(annotation_type): + while True: + annotation_origin = get_origin(annotation_type) + if annotation_origin is Annotated: + annotation_args = get_args(annotation_type) + if annotation_args: + annotation_type = annotation_args[0] + else: + break + elif annotation_origin is Required: + yield Required + (annotation_type,) = get_args(annotation_type) + elif annotation_origin is NotRequired: + yield NotRequired + (annotation_type,) = get_args(annotation_type) + elif annotation_origin is ReadOnly: + yield ReadOnly + (annotation_type,) = get_args(annotation_type) + else: + break + + class _TypedDictMeta(type): def __new__(cls, name, bases, ns, total=True): - """Create new typed dict class object. + """Create a new typed dict class object. This method is called when TypedDict is subclassed, or when TypedDict is instantiated. This way @@ -2343,14 +3144,22 @@ def __new__(cls, name, bases, ns, total=True): Subclasses and instances of TypedDict return actual dictionaries. 
""" for base in bases: - if type(base) is not _TypedDictMeta: + if type(base) is not _TypedDictMeta and base is not Generic: raise TypeError('cannot inherit from both a TypedDict type ' 'and a non-TypedDict base class') - tp_dict = type.__new__(_TypedDictMeta, name, (dict,), ns) + + if any(issubclass(b, Generic) for b in bases): + generic_base = (Generic,) + else: + generic_base = () + + tp_dict = type.__new__(_TypedDictMeta, name, (*generic_base, dict), ns) + + if not hasattr(tp_dict, '__orig_bases__'): + tp_dict.__orig_bases__ = bases annotations = {} own_annotations = ns.get('__annotations__', {}) - own_annotation_keys = set(own_annotations.keys()) msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type" own_annotations = { n: _type_check(tp, msg, module=tp_dict.__module__) @@ -2358,23 +3167,61 @@ def __new__(cls, name, bases, ns, total=True): } required_keys = set() optional_keys = set() + readonly_keys = set() + mutable_keys = set() for base in bases: annotations.update(base.__dict__.get('__annotations__', {})) - required_keys.update(base.__dict__.get('__required_keys__', ())) - optional_keys.update(base.__dict__.get('__optional_keys__', ())) + + base_required = base.__dict__.get('__required_keys__', set()) + required_keys |= base_required + optional_keys -= base_required + + base_optional = base.__dict__.get('__optional_keys__', set()) + required_keys -= base_optional + optional_keys |= base_optional + + readonly_keys.update(base.__dict__.get('__readonly_keys__', ())) + mutable_keys.update(base.__dict__.get('__mutable_keys__', ())) annotations.update(own_annotations) - if total: - required_keys.update(own_annotation_keys) - else: - optional_keys.update(own_annotation_keys) + for annotation_key, annotation_type in own_annotations.items(): + qualifiers = set(_get_typeddict_qualifiers(annotation_type)) + if Required in qualifiers: + is_required = True + elif NotRequired in qualifiers: + is_required = False + else: + is_required = total + if is_required: + required_keys.add(annotation_key) + optional_keys.discard(annotation_key) + else: + optional_keys.add(annotation_key) + required_keys.discard(annotation_key) + + if ReadOnly in qualifiers: + if annotation_key in mutable_keys: + raise TypeError( + f"Cannot override mutable key {annotation_key!r}" + " with read-only key" + ) + readonly_keys.add(annotation_key) + else: + mutable_keys.add(annotation_key) + readonly_keys.discard(annotation_key) + + assert required_keys.isdisjoint(optional_keys), ( + f"Required keys overlap with optional keys in {name}:" + f" {required_keys=}, {optional_keys=}" + ) tp_dict.__annotations__ = annotations tp_dict.__required_keys__ = frozenset(required_keys) tp_dict.__optional_keys__ = frozenset(optional_keys) - if not hasattr(tp_dict, '__total__'): - tp_dict.__total__ = total + tp_dict.__readonly_keys__ = frozenset(readonly_keys) + tp_dict.__mutable_keys__ = frozenset(mutable_keys) + tp_dict.__total__ = total return tp_dict __call__ = dict # static method @@ -2386,72 +3233,164 @@ def __subclasscheck__(cls, other): __instancecheck__ = __subclasscheck__ -def TypedDict(typename, fields=None, /, *, total=True, **kwargs): +def TypedDict(typename, fields=_sentinel, /, *, total=True): """A simple typed namespace. At runtime it is equivalent to a plain dict. - TypedDict creates a dictionary type that expects all of its + TypedDict creates a dictionary type such that a type checker will expect all instances to have a certain set of keys, where each key is associated with a value of a consistent type. 
This expectation - is not checked at runtime but is only enforced by type checkers. - Usage:: - - class Point2D(TypedDict): - x: int - y: int - label: str + is not checked at runtime. - a: Point2D = {'x': 1, 'y': 2, 'label': 'good'} # OK - b: Point2D = {'z': 3, 'label': 'bad'} # Fails type check + Usage:: - assert Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first') + >>> class Point2D(TypedDict): + ... x: int + ... y: int + ... label: str + ... + >>> a: Point2D = {'x': 1, 'y': 2, 'label': 'good'} # OK + >>> b: Point2D = {'z': 3, 'label': 'bad'} # Fails type check + >>> Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first') + True The type info can be accessed via the Point2D.__annotations__ dict, and the Point2D.__required_keys__ and Point2D.__optional_keys__ frozensets. - TypedDict supports two additional equivalent forms:: + TypedDict supports an additional equivalent form:: - Point2D = TypedDict('Point2D', x=int, y=int, label=str) Point2D = TypedDict('Point2D', {'x': int, 'y': int, 'label': str}) By default, all keys must be present in a TypedDict. It is possible - to override this by specifying totality. - Usage:: + to override this by specifying totality:: - class point2D(TypedDict, total=False): + class Point2D(TypedDict, total=False): x: int y: int - This means that a point2D TypedDict can have any of the keys omitted.A type + This means that a Point2D TypedDict can have any of the keys omitted. A type checker is only expected to support a literal False or True as the value of the total argument. True is the default, and makes all items defined in the class body be required. - The class syntax is only supported in Python 3.6+, while two other - syntax forms work for Python 2.7 and 3.2+ + The Required and NotRequired special forms can also be used to mark + individual keys as being required or not required:: + + class Point2D(TypedDict): + x: int # the "x" key must always be present (Required is the default) + y: NotRequired[int] # the "y" key can be omitted + + See PEP 655 for more details on Required and NotRequired. + + The ReadOnly special form can be used + to mark individual keys as immutable for type checkers:: + + class DatabaseUser(TypedDict): + id: ReadOnly[int] # the "id" key must not be modified + username: str # the "username" key can be changed + """ - if fields is None: - fields = kwargs - elif kwargs: - raise TypeError("TypedDict takes either a dict or keyword arguments," - " but not both") + if fields is _sentinel or fields is None: + import warnings + + if fields is _sentinel: + deprecated_thing = "Failing to pass a value for the 'fields' parameter" + else: + deprecated_thing = "Passing `None` as the 'fields' parameter" + + example = f"`{typename} = TypedDict({typename!r}, {{{{}}}})`" + deprecation_msg = ( + "{name} is deprecated and will be disallowed in Python {remove}. " + "To create a TypedDict class with 0 fields " + "using the functional syntax, " + "pass an empty dictionary, e.g. " + ) + example + "." + warnings._deprecated(deprecated_thing, message=deprecation_msg, remove=(3, 15)) + fields = {} ns = {'__annotations__': dict(fields)} - try: + module = _caller() + if module is not None: # Setting correct module is necessary to make typed dict classes pickleable. 
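A short sketch of the key-classification bookkeeping performed above: `Required`/`NotRequired` override totality, and `ReadOnly` feeds the new `__readonly_keys__`/`__mutable_keys__` sets, which only exist in typing versions that ship `ReadOnly`, such as this one (`Movie` is an illustrative name)::

    from typing import TypedDict, NotRequired, ReadOnly

    class Movie(TypedDict):
        id: ReadOnly[int]
        title: str
        year: NotRequired[int]

    assert Movie.__required_keys__ == frozenset({"id", "title"})
    assert Movie.__optional_keys__ == frozenset({"year"})
    assert Movie.__readonly_keys__ == frozenset({"id"})
    assert Movie.__mutable_keys__ == frozenset({"title", "year"})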
+@_SpecialForm
+def Required(self, parameters):
+    """Special typing construct to mark a TypedDict key as required.
+
+    This is mainly useful for total=False TypedDicts.
+
+    For example::
+
+        class Movie(TypedDict, total=False):
+            title: Required[str]
+            year: int
+
+        m = Movie(
+            title='The Matrix',  # typechecker error if key is omitted
+            year=1999,
+        )
+
+    There is no runtime checking that a required key is actually provided
+    when instantiating a related TypedDict.
+    """
+    item = _type_check(parameters, f'{self._name} accepts only a single type.')
+    return _GenericAlias(self, (item,))
+
+
+@_SpecialForm
+def NotRequired(self, parameters):
+    """Special typing construct to mark a TypedDict key as potentially missing.
+
+    For example::
+
+        class Movie(TypedDict):
+            title: str
+            year: NotRequired[int]
+
+        m = Movie(
+            title='The Matrix',  # typechecker error if key is omitted
+            year=1999,
+        )
+    """
+    item = _type_check(parameters, f'{self._name} accepts only a single type.')
+    return _GenericAlias(self, (item,))
+
+
+@_SpecialForm
+def ReadOnly(self, parameters):
+    """A special typing construct to mark an item of a TypedDict as read-only.
+
+    For example::
+
+        class Movie(TypedDict):
+            title: ReadOnly[str]
+            year: int
+
+        def mutate_movie(m: Movie) -> None:
+            m["year"] = 1992  # allowed
+            m["title"] = "The Matrix"  # typechecker error
+
+    There is no runtime checking for this property.
+    """
+    item = _type_check(parameters, f'{self._name} accepts only a single type.')
+    return _GenericAlias(self, (item,))
+
+
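A brief sketch (not part of the patch) of what these three special forms do at runtime; Account is a hypothetical class. Each form only wraps a single argument in a parameterized alias, so nothing is enforced on ordinary dicts.

    from typing import TypedDict, NotRequired, ReadOnly

    class Account(TypedDict):
        id: ReadOnly[int]
        nickname: NotRequired[str]

    acct: Account = {'id': 1}  # a plain dict at runtime
    acct['id'] = 2             # no error: ReadOnly is a type-checker-only hint
    # Passing more than one argument, e.g. ReadOnly[int, str], raises TypeError
    # through the _type_check call above.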
 class NewType:
-    """NewType creates simple unique types with almost zero
-    runtime overhead. NewType(name, tp) is considered a subtype of tp
+    """NewType creates simple unique types with almost zero runtime overhead.
+
+    NewType(name, tp) is considered a subtype of tp
     by static type checkers. At runtime, NewType(name, tp) returns
-    a dummy callable that simply returns its argument. Usage::
+    a dummy callable that simply returns its argument.
+
+    Usage::

         UserId = NewType('UserId', int)

@@ -2466,6 +3405,8 @@ def name_by_id(user_id: UserId) -> str:

         num = UserId(5) + 1         # type: int
     """

+    __call__ = _idfunc
+
     def __init__(self, name, tp):
         self.__qualname__ = name
         if '.' in name:
@@ -2476,12 +3417,24 @@ def __init__(self, name, tp):
         if def_mod != 'typing':
             self.__module__ = def_mod

+    def __mro_entries__(self, bases):
+        # We defined __mro_entries__ to get a better error message
+        # if a user attempts to subclass a NewType instance. bpo-46170
+        superclass_name = self.__name__
+
+        class Dummy:
+            def __init_subclass__(cls):
+                subclass_name = cls.__name__
+                raise TypeError(
+                    f"Cannot subclass an instance of NewType. Perhaps you were looking for: "
+                    f"`{subclass_name} = NewType({subclass_name!r}, {superclass_name})`"
+                )
+
+        return (Dummy,)
+
     def __repr__(self):
         return f'{self.__module__}.{self.__qualname__}'

-    def __call__(self, x):
-        return x
-
     def __reduce__(self):
         return self.__qualname__

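A short sketch (not part of the patch) of the new __mro_entries__ behaviour; AdminId is a hypothetical name, and the suggested spelling comes from the Dummy.__init_subclass__ hook above.

    from typing import NewType

    UserId = NewType('UserId', int)
    assert UserId(5) + 1 == 6  # still an identity callable at runtime (__call__ = _idfunc)

    try:
        class AdminId(UserId):  # subclassing an *instance* of NewType
            pass
    except TypeError as exc:
        # Expected to suggest: AdminId = NewType('AdminId', UserId)
        assert 'Cannot subclass an instance of NewType' in str(exc)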
+ """ + def decorator(cls_or_fn): + cls_or_fn.__dataclass_transform__ = { + "eq_default": eq_default, + "order_default": order_default, + "kw_only_default": kw_only_default, + "frozen_default": frozen_default, + "field_specifiers": field_specifiers, + "kwargs": kwargs, + } + return cls_or_fn + return decorator + + +type _Func = Callable[..., Any] + + +def override[F: _Func](method: F, /) -> F: + """Indicate that a method is intended to override a method in a base class. + + Usage:: + + class Base: + def method(self) -> None: + pass + + class Child(Base): + @override + def method(self) -> None: + super().method() + + When this decorator is applied to a method, the type checker will + validate that it overrides a method or attribute with the same name on a + base class. This helps prevent bugs that may occur when a base class is + changed without an equivalent change to a child class. + + There is no runtime checking of this property. The decorator attempts to + set the ``__override__`` attribute to ``True`` on the decorated object to + allow runtime introspection. + + See PEP 698 for details. + """ + try: + method.__override__ = True + except (AttributeError, TypeError): + # Skip the attribute silently if it is not writable. + # AttributeError happens if the object has __slots__ or a + # read-only property, TypeError if it's a builtin class. + pass + return method + + +def is_protocol(tp: type, /) -> bool: + """Return True if the given type is a Protocol. + + Example:: + + >>> from typing import Protocol, is_protocol + >>> class P(Protocol): + ... def a(self) -> str: ... + ... b: int + >>> is_protocol(P) + True + >>> is_protocol(int) + False + """ + return ( + isinstance(tp, type) + and getattr(tp, '_is_protocol', False) + and tp != Protocol + ) - __all__ = ['IO', 'TextIO', 'BinaryIO'] - IO = IO - TextIO = TextIO - BinaryIO = BinaryIO +def get_protocol_members(tp: type, /) -> frozenset[str]: + """Return the set of members defined in a Protocol. -io.__name__ = __name__ + '.io' -sys.modules[io.__name__] = io + Example:: -Pattern = _alias(stdlib_re.Pattern, 1) -Match = _alias(stdlib_re.Match, 1) + >>> from typing import Protocol, get_protocol_members + >>> class P(Protocol): + ... def a(self) -> str: ... + ... b: int + >>> get_protocol_members(P) == frozenset({'a', 'b'}) + True -class re: - """Wrapper namespace for re type aliases.""" + Raise a TypeError for arguments that are not Protocols. + """ + if not is_protocol(tp): + raise TypeError(f'{tp!r} is not a Protocol') + return frozenset(tp.__protocol_attrs__) - __all__ = ['Pattern', 'Match'] - Pattern = Pattern - Match = Match +def __getattr__(attr): + """Improve the import time of the typing module. -re.__name__ = __name__ + '.re' -sys.modules[re.__name__] = re + Soft-deprecated objects which are costly to create + are only created on-demand here. + """ + if attr in {"Pattern", "Match"}: + import re + obj = _alias(getattr(re, attr), 1) + elif attr in {"ContextManager", "AsyncContextManager"}: + import contextlib + obj = _alias(getattr(contextlib, f"Abstract{attr}"), 2, name=attr, defaults=(bool | None,)) + elif attr == "_collect_parameters": + import warnings + + depr_message = ( + "The private _collect_parameters function is deprecated and will be" + " removed in a future version of Python. Any use of private functions" + " is discouraged and may break in the future." 
+type _Func = Callable[..., Any]
+
+
+def override[F: _Func](method: F, /) -> F:
+    """Indicate that a method is intended to override a method in a base class.
+
+    Usage::
+
+        class Base:
+            def method(self) -> None:
+                pass
+
+        class Child(Base):
+            @override
+            def method(self) -> None:
+                super().method()
+
+    When this decorator is applied to a method, the type checker will
+    validate that it overrides a method or attribute with the same name on a
+    base class. This helps prevent bugs that may occur when a base class is
+    changed without an equivalent change to a child class.
+
+    There is no runtime checking of this property. The decorator attempts to
+    set the ``__override__`` attribute to ``True`` on the decorated object to
+    allow runtime introspection.
+
+    See PEP 698 for details.
+    """
+    try:
+        method.__override__ = True
+    except (AttributeError, TypeError):
+        # Skip the attribute silently if it is not writable.
+        # AttributeError happens if the object has __slots__ or a
+        # read-only property, TypeError if it's a builtin class.
+        pass
+    return method
+
+
+def is_protocol(tp: type, /) -> bool:
+    """Return True if the given type is a Protocol.
+
+    Example::
+
+        >>> from typing import Protocol, is_protocol
+        >>> class P(Protocol):
+        ...     def a(self) -> str: ...
+        ...     b: int
+        >>> is_protocol(P)
+        True
+        >>> is_protocol(int)
+        False
+    """
+    return (
+        isinstance(tp, type)
+        and getattr(tp, '_is_protocol', False)
+        and tp != Protocol
+    )
-    __all__ = ['IO', 'TextIO', 'BinaryIO']
-    IO = IO
-    TextIO = TextIO
-    BinaryIO = BinaryIO
+def get_protocol_members(tp: type, /) -> frozenset[str]:
+    """Return the set of members defined in a Protocol.
-io.__name__ = __name__ + '.io'
-sys.modules[io.__name__] = io
+    Example::
-Pattern = _alias(stdlib_re.Pattern, 1)
-Match = _alias(stdlib_re.Match, 1)
+        >>> from typing import Protocol, get_protocol_members
+        >>> class P(Protocol):
+        ...     def a(self) -> str: ...
+        ...     b: int
+        >>> get_protocol_members(P) == frozenset({'a', 'b'})
+        True
-class re:
-    """Wrapper namespace for re type aliases."""
+    Raise a TypeError for arguments that are not Protocols.
+    """
+    if not is_protocol(tp):
+        raise TypeError(f'{tp!r} is not a Protocol')
+    return frozenset(tp.__protocol_attrs__)
-    __all__ = ['Pattern', 'Match']
-    Pattern = Pattern
-    Match = Match
+def __getattr__(attr):
+    """Improve the import time of the typing module.
-re.__name__ = __name__ + '.re'
-sys.modules[re.__name__] = re
+    Soft-deprecated objects which are costly to create
+    are only created on-demand here.
+    """
+    if attr in {"Pattern", "Match"}:
+        import re
+        obj = _alias(getattr(re, attr), 1)
+    elif attr in {"ContextManager", "AsyncContextManager"}:
+        import contextlib
+        obj = _alias(getattr(contextlib, f"Abstract{attr}"), 2, name=attr, defaults=(bool | None,))
+    elif attr == "_collect_parameters":
+        import warnings
+
+        depr_message = (
+            "The private _collect_parameters function is deprecated and will be"
+            " removed in a future version of Python. Any use of private functions"
+            " is discouraged and may break in the future."
+        )
+        warnings.warn(depr_message, category=DeprecationWarning, stacklevel=2)
+        obj = _collect_type_parameters
+    else:
+        raise AttributeError(f"module {__name__!r} has no attribute {attr!r}")
+    globals()[attr] = obj
+    return obj
\ No newline at end of file
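A sketch (not part of the patch) of the lazy-attribute path above, assuming this patched typing module: the first access builds the alias and caches it in the module globals, so later accesses never reach __getattr__.

    import typing

    p1 = typing.Pattern   # triggers typing.__getattr__('Pattern')
    p2 = typing.Pattern   # served from the cached module global
    assert p1 is p2
    assert 'Pattern' in vars(typing)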
diff --git a/vm/src/stdlib/typing.rs b/vm/src/stdlib/typing.rs
index 29ce516b14..a6c518eace 100644
--- a/vm/src/stdlib/typing.rs
+++ b/vm/src/stdlib/typing.rs
@@ -1,11 +1,24 @@
-pub(crate) use _typing::make_module;
+use crate::{
+    builtins::{PyModule, PyTypeRef},
+    types::{Constructor, Representable},
+    PyRef, PyResult, VirtualMachine,
+};
+
+pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef<PyModule> {
+    let module = _typing::make_module(vm);
+    module
+        .set_attr("NoDefault", vm.ctx.no_default.clone(), vm)
+        .unwrap();
+    module
+}

 #[pymodule]
 pub(crate) mod _typing {
     use crate::{
-        builtins::{pystr::AsPyStr, PyGenericAlias, PyTupleRef, PyTypeRef},
+        builtins::{pystr::AsPyStr, PyGenericAlias, PyStrRef, PyTupleRef, PyTypeRef},
         function::IntoFuncArgs,
-        PyObjectRef, PyPayload, PyResult, VirtualMachine,
+        types::Representable,
+        Py, PyObjectRef, PyPayload, PyResult, VirtualMachine,
     };

     pub(crate) fn _call_typing_func_object<'a>(
@@ -127,7 +140,7 @@ pub(crate) mod _typing {
         // compute_value: PyObjectRef,
         // module: PyObjectRef,
     }
-    #[pyclass(flags(BASETYPE))]
+    #[pyclass(with(Representable), flags(BASETYPE, IMMUTABLETYPE))]
     impl TypeAliasType {
         pub fn new(
             name: PyObjectRef,
@@ -142,6 +155,20 @@ pub(crate) mod _typing {
         }
     }

+    impl Representable for TypeAliasType {
+        fn repr(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyStrRef> {
+            zelf.name
+                .clone()
+                .downcast()
+                .map_err(|_| vm.new_type_error("TypeAliasType.__repr__ doesn't return str.".into()))
+        }
+
+        #[cold]
+        fn repr_str(_zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<String> {
+            unreachable!("use repr instead")
+        }
+    }
+
     #[pyattr]
     #[pyclass(name)]
     #[derive(Debug, PyPayload)]
@@ -170,3 +197,30 @@ pub(crate) mod _typing {
     //     }
     // }
 }
+
+#[pyclass(module = false, name)]
+#[derive(Debug, PyPayload)]
+#[allow(dead_code)]
+pub struct NoDefaultType {}
+
+#[pyclass(with(Constructor, Representable))]
+impl NoDefaultType {
+    #[pymethod(magic)]
+    fn reduce() -> String {
+        "NoDefault".to_string()
+    }
+}
+
+impl Constructor for NoDefaultType {
+    type Args = ();
+
+    fn py_new(_: PyTypeRef, _args: Self::Args, vm: &VirtualMachine) -> PyResult {
+        Ok(vm.ctx.no_default.clone().into())
+    }
+}
+
+impl Representable for NoDefaultType {
+    fn repr_str(_zelf: &crate::Py<Self>, _vm: &VirtualMachine) -> crate::PyResult<String> {
+        Ok("typing.NoDefault".to_owned())
+    }
+}
diff --git a/vm/src/types/zoo.rs b/vm/src/types/zoo.rs
index 492b584e29..213758ac92 100644
--- a/vm/src/types/zoo.rs
+++ b/vm/src/types/zoo.rs
@@ -9,6 +9,7 @@ use crate::{
         union_, weakproxy, weakref, zip,
     },
     class::StaticType,
+    stdlib::typing,
     vm::Context,
     Py,
 };
@@ -92,6 +93,7 @@ pub struct TypeZoo {
     pub generic_alias_type: &'static Py<PyType>,
     pub union_type: &'static Py<PyType>,
     pub member_descriptor_type: &'static Py<PyType>,
+    pub no_default_type: &'static Py<PyType>,

     // RustPython-original types
     pub method_def: &'static Py<PyType>,
@@ -184,6 +186,8 @@ impl TypeZoo {
             union_type: union_::PyUnion::init_builtin_type(),
             member_descriptor_type: descriptor::PyMemberDescriptor::init_builtin_type(),

+            no_default_type: typing::NoDefaultType::init_builtin_type(),
+
             method_def: crate::function::HeapMethodDef::init_builtin_type(),
         }
     }
diff --git a/vm/src/vm/context.rs b/vm/src/vm/context.rs
index f67035d0f3..f2f8503721 100644
--- a/vm/src/vm/context.rs
+++ b/vm/src/vm/context.rs
@@ -22,6 +22,7 @@ use crate::{
     },
     intern::{InternableString, MaybeInternedString, StringPool},
     object::{Py, PyObjectPayload, PyObjectRef, PyPayload, PyRef},
+    stdlib::typing::NoDefaultType,
     types::{PyTypeFlags, PyTypeSlots, TypeZoo},
     PyResult, VirtualMachine,
 };
@@ -41,6 +42,7 @@ pub struct Context {
     pub empty_bytes: PyRef<PyBytes>,
     pub ellipsis: PyRef<PyEllipsis>,
     pub not_implemented: PyRef<PyNotImplemented>,
+    pub no_default: PyRef<NoDefaultType>,

     pub types: TypeZoo,
     pub exceptions: exceptions::ExceptionZoo,
@@ -303,6 +305,7 @@ impl Context {
         let empty_str = unsafe { string_pool.intern("", types.str_type.to_owned()) };
         let empty_bytes = create_object(PyBytes::from(Vec::new()), types.bytes_type);
+        let no_default = create_object(NoDefaultType {}, types.no_default_type);
         Context {
             true_value,
             false_value,
@@ -311,6 +314,7 @@
             empty_frozenset,
             empty_str,
             empty_bytes,
+            no_default,
             ellipsis,
             not_implemented,
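Not part of the patch: a Python-level sketch of what the Rust singleton above should expose once wired through Context and TypeZoo, mirroring CPython 3.13's typing.NoDefault. Importing it via the `_typing` module name is an assumption about how RustPython registers this module.

    import _typing

    nd = _typing.NoDefault
    assert repr(nd) == 'typing.NoDefault'  # Representable::repr_str
    assert type(nd)() is nd                # Constructor hands back the ctx singleton
    assert nd.__reduce__() == 'NoDefault'  # reduces to the module attribute name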