fix: issue in gcp app · jina-ai/serve@8a48521 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8a48521

Browse files
committed
fix: issue in gcp app
1 parent a3dfc9c commit 8a48521

File tree

5 files changed

+45
-9
lines changed

5 files changed

+45
-9
lines changed

jina/serve/executors/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ def __init__(
393393
self._add_dynamic_batching(dynamic_batching)
394394
self._add_runtime_args(runtime_args)
395395
self.logger = JinaLogger(self.__class__.__name__, **vars(self.runtime_args))
396-
self._validate_sagemaker()
396+
self._validate_csp()
397397
self._init_instrumentation(runtime_args)
398398
self._init_monitoring()
399399
self._init_workspace = workspace
@@ -599,14 +599,14 @@ def _add_requests(self, _requests: Optional[Dict]):
599599
f'expect {typename(self)}.{func} to be a function, but receiving {typename(_func)}'
600600
)
601601

602-
def _validate_sagemaker(self):
603-
# sagemaker expects the POST /invocations endpoint to be defined.
602+
def _validate_csp(self):
603+
# csp (sagemaker/azure/gcp) expects the POST /invocations endpoint to be defined.
604604
# if it is not defined, we check if there is only one endpoint defined,
605605
# and if so, we use it as the POST /invocations endpoint, or raise an error
606606
if (
607607
not hasattr(self, 'runtime_args')
608608
or not hasattr(self.runtime_args, 'provider')
609-
or self.runtime_args.provider != ProviderType.SAGEMAKER.value
609+
or self.runtime_args.provider not in (ProviderType.SAGEMAKER.value, ProviderType.GCP.value)
610610
):
611611
return
612612

jina/serve/runtimes/asyncio.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,23 @@ def _get_server(self):
206206
cors=getattr(self.args, 'cors', None),
207207
is_cancel=self.is_cancel,
208208
)
209+
elif (
210+
hasattr(self.args, 'provider')
211+
and self.args.provider == ProviderType.GCP
212+
):
213+
from jina.serve.runtimes.servers.http import GCPHTTPServer
214+
215+
return GCPHTTPServer(
216+
name=self.args.name,
217+
runtime_args=self.args,
218+
req_handler_cls=self.req_handler_cls,
219+
proxy=getattr(self.args, 'proxy', None),
220+
uvicorn_kwargs=getattr(self.args, 'uvicorn_kwargs', None),
221+
ssl_keyfile=getattr(self.args, 'ssl_keyfile', None),
222+
ssl_certfile=getattr(self.args, 'ssl_certfile', None),
223+
cors=getattr(self.args, 'cors', None),
224+
is_cancel=self.is_cancel,
225+
)
209226
elif not hasattr(self.args, 'protocol') or (
210227
len(self.args.protocol) == 1 and self.args.protocol[0] == ProtocolType.GRPC
211228
):

jina/serve/runtimes/worker/http_gcp_app.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def get_fastapi_app(
4141
from jina.serve.runtimes.gateway.models import _to_camel_case
4242

4343
if not docarray_v2:
44-
logger.warning('Only docarray v2 is supported with Sagemaker. ')
44+
logger.warning('Only docarray v2 is supported with GCP. ')
4545
return
4646

4747
class Header(BaseModel):
@@ -129,7 +129,6 @@ async def process(body) -> output_model:
129129
raise HTTPException(status_code=499, detail=status.description)
130130
else:
131131
return {"predictions": resp.docs}
132-
return output_model(predictions=resp.docs)
133132

134133
@app.api_route(**app_kwargs)
135134
async def post(request: Request):
@@ -175,7 +174,7 @@ async def post(request: Request):
175174

176175
from jina.serve.runtimes.gateway.health_model import JinaHealthModel
177176

178-
# `/ping` route is required by AWS Sagemaker
177+
# `/ping` route is required by GCP
179178
@app.get(
180179
path='/ping',
181180
summary='Get the health of Jina Executor service',

jina/serve/runtimes/worker/request_handling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def _init_monitoring(
326326
if metrics_registry:
327327
with ImportExtensions(
328328
required=True,
329-
help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
329+
help_text='You need to install the `prometheus_client` to use the monitoring functionality of jina',
330330
):
331331
from prometheus_client import Counter, Summary
332332

tests/integration/docarray_v2/gcp/test_gcp.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,5 +70,25 @@ def test_provider_gcp_pod_inference():
7070
assert resp.status_code == 200
7171
resp_json = resp.json()
7272
assert len(resp_json['predictions']) == 2
73-
print(resp_json)
7473

74+
75+
def test_provider_gcp_deployment_inference():
76+
with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')):
77+
dep_port = random_port()
78+
with Deployment(uses='config.yml', provider='gcp', port=dep_port):
79+
# Test the `GET /ping` endpoint (added by jina for gcp)
80+
resp = requests.get(f'http://localhost:{dep_port}/ping')
81+
assert resp.status_code == 200
82+
assert resp.json() == {}
83+
84+
# Test the `POST /invocations` endpoint
85+
# Note: this endpoint is not implemented in the sample executor
86+
resp = requests.post(
87+
f'http://localhost:{dep_port}/invocations',
88+
json={
89+
'instances': ["hello world", "good apple"]
90+
},
91+
)
92+
assert resp.status_code == 200
93+
resp_json = resp.json()
94+
assert len(resp_json['predictions']) == 2

0 commit comments

Comments
 (0)
0