8000 Eagerly evaluate type aliases if annotated · pydantic/pydantic@da649cd · GitHub
[go: up one dir, main page]

Skip to content

Commit da649cd

Browse files
committed
Eagerly evaluate type aliases if annotated
1 parent c393317 commit da649cd

File tree

1 file changed

+131
-50
lines changed

1 file changed

+131
-50
lines changed

pydantic/fields.py

Lines changed: 131 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,90 @@
4646
Deprecated: TypeAlias = deprecated
4747

4848

49+
def _unpack_annotated(annotation) -> tuple[Any, list[Any]]:
50+
"""Unpack the annotation if it is wrapped with the `Annotated` type qualifier.
51+
52+
This function also unpacks PEP 695 type aliases if necessary (and also generic
53+
aliases with a PEP 695 type alias origin). However, it does *not* try to evaluate
54+
forward references, so users should make sure the type alias' `__value__` does not
55+
contain unresolvable forward references.
56+
57+
Example:
58+
```python {test="skip" lint="skip"}
59+
from typing import Annotated
60+
61+
type InnerList[T] = Annotated[list[T], 'meta_1']
62+
type MyList[T] = Annotated[InnerList[T], 'meta_2']
63+
type MyIntList = MyList[int]
64+
65+
_unpack_annotated(MyList)
66+
#> (list[T], ['meta_1', 'meta_2'])
67+
_unpack_annotated(MyList[int])
68+
#> (list[int], ['meta_1', 'meta_2'])
69+
_unpack_annotated(MyIntList)
70+
#> (list[int], ['meta_1', 'meta_2'])
71+
```
72+
73+
Returns:
74+
A two-tuple, the first element is the annotated type and the second element
75+
is a list containing the annotated metadata. If the annotation wasn't
76+
wrapped with `Annotated` in the first place, it is returned as is and the
77+
metadata list is empty.
78+
"""
79+
if _typing_extra.is_annotated(annotation):
80+
typ, *metadata = typing_extensions.get_args(annotation)
81+
# The annotated type might be a PEP 695 type alias, so we need to recursively
82+
# unpack it. Note that we could make an optimization here: the following next
83+
# call to `_unpack_annotated` could omit the `is_annotated` check, because Python
84+
# already flattens `Annotated[Annotated[<type>, ...], ...]` forms. However, we would
85+
# need to "re-enable" the check for further recursive calls.
86+
typ, sub_meta = _unpack_annotated(typ)
87+
metadata = sub_meta + metadata
88+
return typ, metadata
89+
elif _typing_extra.is_type_alias_type(annotation):
90+
try:
91+
value = annotation.__value__
92+
except NameError:
93+
# The type alias value contains an unresolvable reference. Note that even if it
94+
# resolves successfully, it might contain string annotations, and because of design
95+
# limitations we don't evaluate the type (we don't have access to a `NsResolver` instance).
96+
pass
97+
else:
98+
typ, metadata = _unpack_annotated(value)
99+
if metadata:
100+
# Having metadata means the type alias' `__value__` was an `Annotated` form
101+
# (or, recursively, a type alias to an `Annotated` form). It is important to
102+
# check for this as we don't want to unpack "normal" type aliases (e.g. `type MyInt = int`).
103+
return typ, metadata
104+
return annotation, []
105+
elif _typing_extra.is_generic_alias(annotation):
106+
# When parametrized, a PEP 695 type alias becomes a generic alias
107+
# (e.g. with `type MyList[T] = Annotated[list[T], ...]`, `MyList[int]`
108+
# is a generic alias).
109+
origin = typing_extensions.get_origin(annotation)
110+
if _typing_extra.is_type_alias_type(origin):
111+
try:
112+
value = origin.__value__
113+
except NameError:
114+
pass
115+
else:
116+
# While Python already handles type variable replacement for simple `Annotated` forms,
117+
# we need to manually apply the same logic for PEP 695 type aliases:
118+
# - With `MyList = Annotated[list[T], ...]`, `MyList[int] == Annotated[list[int], ...]`
119+
# - With `type MyList = Annotated[list[T], ...]`, `MyList[int].__value__ == Annotated[list[T], ...]`.
120+
value = _generics.replace_types(value, _generics.get_standard_typevars_map(annotation))
121+
typ, metadata = _unpack_annotated(value)
122+
if metadata:
123+
return typ, metadata
124+
return annotation, []
125+
126+
return annotation, []
127+
128+
49129
class _FromFieldInfoInputs(typing_extensions.TypedDict, total=False):
50130
"""This class exists solely to add type checking for the `**kwargs` in `FieldInfo.from_field`."""
51131

132+
# TODO PEP 747: use TypeForm:
52133
annotation: type[Any] | None
53134
default_factory: Callable[[], Any] | Callable[[dict[str, Any]], Any] | None
54135
alias: str | None
@@ -207,7 +288,7 @@ def __init__(self, **kwargs: Unpack[_FieldInfoInputs]) -> None:
207288
"""
208289
self._attributes_set = {k: v for k, v in kwargs.items() if v is not _Unset}
209290
kwargs = {k: _DefaultValues.get(k) if v is _Unset else v for k, v in kwargs.items()} # type: ignore
210-
self.annotation, annotation_metadata = self._extract_metadata(kwargs.get('annotation'))
291+
self.annotation = kwargs.get('annotation')
211292
self.evaluated = False
212293

213294
default = kwargs.pop('default', PydanticUndefined)
@@ -247,7 +328,7 @@ def __init__(self, **kwargs: Unpack[_FieldInfoInputs]) -> None:
247328
self.init_var = kwargs.pop('init_var', None)
248329
self.kw_only = kwargs.pop('kw_only', None)
249330

250-
self.metadata = self._collect_metadata(kwargs) + annotation_metadata # type: ignore
331+
self.metadata = self._collect_metadata(kwargs) # type: ignore
251332

252333
@staticmethod
253334
def from_field(default: Any = PydanticUndefined, **kwargs: Unpack[_FromFieldInfoInputs]) -> FieldInfo:
@@ -310,34 +391,50 @@ class MyModel(pydantic.BaseModel):
310391
Returns:
311392
An instance of the field metadata.
312393
"""
313-
final = False
314-
if _typing_extra.is_finalvar(annotation):
315-
final = True
316-
if annotation is not typing_extensions.Final:
394+
# 1. Check if the annotation is the `Final` type qualifier:
395+
final = _typing_extra.is_finalvar(annotation)
396+
if final:
397+
if _typing_extra.is_generic_alias(annotation):
398+
# The annotation is a parametrized `Final`, e.g. `Final[int]`.
399+
# In this case, `annotation` will be `int`:
317400
annotation = typing_extensions.get_args(annotation)[0]
318-
319-
if _typing_extra.is_annotated(annotation):
320-
first_arg, *extra_args = typing_extensions.get_args(annotation)
321-
if _typing_extra.is_finalvar(first_arg):
322-
final = True
323-
field_info_annotations = [a for a in extra_args if isinstance(a, FieldInfo)]
324-
field_info = FieldInfo.merge_field_infos(*field_info_annotations, annotation=first_arg)
401+
else:
402+
# The annotation is a bare `Final`. Use `Any` as a type annotation:
403+
return FieldInfo(annotation=Any, frozen=True) # pyright: ignore[reportArgumentType] (PEP 747)
404+
405+
# 2. Check if the annotation is an `Annotated` form.
406+
# In this case, `annotation` will be the annotated type:
407+
annotation, metadata = _unpack_annotated(annotation)
408+
409+
# 3. If we have metadata, `annotation` was the annotated type:
410+
if metadata:
411+
# 3.1. Check if the annotated type is the `Final` type qualifier.
412+
# (i.e. `Annotated[Final[...], ...]`). Note that we only do
413+
# so if `final` isn't `True` already, because we don't want to
414+
# support the invalid `Final[Annotated[Final, ...]]` form.
415+
if not final:
416+
final = _typing_extra.is_finalvar(annotation)
417+
if final and _typing_extra.is_generic_alias(annotation):
418+
annotation = typing_extensions.get_args(annotation)[0]
419+
420+
field_info_annotations = [a for a in metadata if isinstance(a, FieldInfo)]
421+
field_info = FieldInfo.merge_field_infos(*field_info_annotations, annotation=annotation)
325422
if field_info:
326423
new_field_info = copy(field_info)
327-
new_field_info.annotation = first_arg
424+
new_field_info.annotation = annotation
328425
new_field_info.frozen = final or field_info.frozen
329-
metadata: list[Any] = []
330-
for a in extra_args:
426+
field_metadata: list[Any] = []
427+
for a in metadata:
331428
if _typing_extra.is_deprecated_instance(a):
332429
new_field_info.deprecated = a.message
333430
elif not isinstance(a, FieldInfo):
334-
metadata.append(a)
431+
field_metadata.append(a)
335432
else:
336-
metadata.extend(a.metadata)
337-
new_field_info.metadata = metadata
433+
field_metadata.extend(a.metadata)
434+
new_field_info.metadata = field_metadata
338435
return new_field_info
339436

340-
return FieldInfo(annotation=annotation, frozen=final or None) # pyright: ignore[reportArgumentType]
437+
return FieldInfo(annotation=annotation, frozen=final or None) # pyright: ignore[reportArgumentType] (PEP 747)
341438

342439
@staticmethod
343440
def from_annotated_attribute(annotation: type[Any], default: Any) -> FieldInfo:
@@ -367,16 +464,16 @@ class MyModel(pydantic.BaseModel):
367464
if annotation is default:
368465
raise PydanticUserError(
369466
'Error when building FieldInfo from annotated attribute. '
370-
"Make sure you don't have any field name clashing with a type annotation ",
467+
"Make sure you don't have any field name clashing with a type annotation.",
371468
code='unevaluable-type-annotation',
372469
)
373470

374471
final = _typing_extra.is_finalvar(annotation)
375-
if final and annotation is not typing_extensions.Final:
472+
if final and _typing_extra.is_generic_alias(annotation):
376473
annotation = typing_extensions.get_args(annotation)[0]
377474

378475
if isinstance(default, FieldInfo):
379-
default.annotation, annotation_metadata = FieldInfo._extract_metadata(annotation) # pyright: ignore[reportArgumentType]
476+
default.annotation, annotation_metadata = _unpack_annotated(annotation)
380477
default.metadata += annotation_metadata
381478
default = default.merge_field_infos(
382479
*[x for x in annotation_metadata if isinstance(x, FieldInfo)], default, annotation=default.annotation
@@ -394,7 +491,7 @@ class MyModel(pydantic.BaseModel):
394491
annotation = annotation.type
395492

396493
pydantic_field = FieldInfo._from_dataclass_field(default)
397-
pydantic_field.annotation, annotation_metadata = FieldInfo._extract_metadata(annotation) # pyright: ignore[reportArgumentType]
494+
pydantic_field.annotation, annotation_metadata = _unpack_annotated(annotation)
398495
pydantic_field.metadata += annotation_metadata
399496
pydantic_field = pydantic_field.merge_field_infos(
400497
*[x for x in annotation_metadata if isinstance(x, FieldInfo)],
@@ -407,19 +504,20 @@ class MyModel(pydantic.BaseModel):
407504
pydantic_field.kw_only = getattr(default, 'kw_only', None)
408505
return pydantic_field
409506

410-
if _typing_extra.is_annotated(annotation):
411-
first_arg, *extra_args = typing_extensions.get_args(annotation)
412-
field_infos = [a for a in extra_args if isinstance(a, FieldInfo)]
413-
field_info = FieldInfo.merge_field_infos(*field_infos, annotation=first_arg, default=default)
414-
metadata: list[Any] = []
415-
for a in extra_args:
507+
annotation, metadata = _unpack_annotated(annotation)
508+
509+
if metadata:
510+
field_infos = [a for a in metadata if isinstance(a, FieldInfo)]
511+
field_info = FieldInfo.merge_field_infos(*field_infos, annotation=annotation, default=default)
512+
field_metadata: list[Any] = []
513+
for a in metadata:
416514
if _typing_extra.is_deprecated_instance(a):
417515
field_info.deprecated = a.message
418516
elif not isinstance(a, FieldInfo):
419-
metadata.append(a)
517+
field_metadata.append(a)
420518
else:
421-
metadata.extend(a.metadata)
422-
field_info.metadata = metadata
519+
field_metadata.extend(a.metadata)
520+
field_info.metadata = field_metadata
423521
return field_info
424522

425523
return FieldInfo(annotation=annotation, default=default, frozen=final or None) # pyright: ignore[reportArgumentType]
@@ -516,23 +614,6 @@ def _from_dataclass_field(dc_field: DataclassField[Any]) -> FieldInfo:
516614
dc_field_metadata = {k: v for k, v in dc_field.metadata.items() if k in _FIELD_ARG_NAMES}
517615
return Field(default=default, default_factory=default_factory, repr=dc_field.repr, **dc_field_metadata) # pyright: ignore[reportCallIssue]
518616

519-
@staticmethod
520-
def _extract_metadata(annotation: type[Any] | None) -> tuple[type[Any] | None, list[Any]]:
521-
"""Tries to extract metadata/constraints from an annotation if it uses `Annotated`.
522-
523-
Args:
524-
annotation: The type hint annotation for which metadata has to be extracted.
525-
526-
Returns:
527-
A tuple containing the extracted metadata type and the list of extra arguments.
528-
"""
529-
if annotation is not None:
530-
if _typing_extra.is_annotated(annotation):
531-
first_arg, *extra_args = typing_extensions.get_args(annotation)
532-
return first_arg, list(extra_args)
533-
534-
return annotation, []
535-
536617
@staticmethod
537618
def _collect_metadata(kwargs: dict[str, Any]) -> list[Any]:
538619
"""Collect annotations from kwargs.

0 commit comments

Comments
 (0)
0