8000 fix: use dataclasses proxy for frozen or empty dataclasses (#4878) · pydantic/pydantic@bf4b5ce · GitHub
[go: up one dir, main page]

Skip to content

Commit bf4b5ce

Browse files
authored
fix: use dataclasses proxy for frozen or empty dataclasses (#4878)
* add tests * fix: dataclasses * chore: add change file * refactor: remove useless kwarg * test: add new test * keep old kwarg to avoid breaking change
1 parent a220f87 commit bf4b5ce

File tree

3 files changed

+112
-24
lines changed

3 files changed

+112
-24
lines changed

changes/4878-PrettyWood.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
fix: use dataclass proxy for frozen or empty dataclasses

pydantic/dataclasses.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -34,20 +34,7 @@ class M:
3434
import sys
3535
from contextlib import contextmanager
3636
from functools import wraps
37-
from typing import (
38-
TYPE_CHECKING,
39-
Any,
40-
Callable,
41-
ClassVar,
42-
Dict,
43-
Generator,
44-
Optional,
45-
Set,
46-
Type,
47-
TypeVar,
48-
Union,
49-
overload,
50-
)
37+
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload
5138

5239
from typing_extensions import dataclass_transform
5340

@@ -117,6 +104,7 @@ def dataclass(
117104
frozen: bool = False,
118105
config: Union[ConfigDict, Type[object], None] = None,
119106
validate_on_init: Optional[bool] = None,
107+
use_proxy: Optional[bool] = None,
120108
kw_only: bool = ...,
121109
) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
122110
...
@@ -134,6 +122,7 @@ def dataclass(
134122
frozen: bool = False,
135123
config: Union[ConfigDict, Type[object], None] = None,
136124
validate_on_init: Optional[bool] = None,
125+
use_proxy: Optional[bool] = None,
137126
kw_only: bool = ...,
138127
) -> 'DataclassClassOrWrapper':
139128
...
@@ -152,6 +141,7 @@ def dataclass(
152141
frozen: bool = False,
153142
config: Union[ConfigDict, Type[object], None] = None,
154143
validate_on_init: Optional[bool] = None,
144+
use_proxy: Optional[bool] = None,
155145
) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
156146
...
157147

@@ -168,6 +158,7 @@ def dataclass(
168158
frozen: bool = False,
169159
config: Union[ConfigDict, Type[object], None] = None,
170160
validate_on_init: Optional[bool] = None,
161+
use_proxy: Optional[bool] = None,
171162
) -> 'DataclassClassOrWrapper':
172163
...
173164

@@ -184,6 +175,7 @@ def dataclass(
184175
frozen: bool = False,
185176
config: Union[ConfigDict, Type[object], None] = None,
186177
validate_on_init: Optional[bool] = None,
178+
use_proxy: Optional[bool] = None,
187179
kw_only: bool = False,
188180
) -> Union[Callable[[Type[_T]], 'DataclassClassOrWrapper'], 'DataclassClassOrWrapper']:
189181
"""
@@ -197,7 +189,15 @@ def dataclass(
197189
def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper':
198190
import dataclasses
199191

200-
if is_builtin_dataclass(cls) and _extra_dc_args(_cls) == _extra_dc_args(_cls.__bases__[0]): # type: ignore
192+
should_use_proxy = (
193+
use_proxy
194+
if use_proxy is not None
195+
else (
196+
is_builtin_dataclass(cls)
197+
and (cls.__bases__[0] is object or set(dir(cls)) == set(dir(cls.__bases__[0])))
198+
)
199+
)
200+
if should_use_proxy:
201201
dc_cls_doc = ''
202202
dc_cls = DataclassProxy(cls)
203203
default_validate_on_init = False
@@ -437,14 +437,6 @@ def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value:
437437
object.__setattr__(self, name, value)
438438

439439

440-
def _extra_dc_args(cls: Type[Any]) -> Set[str]:
441-
return {
442-
x
443-
for x in dir(cls)
444-
if x not in getattr(cls, '__dataclass_fields__', {}) and not (x.startswith('__') and x.endswith('__'))
445-
}
446-
447-
448440
def is_builtin_dataclass(_cls: Type[Any]) -> bool:
449441
"""
450442
Whether a class is a stdlib dataclass
@@ -482,4 +474,4 @@ def make_dataclass_validator(dc_cls: Type['Dataclass'], config: Type[BaseConfig]
482474
and yield the validators
483475
It retrieves the parameters of the dataclass and forwards them to the newly created dataclass
484476
"""
485-
yield from _get_validators(dataclass(dc_cls, config=config, validate_on_init=False))
477+
yield from _get_validators(dataclass(dc_cls, config=config, use_proxy=True))

tests/test_dataclasses.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1513,3 +1513,98 @@ class Foo:
15131513
assert config.bar == 'cat'
15141514
setattr(config, 'bar', 'dog')
15151515
assert config.bar == 'dog'
1516+
1517+
1518+
def test_frozen_dataclasses():
1519+
@dataclasses.dataclass(frozen=True)
1520+
class First:
1521+
a: int
1522+
1523+
@dataclasses.dataclass(frozen=True)
1524+
class Second(First):
1525+
@property
1526+
def b(self):
1527+
return self.a
1528+
1529+
class My(BaseModel):
1530+
my: Second
1531+
1532+
assert My(my=Second(a='1')).my.b == 1
1533+
1534+
1535+
def test_empty_dataclass():
1536+
"""should be able to inherit without adding a field"""
1537+
1538+
@dataclasses.dataclass
1539+
class UnvalidatedDataclass:
1540+
a: int = 0
1541+
1542+
@pydantic.dataclasses.dataclass
1543+
class ValidatedDerivedA(UnvalidatedDataclass):
1544+
...
1545+
1546+
@pydantic.dataclasses.dataclass()
1547+
class ValidatedDerivedB(UnvalidatedDataclass):
1548+
b: int = 0
1549+
1550+
@pydantic.dataclasses.dataclass()
1551+
class ValidatedDerivedC(UnvalidatedDataclass):
1552+
...
1553+
1554+
1555+
def test_proxy_dataclass():
1556+
@dataclasses.dataclass
1557+
class Foo:
1558+
a: Optional[int] = dataclasses.field(default=42)
1559+
b: List = dataclasses.field(default_factory=list)
1560+
1561+
@dataclasses.dataclass
1562+
class Bar:
1563+
pass
1564+
1565+
@dataclasses.dataclass
1566+
class Model1:
1567+
foo: Foo
1568+
1569+
class Model2(BaseModel):
1570+
foo: Foo
1571+
1572+
m1 = Model1(foo=Foo())
1573+
m2 = Model2(foo=Foo())
1574+
1575+
assert m1.foo.a == m2.foo.a == 42
1576+
assert m1.foo.b == m2.foo.b == []
1577+
assert m1.foo.Bar() is not None
1578+
assert m2.foo.Bar() is not None
1579+
1580+
1581+
def test_proxy_dataclass_2():
1582+
@dataclasses.dataclass
1583+
class M1:
1584+
a: int
1585+
b: str = 'b'
1586+
c: float = dataclasses.field(init=False)
1587+
1588+
def __post_init__(self):
1589+
self.c = float(self.a)
1590+
1591+
@dataclasses.dataclass
1592+
class M2:
1593+
a: int
1594+
b: str = 'b'
1595+
c: float = dataclasses.field(init=False)
1596+
1597+
def __post_init__(self):
1598+
self.c = float(self.a)
1599+
1600+
@pydantic.validator('b')
1601+
def check_b(cls, v):
1602+
if not v:
1603+
raise ValueError('b should not be empty')
1604+
return v
1605+
1606+
m1 = pydantic.parse_obj_as(M1, {'a': 3})
1607+
m2 = pydantic.parse_obj_as(M2, {'a': 3})
1608+
assert m1.a == m2.a == 3
1609+
assert m1.b == m2.b == 'b'
1610+
assert m1.c == m2.c == 3.0

0 commit comments

Comments
 (0)
0