|
7 | 7 | from collections.abc import Iterator, Mapping, MutableMapping |
8 | 8 | from contextlib import contextmanager |
9 | 9 | from contextvars import ContextVar |
| 10 | +from itertools import zip_longest |
10 | 11 | from types import prepare_class |
11 | 12 | from typing import TYPE_CHECKING, Any, TypeVar |
12 | 13 | from weakref import WeakValueDictionary |
@@ -382,21 +383,55 @@ def has_instance_in_type(type_: Any, isinstance_target: Any) -> bool: |
382 | 383 | return False |
383 | 384 |
|
384 | 385 |
|
385 | | -def check_parameters_count(cls: type[BaseModel], parameters: tuple[Any, ...]) -> None: |
386 | | - """Check the generic model parameters count is equal. |
| 386 | +def map_generic_model_arguments(cls: type[BaseModel], args: tuple[Any, ...]) -> dict[TypeVar, Any]: |
| 387 | + """Return a mapping between the arguments of a generic model and the provided arguments during parametrization. |
387 | 388 |
|
388 | | - Args: |
389 | | - cls: The generic model. |
390 | | - parameters: A tuple of passed parameters to the generic model. |
| 389 | + If the number of arguments does not match the parameters (e.g. if providing too few or too many arguments), |
| 390 | + a `TypeError` is raised. |
391 | 391 |
|
392 | | - Raises: |
393 | | - TypeError: If the passed parameters count is not equal to generic model parameters count. |
| 392 | + Example: |
| 393 | + ```python {test="skip" lint="skip"} |
| 394 | + class Model[T, U, V = int](BaseModel): ... |
| 395 | + |
| 396 | + map_generic_model_arguments(Model, (str, bytes)) |
| 397 | + #> {T: str, U: bytes, V: int} |
| 398 | + |
| 399 | + map_generic_model_arguments(Model, (str,)) |
| 400 | + #> TypeError: Too few arguments for <class '__main__.Model'>; actual 1, expected at least 2 |
| 401 | + |
| 402 | + map_generic_model_argumenst(Model, (str, bytes, int, complex)) |
| 403 | + #> TypeError: Too many arguments for <class '__main__.Model'>; actual 4, expected 3 |
| 404 | + ``` |
| 405 | + |
| 406 | + Note: |
| 407 | + This function is analogous to the private `typing._check_generic_specialization` function. |
394 | 408 | """ |
395 | | - actual = len(parameters) |
396 | | - expected = len(cls.__pydantic_generic_metadata__['parameters']) |
397 | | - if actual != expected: |
398 | | - description = 'many' if actual > expected else 'few' |
399 | | - raise TypeError(f'Too {description} parameters for {cls}; actual {actual}, expected {expected}') |
| 409 | + parameters = cls.__pydantic_generic_metadata__['parameters'] |
| 410 | + expected_len = len(parameters) |
| 411 | + typevars_map: dict[TypeVar, Any] = {} |
| 412 | + |
| 413 | + _missing = object() |
| 414 | + for parameter, argument in zip_longest(parameters, args, fillvalue=_missing): |
| 415 | + if parameter is _missing: |
| 416 | + raise TypeError(f'Too many arguments for {cls}; actual {len(args)}, expected {expected_len}') |
| 417 | + |
| 418 | + if argument is _missing: |
| 419 | + param = typing.cast(TypeVar, parameter) |
| 420 | + try: |
| 421 | + has_default = param.has_default() |
| 422 | + except AttributeError: |
| 423 | + # Happens if using `typing.TypeVar` (and not `typing_extensions`) on Python < 3.13. |
| 424 | + has_default = False |
| 425 | + if has_default: |
| 426 | + typevars_map[param] = param.__default__ |
| 427 | + else: |
| 428 | + expected_len -= sum(hasattr(p, 'has_default') and p.has_default() for p in parameters) |
| 429 | + raise TypeError(f'Too few arguments for {cls}; actual {len(args)}, expected at least {expected_len}') |
| 430 | + else: |
| 431 | + param = typing.cast(TypeVar, parameter) |
| 432 | + typevars_map[param] = argument |
| 433 | + |
| 434 | + return typevars_map |
400 | 435 |
|
401 | 436 |
|
402 | 437 | _generic_recursion_cache: ContextVar[set[str] | None] = ContextVar('_generic_recursion_cache', default=None) |
|
0 commit comments