8000 Fix parameters missing from services (#381) · arpitjain799/python-betterproto@3fd5a0d · GitHub
[go: up one dir, main page]

Skip to content

Commit 3fd5a0d

Browse files
authored
Fix parameters missing from services (danielgtaylor#381)
1 parent bc13e70 commit 3fd5a0d

File tree

9 files changed

+136
-40
lines changed

9 files changed

+136
-40
lines changed

src/betterproto/__init__.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -379,15 +379,10 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
379379
elif proto_type == TYPE_MESSAGE:
380380
if isinstance(value, datetime):
381381
# Convert the `datetime` to a timestamp message.
382-
seconds = int(value.timestamp())
383-
nanos = int(value.microsecond * 1e3)
384-
value = _Timestamp(seconds=seconds, nanos=nanos)
382+
value = _Timestamp.from_datetime(value)
385383
elif isinstance(value, timedelta):
386384
# Convert the `timedelta` to a duration message.
387-
total_ms = value // timedelta(microseconds=1)
388-
seconds = int(total_ms / 1e6)
389-
nanos = int((total_ms % 1e6) * 1e3)
390-
value = _Duration(seconds=seconds, nanos=nanos)
385+
value = _Duration.from_timedelta(value)
391386
elif wraps:
392387
if value is None:
393388
return b""
@@ -1505,6 +1500,15 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]
15051500

15061501

15071502
class _Duration(Duration):
1503+
@classmethod
1504+
def from_timedelta(
1505+
cls, delta: timedelta, *, _1_microsecond: timedelta = timedelta(microseconds=1)
1506+
) -> "_Duration":
1507+
total_ms = delta // _1_microsecond
1508+
seconds = int(total_ms / 1e6)
1509+
nanos = int((total_ms % 1e6) * 1e3)
1510+
return cls(seconds, nanos)
1511+
15081512
def to_timedelta(self) -> timedelta:
15091513
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
15101514

@@ -1518,6 +1522,12 @@ def delta_to_json(delta: timedelta) -> str:
15181522

15191523

15201524
class _Timestamp(Timestamp):
1525+
@classmethod
1526+
def from_datetime(cls, dt: datetime) -> "_Timestamp":
1527+
seconds = int(dt.timestamp())
1528+
nanos = int(dt.microsecond * 1e3)
1529+
return cls(seconds, nanos)
1530+
15211531
def to_datetime(self) -> datetime:
15221532
ts = self.seconds + (self.nanos / 1e9)
15231533
return datetime.fromtimestamp(ts, tz=timezone.utc)

src/betterproto/compile/importing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def parse_source_type_name(field_type_name: str) -> Tuple[str, str]:
4343

4444

4545
def get_type_reference(
46-
package: str, imports: set, source_type: str, unwrap: bool = True
46+
*, package: str, imports: set, source_type: str, unwrap: bool = True
4747
) -> str:
4848
"""
4949
Return a Python type name for a proto type reference. Adds the import if

src/betterproto/grpc/grpclib_client.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,22 @@
1515

1616
import grpclib.const
1717

18-
from .._types import (
19-
ST,
20-
T,
21-
)
22-
2318

2419
if TYPE_CHECKING:
2520
from grpclib.client import Channel
2621
from grpclib.metadata import Deadline
2722

23+
from .._types import (
24+
ST,
25+
IProtoMessage,
26+
Message,
27+
T,
28+
)
29+
2830

2931
Value = Union[str, bytes]
3032
MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]]
31-
MessageLike = Union[T, ST]
32-
MessageSource = Union[Iterable[ST], AsyncIterable[ST]]
33+
MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]
3334

3435

3536
class ServiceStub(ABC):
@@ -65,13 +66,13 @@ def __resolve_request_kwargs(
6566
async def _unary_unary(
6667
self,
6768
route: str,
68-
request: MessageLike,
69-
response_type: Type[T],
69+
request: "IProtoMessage",
70+
response_type: Type["T"],
7071
*,
7172
timeout: Optional[float] = None,
7273
deadline: Optional["Deadline"] = None,
7374
metadata: Optional[MetadataLike] = None,
74-
) -> T:
75+
) -> "T":
7576
"""Make a unary request and return the response."""
7677
async with self.channel.request(
7778
route,
@@ -88,13 +89,13 @@ async def _unary_unary(
8889
async def _unary_stream(
8990
self,
9091
route: str,
91-
request: MessageLike,
92-
response_type: Type[T],
92+
request: "IProtoMessage",
93+
response_type: Type["T"],
9394
*,
9495
timeout: Optional[float] = None,
9596
deadline: Optional["Deadline"] = None,
9697
metadata: Optional[MetadataLike] = None,
97-
) -> AsyncIterator[T]:
98+
) -> AsyncIterator["T"]:
9899
"""Make a unary request and return the stream response iterator."""
99100
async with self.channel.request(
100101
route,
@@ -111,13 +112,13 @@ async def _stream_unary(
111112
self,
112113
route: str,
113114
request_iterator: MessageSource,
114-
request_type: Type[ST],
115-
response_type: Type[T],
115+
request_type: Type["IProtoMessage"],
116+
response_type: Type["T"],
116117
*,
117118
timeout: Optional[float] = None,
118119
deadline: Optional["Deadline"] = None,
119120
metadata: Optional[MetadataLike] = None,
120-
) -> T:
121+
) -> "T":
121122
"""Make a stream request and return the response."""
122123
async with self.channel.request(
123124
route,
@@ -135,13 +136,13 @@ async def _stream_stream(
135136
self,
136137
route: str,
137138
request_iterator: MessageSource,
138-
request_type: Type[ST],
139-
response_type: Type[T],
139+
request_type: Type["IProtoMessage"],
140+
response_type: Type["T"],
140141
*,
141142
timeout: Optional[float] = None,
142143
deadline: Optional["Deadline"] = None,
143144
metadata: Optional[MetadataLike] = None,
144-
) -> AsyncIterator[T]:
145+
) -> AsyncIterator["T"]:
145146
"""
146147
Make a stream request and return an AsyncIterator to iterate over response
147148
messages.

src/betterproto/plugin/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ class OutputTemplate:
252252
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
253253
services: List["ServiceCompiler"] = field(default_factory=list)
254254
imports_type_checking_only: Set[str] = field(default_factory=set)
255+
output: bool = True
255256

256257
@property
257258
def package(self) -> str:
@@ -704,6 +705,7 @@ def __post_init__(self) -> None:
704705

705706
# add imports required for request arguments timeout, deadline and metadata
706707
self.output_file.typing_imports.add("Optional")
708+
self.output_file.imports_type_checking_only.add("import grpclib.server")
707709
self.output_file.imports_type_checking_only.add(
708710
"from betterproto.grpc.grpclib_client import MetadataLike"
709711
)
@@ -768,6 +770,7 @@ def py_input_message_type(self) -> str:
768770
package=self.output_file.package,
769771
imports=self.output_file.imports,
770772
source_type=self.proto_obj.input_type,
773+
unwrap=False,
771774
).strip('"')
772775

773776
@property

src/betterproto/plugin/parser.py

Lines changed: 10 additions & 8 deletions
< 1241 td data-grid-cell-id="diff-48441cfc482b1f36a706d17227283805b1ff1bcffa0b258c1d6a4f2065b9cddf-92-84-0" data-selected="false" role="gridcell" style="background-color:var(--bgColor-default);text-align:center" tabindex="-1" valign="top" class="focusable-grid-cell diff-line-number position-relative diff-line-number-neutral left-side">92
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,6 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
7474
request_data = PluginRequestCompiler(plugin_request_obj=request)
7575
# Gather output packages
7676
for proto_file in request.proto_file:
77-
if (
78-
proto_file.package == "google.protobuf"
79-
and "INCLUDE_GOOGLE" not in plugin_options
80-
):
81-
# If not INCLUDE_GOOGLE,
82-
# skip re-compiling Google's well-known types
83-
continue
84-
8577
output_package_name = proto_file.package
8678
if output_package_name not in request_data.output_packages:
8779
# Create a new output if there is no output for this package
@@ -91,6 +83,14 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
9183
# Add this input file to the output corresponding to this package
84
request_data.output_packages[output_package_name].input_files.append(proto_file)
9385

86+
if (
87+
proto_file.package == "google.protobuf"
88+
and "INCLUDE_GOOGLE" not in plugin_options
89+
):
90+
# If not INCLUDE_GOOGLE,
91+
# skip outputting Google's well-known types
92+
request_data.output_packages[output_package_name].output = False
93+
9494
# Read Messages and Enums
9595
# We need to read Messages before Services in so that we can
9696
# get the references to input/output messages for each service
@@ -113,6 +113,8 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
113113
# Generate output files
114114
output_paths: Set[pathlib.Path] = set()
115115
for output_package_name, output_package in request_data.output_packages.items():
116+
if not output_package.output:
117+
continue
116118

117119
# Add files to the response object
118120
output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")

src/betterproto/templates/template.py.j2

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no
1515
{% endif %}
1616

1717
import betterproto
18+
{% if output_file.services %}
1819
from betterproto.grpc.grpclib_server import ServiceBase
20+
import grpclib
21+
{% endif %}
22+
1923
{% for i in output_file.imports|sort %}
2024
{{ i }}
2125
{% endfor %}
22-
{% if output_file.services %}
23-
import grpclib
24-
{% endif %}
2526

2627
{% if output_file.imports_type_checking_only %}
2728
from typing import TYPE_CHECKING
@@ -96,9 +97,11 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
9697
{# Client streaming: need a request iterator instead #}
9798
, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
9899
{%- endif -%}
100+
,
101+
*
99102
, timeout: Optional[float] = None
100103
, deadline: Optional["Deadline"] = None
101-
, metadata: Optional["_MetadataLike"] = None
104+
, metadata: Optional["MetadataLike"] = None
102105
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
103106
{% if method.comment %}
104107
{{ method.comment }}
@@ -179,7 +182,7 @@ class {{ service.py_name }}Base(ServiceBase):
179182
{% endfor %}
180183

181184
{% for method in service.methods %}
182-
async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None:
185+
async def __rpc_{{ method.py_name }}(self, stream: "grpclib.server.Stream[{{ method.py_input_message_type }}, {{ method.py_output_message_type }}]") -> None:
183186
{% if not method.client_streaming %}
184187
request = await stream.recv_message()
185188
{% else %}

tests/inputs/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
}
1010

1111
services = {
12+
"googletypes_request",
1213
"googletypes_response",
1314
"googletypes_response_embedded",
1415
"service",
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
syntax = "proto3";
2+
3+
package googletypes_request;
4+
5+
import "google/protobuf/duration.proto";
6+
import "google/protobuf/empty.proto";
7+
import "google/protobuf/timestamp.proto";
8+
import "google/protobuf/wrappers.proto";
9+
10+
// Tests that google types can be used as params
11+
12+
service Test {
13+
rpc SendDouble (google.protobuf.DoubleValue) returns (Input);
14+
rpc SendFloat (google.protobuf.FloatValue) returns (Input);
15+
rpc SendInt64 (google.protobuf.Int64Value) returns (Input);
16+
rpc SendUInt64 (google.protobuf.UInt64Value) returns (Input);
17+
rpc SendInt32 (google.protobuf.Int32Value) returns (Input);
18+
rpc SendUInt32 (google.protobuf.UInt32Value) returns (Input);
19+
rpc SendBool (google.protobuf.BoolValue) returns (Input);
20+
rpc SendString (google.protobuf.StringValue) returns (Input);
21+
rpc SendBytes (google.protobuf.BytesValue) returns (Input);
22+
rpc SendDatetime (google.protobuf.Timestamp) returns (Input);
23+
rpc SendTimedelta (google.protobuf.Duration) returns (Input);
24+
rpc SendEmpty (google.protobuf.Empty) returns (Input);
25+
}
26+
27+
message Input {
28+
29+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from datetime import (
2+
datetime,
3+
timedelta,
4+
)
5+
from typing import (
6+
Any,
7+
Callable,
8+
)
9+
10+
import pytest
11+
12+
import betterproto.lib.google.protobuf as protobuf
13+
from tests.mocks import MockChannel
14+
from tests.output_betterproto.googletypes_request import (
15+
Input,
16+
TestStub,
17+
)
18+
19+
20+
test_cases = [
21+
(TestStub.send_double, protobuf.DoubleValue, 2.5),
22+
(TestStub.send_float, protobuf.FloatValue, 2.5),
23+
(TestStub.send_int64, protobuf.Int64Value, -64),
24+
(TestStub.send_u_int64, protobuf.UInt64Value, 64),
25+
(TestStub.send_int32, protobuf.Int32Value, -32),
26+
(TestStub.send_u_int32, protobuf.UInt32Value, 32),
27+
(TestStub.send_bool, protobuf.BoolValue, True),
28+
(TestStub.send_string, protobuf.StringValue, "string"),
29+
(TestStub.send_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]),
30+
(TestStub.send_datetime, protobuf.Timestamp, datetime(2038, 1, 19, 3, 14, 8)),
31+
(TestStub.send_timedelta, protobuf.Duration, timedelta(seconds=123456)),
32+
]
33+
34+
35+
@pytest.mark.asyncio
36+
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
37+
async def test_channel_receives_wrapped_type(
38+
service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
39+
):
40+
wrapped_value = wrapper_class()
41+
wrapped_value.value = value
42+
channel = MockChannel(responses=[Input()])
43+
service = TestStub(channel)
44+
45+
await service_method(service, wrapped_value)
46+
47+
assert channel.requests[0]["request"] == type(wrapped_value)

0 commit comments

Comments
 (0)
0