@@ -61,14 +61,35 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
61
61
namespace = _typing_extra .add_module_globals (function , None )
62
62
config_wrapper = ConfigWrapper (config )
63
63
gen_schema = _generate_schema .GenerateSchema (config_wrapper , namespace )
64
- self .__pydantic_core_schema__ = schema = gen_schema .generate_schema (function )
64
+ self .__pydantic_core_schema__ = schema = gen_schema .collect_definitions ( gen_schema . generate_schema (function ) )
65
65
core_config = config_wrapper .core_config (self )
66
66
schema = _discriminated_union .apply_discriminators (flatten_schema_defs (schema ))
67
67
simplified_schema = inline_schema_defs (schema )
68
68
self .__pydantic_validator__ = pydantic_core .SchemaValidator (simplified_schema , core_config )
69
69
70
+ if self ._validate_return :
71
+ return_type = (
72
+ self .__signature__ .return_annotation
73
+ if self .__signature__ .return_annotation is not self .__signature__ .empty
74
+ else Any
75
+ )
76
+ gen_schema = _generate_schema .GenerateSchema (config_wrapper , namespace )
77
+ self .__return_pydantic_core_schema__ = schema = gen_schema .collect_definitions (
78
+ gen_schema .generate_schema (return_type )
79
+ )
80
+ core_config = config_wrapper .core_config (self )
81
+ schema = _discriminated_union .apply_discriminators (flatten_schema_defs (schema ))
82
+ simplified_schema = inline_schema_defs (schema )
83
+ self .__return_pydantic_validator__ = pydantic_core .SchemaValidator (simplified_schema , core_config )
84
+ else :
85
+ self .__return_pydantic_core_schema__ = None
86
+ self .__return_pydantic_validator__ = None
87
+
70
88
def __call__ (self , * args : Any , ** kwargs : Any ) -> Any :
71
- return self .__pydantic_validator__ .validate_python (pydantic_core .ArgsKwargs (args , kwargs ))
89
+ res = self .__pydantic_validator__ .validate_python (pydantic_core .ArgsKwargs (args , kwargs ))
90
+ if self .__return_pydantic_validator__ :
91
+ return self .__return_pydantic_validator__ .validate_python (res )
92
+ return res
72
93
73
94
def __get__ (self , obj : Any , objtype : type [Any ] | None = None ) -> ValidateCallWrapper :
74
95
"""Bind the raw function and return another ValidateCallWrapper wrapping that."""
0 commit comments