8000 Make sure Pydantic dataclasses with slots and `validate_assignment` a… · pydantic/pydantic@d174f8b · GitHub
[go: up one dir, main page]

Skip to content

Commit d174f8b

Browse files
committed
Make sure Pydantic dataclasses with slots and validate_assignment are unpickable
1 parent a3f42a4 commit d174f8b

File tree

3 files changed

+49
-12
lines changed

3 files changed

+49
-12
lines changed

pydantic/_internal/_dataclasses.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import dataclasses
66
import typing
77
import warnings
8-
from functools import partial, wraps
8+
from functools import partial
99
from typing import Any, ClassVar
1010

1111
from pydantic_core import (
@@ -178,22 +178,12 @@ def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -
178178
# We are about to set all the remaining required properties expected for this cast;
179179
# __pydantic_decorators__ and __pydantic_fields__ should already be set
180180
cls = typing.cast('type[PydanticDataclass]', cls)
181-
# debug(schema)
182181

183182
cls.__pydantic_core_schema__ = schema
184-
cls.__pydantic_validator__ = validator = create_schema_validator(
183+
cls.__pydantic_validator__ = create_schema_validator(
185184
schema, cls, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings
186185
)
187186
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
188-
189-
if config_wrapper.validate_assignment:
190-
191-
@wraps(cls.__setattr__)
192-
def validated_setattr(instance: Any, field: str, value: str, /) -> None:
193-
validator.validate_assignment(instance, field, value)
194-
195-
cls.__setattr__ = validated_setattr.__get__(None, cls) # type: ignore
196-
197187
cls.__pydantic_complete__ = True
198188
return True
199189

pydantic/dataclasses.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations as _annotations
44

55
import dataclasses
6+
import functools
67
import sys
78
import types
89
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, NoReturn, TypeVar, overload
@@ -264,6 +265,35 @@ def create_dataclass(cls: type[Any]) -> type[PydanticDataclass]:
264265
**kwargs,
265266
)
266267

268+
if config_wrapper.validate_assignment:
269+
270+
@functools.wraps(cls.__setattr__)
271+
def validated_setattr(instance: Any, field: str, value: str, /) -> None:
272+
type(instance).__pydantic_validator__.validate_assignment(instance, field, value)
273+
274+
cls.__setattr__ = validated_setattr.__get__(None, cls) # type: ignore
275+
276+
if slots and not hasattr(cls, '__setstate__'):
277+
# If slots is set, `pickle` (relied on by `copy.copy()`) will use
278+
# `__setattr__()` to reconstruct the dataclass. However, the custom
279+
# `__setattr__()` set above relies on `validate_assignment()`, which
280+
# in turn excepts all the field values to be already present on the
281+
# instance, resulting in attribute errors.
282+
# As such, we make use of `object.__setattr__()` instead.
283+
# Note that we do so only if `__setstate__()` isn't already set (this is the
284+
# case if on top of `slots`, `frozen` is used).
285+
286+
# Taken from `dataclasses._dataclass_get/setstate()`:
287+
def _dataclass_getstate(self: Any) -> list[Any]:
288+
return [getattr(self, f.name) for f in dataclasses.fields(self)]
289+
290+
def _dataclass_setstate(self: Any, state: list[Any]) -> None:
291+
for field, value in zip(dataclasses.fields(self), state):
292+
object.__setattr__(self, field.name, value)
293+
294+
cls.__getstate__ = _dataclass_getstate # pyright: ignore[reportAttributeAccessIssue]
295+
cls.__setstate__ = _dataclass_setstate # pyright: ignore[reportAttributeAccessIssue]
296+
267297
# This is an undocumented attribute to distinguish stdlib/Pydantic dataclasses.
268298
# It should be set as early as possible:
269299
cls.__is_pydantic_dataclass__ = True

tests/test_dataclasses.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2430,6 +2430,23 @@ class Model:
24302430
assert dc.b == 'bar'
24312431

24322432

2433+
# Must be defined at the module level to be pickable:
2434+
@pydantic.dataclasses.dataclass(slots=True, config={'validate_assignment': True})
2435+
class DataclassSlotsValidateAssignment:
2436+
a: int
2437+
2438+
2439+
@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python >= 3.10')
2440+
def test_dataclass_slots_validate_assignment():
2441+
"""https://github.com/pydantic/pydantic/issues/11768"""
2442+
2443+
m = DataclassSlotsValidateAssignment(1)
2444+
m_pickle = pickle.loads(pickle.dumps(m))
2445+
assert m_pickle.a == 1
2446+
with pytest.raises(ValidationError):
2447+
m.a = 'not_an_int'
2448+
2449+
24332450
@pytest.mark.parametrize(
24342451
'dataclass_decorator',
24352452
[

0 commit comments

Comments
 (0)
0