8000 Implement __deepcopy__ for Message (#339) · danielgtaylor/python-betterproto@74205e3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 74205e3

Browse files
authored
Implement __deepcopy__ for Message (#339)
1 parent 3f377e3 commit 74205e3

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

src/betterproto/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import typing
99
from abc import ABC
1010
from base64 import b64decode, b64encode
11+
from copy import deepcopy
1112
from datetime import datetime, timedelta, timezone
1213
from dateutil.parser import isoparse
1314
from typing import (
@@ -717,6 +718,14 @@ def __bool__(self) -> bool:
717718
for field_name in self._betterproto.meta_by_field_name
718719
)
719720

721+
def __deepcopy__(self: T, _: Any = {}) -> T:
722+
kwargs = {}
723+
for name in self._betterproto.sorted_field_names:
724+
value = self.__raw_get(name)
725+
if value is not PLACEHOLDER:
726+
kwargs[name] = deepcopy(value)
727+
return self.__class__(**kwargs) # type: ignore
728+
720729
@property
721730
def _betterproto(self) -> ProtoClassMetadata:
722731
"""

tests/test_features.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import dataclass
2+
from copy import copy, deepcopy
23
from datetime import datetime
34
from inspect import Parameter, signature
45
from typing import Dict, List, Optional
@@ -485,3 +486,22 @@ def test_service_argument__expected_parameter():
485486
do_thing_request_parameter = sig.parameters["do_thing_request"]
486487
assert do_thing_request_parameter.default is Parameter.empty
487488
assert do_thing_request_parameter.annotation == "DoThingRequest"
489+
490+
491+
def test_copyability():
492+
@dataclass
493+
class Spam(betterproto.Message):
494+
foo: bool = betterproto.bool_field(1)
495+
bar: int = betterproto.int32_field(2)
496+
baz: List[str] = betterproto.string_field(3)
497+
498+
spam = Spam(bar=12, baz=["hello"])
499+
copied = copy(spam)
500+
assert spam == copied
501+
assert spam is not copied
502+
assert spam.baz is copied.baz
503+
504+
deepcopied = deepcopy(spam)
505+
assert spam == deepcopied
506+
assert spam is not deepcopied
507+
assert spam.baz is not deepcopied.baz

0 commit comments

Comments
 (0)
0