8000 Properly support type variable defaults · pydantic/pydantic@e362df2 · GitHub
[go: up one dir, main page]

Skip to content

Commit e362df2

Browse files
committed
Properly support type variable defaults
If a type argument isn't provided, use the default value.
1 parent 14d14b0 commit e362df2

File tree

4 files changed

+74
-19
lines changed

4 files changed

+74
-19
lines changed

pydantic/_internal/_generate_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1989,7 +1989,7 @@ def _unsubstituted_typevar_schema(self, typevar: typing.TypeVar) -> core_schema.
19891989
try:
19901990
has_default = typevar.has_default()
19911991
except AttributeError:
1992-
# Happens if using `typing.TypeVar` on Python < 3.13
1992+
# Happens if using `typing.TypeVar` (and not `typing_extensions`) on Python < 3.13
19931993
pass
19941994
else:
19951995
if has_default:

pydantic/_internal/_generics.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from collections.abc import Iterator, Mapping, MutableMapping
88
from contextlib import contextmanager
99
from contextvars import ContextVar
10+
from itertools import zip_longest
1011
from types import prepare_class
1112
from typing import TYPE_CHECKING, Any, TypeVar
1213
from weakref import WeakValueDictionary
@@ -382,21 +383,55 @@ def has_instance_in_type(type_: Any, isinstance_target: Any) -> bool:
382383
return False
383384

384385

385-
def check_parameters_count(cls: type[BaseModel], parameters: tuple[Any, ...]) -> None:
386-
"""Check the generic model parameters count is equal.
386+
def map_generic_model_arguments(cls: type[BaseModel], args: tuple[Any, ...]) -> dict[TypeVar, Any]:
387+
"""Return a mapping between the arguments of a generic model and the provided arguments during parametrization.
387388

388-
Args:
389-
cls: The generic model.
390-
parameters: A tuple of passed parameters to the generic model.
389+
If the number of arguments does not match the parameters (e.g. if providing too few or too many arguments),
390+
a `TypeError` is raised.
391391

392-
Raises:
393-
TypeError: If the passed parameters count is not equal to generic model parameters count.
392+
Example:
393+
```python {test="skip" lint="skip"}
394+
class Model[T, U, V = int](BaseModel): ...
395+
396+
map_generic_model_arguments(Model, (str, bytes))
397+
#> {T: str, U: bytes, V: int}
398+
399+
map_generic_model_arguments(Model, (str,))
400+
#> TypeError: Too few arguments for <class '__main__.Model'>; actual 1, expected at least 2
401+
402+
map_generic_model_argumenst(Model, (str, bytes, int, complex))
403+
#> TypeError: Too many arguments for <class '__main__.Model'>; actual 4, expected 3
404+
```
405+
406+
Note:
407+
This function is analogous to the private `typing._check_generic_specialization` function.
394408
"""
395-
actual = len(parameters)
396-
expected = len(cls.__pydantic_generic_metadata__['parameters'])
397-
if actual != expected:
398-
description = 'many' if actual > expected else 'few'
399-
raise TypeError(f'Too {description} parameters for {cls}; actual {actual}, expected {expected}')
409+
parameters = cls.__pydantic_generic_metadata__['parameters']
410+
expected_len = len(parameters)
411+
typevars_map: dict[TypeVar, Any] = {}
412+
413+
_missing = object()
414+
for parameter, argument in zip_longest(parameters, args, fillvalue=_missing):
415+
if parameter is _missing:
416+
raise TypeError(f'Too many arguments for {cls}; actual {len(args)}, expected {expected_len}')
417+
418+
if argument is _missing:
419+
param = typing.cast(TypeVar, parameter)
420+
try:
421+
has_default = param.has_default()
422+
except AttributeError:
423+
# Happens if using `typing.TypeVar` (and not `typing_extensions`) on Python < 3.13.
424+
has_default = False
425+
if has_default:
426+
typevars_map[param] = param.__default__
427+
else:
428+
expected_len -= sum(hasattr(p, 'has_default') and p.has_default() for p in parameters)
429+
raise TypeError(f'Too few arguments for {cls}; actual {len(args)}, expected at least {expected_len}')
430+
else:
431+
param = typing.cast(TypeVar, parameter)
432+
typevars_map[param] = argument
433+
434+
return typevars_map
400435

401436

402437
_generic_recursion_cache: ContextVar[set[str] | None] = ContextVar('_generic_recursion_cache', default=None)

pydantic/main.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -766,12 +766,10 @@ def __class_getitem__(
766766

767767
if not isinstance(typevar_values, tuple):
768768
typevar_values = (typevar_values,)
769-
_generics.check_parameters_count(cls, typevar_values)
770769

771-
# Build map from generic typevars to passed params
772-
typevars_map: dict[TypeVar, type[Any]] = dict(
773-
zip(cls.__pydantic_generic_metadata__['parameters'], typevar_values)
774-
)
770+
typevars_map = _generics.map_generic_model_arguments(cls, typevar_values)
771+
# In case type variables have defaults and a type wasn't provided, use the defaults:
772+
typevar_values = tuple(v for v in typevars_map.values())
775773

776774
if _utils.all_identical(typevars_map.keys(), typevars_map.values()) and typevars_map:
777775
submodel = cls # if arguments are equal to parameters it's the same object

tests/test_generics.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,28 @@ class Model(BaseModel, Generic[T, S]):
306306
)
307307

308308

309+
def test_arguments_count_validation() -> None:
310+
T = TypeVar('T')
311+
S = TypeVar('S')
312+
U = TypingExtensionsTypeVar('U', default=int)
313+
314+
class Model(BaseModel, Generic[T, S, U]):
315+
t: T
316+
s: S
317+
u: U
318+
319+
model_repr = repr(Model)
320+
321+
with pytest.raises(TypeError, match=f'Too many arguments for {model_repr}; actual 4, expected 3'):
322+
Model[int, int, int, int]
323+
324+
with pytest.raises(TypeError, match=f'Too few arguments for {model_repr}; actual 1, expected at least 2'):
325+
Model[int]
326+
327+
assert Model[int, int].__pydantic_generic_metadata__['args'] == (int, int, int)
328+
assert Model[int, int, str].__pydantic_generic_metadata__['args'] == (int, int, str)
329+
330+
309331
def test_cover_cache(clean_cache):
310332
cache_size = len(_GENERIC_TYPES_CACHE)
311333
T = TypeVar('T')
@@ -2778,7 +2800,7 @@ class MyErrorDetails(ErrorDetails):
27782800
ids=['default', 'constraint'],
27792801
)
27802802
def test_serialize_unsubstituted_typevars_variants(
2781-
type_var: type[BaseModel],
2803+
type_var: TypeVar,
27822804
) -> None:
27832805
class ErrorDetails(BaseModel):
27842806
foo: str

0 commit comments

Comments
 (0)
0