8000 Fix `validate_call` schema generation and doc error for `TypedDict` (… · pydantic/pydantic@61cdaea · GitHub
[go: up one dir, main page]

Skip to content

Commit 61cdaea

Browse files
authored
Fix validate_call schema generation and doc error for TypedDict (#6370)
1 parent 1d0f405 commit 61cdaea

File tree

3 files changed

+43
-6
lines changed

3 files changed

+43
-6
lines changed

docs/usage/errors.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,8 @@ except PydanticUserError as exc_info:
290290

291291
## `TypedDict` version {#typed-dict-version}
292292

293-
This error is raised when you use `typing_extensions.TypedDict`
294-
instead of `typing.TypedDict` on Python < 3.11.
293+
This error is raised when you use `typing.TypedDict`
294+
instead of `typing_extensions.TypedDict` on Python < 3.12.
295295

296296
## Model parent field overridden {#model-field-overridden}
297297

pydantic/_internal/_validate_call.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,35 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
6161
namespace = _typing_extra.add_module_globals(function, None)
6262
config_wrapper = ConfigWrapper(config)
6363
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
64-
self.__pydantic_core_schema__ = schema = gen_schema.generate_schema(function)
64+
self.__pydantic_core_schema__ = schema = gen_schema.collect_definitions(gen_schema.generate_schema(function))
6565
core_config = config_wrapper.core_config(self)
6666
schema = _discriminated_union.apply_discriminators(flatten_schema_defs(schema))
6767
simplified_schema = inline_schema_defs(schema)
6868
self.__pydantic_validator__ = pydantic_core.SchemaValidator(simplified_schema, core_config)
6969

70+
if self._validate_return:
71+
return_type = (
72+
self.__signature__.return_annotation
73+
if self.__signature__.return_annotation is not self.__signature__.empty
74+
else Any
75+
)
76+
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
77+
self.__return_pydantic_core_schema__ = schema = gen_schema.collect_definitions(
78+
gen_schema.generate_schema(return_type)
79+
)
80+
core_config = config_wrapper.core_config(self)
81+
schema = _discriminated_union.apply_discriminators(flatten_schema_defs(schema))
82+
simplified_schema = inline_schema_defs(schema)
83+
self.__return_pydantic_validator__ = pydantic_core.SchemaValidator(simplified_schema, core_config)
84+
else:
85+
self.__return_pydantic_core_schema__ = None
86+
self.__return_pydantic_validator__ = None
87+
7088
def __call__(self, *args: Any, **kwargs: Any) -> Any:
71-
return self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
89+
res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
90+
if self.__return_pydantic_validator__:
91+
return self.__return_pydantic_validator__.validate_python(res)
92+
return res
7293

7394
def __get__(self, obj: Any, objtype: type[Any] | None = None) -> ValidateCallWrapper:
7495
"""Bind the raw function and return another ValidateCallWrapper wrapping that."""

tests/test_validate_call.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
import sys
44
from datetime import datetime
55
from functools import partial
6-
from typing import List
6+
from typing import List, Tuple
77

88
import pytest
99
from pydantic_core import ArgsKwargs
10-
from typing_extensions import Annotated
10+
from typing_extensions import Annotated, TypedDict
1111

1212
from pydantic import Field, TypeAdapter, ValidationError, validate_call
13+
from pydantic.main import BaseModel
1314

1415
skip_pre_38 = pytest.mark.skipif(sys.version_info < (3, 8), reason='testing >= 3.8 behaviour only')
1516
skip_pre_39 = pytest.mark.skipif(sys.version_info < (3, 9), reason='testing >= 3.9 behaviour only')
@@ -621,3 +622,18 @@ def f(a: int, /, **kwargs):
621622
"""
622623
)
623624
assert module.f(1, a=2) == (1, {'a': 2})
625+
626+
627+
def test_model_as_arg() -> None:
628+
class Model1(TypedDict):
629+
x: int
630+
631+
class Model2(BaseModel):
632+
y: int
633+
634+
@validate_call(validate_return=True)
635+
def f1(m1: Model1, m2: Model2) -> Tuple[Model1, Model2]:
636+
return (m1, m2.model_dump()) # type: ignore
637+
638+
res = f1({'x': '1'}, {'y': '2'}) # type: ignore
639+
assert res == ({'x': 1}, Model2(y=2))

0 commit comments

Comments
 (0)
0