8000 Add support for Type[T] typehints when arbitrary_types_allowed==True.… · pydantic/pydantic@f08fd2f · GitHub
[go: up one dir, main page]

Skip to content

Commit f08fd2f

Browse files
timonbimonsamuelcolvin
authored andcommitted
Add support for Type[T] typehints when arbitrary_types_allowed==True. (#808)
* Add support for Type[T] typehints when arbitrary_types_allowe==True. * Add documentation. * Let black do its magic. * Ignore mypy warning - see here: python/mypy#3060 * Prettify docs. * Change Changelog. * Refactor and simplify check for Type[T]. * Black again. ^^ - Really need pre-commit hooks. * Update pydantic/validators.py Co-Authored-By: Samuel Colvin <samcolvin@gmail.com> * Rename arbitrary_class to class. * Black. * Add type hints. * Make private function public. * Add support for bare Type. * Black again. * Update docs. * CO_ct not meant for export. * Fix get_class for Python3.6 * Update error message of ClassError. * Use relative import. * Incorporate typing feedback (both versions are fine with mypy). * Move from issubclass to lenient_issubclass. * correct docs
1 parent ef894d2 commit f08fd2f

File tree

9 files changed

+201
-4
lines changed

9 files changed

+201
-4
lines changed

changes/807-timonbimon.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add support for ``Type[T]`` type hints

docs/examples/bare_type_type.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from typing import Type
2+
3+
from pydantic import BaseModel, ValidationError
4+
5+
6+
class Foo:
7+
pass
8+
9+
10+
class LenientSimpleModel(BaseModel):
11+
any_class_goes: Type
12+
13+
14+
LenientSimpleModel(any_class_goes=int)
15+
LenientSimpleModel(any_class_goes=Foo)
16+
try:
17+
LenientSimpleModel(any_class_goes=Foo())
18+
except ValidationError as e:
19+
print(e)
20+
"""
21+
1 validation error
22+
any_class_goes
23+
subclass of type expected (type=type_error.class)
24+
"""

docs/examples/type_type.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from typing import Type
2+
3+
from pydantic import BaseModel
4+
from pydantic import ValidationError
5+
6+
class Foo:
7+
pass
8+
9+
class Bar(Foo):
10+
pass
11+
12+
class Other:
13+
pass
14+
15+
class SimpleModel(BaseModel):
16+
just_subclasses: Type[Foo]
17+
18+
19+
SimpleModel(just_subclasses=Foo)
20+
SimpleModel(just_subclasses=Bar)
21+
try:
22+
SimpleModel(just_subclasses=Other)
23+
except ValidationError as e:
24+
print(e)
25+
"""
26+
1 validation error
27+
just_subclasses
28+
subclass of Foo expected (type=type_error.class)
29+
"""

docs/index.rst

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,18 @@ With proper ordering in an annotated ``Union``, you can use this to parse types
818818

819819
(This script is complete, it should run "as is")
820820

821+
Type Type
822+
............
823+
824+
Pydantic supports the use of ``Type[T]`` to specify that a field may only accept classes (not instances)
825+
that are subclasses of ``T``.
826+
827+
.. literalinclude:: examples/type_type.py
828+
829+
You may also use ``Type`` to specify that any class is allowed.
830+
831+
.. literalinclude:: examples/bare_type_type.py
832+
821833
Custom Data Types
822834
.................
823835

@@ -898,7 +910,7 @@ Options:
898910
:error_msg_templates: let's you to override default error message templates.
899911
Pass in a dictionary with keys matching the error messages you want to override (default: ``{}``)
900912
:arbitrary_types_allowed: whether to allow arbitrary user types for fields (they are validated simply by checking if the
901-
value is instance of that type). If False - RuntimeError will be raised on model declaration (default: ``False``)
913+
value is instance of that type). If ``False`` - ``RuntimeError`` will be raised on model declaration (default: ``False``)
902914
:json_encoders: customise the way types are encoded to json, see :ref:`JSON Serialisation <json_dump>` for more
903915
details.
904916
:orm_mode: allows usage of :ref:`ORM mode <orm_mode>`

pydantic/errors.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,19 @@ def __init__(self, *, expected_arbitrary_type: AnyType) -> None:
324324
super().__init__(expected_arbitrary_type=display_as_type(expected_arbitrary_type))
325325

326326

327+
class ClassError(PydanticTypeError):
328+
code = 'class'
329+
msg_template = 'a class is expected'
330+
331+
332+
class SubclassError(PydanticTypeError):
333+
code = 'subclass'
334+
msg_template = 'subclass of {expected_class} expected'
335+
336+
def __init__(self, *, expected_class: AnyType) -> None:
337+
super().__init__(expected_class=display_as_type(expected_class))
338+
339+
327340
class JsonError(PydanticValueError):
328341
msg_template = 'Invalid JSON'
329342

pydantic/fields.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,8 @@ def _populate_sub_fields(self) -> None: # noqa: C901 (ignore complexity)
248248
)
249249
self.type_ = self.type_.__args__[1] # type: ignore
250250
self.shape = SHAPE_MAPPING
251+
elif issubclass(origin, Type): # type: ignore
252+
return
251253
else:
252254
raise TypeError(f'Fields of type "{origin}" are not supported.')
253255

pydantic/typing.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,21 @@ def update_field_forward_refs(field: 'Field', globalns: Any, localns: Any) -> No
193193
if field.sub_fields:
194194
for sub_f in field.sub_fields:
195195
update_field_forward_refs(sub_f, globalns=globalns, localns=localns)
196+
197+
198+
def get_class(type_: AnyType) -> Union[None, bool, AnyType]:
199+
"""
200+
Tries to get the class of a Type[T] annotation. Returns True if Type is used
201+
without brackets. Otherwise returns None.
202+
"""
203+
try:
204+
origin = getattr(type_, '__origin__')
205+
if origin is None: # Python 3.6
206+
origin = type_
207+
if issubclass(origin, Type): # type: ignore
208+
if type_.__args__ is None or not isinstance(type_.__args__[0], type):
209+
return True
210+
return type_.__args__[0]
211+
except AttributeError:
212+
pass
213+
return None

pydantic/validators.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626

2727
from . import errors
2828
from .datetime_parse import parse_date, parse_datetime, parse_duration, parse_time
29-
from .typing import AnyCallable, AnyType, ForwardRef, display_as_type, is_callable_type, is_literal_type
30-
from .utils import almost_equal_floats, change_exception, sequence_like
29+
from .typing import AnyCallable, AnyType, ForwardRef, display_as_type, get_class, is_callable_type, is_literal_type
30+
from .utils import almost_equal_floats, change_exception, lenient_issubclass, sequence_like
3131

3232
if TYPE_CHECKING: # pragma: no cover
3333
from .fields import Field
@@ -404,6 +404,21 @@ def arbitrary_type_validator(v: Any) -> T:
404404
return arbitrary_type_validator
405405

406406

407+
def make_class_validator(type_: Type[T]) -> Callable[[Any], Type[T]]:
408+
def class_validator(v: Any) -> Type[T]:
409+
if lenient_issubclass(v, type_):
410+
return v
411+
raise errors.SubclassError(expected_class=type_)
412+
413+
return class_validator
414+
415+
416+
def any_class_validator(v: Any) -> Type[T]:
417+
if isinstance(v, type):
418+
return v
419+
raise errors.ClassError()
420+
421+
407422
def pattern_validator(v: Any) -> Pattern[str]:
408423
with change_exception(errors.PatternError, re.error):
409424
return re.compile(v)
@@ -486,6 +501,14 @@ def find_validators( # noqa: C901 (ignore complexity)
486501
yield make_literal_validator(type_)
487502
return
488503

504+
class_ = get_class(type_)
505+
if class_ is not None:
506+
if isinstance(class_, type):
507+
yield make_class_validator(class_)
508+
else:
509+
yield any_class_validator
510+
return
511+
489512
supertype = _find_supertype(type_)
490513
if supertype is not None:
491514
type_ = supertype

tests/test_main.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import Enum
2-
from typing import Any, ClassVar, List, Mapping
2+
from typing import Any, ClassVar, List, Mapping, Type
33

44
import pytest
55

@@ -530,6 +530,81 @@ class ArbitraryTypeNotAllowedModel(BaseModel):
530530
assert exc_info.value.args[0].startswith('no validator found for')
531531

532532

533+
def test_type_type_validation_success():
534+
class ArbitraryClassAllowedModel(BaseModel):
535+
t: Type[ArbitraryType]
536+
537+
arbitrary_type_class = ArbitraryType
538+
m = ArbitraryClassAllowedModel(t=arbitrary_type_class)
539+
assert m.t == arbitrary_type_class
540+
541+
542+
def test_type_type_subclass_validation_success():
543+
class ArbitraryClassAllowedModel(BaseModel):
544+
t: Type[ArbitraryType]
545+
546+
class ArbitrarySubType(ArbitraryType):
547+
pass
548+
549+
arbitrary_type_class = ArbitrarySubType
550+
m = ArbitraryClassAllowedModel(t=arbitrary_type_class)
551+
assert m.t == arbitrary_type_class
552+
553+
554+
def test_type_type_validation_fails_for_instance():
555+
class ArbitraryClassAllowedModel(BaseModel):
556+
t: Type[ArbitraryType]
557+
558+
class C:
559+
pass
560+
561+
with pytest.raises(ValidationError) as exc_info:
562+
ArbitraryClassAllowedModel(t=C)
563+
assert exc_info.value.errors() == [
564+
{
565+
'loc': ('t',),
566+
'msg': 'subclass of ArbitraryType expected',
567+
'type': 'type_error.subclass',
568+
'ctx': {'expected_class': 'ArbitraryType'},
569+
}
570+
]
571+
572+
573+
def test_type_type_validation_fails_for_basic_type():
574+
class ArbitraryClassAllowedModel(BaseModel):
575+
t: Type[ArbitraryType]
576+
577+
with pytest.raises(ValidationError) as exc_info:
578+
ArbitraryClassAllowedModel(t=1)
579+
assert exc_info.value.errors() == [
580+
{
581+
'loc': ('t',),
582+
'msg': 'subclass of ArbitraryType expected',
583+
'type': 'type_error.subclass',
584+
'ctx': {'expected_class': 'ArbitraryType'},
585+
}
586+
]
587+
588+
589+
def test_bare_type_type_validation_success():
590+
class ArbitraryClassAllowedModel(BaseModel):
591+
t: Type
592+
593+
arbitrary_type_class = ArbitraryType
594+
m = ArbitraryClassAllowedModel(t=arbitrary_type_class)
595+
assert m.t == arbitrary_type_class
596+
597+
598+
def test_bare_type_type_validation_fails():
599+
class ArbitraryClassAllowedModel(BaseModel):
600+
t: Type
601+
602+
arbitrary_type = ArbitraryType()
603+
with pytest.raises(ValidationError) as exc_info:
604+
ArbitraryClassAllowedModel(t=arbitrary_type)
605+
assert exc_info.value.errors() == [{'loc': ('t',), 'msg': 'a class is expected', 'type': 'type_error.class'}]
606+
607+
533608
def test_annotation_field_name_shadows_attribute():
534609
with pytest.raises(NameError):
535610
# When defining a model that has an attribute with the name of a built-in attribute, an exception is raised

0 commit comments

Comments
 (0)
0