8000 Backport parameter defaults for `(Async)Generator` and `(Async)Contex… · python/typing_extensions@12e901e · GitHub
[go: up one dir, main page]

Skip to content

Commit 12e901e

Browse files
authored
Backport parameter defaults for (Async)Generator and (Async)ContextManager (#382)
1 parent 781e996 commit 12e901e

File tree

4 files changed

+151
-19
lines changed

4 files changed

+151
-19
lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@
2222
- At runtime, `assert_never` now includes the repr of the argument
2323
in the `AssertionError`. Patch by Hashem, backporting of the original
2424
fix https://github.com/python/cpython/pull/91720 by Jelle Zijlstra.
25+
- The second and third parameters of `typing_extensions.Generator`,
26+
and the second parameter of `typing_extensions.AsyncGenerator`,
27+
now default to `None`. This matches the behaviour of `typing.Generator`
28+
and `typing.AsyncGenerator` on Python 3.13+.
29+
- `typing.ContextManager` and `typing.AsyncContextManager` now have an
30+
optional second parameter, which defaults to `Optional[bool]`. The new
31+
parameter signifies the return type of the `__(a)exit__` method,
32+
matching `typing.ContextManager` and `typing.AsyncContextManager` on
33+
Python 3.13+.
2534

2635
# Release 4.11.0 (April 5, 2024)
2736

doc/index.rst

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -885,8 +885,8 @@ Annotation metadata
885885
Pure aliases
886886
~~~~~~~~~~~~
887887

888-
These are simply re-exported from the :mod:`typing` module on all supported
889-
versions of Python. They are listed here for completeness.
888+
Most of these are simply re-exported from the :mod:`typing` module on all supported
889+
versions of Python, but all are listed here for completeness.
890890

891891
.. class:: AbstractSet
892892

@@ -904,10 +904,19 @@ versions of Python. They are listed here for completeness.
904904

905905
See :py:class:`typing.AsyncContextManager`. In ``typing`` since 3.5.4 and 3.6.2.
906906

907+
.. versionchanged:: 4.12.0
908+
909+
``AsyncContextManager`` now has an optional second parameter, defaulting to
910+
``Optional[bool]``, signifying the return type of the ``__aexit__`` method.
911+
907912
.. class:: AsyncGenerator
908913

909914
See :py:class:`typing.AsyncGenerator`. In ``typing`` since 3.6.1.
910915

916+
.. versionchanged:: 4.12.0
917+
918+
The second type parameter is now optional (it defaults to ``None``).
919+
911920
.. class:: AsyncIterable
912921

913922
See :py:class:`typing.AsyncIterable`. In ``typing`` since 3.5.2.
@@ -956,6 +965,11 @@ versions of Python. They are listed here for completeness.
956965

957966
See :py:class:`typing.ContextManager`. In ``typing`` since 3.5.4.
958967

968+
.. versionchanged:: 4.12.0
969+
970+
``AsyncContextManager`` now has an optional second parameter, defaulting to
971+
``Optional[bool]``, signifying the return type of the ``__aexit__`` method.
972+
959973
.. class:: Coroutine
960974

961975
See :py:class:`typing.Coroutine`. In ``typing`` since 3.5.3.
@@ -996,6 +1010,11 @@ versions of Python. They are listed here for completeness.
9961010

9971011
.. versionadded:: 4.7.0
9981012

1013+
.. versionchanged:: 4.12.0
1014+
1015+
The second type and third type parameters are now optional
1016+
(they both default to ``None``).
1017+
9991018
.. class:: Generic
10001019

10011020
See :py:class:`typing.Generic`.

src/test_typing_extensions.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
from typing_extensions import Doc, NoDefault
4242
from _typed_dict_test_helper import Foo, FooGeneric, VeryAnnotated
4343

44+
NoneType = type(None)
45+
4446
# Flags used to mark tests that only apply after a specific
4547
# version of the typing module.
4648
TYPING_3_9_0 = sys.version_info[:3] >= (3, 9, 0)
@@ -1626,6 +1628,17 @@ async def g(): yield 0
16261628
self.assertNotIsInstance(type(g), G)
16271629
self.assertNotIsInstance(g, G)
16281630

1631+
def test_generator_default(self):
1632+
g1 = typing_extensions.Generator[int]
1633+
g2 = typing_extensions.Generator[int, None, None]
1634+
self.assertEqual(get_args(g1), (int, type(None), type(None)))
1635+
self.assertEqual(get_args(g1), get_args(g2))
1636+
1637+
g3 = typing_extensions.Generator[int, float]
1638+
g4 = typing_extensions.Generator[int, float, None]
1639+
self.assertEqual(get_args(g3), (int, float, type(None)))
1640+
self.assertEqual(get_args(g3), get_args(g4))
1641+
16291642

16301643
class OtherABCTests(BaseTestCase):
16311644

@@ -1638,6 +1651,12 @@ def manager():
16381651
self.assertIsInstance(cm, typing_extensions.ContextManager)
16391652
self.assertNotIsInstance(42, typing_extensions.ContextManager)
16401653

1654+
def test_contextmanager_type_params(self):
1655+
cm1 = typing_extensions.ContextManager[int]
1656+
self.assertEqual(get_args(cm1), (int, typing.Optional[bool]))
1657+
cm2 = typing_extensions.ContextManager[int, None]
1658+
self.assertEqual(get_args(cm2), (int, NoneType))
1659+
16411660
def test_async_contextmanager(self):
16421661
class NotACM:
16431662
pass
@@ -1649,11 +1668,20 @@ def manager():
16491668

16501669
cm = manager()
16511670
self.assertNotIsInstance(cm, typing_extensions.AsyncContextManager)
1652-
self.assertEqual(typing_extensions.AsyncContextManager[int].__args__, (int,))
1671+
self.assertEqual(
1672+
typing_extensions.AsyncContextManager[int].__args__,
1673+
(int, typing.Optional[bool])
1674+
)
16531675
with self.assertRaises(TypeError):
16541676
isinstance(42, typing_extensions.AsyncContextManager[int])
16551677
with self.assertRaises(TypeError):
1656-
typing_extensions.AsyncContextManager[int, str]
1678+
typing_extensions.AsyncContextManager[int, str, float]
1679+
1680+
def test_asynccontextmanager_type_params(self):
1681+
cm1 = typing_extensions.AsyncContextManager[int]
1682+
self.assertEqual(get_args(cm1), (int, typing.Optional[bool]))
1683+
cm2 = typing_extensions.AsyncContextManager[int, None]
1684+
self.assertEqual(get_args(cm2), (int, NoneType))
16571685

16581686

16591687
class TypeTests(BaseTestCase):
@@ -5533,28 +5561,25 @@ def test_all_names_in___all__(self):
55335561
self.assertLessEqual(exclude, actual_names)
55345562

55355563
def te 10000 st_typing_extensions_defers_when_possible(self):
5536-
exclude = {
5537-
'dataclass_transform',
5538-
'overload',
5539-
'ParamSpec',
5540-
'TypeVar',
5541-
'TypeVarTuple',
5542-
'get_type_hints',
5543-
}
5564+
exclude = set()
55445565
if sys.version_info < (3, 10):
55455566
exclude |= {'get_args', 'get_origin'}
55465567
if sys.version_info < (3, 10, 1):
55475568
exclude |= {"Literal"}
55485569
if sys.version_info < (3, 11):
5549-
exclude |= {'final', 'Any', 'NewType'}
5570+
exclude |= {'final', 'Any', 'NewType', 'overload'}
55505571
if sys.version_info < (3, 12):
55515572
exclude |= {
55525573
'SupportsAbs', 'SupportsBytes',
55535574
'SupportsComplex', 'SupportsFloat', 'SupportsIndex', 'SupportsInt',
5554-
'SupportsRound', 'Unpack',
5575+
'SupportsRound', 'Unpack', 'dataclass_transform',
55555576
}
55565577
if sys.version_info < (3, 13):
5557-
exclude |= {'NamedTuple', 'Protocol', 'runtime_checkable'}
5578+
exclude |= {
5579+
'NamedTuple', 'Protocol', 'runtime_checkable', 'Generator',
5580+
'AsyncGenerator', 'ContextManager', 'AsyncContextManager',
5581+
'ParamSpec', 'TypeVar', 'TypeVarTuple', 'get_type_hints',
5582+
}
55585583
if not typing_extensions._PEP_728_IMPLEMENTED:
55595584
exclude |= {'TypedDict', 'is_typeddict'}
55605585
for item in typing_extensions.__all__:

src/typing_extensions.py

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import abc
22
import collections
33
import collections.abc
4+
import contextlib
45
import functools
56
import inspect
67
import operator
@@ -408,17 +409,96 @@ def clear_overloads():
408409
AsyncIterable = typing.AsyncIterable
409410
AsyncIterator = typing.AsyncIterator
410411
Deque = typing.Deque
411-
ContextManager = typing.ContextManager
412-
AsyncContextManager = typing.AsyncContextManager
413412
DefaultDict = typing.DefaultDict
414413
OrderedDict = typing.OrderedDict
415414
Counter = typing.Counter
416415
ChainMap = typing.ChainMap
417-
AsyncGenerator = typing.AsyncGenerator
418416
Text = typing.Text
419417
TYPE_CHECKING = typing.TYPE_CHECKING
420418

421419

420+
if sys.version_info >= (3, 13, 0, "beta"):
421+
from typing import ContextManager, AsyncContextManager, Generator, AsyncGenerator
422+
else:
423+
def _is_dunder(attr):
424+
return attr.startswith('__') and attr.endswith('__')
425+
426+
# Python <3.9 doesn't have typing._SpecialGenericAlias
427+
_special_generic_alias_base = getattr(
428+
typing, "_SpecialGenericAlias", typing._GenericAlias
429+
)
430+
431+
class _SpecialGenericAlias(_special_generic_alias_base, _root=True):
432+
def __init__(self, origin, nparams, *, inst=True, name=None, defaults=()):
433+
if _special_generic_alias_base is typing._GenericAlias:
434+
# Python <3.9
435+
self.__origin__ = origin
436+
self._nparams = nparams
437+
super().__init__(origin, nparams, special=True, inst=inst, name=name)
438+
else:
439+
# Python >= 3.9
440+
super().__init__(origin, nparams, inst=inst, name=name)
441+
self._defaults = defaults
442+
443+
def __setattr__(self, attr, val):
444+
allowed_attrs = {'_name', '_inst', '_nparams', '_defaults'}
445+
if _special_generic_alias_base is typing._GenericAlias:
446+
# Python <3.9
447+
allowed_attrs.add("__origin__")
448+
if _is_dunder(attr) or attr in allowed_attrs:
449+
object.__setattr__(self, attr, val)
450+
else:
451+
setattr(self.__origin__, attr, val)
452+
453+
@typing._tp_cache
454+
def __getitem__(self, params):
455+
if not isinstance(params, tuple):
456+
params = (params,)
457+
msg = "Parameters to generic types must be types."
458+
params = tuple(typing._type_check(p, msg) for p in params)
459+
if (
460+
self._defaults
461+
and len(params) < self._nparams
462+
and len(params) + len(self._defaults) >= self._nparams
463+
):
464+
params = (*params, *self._defaults[len(params) - self._nparams:])
465+
actual_len = len(params)
466+
467+
if actual_len != self._nparams:
468+
if self._defaults:
469+
expected = f"at least {self._nparams - len(self._defaults)}"
470+
else:
471+
expected = str(self._nparams)
472+
if not self._nparams:
473+
raise TypeError(f"{self} is not a generic class")
474+
raise TypeError(
475+
f"Too {'many' if actual_len > self._nparams else 'few'}"
476+
f" arguments for {self};"
477+
f" actual {actual_len}, expected {expected}"
478+
)
479+
return self.copy_with(params)
480+
481+
_NoneType = type(None)
482+
Generator = _SpecialGenericAlias(
483+
collections.abc.Generator, 3, defaults=(_NoneType, _NoneType)
484+
)
485+
AsyncGenerator = _SpecialGenericAlias(
486+
collections.abc.AsyncGenerator, 2, defaults=(_NoneType,)
487+
)
488+
ContextManager = _SpecialGenericAlias(
489+
contextlib.AbstractContextManager,
490+
2,
491+
name="ContextManager",
492+
defaults=(typing.Optional[bool],)
493+
)
494+
AsyncContextManager = _SpecialGenericAlias(
495+
contextlib.AbstractAsyncContextManager,
496+
2,
497+
name="AsyncContextManager",
498+
defaults=(typing.Optional[bool],)
499+
)
500+
501+
422502
_PROTO_ALLOWLIST = {
423503
'collections.abc': [
424504
'Callable', 'Awaitable', 'Iterable', 'Iterator', 'AsyncIterable',
@@ -3344,7 +3424,6 @@ def __eq__(self, other: object) -> bool:
33443424
Dict = typing.Dict
33453425
ForwardRef = typing.ForwardRef
33463426
FrozenSet = typing.FrozenSet
3347-
Generator = typing.Generator
33483427
Generic = typing.Generic
33493428
Hashable = typing.Hashable
33503429
IO = typing.IO

0 commit comments

Comments
 (0)
0