8000 Backport parameter defaults for `(Async)Generator` and `(Async)ContextManager` by AlexWaygood · Pull Request #382 · python/typing_extensions · GitHub
[go: up one dir, main page]

Skip to content

Backport parameter defaults for (Async)Generator and (Async)ContextManager 8000 #382

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.md
8000
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@
- At runtime, `assert_never` now includes the repr of the argument
in the `AssertionError`. Patch by Hashem, backporting of the original
fix https://github.com/python/cpython/pull/91720 by Jelle Zijlstra.
- The second and third parameters of `typing_extensions.Generator`,
and the second parameter of `typing_extensions.AsyncGenerator`,
now default to `None`. This matches the behaviour of `typing.Generator`
and `typing.AsyncGenerator` on Python 3.13+.
- `typing.ContextManager` and `typing.AsyncContextManager` now have an
optional second parameter, which defaults to `Optional[bool]`. The new
parameter signifies the return type of the `__(a)exit__` method,
matching `typing.ContextManager` and `typing.AsyncContextManager` on
Python 3.13+.

# Release 4.11.0 (April 5, 2024)

Expand Down
23 changes: 21 additions & 2 deletions doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -885,8 +885,8 @@ Annotation metadata
Pure aliases
~~~~~~~~~~~~

These are simply re-exported from the :mod:`typing` module on all supported
versions of Python. They are listed here for completeness.
Most of these are simply re-exported from the :mod:`typing` module on all supported
versions of Python, but all are listed here for completeness.

.. class:: AbstractSet

Expand All @@ -904,10 +904,19 @@ versions of Python. They are listed here for completeness.

See :py:class:`typing.AsyncContextManager`. In ``typing`` since 3.5.4 and 3.6.2.

.. versionchanged:: 4.12.0

``AsyncContextManager`` now has an optional second parameter, defaulting to
``Optional[bool]``, signifying the return type of the ``__aexit__`` method.

.. class:: AsyncGenerator

See :py:class:`typing.AsyncGenerator`. In ``typing`` since 3.6.1.

.. versionchanged:: 4.12.0

The second type parameter is now optional (it defaults to ``None``).

.. class:: AsyncIterable

See :py:class:`typing.AsyncIterable`. In ``typing`` since 3.5.2.
Expand Down Expand Up @@ -956,6 +965,11 @@ versions of Python. They are listed here for completeness.

See :py:class:`typing.ContextManager`. In ``typing`` since 3.5.4.

.. versionchanged:: 4.12.0

``AsyncContextManager`` now has an optional second parameter, defaulting to
``Optional[bool]``, signifying the return type of the ``__aexit__`` method.

.. class:: Coroutine

See :py:class:`typing.Coroutine`. In ``typing`` since 3.5.3.
Expand Down Expand Up @@ -996,6 +1010,11 @@ versions of Python. They are listed here for completeness.

.. versionadded:: 4.7.0

.. versionchanged:: 4.12.0

The second type and third type parameters are now optional
(they both default to ``None``).

.. class:: Generic

See :py:class:`typing.Generic`.
Expand Down
51 changes: 38 additions & 10000 13 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
from typing_extensions import Doc, NoDefault
from _typed_dict_test_helper import Foo, FooGeneric, VeryAnnotated

NoneType = type(None)

# Flags used to mark tests that only apply after a specific
# version of the typing module.
TYPING_3_9_0 = sys.version_info[:3] >= (3, 9, 0)
Expand Down Expand Up @@ -1626,6 +1628,17 @@ async def g(): yield 0
self.assertNotIsInstance(type(g), G)
self.assertNotIsInstance(g, G)

def test_generator_default(self):
g1 = typing_extensions.Generator[int]
g2 = typing_extensions.Generator[int, None, None]
self.assertEqual(get_args(g1), (int, type(None), type(None)))
self.assertEqual(get_args(g1), get_args(g2))

g3 = typing_extensions.Generator[int, float]
g4 = typing_extensions.Generator[int, float, None]
self.assertEqual(get_args(g3), (int, float, type(None)))
self.assertEqual(get_args(g3), get_args(g4))


class OtherABCTests(BaseTestCase):

Expand All @@ -1638,6 +1651,12 @@ def manager():
self.assertIsInstance(cm, typing_extensions.ContextManager)
self.assertNotIsInstance(42, typing_extensions.ContextManager)

def test_contextmanager_type_params(self):
cm1 = typing_extensions.ContextManager[int]
self.assertEqual(get_args(cm1), (int, typing.Optional[bool]))
cm2 = typing_extensions.ContextManager[int, None]
self.assertEqual(get_args(cm2), (int, NoneType))

def test_async_contextmanager(self):
class NotACM:
pass
Expand All @@ -1649,11 +1668,20 @@ def manager():

cm = manager()
self.assertNotIsInstance(cm, typing_extensions.AsyncContextManager)
self.assertEqual(typing_extensions.AsyncContextManager[int].__args__, (int,))
self.assertEqual(
typing_extensions.AsyncContextManager[int].__args__,
(int, typing.Optional[bool])
)
with self.assertRaises(TypeError):
isinstance(42, typing_extensions.AsyncContextManager[int])
with self.assertRaises(TypeError):
typing_extensions.AsyncContextManager[int, str]
typing_extensions.AsyncContextManager[int, str, float]

def test_asynccontextmanager_type_params(self):
cm1 = typing_extensions.AsyncContextManager[int]
self.assertEqual(get_args(cm1), (int, typing.Optional[bool]))
cm2 = typing_extensions.AsyncContextManager[int, None]
self.assertEqual(get_args(cm2), (int, NoneType))


class TypeTests(BaseTestCase):
Expand Down Expand Up @@ -5533,28 +5561,25 @@ def test_all_names_in___all__(self):
self.assertLessEqual(exclude, actual_names)

def test_typing_extensions_defers_when_possible(self):
exclude = {
'dataclass_transform',
'overload',
'ParamSpec',
'TypeVar',
'TypeVarTuple',
'get_type_hints',
}
exclude = set()
if sys.version_info < (3, 10):
exclude |= {'get_args', 'get_origin'}
if sys.version_info < (3, 10, 1):
exclude |= {"Literal"}
if sys.version_info < (3, 11):
exclude |= {'final', 'Any', 'NewType'}
exclude |= {'final', 'Any', 'NewType', 'overload'}
if sys.version_info < (3, 12):
exclude |= {
'SupportsAbs', 'SupportsBytes',
'SupportsComplex', 'SupportsFloat', 'SupportsIndex', 'SupportsInt',
'SupportsRound', 'Unpack',
'SupportsRound', 'Unpack', 'dataclass_transform',
}
if sys.version_info < (3, 13):
exclude |= {'NamedTuple', 'Protocol', 'runtime_checkable'}
exclude |= {
'NamedTuple', 'Protocol', 'runtime_checkable', 'Generator',
'AsyncGenerator', 'ContextManager', 'AsyncContextManager',
'ParamSpec', 'TypeVar', 'TypeVarTuple', 'get_type_hints',
}
if not typing_extensions._PEP_728_IMPLEMENTED:
exclude |= {'TypedDict', 'is_typeddict'}
for item in typing_extensions.__all__:
Expand Down
87 changes: 83 additions & 4 deletions src/typing_extensions.py
23D3
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import collections
import collections.abc
import contextlib
import functools
import inspect
import operator
Expand Down Expand Up @@ -408,17 +409,96 @@ def clear_overloads():
AsyncIterable = typing.AsyncIterable
AsyncIterator = typing.AsyncIterator
Deque = typing.Deque
ContextManager = typing.ContextManager
AsyncContextManager = typing.AsyncContextManager
DefaultDict = typing.DefaultDict
OrderedDict = typing.OrderedDict
Counter = typing.Counter
ChainMap = typing.ChainMap
AsyncGenerator = typing.AsyncGenerator
Text = typing.Text
TYPE_CHECKING = typing.TYPE_CHECKING


if sys.version_info >= (3, 13, 0, "beta"):
from typing import ContextManager, AsyncContextManager, Generator, AsyncGenerator
else:
def _is_dunder(attr):
return attr.startswith('__') and attr.endswith('__')

# Python <3.9 doesn't have typing._SpecialGenericAlias
_special_generic_alias_base = getattr(
typing, "_SpecialGenericAlias", typing._GenericAlias
)

class _SpecialGenericAlias(_special_generic_alias_base, _root=True):
def __init__(self, origin, nparams, *, inst=True, name=None, defaults=()):
if _special_generic_alias_base is typing._GenericAlias:
# Python <3.9
self.__origin__ = origin
self._nparams = nparams
super().__init__(origin, nparams, special=True, inst=inst, name=name)
else:
# Python >= 3.9
super().__init__(origin, nparams, inst=inst, name=name)
self._defaults = defaults

def __setattr__(self, attr, val):
allowed_attrs = {'_name', '_inst', '_nparams', '_defaults'}
if _special_generic_alias_base is typing._GenericAlias:
# Python <3.9
allowed_attrs.add("__origin__")
if _is_dunder(attr) or attr in allowed_attrs:
object.__setattr__(self, attr, val)
else:
setattr(self.__origin__, attr, val)

@typing._tp_cache
def __getitem__(self, params):
if not isinstance(params, tuple):
params = (params,)
msg = "Parameters to generic types must be types."
params = tuple(typing._type_check(p, msg) for p in params)
if (
self._defaults
and len(params) < self._nparams
and len(params) + len(self._defaults) >= self._nparams
):
params = (*params, *self._defaults[len(params) - self._nparams:])
actual_len = len(params)

if actual_len != self._nparams:
if self._defaults:
expected = f"at least {self._nparams - len(self._defaults)}"
else:
expected = str(self._nparams)
if not self._nparams:
raise TypeError(f"{self} is not a generic class")
raise TypeError(
f"Too {'many' if actual_len > self._nparams else 'few'}"
f" arguments for {self};"
f" actual {actual_len}, expected {expected}"
)
return self.copy_with(params)

_NoneType = type(None)
Generator = _SpecialGenericAlias(
collections.abc.Generator, 3, defaults=(_NoneType, _NoneType)
)
AsyncGenerator = _SpecialGenericAlias(
collections.abc.AsyncGenerator, 2, defaults=(_NoneType,)
)
ContextManager = _SpecialGenericAlias(
contextlib.AbstractContextManager,
2,
name="ContextManager",
defaults=(typing.Optional[bool],)
)
AsyncContextManager = _SpecialGenericAlias(
contextlib.AbstractAsyncContextManager,
2,
name="AsyncContextManager",
defaults=(typing.Optional[bool],)
)


_PROTO_ALLOWLIST = {
'collections.abc': [
'Callable', 'Awaitable', 'Iterable', 'Iterator', 'AsyncIterable',
Expand Down Expand Up @@ -3344,7 +3424,6 @@ def __eq__(self, other: object) -> bool:
Dict = typing.Dict
ForwardRef = typing.ForwardRef
FrozenSet = typing.FrozenSet
Generator = typing.Generator
Generic = typing.Generic
Hashable = typing.Hashable
IO = typing.IO
Expand Down
Loading
0