8000 Allow None for (virtual) superclasses of NoneType (#1266) · ag-python/pydantic@b749e22 · GitHub
[go: up one dir, main page]

Skip to content

Commit b749e22

Browse files
authored
Allow None for (virtual) superclasses of NoneType (pydantic#1266)
* Set allow_none to True for (virtual) superclasses of NoneType * Add full support for Hashable
1 parent 76ebdb9 commit b749e22

File tree

5 files changed

+71
-0
lines changed

5 files changed

+71
-0
lines changed

pydantic/errors.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
'DateError',
6262
'TimeError',
6363
'DurationError',
64+
'HashableError',
6465
'UUIDError',
6566
'UUIDVersionError',
6667
'ArbitraryTypeError',
@@ -392,6 +393,10 @@ class DurationError(PydanticValueError):
392393
msg_template = 'invalid duration format'
393394

394395

396+
class HashableError(PydanticTypeError):
397+
msg_template = 'value is not a valid hashable'
398+
399+
395400
class UUIDError(PydanticTypeError):
396401
msg_template = 'value is not a valid uuid'
397402

pydantic/fields.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,9 @@ def _type_analysis(self) -> None: # noqa: C901 (ignore complexity)
401401
origin = getattr(self.type_, '__origin__', None)
402402
if origin is None:
403403
# field is not "typing" object eg. Union, Dict, List etc.
404+
# allow None for virtual superclasses of NoneType, e.g. Hashable
405+
if isinstance(self.type_, type) and isinstance(None, self.type_):
406+
self.allow_none = True
404407
return
405408
if origin is Callable:
406409
return

pydantic/validators.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import re
22
import sys
33
from collections import OrderedDict
4+
from collections.abc import Hashable
45
from datetime import date, datetime, time, timedelta
56
from decimal import Decimal, DecimalException
67
from enum import Enum, IntEnum
@@ -284,6 +285,13 @@ def decimal_validator(v: Any) -> Decimal:
284285
return v
285286

286287

288+
def hashable_validator(v: Any) -> Hashable:
289+
if isinstance(v, Hashable):
290+
return v
291+
292+
raise errors.HashableError()
293+
294+
287295
def ip_v4_address_validator(v: Any) -> IPv4Address:
288296
if isinstance(v, IPv4Address):
289297
return v
@@ -539,6 +547,9 @@ def find_validators( # noqa: C901 (ignore complexity)
539547
if type_ is Pattern:
540548
yield pattern_validator
541549
return
550+
if type_ is Hashable:
551+
yield hashable_validator
552+
return
542553
if is_callable_type(type_):
543554
yield callable_validator
544555
return

tests/test_dataclasses.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import dataclasses
2+
from collections.abc import Hashable
23
from datetime import datetime
34
from pathlib import Path
45
from typing import ClassVar, Dict, FrozenSet, Optional
@@ -578,3 +579,29 @@ class Child(Base):
578579

579580
Child(a=1, b=2)
580581
assert post_init_called
582+
583+
584+
def test_hashable_required():
585+
@pydantic.dataclasses.dataclass
586+
class MyDataclass:
587+
v: Hashable
588+
589+
MyDataclass(v=None)
590+
with pytest.raises(ValidationError) as exc_info:
591+
MyDataclass(v=[])
592+
assert exc_info.value.errors() == [
593+
{'loc': ('v',), 'msg': 'value is not a valid hashable', 'type': 'type_error.hashable'}
594+
]
595+
with pytest.raises(TypeError) as exc_info:
596+
MyDataclass()
597+
assert str(exc_info.value) == "__init__() missing 1 required positional argument: 'v'"
598+
599+
600+
@pytest.mark.parametrize('default', [1, None, ...])
601+
def test_hashable_optional(default):
602+
@pydantic.dataclasses.dataclass
603+
class MyDataclass:
604+
v: Hashable = default
605+
606+
MyDataclass()
607+
MyDataclass(v=None)

tests/test_edge_cases.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import sys
2+
from collections.abc import Hashable
23
from decimal import Decimal
34
from enum import Enum
45
from typing import Any, Dict, FrozenSet, Generic, List, Optional, Set, Tuple, Type, TypeVar, Union
@@ -1478,3 +1479,27 @@ def __init__(self, t1: T1, t2: T2):
14781479
class Model(BaseModel):
14791480
a: str
14801481
gen 1E0A : MyGen[str, bool]
1482+
1483+
1484+
def test_hashable_required():
1485+
class Model(BaseModel):
1486+
v: Hashable
1487+
1488+
Model(v=None)
1489+
with pytest.raises(ValidationError) as exc_info:
1490+
Model(v=[])
1491+
assert exc_info.value.errors() == [
1492+
{'loc': ('v',), 'msg': 'value is not a valid hashable', 'type': 'type_error.hashable'}
1493+
]
1494+
with pytest.raises(ValidationError) as exc_info:
1495+
Model()
1496+
assert exc_info.value.errors() == [{'loc': ('v',), 'msg': 'field required', 'type': 'value_error.missing'}]
1497+
1498+
1499+
@pytest.mark.parametrize('default', [1, None])
1500+
def test_hashable_optional(default):
1501+
class Model(BaseModel):
1502+
v: Hashable = default
1503+
1504+
Model(v=None)
1505+
Model()

0 commit comments

Comments
 (0)
0