10000 Include JSON Schema input core schema in function schemas (#11142) · pydantic/pydantic@a07c31e · GitHub
[go: up one dir, main page]

Skip to content

Commit a07c31e

Browse files
Include JSON Schema input core schema in function schemas (#11142)
Co-authored-by: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com>
1 parent 9166d55 commit a07c31e

File tree

7 files changed

+60
-80
lines changed

7 files changed

+60
-80
lines changed

pydantic/_internal/_core_metadata.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
from warnings import warn
55

66
if TYPE_CHECKING:
7-
from pydantic_core import CoreSchema
8-
97
from ..config import JsonDict, JsonSchemaExtraCallable
108
from ._schema_generation_shared import (
119
GetJsonSchemaFuncti 10000 on,
@@ -20,7 +18,6 @@ class CoreMetadata(TypedDict, total=False):
2018
pydantic_js_annotation_functions: List of JSON schema functions that don't resolve refs during application.
2119
pydantic_js_prefer_positional_arguments: Whether JSON schema generator will
2220
prefer positional over keyword arguments for an 'arguments' schema.
23-
pydantic_js_input_core_schema: Schema associated with the input value for the associated
2421
custom validation function. Only applies to before, plain, and wrap validators.
2522
pydantic_js_udpates: key / value pair updates to apply to the JSON schema for a type.
2623
pydantic_js_extra: WIP, either key/value pair updates to apply to the JSON schema, or a custom callable.
@@ -37,7 +34,6 @@ class CoreMetadata(TypedDict, total=False):
3734
pydantic_js_functions: list[GetJsonSchemaFunction]
3835
pydantic_js_annotation_functions: list[GetJsonSchemaFunction]
3936
pydantic_js_prefer_positional_arguments: bool
40-
pydantic_js_input_core_schema: CoreSchema
4137
pydantic_js_updates: JsonDict
4238
pydantic_js_extra: JsonDict | JsonSchemaExtraCallable
4339

pydantic/_internal/_core_utils.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,12 +280,38 @@ def handle_dict_schema(self, schema: core_schema.DictSchema, f: Walk) -> core_sc
280280
schema['values_schema'] = self.walk(values_schema, f)
281281
return schema
282282

283-
def handle_function_schema(self, schema: AnyFunctionSchema, f: Walk) -> core_schema.CoreSchema:
284-
if not is_function_with_inner_schema(schema):
285-
return schema
283+
def handle_function_after_schema(
284+
self, schema: core_schema.AfterValidatorFunctionSchema, f: Walk
285+
) -> core_schema.CoreSchema:
286286
schema['schema'] = self.walk(schema['schema'], f)
287287
return schema
288288

289+
def handle_function_before_schema(
290+
self, schema: core_schema.BeforeValidatorFunctionSchema, f: Walk
291+
) -> core_schema.CoreSchema:
292+
schema['schema'] = self.walk(schema['schema'], f)
293+
if 'json_schema_input_schema' in schema:
294+
schema['json_schema_input_schema'] = self.walk(schema['json_schema_input_schema'], f)
295+
return schema
296+
297+
# TODO duplicate schema types for serializers and validators, needs to be deduplicated:
298+
def handle_function_plain_schema(
299+
self, schema: core_schema.PlainValidatorFunctionSchema | core_schema.PlainSerializerFunctionSerSchema, f: Walk
300+
) -> core_schema.CoreSchema:
301+
if 'json_schema_input_schema' in schema:
302+
schema['json_schema_input_schema'] = self.walk(schema['json_schema_input_schema'], f)
303+
return schema # pyright: ignore[reportReturnType]
304+
305+
# TODO duplicate schema types for serializers and validators, needs to be deduplicated:
306+
def handle_function_wrap_schema(
307+
self, schema: core_schema.WrapValidatorFunctionSchema | core_schema.WrapSerializerFunctionSerSchema, f: Walk
308+
) -> core_schema.CoreSchema:
309+
if 'schema' in schema:
310+
schema['schema'] = self.walk(schema['schema'], f)
311+
if 'json_schema_input_schema' in schema:
312+
schema['json_schema_input_schema'] = self.walk(schema['json_schema_input_schema'], f)
313+
return schema # pyright: ignore[reportReturnType]
314+
289315
def handle_union_schema(self, schema: core_schema.UnionSchema, f: Walk) -> core_schema.CoreSchema:
290316
new_choices: list[CoreSchema | tuple[CoreSchema, str]] = []
291317
for v in schema['choices']:

pydantic/functional_validators.py

Lines changed: 8 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,6 @@ def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaH
127127
if self.json_schema_input_type is PydanticUndefined
128128
else handler.generate_schema(self.json_schema_input_type)
129129
)
130-
# Try to resolve the original schema if required, because schema cleaning
131-
# won't inline references in metadata:
132-
if input_schema is not None:
133-
try:
134-
input_schema = handler.resolve_ref_schema(input_schema)
135-
except LookupError:
136-
pass
137-
metadata = {'pydantic_js_input_core_schema': input_schema} if input_schema is not None else {}
138130

139131
info_arg = _inspect_validator(self.func, 'before')
140132
if info_arg:
@@ -143,11 +135,13 @@ def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaH
143135
func,
144136
schema=schema,
145137
field_name=handler.field_name,
146-
metadata=metadata,
138+
json_schema_input_schema=input_schema,
147139
)
148140
else:
149141
func = cast(core_schema.NoInfoValidatorFunction, self.func)
150-
return core_schema.no_info_before_validator_function(func, schema=schema, metadata=metadata)
142+
return core_schema.no_info_before_validator_function(
143+
func, schema=schema, json_schema_input_schema=input_schema
144+
)
151145

152146
@classmethod
153147
def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self:
@@ -229,13 +223,6 @@ def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaH
229223
serialization = None
230224

231225
input_schema = handler.generate_schema(self.json_schema_input_type)
232-
# Try to resolve the original schema if required, because schema cleaning
233-
# won't inline references in metadata:
234-
try:
235-
input_schema = handler.resolve_ref_schema(input_schema)
236-
except LookupError:
237-
pass
238-
metadata = {'pydantic_js_input_core_schema': input_schema} if input_schema is not None else {}
239226

240227
info_arg = _inspect_validator(self.func, 'plain')
241228
if info_arg:
@@ -244,14 +231,14 @@ def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaH
244231
func,
245232
field_name=handler.field_name,
246233
serialization=serialization, # pyright: ignore[reportArgumentType]
247-
metadata=metadata,
234+
json_schema_input_schema=input_schema,
248235
)
249236
else:
250237
func = cast(core_schema.NoInfoValidatorFunction, self.func)
251238
return core_schema.no_info_plain_validator_function(
252239
func,
253240
serialization=serialization, # pyright: ignore[reportArgumentType]
254-
metadata=metadata,
241+
json_schema_input_schema=input_schema,
255242
)
256243

257244
@classmethod
@@ -312,14 +299,6 @@ def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaH
312299
if self.json_schema_input_type is PydanticUndefined
313300
else handler.generate_schema(self.json_schema_input_type)
314301
)
315-
# Try to resolve the original schema if required, because schema cleaning
316-
# won't inline references in metadata:
317-
if input_schema is not None:
318-
try:
319-
input_schema = handler.resolve_ref_schema(input_schema)
320-
except LookupError:
321-
pass
322-
metadata = {'pydantic_js_input_core_schema': input_schema} if input_schema is not None else {}
323302

324303
info_arg = _inspect_validator(self.func, 'wrap')
325304
if info_arg:
@@ -328,14 +307,14 @@ def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaH
328307
func,
329308
schema=schema,
330309
field_name=handler.field_name,
331-
metadata=metadata,
310+
json_schema_input_schema=input_schema,
332311
)
333312
else:
334313
func = cast(core_schema.NoInfoWrapValidatorFunction, self.func)
335314
return core_schema.no_info_wrap_validator_function(
336315
func,
337316
schema=schema,
338-
metadata=metadata,
317+
json_schema_input_schema=input_schema,
339318
)
340319

341320
@classmethod

pydantic/json_schema.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,9 +1053,7 @@ def function_before_schema(self, schema: core_schema.BeforeValidatorFunctionSche
10531053
Returns:
10541054
The generated JSON schema.
10551055
"""
1056-
if self._mode == 'validation' and (
1057-
input_schema := schema.get('metadata', {}).get('pydantic_js_input_core_schema')
1058-
):
1056+
if self._mode == 'validation' and (input_schema := schema.get('json_schema_input_schema')):
10591057
return self.generate_inner(input_schema)
10601058

10611059
return self.generate_inner(schema['schema'])
@@ -1080,9 +1078,7 @@ def function_plain_schema(self, schema: core_schema.PlainValidatorFunctionSchema
10801078
Returns:
10811079
The generated JSON schema.
10821080
"""
1083-
if self._mode == 'validation' and (
1084-
input_schema := schema.get('metadata', {}).get('pydantic_js_input_core_schema')
1085-
):
1081+
if self._mode == 'validation' and (input_schema := schema.get('json_schema_input_schema')):
10861082
return self.generate_inner(input_schema)
10871083

10881084
return self.handle_invalid_for_json_schema(
@@ -1098,9 +1094,7 @@ def function_wrap_schema(self, schema: core_schema.WrapValidatorFunctionSchema)
10981094
Returns:
10991095
The generated JSON schema.
11001096
"""
1101-
if self._mode == 'validation' and (
1102-
input_schema := schema.get('metadata', {}).get('pydantic_js_input_core_schema')
1103-
):
1097+
if self._mode == 'validation' and (input_schema := schema.get('json_schema_input_schema')):
11041098
return self.generate_inner(input_schema)
11051099

11061100
return self.generate_inner(schema['schema'])

tests/test_json_schema.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6531,31 +6531,39 @@ def validate_f(cls, value: Any) -> int: ...
65316531
@pytest.mark.parametrize(
65326532
'validator',
65336533
[
6534-
PlainValidator(lambda v: v, json_schema_input_type='Sub'),
6535-
BeforeValidator(lambda v: v, json_schema_input_type='Sub'),
6536-
WrapValidator(lambda v, h: h(v), json_schema_input_type='Sub'),
6534+
PlainValidator(lambda v: v, json_schema_input_type='Union[Sub1, Sub2]'),
6535+
BeforeValidator(lambda v: v, json_schema_input_type='Union[Sub1, Sub2]'),
6536+
WrapValidator(lambda v, h: h(v), json_schema_input_type='Union[Sub1, Sub2]'),
65376537
],
65386538
)
65396539
def test_json_schema_input_type_with_refs(validator) -> None:
6540-
"""Test that `'definition-ref` schemas for `json_schema_input_type` are inlined.
6540+
"""Test that `'definition-ref` schemas for `json_schema_input_type` are supported.
65416541
65426542
See: https://github.com/pydantic/pydantic/issues/10434.
6543+
See: https://github.com/pydantic/pydantic/issues/11033
65436544
"""
65446545

6545-
class Sub(BaseModel):
6546+
class Sub1(BaseModel):
6547+
pass
6548+
6549+
class Sub2(BaseModel):
65466550
pass
65476551

65486552
class Model(BaseModel):
65496553
sub: Annotated[
6550-
Sub,
6551-
PlainSerializer(lambda v: v, return_type=Sub),
6554+
Union[Sub1, Sub2],
6555+
PlainSerializer(lambda v: v, return_type=Union[Sub1, Sub2]),
65526556
validator,
65536557
]
65546558

65556559
json_schema = Model.model_json_schema()
65566560

6557-
assert 'Sub' in json_schema['$defs']
6558-
assert json_schema['properties']['sub']['$ref'] == '#/$defs/Sub'
6561+
assert 'Sub1' in json_schema['$defs']
6562+
assert 'Sub2' in json_schema['$defs']
6563+
assert json_schema['properties']['sub'] == {
6564+
'anyOf': [{'$ref': '#/$defs/Sub1'}, {'$ref': '#/$defs/Sub2'}],
6565+
'title': 'Sub',
6566+
}
65596567

65606568

65616569
@pytest.mark.parametrize(

tests/test_utils.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -637,29 +637,6 @@ def walk(s, recurse):
637637
}
638638

639639

640-
def test_handle_function_schema():
641-
schema = core_schema.with_info_before_validator_function(
642-
lambda v, _info: v, core_schema.float_schema(), field_name='field_name'
643-
)
644 D041 -
645-
def walk(s, recurse):
646-
# change type to str
647-
if s['type'] == 'float':
648-
s['type'] = 'str'
649-
return s
650-
651-
schema = _WalkCoreSchema().handle_function_schema(schema, walk)
652-
assert schema['type'] == 'function-before'
653-
assert schema['schema'] == {'type': 'str'}
654-
655-
def walk1(s, recurse):
656-
# this is here to make sure this function is not called
657-
assert False
658-
659-
schema = _WalkCoreSchema().handle_function_schema(core_schema.int_schema(), walk1)
660-
assert schema['type'] == 'int'
661-
662-
663640
def test_handle_call_schema():
664641
param_a = core_schema.arguments_parameter(name='a', schema=core_schema.str_schema(), mode='positional_only')
665642
args_schema = core_schema.arguments_schema([param_a])

uv.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)
0