8000 Properly support type variable defaults by Viicos · Pull Request #11332 · pydantic/pydantic · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pydantic/_internal/_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1989,7 +1989,7 @@ def _unsubstituted_typevar_schema(self, typevar: typing.TypeVar) -> core_schema.
try:
has_default = typevar.has_default()
except AttributeError:
# Happens if using `typing.TypeVar` on Python < 3.13
# Happens if using `typing.TypeVar` (and not `typing_extensions`) on Python < 3.13
pass
else:
if has_default:
Expand Down
59 changes: 47 additions & 12 deletions pydantic/_internal/_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections.abc import Iterator, Mapping, MutableMapping
from contextlib import contextmanager
from contextvars import ContextVar
from itertools import zip_longest
from types import prepare_class
from typing import TYPE_CHECKING, Any, TypeVar
from weakref import WeakValueDictionary
Expand Down Expand Up @@ -382,21 +383,55 @@ def has_instance_in_type(type_: Any, isinstance_target: Any) -> bool:
return False


def check_parameters_count(cls: type[BaseModel], parameters: tuple[Any, ...]) -> None:
"""Check the generic model parameters count is equal.
def map_generic_model_arguments(cls: type[BaseModel], args: tuple[Any, ...]) -> dict[TypeVar, Any]:
"""Return a mapping between the arguments of a generic model and the provided arguments during parametrization.

Args:
cls: The generic model.
parameters: A tuple of passed parameters to the generic model.
If the number of arguments does not match the parameters (e.g. if providing too few or too many arguments),
a `TypeError` is raised.

Raises:
TypeError: If the passed parameters count is not equal to generic model parameters count.
Example:
```python {test="skip" lint="skip"}
class Model[T, U, V = int](BaseModel): ...

map_generic_model_arguments(Model, (str, bytes))
#> {T: str, U: bytes, V: int}

map_generic_model_arguments(Model, (str,))
#> TypeError: Too few arguments for <class '__main__.Model'>; actual 1, expected at least 2

map_generic_model_argumenst(Model, (str, bytes, int, complex))
#> TypeError: Too many arguments for <class '__main__.Model'>; actual 4, expected 3
```

Note:
This function is analogous to the private `typing._check_generic_specialization` function.
"""
actual = len(parameters)
expected = len(cls.__pydantic_generic_metadata__['parameters'])
if actual != expected:
description = 'many' if actual > expected else 'few'
raise TypeError(f'Too {description} parameters for {cls}; actual {actual}, expected {expected}')
parameters = cls.__pydantic_generic_metadata__['parameters']
expected_len = len(parameters)
typevars_map: dict[TypeVar, Any] = {}

_missing = object()
for parameter, argument in zip_longest(parameters, args, fillvalue=_missing):
if parameter is _missing:
raise TypeError(f'Too many arguments for {cls}; actual {len(args)}, expected {expected_len}')

if argument is _missing:
param = typing.cast(TypeVar, parameter)
try:
has_default = param.has_default()
except AttributeError:
# Happens if using `typing.TypeVar` (and not `typing_extensions`) on Python < 3.13.
has_default = False
if has_default:
typevars_map[param] = param.__default__
else:
expected_len -= sum(hasattr(p, 'has_default') and p.has_default() for p in parameters)
raise TypeError(f'Too few arguments for {cls}; actual {len(args)}, expected at least {expected_len}')
else:
param = typing.cast(TypeVar, parameter)
typevars_map[param] = argument

return typevars_map


_generic_recursion_cache: ContextVar[set[str] | None] = ContextVar('_generic_recursion_cache', default=None)
Expand Down
8 changes: 3 additions & 5 deletions pydantic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,12 +766,10 @@ def __class_getitem__(

if not isinstance(typevar_values, tuple):
typevar_values = (typevar_values,)
_generics.check_parameters_count(cls, typevar_values)

# Build map from generic typevars to passed params
typevars_map: dict[TypeVar, type[Any]] = dict(
zip(cls.__pydantic_generic_metadata__['parameters'], typevar_values)
)
typevars_map = _generics.map_generic_model_arguments(cls, typevar_values)
# In case type variables have defaults and a type wasn't provided, use the defaults:
typevar_values = tuple(v for v in typevars_map.values())

if _utils.all_identical(typevars_map.keys(), typevars_map.values()) and typevars_map:
submodel = cls # if arguments are equal to parameters it's the same object
Expand Down
24 changes: 23 additions & 1 deletion tests/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,28 @@ class Model(BaseModel, Generic[T, S]):
)


def test_arguments_count_validation() -> None:
T = TypeVar('T')
S = TypeVar('S')
U = TypingExtensionsTypeVar('U', default=int)

class Model(BaseModel, Generic[T, S, U]):
t: T
s: S
u: U

model_repr = repr(Model)

with pytest.raises(TypeError, match=f'Too many arguments for {model_repr}; actual 4, expected 3'):
Model[int, int, int, int]

with pytest.raises(TypeError, match=f'Too few arguments for {model_repr}; actual 1, expected at least 2'):
Model[int]

assert Model[int, int].__pydantic_generic_metadata__['args'] == (int, int, int)
assert Model[int, int, str].__pydantic_generic_metadata__['args'] == (int, int, str)


def test_cover_cache(clean_cache):
cache_size = len(_GENERIC_TYPES_CACHE)
T = TypeVar('T')
Expand Down Expand Up @@ -2778,7 +2800,7 @@ class MyErrorDetails(ErrorDetails):
ids=['default', 'constraint'],
)
def test_serialize_unsubstituted_typevars_variants(
type_var: type[BaseModel],
type_var: TypeVar,
) -> None:
class ErrorDetails(BaseModel):
foo: str
Expand Down
Loading
0