10000 Make sure Pydantic dataclasses with slots and `validate_assignment` a… · pydantic/pydantic@63c9727 · GitHub
[go: up one dir, main page]

Skip to content

Commit 63c9727

Browse files
committed
Make sure Pydantic dataclasses with slots and validate_assignment are unpickable
1 parent a3f42a4 commit 63c9727

File tree

3 files changed

+41
-1
lines changed

3 files changed

+41
-1
lines changed

pydantic/_internal/_dataclasses.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def complete_dataclass(
8585
cls: type[Any],
8686
config_wrapper: _config.ConfigWrapper,
8787
*,
88+
slots: bool,
8889
raise_errors: bool = True,
8990
ns_resolver: NsResolver | None = None,
9091
_force_build: bool = False,
@@ -98,6 +99,7 @@ def complete_dataclass(
9899
Args:
99100
cls: The class.
100101
config_wrapper: The config wrapper instance.
102+
slots: Whether slots was set on the class.
101103
raise_errors: Whether to raise errors, defaults to `True`.
102104
ns_resolver: The namespace resolver instance to use when collecting dataclass fields
103105
and during schema building.
@@ -194,6 +196,27 @@ def validated_setattr(instance: Any, field: str, value: str, /) -> None:
194196

195197
cls.__setattr__ = validated_setattr.__get__(None, cls) # type: ignore
196198

199+
if slots and not hasattr(cls, '__setstate__'):
200+
# If slots is set, `pickle` (relied on by `copy.copy()`) will use
201+
# `__setattr__()` to reconstruct the dataclass. However, the custom
202+
# `__setattr__()` set above relies on `validate_assignment()`, which
203+
# in turn excepts all the field values to be already present on the
204+
# instance, resulting in attribute errors.
205+
# As such, we make use of `object.__setattr__()` instead.
206+
# Note that we do so only if `__setstate__()` isn't already set (this is the
207+
# case if on top of `slots`, `frozen` is used).
208+
209+
# Taken from `dataclasses._dataclass_get/setstate()`:
210+
def _dataclass_getstate(self: Any) -> list[Any]:
211+
return [getattr(self, f.name) for f in dataclasses.fields(self)]
212+
213+
def _dataclass_setstate(self: Any, state: list[Any]) -> None:
214+
for field, value in zip(dataclasses.fields(self), state):
215+
object.__setattr__(self, field.name, value)
216+
217+
cls.__getstate__ = _dataclass_getstate # pyright: ignore[reportAttributeAccessIssue]
218+
cls.__setstate__ = _dataclass_setstate
219+
197220
cls.__pydantic_complete__ = True
198221
return True
199222

pydantic/dataclasses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def create_dataclass(cls: type[Any]) -> type[PydanticDataclass]:
276276
# TODO `parent_namespace` is currently None, but we could do the same thing as Pydantic models:
277277
# fetch the parent ns using `parent_frame_namespace` (if the dataclass was defined in a function),
278278
# and possibly cache it (see the `__pydantic_parent_namespace__` logic for models).
279-
_pydantic_dataclasses.complete_dataclass(cls, config_wrapper, raise_errors=False)
279+
_pydantic_dataclasses.complete_dataclass(cls, config_wrapper, slots=slots, raise_errors=False)
280280
return cls
281281

282282
return create_dataclass if _cls is None else create_dataclass(_cls)

tests/test_dataclasses.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2430,6 +2430,23 @@ class Model:
24302430
assert dc.b == 'bar'
24312431

24322432

2433+
# Must be defined at the module level to be pickable:
2434+
@pydantic.dataclasses.dataclass(slots=True, config={'validate_assignment': True})
2435+
class DataclassSlotsValidateAssignment:
2436+
a: int
2437+
2438+
2439+
@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python >= 3.10')
2440+
def test_dataclass_slots_validate_assignment():
2441+
"""https://github.com/pydantic/pydantic/issues/11768"""
2442+
2443+
m = DataclassSlotsValidateAssignment(1)
2444+
m_pickle = pickle.loads(pickle.dumps(m))
2445+
assert m_pickle.a == 1
2446+
with pytest.raises(ValidationError):
2447+
m.a = 'not_an_int'
2448+
2449+
24332450
@pytest.mark.parametrize(
24342451
'dataclass_decorator',
24352452
[

0 commit comments

Comments
 (0)
0