8000 Add basic co/contravariance feature to generics. · python/typing@9a434e6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9a434e6

Browse files
author
Guido van Rossum
committed
Add basic co/contravariance feature to generics.
1 parent 097dcf0 commit 9a434e6

File tree

2 files changed

+135
-23
lines changed

2 files changed

+135
-23
lines changed

prototyping/test_typing.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ def test_repr(self):
139139
self.assertEqual(repr(KT), '~KT')
140140
self.assertEqual(repr(VT), '~VT')
141141
self.assertEqual(repr(AnyStr), '~AnyStr')
142+
self.assertEqual(repr(typing.T_in), '+T_in')
143+
T_out = TypeVar('T_out', kind='out')
144+
self.assertEqual(repr(T_out), '-T_out')
142145

143146
def test_no_redefinition(self):
144147
self.assertNotEqual(TypeVar('T'), TypeVar('T'))
@@ -843,6 +846,47 @@ def test_type_alias(self):
843846
Undefined(typing.re.Pattern[bytes])
844847
Undefined(typing.re.Pattern[Any])
845848

849+
def test_invariance(self):
850+
# Because of invariance, List[subclass of X] is not a subclass
851+
# of List[X], and ditto for MutableSequence.
852+
assert not issubclass(typing.List[Manager], typing.List[Employee])
853+
assert not issubclass(typing.MutableSequence[Manager],
854+
typing.MutableSequence[Employee])
855+
# It's still reflexive.
856+
assert issubclass(typing.List[Employee], typing.List[Employee])
857+
assert issubclass(typing.MutableSequence[Employee],
858+
typing.MutableSequence[Employee])
859+
860+
def test_covariance_tuple(self):
861+
# Check covariace for Tuple (which are really special cases).
862+
assert issubclass(Tuple[Manager], Tuple[Employee])
863+
assert not issubclass(Tuple[Employee], Tuple[Manager])
864+
# And pairwise.
865+
assert issubclass(Tuple[Manager, Manager], Tuple[Employee, Employee])
866+
assert not issubclass(Tuple[Employee, Employee],
867+
Tuple[Manager, Employee])
868+
# And using ellipsis.
869+
assert issubclass(Tuple[Manager, ...], Tuple[Employee, ...])
870+
assert not issubclass(Tuple[Employee, ...], Tuple[Manager, ...])
871+
872+
def test_covariance_sequence(self):
873+
# Check covariance for Sequence (which is just a generic class
874+
# for this purpose, but using a covariant type variable).
875+
assert issubclass(typing.Sequence[Manager], typing.Sequence[Employee])
876+
assert not issubclass(typing.Sequence[Employee],
877+
typing.Sequence[Manager])
878+
879+
def test_covariance_mapping(self):
880+
# Ditto for Mapping (a generic class with two parameters).
881+
assert issubclass(typing.Mapping[Employee, Manager],
882+
typing.Mapping[Employee, Employee])
883+
assert issubclass(typing.Mapping[Manager, Employee],
884+
typing.Mapping[Employee, Employee])
885+
assert not issubclass(typing.Mapping[Employee, Manager],
886+
typing.Mapping[Manager, Manager])
887+
assert not issubclass(typing.Mapping[Manager, Employee],
888+
typing.Mapping[Manager, Manager])
889+
846890

847891
class CastTest(TestCase):
848892

prototyping/typing.py

Lines changed: 91 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,8 @@ class TypeVar(TypingMeta, metaclass=TypingMeta, _root=True):
338338
T1 = TypeVar('T1') # Unconstrained
339339
T2 = TypeVar('T2', t1, t2, ...) # Constrained to any of (t1, t2, ...)
340340
341+
(TODO: The rules below are inconsistent. See issue 62.)
342+
341343
For an unconstrained type variable T, isinstance(x, T) is false
342344
for all x, and similar for issubclass(cls, T). Example::
343345
@@ -388,9 +390,15 @@ class MyStr(str):
388390
389391
"""
390392

391-
def __new__(cls, name, *constraints):
393+
def __new__(cls, name, *constraints, kind=None):
392394
self = super().__new__(cls, name, (Final,), {}, _root=True)
395+
if kind not in (None, 'in', 'out'):
396+
raise ValueError("kind must be 'in', 'out' or None; got %r." %
397+
(kind,))
398+
if kind is not None and constraints:
399+
raise TypeError("kind does not combine with constraints")
393400
msg = "TypeVar(name, constraint, ...): constraints must be types."
401+
self.__kind__ = kind
394402
self.__constraints__ = tuple(_type_check(t, msg) for t in constraints)
395403
self.__binding__ = None
396404
return self
@@ -399,7 +407,13 @@ def _has_type_var(self):
399407
return True
400408

401409
def __repr__(self):
402-
return '~' + self.__name__
410+
if self.__kind__ == 'in':
411+
prefix = '+'
412+
elif self.__kind__ == 'out':
413+
prefix = '-'
414+
else:
415+
prefix = '~'
416+
return prefix + self.__name__
403417

404418
def __instancecheck__(self, instance):
405419
if self.__binding__ is not None:
@@ -500,6 +514,9 @@ def __exit__(self, *args):
500514
T = TypeVar('T') # Any type.
501515
KT = TypeVar('KT') # Key type.
502516
VT = TypeVar('VT') # Value type.
517+
T_in = TypeVar('T_in', kind='in') # Any type, for covariant containers.
518+
KT_in = TypeVar('KT_in', kind='in') # Key type, for covariant containers.
519+
VT_in = TypeVar('VT_in', kind='in') # Value type, for covariant containers.
503520

504521
# A useful type variable with constraints. This represents string types.
505522
# TODO: What about bytearray, memoryview?
@@ -947,9 +964,15 @@ class GenericMeta(TypingMeta, abc.ABCMeta):
947964
# TODO: Constrain more how Generic is used; only a few
948965
# standard patterns should be allowed.
949966

967+
# TODO: Use a more precise rule than matching __name__ to decide
968+
# whether two classes are the same. Also, save the formal
969+
# parameters. (These things are related! A solution lies in
970+
# using origin.)
971+
950972
__extra__ = None
951973

952-
def __new__(cls, name, bases, namespace, parameters=None, extra=None):
974+
def __new__(cls, name, bases, namespace,
975+
parameters=None, origin=None, extra=None):
953976
if parameters is None:
954977
# Extract parameters from direct base classes. Only
955978
# direct bases are considered and only those that are
@@ -983,6 +1006,7 @@ def __new__(cls, name, bases, namespace, parameters=None, extra=None):
9831006
self.__extra__ = extra
9841007
# Else __extra__ is inherited, eventually from the
9851008
# (meta-)class default above.
1009+
self.__origin__ = origin
9861010
return self
9871011

9881012
def _has_type_var(self):
@@ -1035,14 +1059,49 @@ def __getitem__(self, params):
10351059

10361060
return self.__class__(self.__name__, self.__bases__,
10371061
dict(self.__dict__),
1038-
parameters=params, extra=self.__extra__)
1062+
parameters=params,
1063+
origin=self,
1064+
extra=self.__extra__)
10391065

10401066
def __subclasscheck__(self, cls):
10411067
if cls is Any:
10421068
return True
1069+
if isinstance(cls, GenericMeta):
1070+
# For a class C(Generic[T]) where T is co-variant,
1071+
# C[X] is a subclass of C[Y] iff X is a subclass of Y.
1072+
origin = self.__origin__
1073+
if origin is not None and origin is cls.__origin__:
1074+
assert len(self.__parameters__) == len(origin.__parameters__)
1075+
assert len(cls.__parameters__) == len(origin.__parameters__)
1076+
for p_self, p_cls, p_origin in zip(self.__parameters__,
1077+
cls.__parameters__,
1078+
origin.__parameters__):
1079+
if isinstance(p_origin, TypeVar):
1080+
if p_origin.__kind__ is None:
1081+
# Invariant -- p_cls and p_self must equal.
1082+
if p_self != p_cls:
1083+
break
1084+
elif p_origin.__kind__ == 'in':
1085+
# Covariant -- p_cls must be a subclass of p_self.
1086+
if not issubclass(p_cls, p_self):
1087+
break
1088+
elif p_origin.__kind__ == 'out':
1089+
# Contravariant. I think it's the opposite. :-)
1090+
if not issubclass(p_self, p_cls):
1091+
break
1092+
else:
1093+
assert False, p_origin.__kind__
1094+
else:
1095+
# If the origin's parameter is not a typevar,
1096+
# insist on invariance.
1097+
if p_self != p_cls:
1098+
break
1099+
else:
1100+
return True
1101+
# If we break out of the loop, the superclass gets a chance.
10431102
if super().__subclasscheck__(cls):
10441103
return True
1045-
if self.__extra__ is None:
1104+
if self.__extra__ is None or isinstance(cls, GenericMeta):
10461105
return False
10471106
return issubclass(cls, self.__extra__)
10481107

@@ -1233,18 +1292,14 @@ def overload(func):
12331292
raise RuntimeError("Overloading is only supported in library stubs")
12341293

12351294

1236-
class _Protocol(Generic):
1237-
"""Internal base class for protocol classes.
1295+
class _ProtocolMeta(GenericMeta):
1296+
"""Internal metaclass for _Protocol.
12381297
1239-
This implements a simple-minded structural isinstance check
1240-
(similar but more general than the one-offs in collections.abc
1241-
such as Hashable).
1298+
This exists so _Protocol classes can be generic without deriving
1299+
from Generic.
12421300
"""
12431301

1244-
_is_protocol = True
1245-
1246-
@classmethod
1247-
def __subclasshook__(self, cls):
1302+
def __subclasscheck__(self, cls):
12481303
if not self._is_protocol:
12491304
# No structural checks since this isn't a protocol.
12501305
return NotImplemented
@@ -1258,10 +1313,9 @@ def __subclasshook__(self, cls):
12581313

12591314
for attr in attrs:
12601315
if not any(attr in d.__dict__ for d in cls.__mro__):
1261-
return NotImplemented
1316+
return False
12621317
return True
12631318

1264-
@classmethod
12651319
def _get_protocol_attrs(self):
12661320
# Get all Protocol base classes.
12671321
protocol_bases = []
@@ -1284,19 +1338,32 @@ def _get_protocol_attrs(self):
12841338
attr != '_is_protocol' and
12851339
attr != '__dict__' and
12861340
attr != '_get_protocol_attrs' and
1341+
attr != '__parameters__' and
1342+
attr != '__origin__' and
12871343
attr != '__module__'):
12881344
attrs.add(attr)
12891345

12901346
return attrs
12911347

12921348

1349+
class _Protocol(metaclass=_ProtocolMeta):
1350+
"""Internal base class for protocol classes.
1351+
1352+
This implements a simple-minded structural isinstance check
1353+
(similar but more general than the one-offs in collections.abc
1354+
such as Hashable).
1355+
"""
1356+
1357+
_is_protocol = True
1358+
1359+
12931360
# Various ABCs mimicking those in collections.abc.
12941361
# A few are simply re-exported for completeness.
12951362

12961363
Hashable = collections_abc.Hashable # Not generic.
12971364

12981365

1299-
class Iterable(Generic[T], extra=collections_abc.Iterable):
1366+
class Iterable(Generic[T_in], extra=collections_abc.Iterable):
13001367
pass
13011368

13021369

@@ -1356,7 +1423,7 @@ def __reversed__(self) -> 'Iterator[T]':
13561423
Sized = collections_abc.Sized # Not generic.
13571424

13581425

1359-
class Container(Generic[T], extra=collections_abc.Container):
1426+
class Container(Generic[T_in], extra=collections_abc.Container):
13601427
pass
13611428

13621429

@@ -1367,24 +1434,24 @@ class AbstractSet(Sized, Iterable, Container, extra=collections_abc.Set):
13671434
pass
13681435

13691436

1370-
class MutableSet(AbstractSet, extra=collections_abc.MutableSet):
1437+
class MutableSet(AbstractSet[T], extra=collections_abc.MutableSet):
13711438
pass
13721439

13731440

1374-
class Mapping(Sized, Iterable[KT], Container[KT], Generic[KT, VT],
1441+
class Mapping(Sized, Iterable[KT_in], Container[KT_in], Generic[KT_in, VT_in],
13751442
extra=collections_abc.Mapping):
13761443
pass
13771444

13781445

1379-
class MutableMapping(Mapping, extra=collections_abc.MutableMapping):
1446+
class MutableMapping(Mapping[KT, VT], extra=collections_abc.MutableMapping):
13801447
pass
13811448

13821449

13831450
class Sequence(Sized, Iterable, Container, extra=collections_abc.Sequence):
13841451
pass
13851452

13861453

1387-
class MutableSequence(Sequence, extra=collections_abc.MutableSequence):
1454+
class MutableSequence(Sequence[T], extra=collections_abc.MutableSequence):
13881455
pass
13891456

13901457

@@ -1459,7 +1526,8 @@ class KeysView(MappingView, Set[KT], extra=collections_abc.KeysView):
14591526

14601527

14611528
# TODO: Enable Set[Tuple[KT, VT]] instead of Generic[KT, VT].
1462-
class ItemsView(MappingView, Generic[KT, VT], extra=collections_abc.ItemsView):
1529+
class ItemsView(MappingView, Generic[KT_in, VT_in],
1530+
extra=collections_abc.ItemsView):
14631531
pass
14641532

14651533

0 commit comments

Comments
 (0)
0