8000 Generate grpclib service stubs (#170) · kilimnik/python-betterproto@1d54ef8 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1d54ef8

Browse files
authored
Generate grpclib service stubs (danielgtaylor#170)
1 parent 73cea12 commit 1d54ef8

File tree

6 files changed

+232
-1
lines changed

6 files changed

+232
-1
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from abc import ABC
2+
from collections import AsyncIterable
3+
from typing import Callable, Any, Dict
4+
5+
import grpclib
6+
import grpclib.server
7+
8+
9+
class ServiceBase(ABC):
10+
"""
11+
Base class for async gRPC servers.
12+
"""
13+
14+
async def _call_rpc_handler_server_stream(
15+
self,
16+
handler: Callable,
17+
stream: grpclib.server.Stream,
18+
request_kwargs: Dict[str, Any],
19+
) -> None:
20+
21+
response_iter = handler(**request_kwargs)
22+
# check if response is actually an AsyncIterator
23+
# this might be false if the method just returns without
24+
# yielding at least once
25+
# in that case, we just interpret it as an empty iterator
26+
if isinstance(response_iter, AsyncIterable):
27+
async for response_message in response_iter:
28+
await stream.send_message(response_message)
29+
else:
30+
response_iter.close()

src/betterproto/plugin/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,7 @@ class ServiceCompiler(ProtoContentBase):
553553
def __post_init__(self) -> None:
554554
# Add service to output file
555555
self.output_file.services.append(self)
556+
self.output_file.typing_imports.add("Dict")
556557
super().__post_init__() # check for unset fields
557558

558559
@property

src/betterproto/templates/template.py.j2

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no
1515
{% endif %}
1616

1717
import betterproto
18+
from betterproto.grpc.grpclib_server import ServiceBase
1819
{% if output_file.services %}
1920
import grpclib
2021
{% endif %}
@@ -82,7 +83,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
8283
Optional[{{ field.annotation }}]
8384
{%- else -%}
8485
{{ field.annotation }}
85-
{%- endif -%} =
86+
{%- endif -%} =
8687
{%- if field.py_name not in method.mutable_default_args -%}
8788
{{ field.default_value_string }}
8889
{%- else -%}
@@ -154,6 +155,89 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
154155
{% endfor %}
155156
{% endfor %}
156157

158+
{% for service in output_file.services %}
159+
class {{ service.py_name }}Base(ServiceBase):
160+
{% if service.comment %}
161+
{{ service.comment }}
162+
163+
{% endif %}
164+
165+
{% for method in service.methods %}
166+
async def {{ method.py_name }}(self
167+
{%- if not method.client_streaming -%}
168+
{%- if method.py_input_message and method.py_input_message.fields -%},
169+
{%- for field in method.py_input_message.fields -%}
170+
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
171+
Optional[{{ field.annotation }}]
172+
{%- else -%}
173+
{{ field.annotation }}
174+
{%- endif -%}
175+
{%- if not loop.last %}, {% endif -%}
176+
{%- endfor -%}
177+
{%- endif -%}
178+
{%- else -%}
179+
{# Client streaming: need a request iterator instead #}
180+
, request_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
181+
{%- endif -%}
182+
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
183+
{% if method.comment %}
184+
{{ method.comment }}
185+
186+
{% endif %}
187+
raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED)
188+
189+
{% endfor %}
190+
191+
{% for method in service.methods %}
192+
async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None:
193+
{% if not method.client_streaming %}
194+
request = await stream.recv_message()
195+
196+
request_kwargs = {
197+
{% for field in method.py_input_message.fields %}
198+
"{{ field.py_name }}": request.{{ field.py_name }},
199+
{% endfor %}
200+
}
201+
202+
{% else %}
203+
request_kwargs = {"request_iterator": stream.__aiter__()}
204+
{% endif %}
205+
206+
{% if not method.server_streaming %}
207+
response = await self.{{ method.py_name }}(**request_kwargs)
208+
await stream.send_message(response)
209+
{% else %}
210+
await self._call_rpc_handler_server_stream(
211+
self.{{ method.py_name }},
212+
stream,
213+
request_kwargs,
214+
)
215+
{% endif %}
216+
217+
{% endfor %}
218+
219+
def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
220+
return {
221+
{% for method in service.methods %}
222+
"{{ method.route }}": grpclib.const.Handler(
223+
self.__rpc_{{ method.py_name }},
224+
{% if not method.client_streaming and not method.server_streaming %}
225+
grpclib.const.Cardinality.UNARY_UNARY,
226+
{% elif not method.client_streaming and method.server_streaming %}
227+
grpclib.const.Cardinality.UNARY_STREAM,
228+
{% elif method.client_streaming and not method.server_streaming %}
229+
grpclib.const.Cardinality.STREAM_UNARY,
230+
{% else %}
231+
grpclib.const.Cardinality.STREAM_STREAM,
232+
{% endif %}
233+
{{ method.py_input_message_type }},
234+
{{ method.py_output_message_type }},
235+
),
236+
{% endfor %}
237+
}
238+
239+
{% endfor %}
240+
157241
{% for i in output_file.imports|sort %}
158242
{{ i }}
159243
{% endfor %}

tests/inputs/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@
1717
"import_service_input_message",
1818
"googletypes_service_returns_empty",
1919
"googletypes_service_returns_googletype",
20+
"example_service",
2021
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
syntax = "proto3";
2+
3+
package example_service;
4+
5+
service Test {
6+
rpc ExampleUnaryUnary(ExampleRequest) returns (ExampleResponse);
7+
rpc ExampleUnaryStream(ExampleRequest) returns (stream ExampleResponse);
8+
rpc ExampleStreamUnary(stream ExampleRequest) returns (ExampleResponse);
9+
rpc ExampleStreamStream(stream ExampleRequest) returns (stream ExampleResponse);
10+
}
11+
12+
message ExampleRequest {
13+
string example_string = 1;
14+
int64 example_integer = 2;
15+
}
16+
17+
message ExampleResponse {
18+
string example_string = 1;
19+
int64 example_integer = 2;
20+
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from typing import AsyncIterator, AsyncIterable
2+
3+
import pytest
4+
from grpclib.testing import ChannelFor
5+
6+
from tests.output_betterproto.example_service.example_service import (
7+
TestBase,
8+
TestStub,
9+
ExampleRequest,
10+
ExampleResponse,
11+
)
12+
13+
14+
class ExampleService(TestBase):
15+
async def example_unary_unary(
16+
self, example_string: str, example_integer: int
17+
) -> "ExampleResponse":
18+
return ExampleResponse(
19+
example_string=example_string,
20+
example_integer=example_integer,
21+
)
22+
23+
async def example_unary_stream(
24+
self, example_string: str, example_integer: int
25+
) -> AsyncIterator["ExampleResponse"]:
26+
response = ExampleResponse(
27+
example_string=example_string,
28+
example_integer=example_integer,
29+
)
30+
yield response 57AE
31+
yield response
32+
yield response
33+
34+
async def example_stream_unary(
35+
self, request_iterator: AsyncIterator["ExampleRequest"]
36+
) -> "ExampleResponse":
37+
async for example_request in request_iterator:
38+
return ExampleResponse(
39+
example_string=example_request.example_string,
40+
example_integer=example_request.example_integer,
41+
)
42+
43+
async def example_stream_stream(
44+
self, request_iterator: AsyncIterator["ExampleRequest"]
45+
) -> AsyncIterator["ExampleResponse"]:
46+
async for example_request in request_iterator:
47+
yield ExampleResponse(
48+
example_string=example_request.example_string,
49+
example_integer=example_request.example_integer,
50+
)
51+
52+
53+
@pytest.mark.asyncio
54+
async def test_calls_with_different_cardinalities():
55+
test_string = "test string"
56+
test_int = 42
57+
58+
async with ChannelFor([ExampleService()]) as channel:
59+
stub = TestStub(channel)
60+
61+
# unary unary
62+
response = await stub.example_unary_unary(
63+
example_string="test string",
64+
example_integer=42,
65+
)
66+
assert response.example_string == test_string
67+
assert response.example_integer == test_int
68+
69+
# unary stream
70+
async for response in stub.example_unary_stream(
71+
example_string="test string",
72+
example_integer=42,
73+
):
74+
assert response.example_string == test_string
75+
assert response.example_integer == test_int
76+
77+
# stream unary
78+
request = ExampleRequest(
79+
example_string=test_string,
80+
example_integer=42,
81+
)
82+
83+
async def request_iterator():
84+
yield request
85+
yield request
86+
yield request
87+
88+
response = await stub.example_stream_unary(request_iterator())
89+
assert response.example_string == test_string
90+
assert response.example_integer == test_int
91+
92+
# stream stream
93+
async for response in stub.example_stream_stream(request_iterator()):
94+
assert response.example_string == test_string
95+
assert response.example_integer == test_int

0 commit comments

Comments
 (0)
0