8000 Support usage of `type` with `typing.Self` and type aliases by kc0506 · Pull Request #10621 · pydantic/pydantic · GitHub
[go: up one dir, main page]

Skip to content

Support usage of type with typing.Self and type aliases #10621

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 33 additions & 10 deletions pydantic/_internal/_generate_schema.py
< 10000 /tr>
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,12 @@ def _unpack_refs_defs(self, schema: CoreSchema) -> CoreSchema:
return schema['schema']
return schema

def _resolve_self_type(self, obj: Any) -> Any:
obj = self.model_type_stack.get()
if obj is None:
raise PydanticUserError('`typing.Self` is invalid in this context', code='invalid-self-type')
return obj

def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.CoreSchema | None:
"""Try to generate schema from either the `__get_pydantic_core_schema__` function or
`__pydantic_core_schema__` property.
Expand All @@ -776,9 +782,7 @@ def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.C
"""
# avoid calling `__get_pydantic_core_schema__` if we've already visited this object
if is_self_type(obj):
obj = self.model_type_stack.get()
if obj is None:
raise PydanticUserError('`typing.Self` is invalid in this context', code='invalid-self-type')
obj = self._resolve_self_type(obj)
with self.defs.get_schema_or_ref(obj) as (_, maybe_schema):
if maybe_schema is not None:
return maybe_schema
Expand Down Expand Up @@ -858,27 +862,33 @@ def _resolve_forward_ref(self, obj: Any) -> Any:
return obj

@overload
def _get_args_resolving_forward_refs(self, obj: Any, required: Literal[True]) -> tuple[Any, ...]: ...
def _get_args_resolving_forward_refs(
self, obj: Any, *, required: Literal[True], eval_str: bool = True
) -> tuple[Any, ...]: ...

@overload
def _get_args_resolving_forward_refs(self, obj: Any) -> tuple[Any, ...] | None: ...
def _get_args_resolving_forward_refs(self, obj: Any, *, eval_str: bool = True) -> tuple[Any, ...] | None: ...

def _get_args_resolving_forward_refs(self, obj: Any, required: bool = False) -> tuple[Any, ...] | None:
def _get_args_resolving_forward_refs(
self, obj: Any, *, required: bool = False, eval_str: bool = True
) -> tuple[Any, ...] | None:
args = get_args(obj)
if args:
if eval_str:
args = [_typing_extra._make_forward_ref(a) if isinstance(a, str) else a for a in args]
args = tuple([self._resolve_forward_ref(a) if isinstance(a, ForwardRef) else a for a in args])
elif required: # pragma: no cover
raise TypeError(f'Expected {obj} to have generic parameters but it had none')
return args

def _get_first_arg_or_any(self, obj: Any) -> Any:
args = self._get_args_resolving_forward_refs(obj)
def _get_first_arg_or_any(self, obj: Any, eval_str: bool = True) -> Any:
args = self._get_args_resolving_forward_refs(obj, eval_str=eval_str)
if not args:
return Any
return args[0]

def _get_first_two_args_or_any(self, obj: Any) -> tuple[Any, Any]:
args = self._get_args_resolving_forward_refs(obj)
def _get_first_two_args_or_any(self, obj: Any, eval_str: bool = True) -> tuple[Any, Any]:
args = self._get_args_resolving_forward_refs(obj, eval_str=eval_str)
if not args:
return (Any, Any)
if len(args) < 2:
Expand Down Expand Up @@ -1659,10 +1669,14 @@ def _union_is_subclass_schema(self, union_type: Any) -> core_schema.CoreSchema:
def _subclass_schema(self, type_: Any) -> core_schema.CoreSchema:
"""Generate schema for a Type, e.g. `Type[int]`."""
type_param = self._get_first_arg_or_any(type_)

# Assume `type[Annotated[<typ>, ...]]` is equivalent to `type[<typ>]`:
type_param = _typing_extra.annotated_type(type_param) or type_param

if type_param == Any:
return self._type_schema()
elif isinstance(type_param, TypeAliasType):
return self.generate_schema(typing.Type[type_param.__value__])
elif isinstance(type_param, typing.TypeVar):
if type_param.__bound__:
if _typing_extra.origin_is_union(get_origin(type_param.__bound__)):
Expand All @@ -1677,6 +1691,11 @@ def _subclass_schema(self, type_: Any) -> core_schema.CoreSchema:
elif _typing_extra.origin_is_union(get_origin(type_param)):
return self._union_is_subclass_schema(type_param)
else:
if is_self_type(type_param):
type_param = self._resolve_self_type(type_param)

if not inspect.isclass(type_param):
raise TypeError(f'Expected a class, got {type_param!r}')
return core_schema.is_subclass_schema(type_param)

def _sequence_schema(self, items_type: Any) -> core_schema.CoreSchema:
Expand Down Expand Up @@ -2020,10 +2039,14 @@ def _annotated_schema(self, annotated_type: Any) -> core_schema.CoreSchema:
"""Generate schema for an Annotated type, e.g. `Annotated[int, Field(...)]` or `Annotated[int, Gt(0)]`."""
FieldInfo = import_cached_field_info()

# We don't want to eval string annotations as type,
# because there may be cases like `Annotated[int, 'not a type']`
source_type, *annotations = self._get_args_resolving_forward_refs(
annotated_type,
required=True,
eval_str=False,
)

schema = self._apply_annotations(source_type, annotations)
# put the default validator last so that TypeAdapter.get_default_value() works
# even if there are function validators involved
Expand Down
46 changes: 45 additions & 1 deletion tests/test_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import pytest
from dirty_equals import HasRepr, IsStr
from pydantic_core import ErrorDetails, InitErrorDetails, PydanticSerializationError, core_schema
from typing_extensions import Annotated, Literal, TypedDict, get_args
from typing_extensions import Annotated, Literal, TypeAliasType, TypedDict, get_args

from pydantic import (
BaseModel,
Expand Down Expand Up @@ -1394,6 +1394,50 @@ class Model(BaseModel):
assert Model.model_fields.keys() == set('abcdefg')


def test_type_on_none():
class Model(BaseModel):
a: Type[None]

Model(a=type(None))

with pytest.raises(ValidationError) as exc_info:
Model(a=None)

assert exc_info.value.errors(include_url=False) == [
{
'type': 'is_subclass_of',
'loc': ('a',),
'msg': 'Input should be a subclass of NoneType',
'input': None,
'ctx': {'class': 'NoneType'},
}
]


def test_type_on_typealias():
Float = TypeAliasType('Float', float)

class MyFloat(float): ...

adapter = TypeAdapter(Type[Float])

adapter.validate_python(float)
adapter.validate_python(MyFloat)

with pytest.raises(ValidationError) as exc_info:
adapter.validate_python(str)

assert exc_info.value.errors(include_url=False) == [
{
'type': 'is_subclass_of',
'loc': (),
'msg': 'Input should be a subclass of float',
'input': str,
'ctx': {'class': 'float'},
}
]


def test_annotated_inside_type():
class Model(BaseModel):
a: Type[Annotated[int, ...]]
Expand Down
46 changes: 44 additions & 2 deletions tests/test_types_self.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import dataclasses
import re
import typing
from typing import List, Optional, Union
from typing import List, Optional, Type, Union

import pytest
import typing_extensions
from typing_extensions import NamedTuple, TypedDict

from pydantic import BaseModel, Field, PydanticUserError, TypeAdapter, ValidationError, validate_call
from pydantic import BaseModel, Field, PydanticUserError, TypeAdapter, ValidationError, computed_field, validate_call


@pytest.fixture(
Expand Down Expand Up @@ -188,3 +188,45 @@ class A(BaseModel):
@validate_call
def foo(self: Self):
pass


def test_type_of_self(Self):
class A(BaseModel):
self_type: Type[Self]

@computed_field
def self_types1(self) -> List[Type[Self]]:
return [type(self), self.self_type]

# make sure forward refs are supported:
@computed_field
def self_types2(self) -> List[Type['Self']]:
return [type(self), self.self_type]

@computed_field
def self_types3(self) -> 'List[Type[Self]]':
return [type(self), self.self_type]

class B(A): ...

A(self_type=A)
A(self_type=B)
B(self_type=B)

a = A(self_type=B)
for prop in (a.self_types1, a.self_types2, a.self_types3):
assert prop == [A, B]

for invalid_type in (type, int, A, object):
with pytest.raises(ValidationError) as exc_info:
B(self_type=invalid_type)

assert exc_info.value.errors(include_url=False) == [
{
'type': 'is_subclass_of',
'loc': ('self_type',),
'msg': f'Input should be a subclass of {B.__qualname__}',
'input': invalid_type,
'ctx': {'class': B.__qualname__},
}
]
Loading
0