8000 Raise `AttributeError` on attempts to access unset `oneof` fields · danielgtaylor/python-betterproto@774b3c2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 774b3c2

Browse files
committed
Raise AttributeError on attempts to access unset oneof fields
This commit modifies `Message.__getattribute__` to raise `AttributeError` whenever an attempt is made to access an unset `oneof` field. This provides several benefits over the current approach: * There is no longer any risk of `betterproto` users accidentally relying on values of unset fields. * Pattern matching with `match/case` on messages containing `oneof` groups is now supported. The following is now possible: ``` @dataclasses.dataclass(eq=Fals 8000 e, repr=False) class Test(betterproto.Message): x: int = betterproto.int32_field(1, group="g") y: str = betterproto.string_field(2, group="g") match Test(y="text"): case Test(x=v): print("x", v) case Test(y=v): print("y", v) ``` Before this commit the code above would output `x 0` instead of `y text`, but now the output is `y text` as expected. The reason this works is because an `AttributeError` in a `case` pattern does not propagate and instead simply skips the `case`. * We now have a type-checkable way to deconstruct `oneof`. When running `mypy` for the snippet above `v` has type `int` in the first `case` and type `str` in the second `case`. For versions of Python that do not support `match/case` (before 3.10) it is now possbile to use `try/except/else` blocks to achieve the same result: ``` t = Test(y="text") try: v0: int = t.x except AttributeError: v1: str = t.y # `oneof` contains `y` else: pass # `oneof` contains `x` ``` This is a breaking change.
1 parent 098989e commit 774b3c2

File tree

4 files changed

+45
-22
lines changed

4 files changed

+45
-22
lines changed

src/betterproto/__init__.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -693,8 +693,21 @@ def __repr__(self) -> str:
693693
def __getattribute__(self, name: str) -> Any:
694694
"""
695695
Lazily initialize default values to avoid infinite recursion for recursive
696-
message types
696+
message types.
697+
Raise `AttributeError` on attempts to access unset `oneof` fields.
697698
"""
699+
try:
700+
group_current = super().__getattribute__("_group_current")
701+
except AttributeError:
702+
pass
703+
else:
704+
if name not in {"__class__", "_betterproto"}:
705+
group = self._betterproto.oneof_group_by_field.get(name)
706+
if group is not None and group_current[group] != name:
707+
raise AttributeError(
708+
f"'{self.__class__.__name__}.{group}' is set to '{group_current[group]}', not '{name}'"
709+
)
710+
698711
value = super().__getattribute__(name)
699712
if value is not PLACEHOLDER:
700713
return value
@@ -761,7 +774,10 @@ def __bytes__(self) -> bytes:
761774
"""
762775
output = bytearray()
763776
for field_name, meta in self._betterproto.meta_by_field_name.items():
764-
value = getattr(self, field_name)
777+
try:
778+
value = getattr(self, field_name)
779+
except AttributeError:
780+
continue
765781

766782
if value is None:
767783
# Optional items should be skipped. This is used for the Google
@@ -775,9 +791,7 @@ def __bytes__(self) -> bytes:
775791
# Note that proto3 field presence/optional fields are put in a
776792
# synthetic single-item oneof by protoc, which helps us ensure we
777793
# send the value even if the value is the default zero value.
778-
selected_in_group = (
779-
meta.group and self._group_current[meta.group] == field_name
780-
)
794+
selected_in_group = bool(meta.group)
781795

782796
# Empty messages can still be sent on the wire if they were
783797
# set (or received empty).
@@ -1016,7 +1030,12 @@ def parse(self: T, data: bytes) -> T:
10161030
parsed.wire_type, meta, field_name, parsed.value
10171031
)
10181032

1019-
current = getattr(self, field_name)
1033+
try:
1034+
current = getattr(self, field_name)
1035+
except AttributeError:
1036+
current = self._get_field_default(field_name)
1037+
setattr(self, field_name, current)
1038+
10201039
if meta.proto_type == TYPE_MAP:
10211040
# Value represents a single key/value pair entry in the map.
10221041
current[value.key] = value.value
@@ -1077,7 +1096,10 @@ def to_dict(
10771096
defaults = self._betterproto.default_gen
10781097
for field_name, meta in self._betterproto.meta_by_field_name.items():
10791098
field_is_repeated = defaults[field_name] is list
1080-
value = getattr(self, field_name)
1099+
try:
1100+
value = getattr(self, field_name)
1101+
except AttributeError:
1102+
value = self._get_field_default(field_name)
10811103
cased_name = casing(field_name).rstrip("_") # type: ignore
10821104
if meta.proto_type == TYPE_MESSAGE:
10831105
if isinstance(value, datetime):
@@ -1209,7 +1231,7 @@ def from_dict(self: T, value: Mapping[str, Any]) -> T:
12091231

12101232
if value[key] is not None:
12111233
if meta.proto_type == TYPE_MESSAGE:
1212-
v = getattr(self, field_name)
1234+
v = self._get_field_default(field_name)
12131235
cls = self._betterproto.cls_by_field[field_name]
12141236
if isinstance(v, list):
12151237
if cls == datetime:
@@ -1486,7 +1508,6 @@ def _validate_field_groups(cls, values):
14861508
field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore
14871509

14881510
for group, field_set in group_to_one_ofs.items():
1489-
14901511
if len(field_set) == 1:
14911512
(field,) = field_set
14921513
field_name = field.name

tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@ def test_bytes_are_the_same_for_oneof():
5050

5151
# None of these fields were explicitly set BUT they should not actually be null
5252
# themselves
53-
assert isinstance(message.foo, Foo)
54-
assert isinstance(message2.foo, Foo)
53+
assert not hasattr(message, "foo")
54+
assert object.__getattribute__(message, "foo") == betterproto.PLACEHOLDER
55+
assert not hasattr(message2, "foo")
56+
assert object.__getattribute__(message2, "foo") == betterproto.PLACEHOLDER
5557

5658
assert isinstance(message_reference.foo, ReferenceFoo)
5759
assert isinstance(message_reference2.foo, ReferenceFoo)

tests/inputs/oneof_enum/test_oneof_enum.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@ def test_which_one_of_returns_enum_with_default_value():
1818
get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json
1919
)
2020

21-
assert message.move == Move(
22-
x=0, y=0
23-
) # Proto3 will default this as there is no null
21+
assert not hasattr(message, "move")
22+
assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER
2423
assert message.signal == Signal.PASS
2524
assert betterproto.which_one_of(message, "action") == ("signal", Signal.PASS)
2625

@@ -33,9 +32,8 @@ def test_which_one_of_returns_enum_with_non_default_value():
3332
message.from_json(
3433
get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json
3534
)
36-
assert message.move == Move(
37-
x=0, y=0
38-
) # Proto3 will default this as there is no null
35+
assert not hasattr(message, "move")
36+
assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER
3937
assert message.signal == Signal.RESIGN
4038
assert betterproto.which_one_of(message, "action") == ("signal", Signal.RESIGN)
4139

@@ -44,5 +42,6 @@ def test_which_one_of_returns_second_field_when_set():
4442
message = Test()
4543
message.from_json(get_test_case_json_data("oneof_enum")[0].json)
4644
assert message.move == Move(x=2, y=3)
47-
assert message.signal == Signal.PASS
45+
assert not hasattr(message, "signal")
46+
assert object.__getattribute__(message, "signal") == betterproto.PLACEHOLDER
4847
assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3))

tests/test_features.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,17 +151,18 @@ class Foo(betterproto.Message):
151151
foo.baz = "test"
152152

153153
# Other oneof fields should now be unset
154-
assert foo.bar == 0
154+
assert not hasattr(foo, "bar")
155+
assert object.__getattribute__(foo, "bar") == betterproto.PLACEHOLDER
155156
assert betterproto.which_one_of(foo, "group1")[0] == "baz"
156157

157-
foo.sub.val = 1
158+
foo.sub = Sub(val=1)
158159
assert betterproto.serialized_on_wire(foo.sub)
159160

160161
foo.abc = "test"
161162

162163
# Group 1 shouldn't be touched, group 2 should have reset
163-
assert foo.sub.val == 0
164-
assert betterproto.serialized_on_wire(foo.sub) is False
164+
assert not hasattr(foo, "sub")
165+
assert object.__geta 49DC ttribute__(foo, "sub") == betterproto.PLACEHOLDER
165166
assert betterproto.which_one_of(foo, "group2")[0] == "abc"
166167

167168
# Zero value should always serialize for one-of

0 commit comments

Comments
 (0)
0